diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index fc526132..3c8981b6 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -508,20 +508,33 @@ public void WithTools_Parameters_Satisfiable_From_DI(bool parameterInServices) } [Theory] - [InlineData(false)] - [InlineData(true)] - public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(bool parameterInServices) + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + [InlineData(null)] + public void WithToolsFromAssembly_Parameters_Satisfiable_From_DI(ServiceLifetime? lifetime) { ServiceCollection sc = new(); - if (parameterInServices) + switch (lifetime) { - sc.AddSingleton(new ComplexObject()); + case ServiceLifetime.Singleton: + sc.AddSingleton(new ComplexObject()); + break; + + case ServiceLifetime.Scoped: + sc.AddScoped(_ => new ComplexObject()); + break; + + case ServiceLifetime.Transient: + sc.AddTransient(_ => new ComplexObject()); + break; } + sc.AddMcpServer().WithToolsFromAssembly(); IServiceProvider services = sc.BuildServiceProvider(); McpServerTool tool = services.GetServices().First(t => t.ProtocolTool.Name == "EchoComplex"); - if (parameterInServices) + if (lifetime is not null) { Assert.DoesNotContain("\"complex\"", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema)); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 2d747d9d..30d13fd9 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -45,20 +45,47 @@ public async Task SupportsIMcpServer() Assert.Equal("42", result.Content[0].Text); } - [Fact] - public async Task SupportsServiceFromDI() + [Theory] + [InlineData(ServiceLifetime.Singleton)] + [InlineData(ServiceLifetime.Scoped)] + [InlineData(ServiceLifetime.Transient)] + public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime) { - MyService expectedMyService = new(); + MyService singletonService = new(); ServiceCollection sc = new(); - sc.AddSingleton(expectedMyService); - IServiceProvider services = sc.BuildServiceProvider(); + switch (injectedArgumentLifetime) + { + case ServiceLifetime.Singleton: + sc.AddSingleton(singletonService); + break; + + case ServiceLifetime.Scoped: + sc.AddScoped(_ => new MyService()); + break; + + case ServiceLifetime.Transient: + sc.AddTransient(_ => new MyService()); + break; + } - McpServerTool tool = McpServerTool.Create((MyService actualMyService) => + sc.AddSingleton(services => { - Assert.Same(expectedMyService, actualMyService); - return "42"; - }, new() { Services = services }); + return McpServerTool.Create((MyService actualMyService) => + { + Assert.NotNull(actualMyService); + if (injectedArgumentLifetime == ServiceLifetime.Singleton) + { + Assert.Same(singletonService, actualMyService); + } + + return "42"; + }, new() { Services = services }); + }); + + IServiceProvider services = sc.BuildServiceProvider(); + + McpServerTool tool = services.GetRequiredService(); Assert.DoesNotContain("actualMyService", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema));