Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using Moq;
using System.IO.Pipelines;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
Expand Down Expand Up @@ -40,6 +43,180 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper)
_serverTask = server.RunAsync(cancellationToken: _cts.Token);
}

[Theory]
[InlineData(null, null)]
[InlineData(0.7f, 50)]
[InlineData(1.0f, 100)]
public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperature, int? maxTokens)
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var requestParams = new CreateMessageRequestParams
{
Messages =
[
new SamplingMessage
{
Role = Role.User,
Content = new Content { Type = "text", Text = "Hello" }
}
],
Temperature = temperature,
MaxTokens = maxTokens,
Meta = new RequestParamsMetadata
{
ProgressToken = new ProgressToken(),
}
};

var cancellationToken = CancellationToken.None;
var expectedResponse = new[] {
new ChatResponseUpdate
{
ModelId = "test-model",
FinishReason = ChatFinishReason.Stop,
Role = ChatRole.Assistant,
Contents =
[
new TextContent("Hello, World!") { RawRepresentation = "Hello, World!" }
]
}
}.ToAsyncEnumerable();

mockChatClient
.Setup(client => client.GetStreamingResponseAsync(It.IsAny<IEnumerable<ChatMessage>>(), It.IsAny<ChatOptions>(), cancellationToken))
.Returns(expectedResponse);

var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object);

// Act
var result = await handler(requestParams, Mock.Of<IProgress<ProgressNotificationValue>>(), cancellationToken);

// Assert
Assert.NotNull(result);
Assert.Equal("Hello, World!", result.Content.Text);
Assert.Equal("test-model", result.Model);
Assert.Equal("assistant", result.Role);
Assert.Equal("endTurn", result.StopReason);
}

[Fact]
public async Task CreateSamplingHandler_ShouldHandleImageMessages()
{
// Arrange
var mockChatClient = new Mock<IChatClient>();
var requestParams = new CreateMessageRequestParams
{
Messages =
[
new SamplingMessage
{
Role = Role.User,
Content = new Content
{
Type = "image",
MimeType = "image/png",
Data = Convert.ToBase64String(new byte[] { 1, 2, 3 })
}
}
],
MaxTokens = 100
};

const string expectedData = "SGVsbG8sIFdvcmxkIQ==";
var cancellationToken = CancellationToken.None;
var expectedResponse = new[] {
new ChatResponseUpdate
{
ModelId = "test-model",
FinishReason = ChatFinishReason.Stop,
Role = ChatRole.Assistant,
Contents =
[
new DataContent($"data:image/png;base64,{expectedData}") { RawRepresentation = "Hello, World!" }
]
}
}.ToAsyncEnumerable();

mockChatClient
.Setup(client => client.GetStreamingResponseAsync(It.IsAny<IEnumerable<ChatMessage>>(), It.IsAny<ChatOptions>(), cancellationToken))
.Returns(expectedResponse);

var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object);

// Act
var result = await handler(requestParams, Mock.Of<IProgress<ProgressNotificationValue>>(), cancellationToken);

// Assert
Assert.NotNull(result);
Assert.Equal(expectedData, result.Content.Data);
Assert.Equal("test-model", result.Model);
Assert.Equal("assistant", result.Role);
Assert.Equal("endTurn", result.StopReason);
}

[Fact]
public async Task CreateSamplingHandler_ShouldHandleResourceMessages()
{
// Arrange
const string data = "SGVsbG8sIFdvcmxkIQ==";
string content = $"data:application/octet-stream;base64,{data}";
var mockChatClient = new Mock<IChatClient>();
var resource = new BlobResourceContents
{
Blob = data,
MimeType = "application/octet-stream",
Uri = "data:application/octet-stream"
};

var requestParams = new CreateMessageRequestParams
{
Messages =
[
new SamplingMessage
{
Role = Role.User,
Content = new Content
{
Type = "resource",
Resource = resource
},
}
],
MaxTokens = 100
};

var cancellationToken = CancellationToken.None;
var expectedResponse = new[] {
new ChatResponseUpdate
{
ModelId = "test-model",
FinishReason = ChatFinishReason.Stop,
AuthorName = "bot",
Role = ChatRole.Assistant,
Contents =
[
resource.ToAIContent()
]
}
}.ToAsyncEnumerable();

mockChatClient
.Setup(client => client.GetStreamingResponseAsync(It.IsAny<IEnumerable<ChatMessage>>(), It.IsAny<ChatOptions>(), cancellationToken))
.Returns(expectedResponse);

var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object);

// Act
var result = await handler(requestParams, Mock.Of<IProgress<ProgressNotificationValue>>(), cancellationToken);

// Assert
Assert.NotNull(result);
Assert.Equal("test-model", result.Model);
Assert.Equal(ChatRole.Assistant.ToString(), result.Role);
Assert.Equal("endTurn", result.StopReason);
}

public async ValueTask DisposeAsync()
{
await _cts.CancelAsync();
Expand Down
77 changes: 75 additions & 2 deletions tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Protocol.Types;
using Moq;
using System.Text.Json;
using System.Threading.Channels;

Expand Down Expand Up @@ -187,7 +189,68 @@ public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(s
await Assert.ThrowsAsync<ArgumentException>(() => McpClientFactory.CreateAsync(config, _defaultOptions, cancellationToken: TestContext.Current.CancellationToken));
}

private sealed class NopTransport : ITransport, IClientTransport
[Theory]
[InlineData(typeof(NopTransport))]
[InlineData(typeof(FailureTransport))]
public async Task CreateAsync_WithCapabilitiesOptions(Type transportType)
{
// Arrange
var serverConfig = new McpServerConfig
{
Id = "TestServer",
Name = "TestServer",
TransportType = "stdio",
Location = "test-location"
};

var clientOptions = new McpClientOptions
{
ClientInfo = new Implementation
{
Name = "TestClient",
Version = "1.0.0.0"
},
Capabilities = new ClientCapabilities
{
Sampling = new SamplingCapability
{
SamplingHandler = (c, p, t) => Task.FromResult(
new CreateMessageResult {
Content = new Content { Text = "result" },
Model = "test-model",
Role = "test-role",
StopReason = "endTurn"
}),
},
Roots = new RootsCapability
{
ListChanged = true,
RootsHandler = (t, r) => Task.FromResult(new ListRootsResult { Roots = [] }),
}
}
};

var clientTransport = (IClientTransport?)Activator.CreateInstance(transportType);
IMcpClient? client = null;

var actionTask = McpClientFactory.CreateAsync(serverConfig, clientOptions, (config, logger) => clientTransport ?? new NopTransport(), new Mock<ILoggerFactory>().Object, CancellationToken.None);

// Act
if (clientTransport is FailureTransport)
{
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async() => await actionTask);
Assert.Equal(FailureTransport.ExpectedMessage, exception.Message);
}
else
{
client = await actionTask;

// Assert
Assert.NotNull(client);
}
}

private class NopTransport : ITransport, IClientTransport
{
private readonly Channel<IJsonRpcMessage> _channel = Channel.CreateUnbounded<IJsonRpcMessage>();

Expand All @@ -199,7 +262,7 @@ private sealed class NopTransport : ITransport, IClientTransport

public ValueTask DisposeAsync() => default;

public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
public virtual Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
switch (message)
{
Expand All @@ -224,4 +287,14 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
return Task.CompletedTask;
}
}

private sealed class FailureTransport : NopTransport
{
public const string ExpectedMessage = "Something failed";

public override Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
throw new InvalidOperationException(ExpectedMessage);
}
}
}