diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 02bd50d1..1a451a2d 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -1,9 +1,12 @@ +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; +using Moq; using System.IO.Pipelines; using System.Text.Json; using System.Text.Json.Serialization.Metadata; @@ -40,6 +43,180 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper) _serverTask = server.RunAsync(cancellationToken: _cts.Token); } + [Theory] + [InlineData(null, null)] + [InlineData(0.7f, 50)] + [InlineData(1.0f, 100)] + public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperature, int? maxTokens) + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new Content { Type = "text", Text = "Hello" } + } + ], + Temperature = temperature, + MaxTokens = maxTokens, + Meta = new RequestParamsMetadata + { + ProgressToken = new ProgressToken(), + } + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new TextContent("Hello, World!") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("Hello, World!", result.Content.Text); + Assert.Equal("test-model", result.Model); + Assert.Equal("assistant", result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleImageMessages() + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new Content + { + Type = "image", + MimeType = "image/png", + Data = Convert.ToBase64String(new byte[] { 1, 2, 3 }) + } + } + ], + MaxTokens = 100 + }; + + const string expectedData = "SGVsbG8sIFdvcmxkIQ=="; + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new DataContent($"data:image/png;base64,{expectedData}") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal(expectedData, result.Content.Data); + Assert.Equal("test-model", result.Model); + Assert.Equal("assistant", result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleResourceMessages() + { + // Arrange + const string data = "SGVsbG8sIFdvcmxkIQ=="; + string content = $"data:application/octet-stream;base64,{data}"; + var mockChatClient = new Mock(); + var resource = new BlobResourceContents + { + Blob = data, + MimeType = "application/octet-stream", + Uri = "data:application/octet-stream" + }; + + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new Content + { + Type = "resource", + Resource = resource + }, + } + ], + MaxTokens = 100 + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + AuthorName = "bot", + Role = ChatRole.Assistant, + Contents = + [ + resource.ToAIContent() + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("test-model", result.Model); + Assert.Equal(ChatRole.Assistant.ToString(), result.Role); + Assert.Equal("endTurn", result.StopReason); + } + public async ValueTask DisposeAsync() { await _cts.CancelAsync(); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index d4545ddc..ae58023f 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -1,7 +1,9 @@ +using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; +using Moq; using System.Text.Json; using System.Threading.Channels; @@ -187,7 +189,68 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken)); } - private sealed class NopTransport : ITransport, IClientTransport + [Theory] + [InlineData(typeof(NopTransport))] + [InlineData(typeof(FailureTransport))] + public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) + { + // Arrange + var serverConfig = new McpServerConfig + { + Id = "TestServer", + Name = "TestServer", + TransportType = "stdio", + Location = "test-location" + }; + + var clientOptions = new McpClientOptions + { + ClientInfo = new Implementation + { + Name = "TestClient", + Version = "1.0.0.0" + }, + Capabilities = new ClientCapabilities + { + Sampling = new SamplingCapability + { + SamplingHandler = (c, p, t) => Task.FromResult( + new CreateMessageResult { + Content = new Content { Text = "result" }, + Model = "test-model", + Role = "test-role", + StopReason = "endTurn" + }), + }, + Roots = new RootsCapability + { + ListChanged = true, + RootsHandler = (t, r) => Task.FromResult(new ListRootsResult { Roots = [] }), + } + } + }; + + var clientTransport = (IClientTransport?)Activator.CreateInstance(transportType); + IMcpClient? client = null; + + var actionTask = McpClientFactory.CreateAsync(serverConfig, clientOptions, (config, logger) => clientTransport ?? new NopTransport(), new Mock().Object, CancellationToken.None); + + // Act + if (clientTransport is FailureTransport) + { + var exception = await Assert.ThrowsAsync(async() => await actionTask); + Assert.Equal(FailureTransport.ExpectedMessage, exception.Message); + } + else + { + client = await actionTask; + + // Assert + Assert.NotNull(client); + } + } + + private class NopTransport : ITransport, IClientTransport { private readonly Channel _channel = Channel.CreateUnbounded(); @@ -199,7 +262,7 @@ private sealed class NopTransport : ITransport, IClientTransport public ValueTask DisposeAsync() => default; - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public virtual Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { switch (message) { @@ -224,4 +287,14 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella return Task.CompletedTask; } } + + private sealed class FailureTransport : NopTransport + { + public const string ExpectedMessage = "Something failed"; + + public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + { + throw new InvalidOperationException(ExpectedMessage); + } + } }