Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/ModelContextProtocol/Logging/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ internal static partial class Log
[LoggerMessage(Level = LogLevel.Error, Message = "Request failed for {endpointName} with method {method}: {message} ({code})")]
internal static partial void RequestFailed(this ILogger logger, string endpointName, string method, string message, int code);

[LoggerMessage(Level = LogLevel.Information, Message = "Request '{requestId}' canceled via client notification.")]
internal static partial void RequestCanceled(this ILogger logger, RequestId requestId);

[LoggerMessage(Level = LogLevel.Information, Message = "Request response received payload for {endpointName}: {payload}")]
internal static partial void RequestResponseReceivedPayload(this ILogger logger, string endpointName, string payload);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Messages;

/// <summary>
/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
/// </summary>
public sealed class CancelledNotification
{
/// <summary>
/// The ID of the request to cancel.
/// </summary>
[JsonPropertyName("requestId")]
public RequestId RequestId { get; set; }

/// <summary>
/// An optional string describing the reason for the cancellation.
/// </summary>
[JsonPropertyName("reason")]
public string? Reason { get; set; }
}
165 changes: 127 additions & 38 deletions src/ModelContextProtocol/Shared/McpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ internal sealed class McpSession : IDisposable
private readonly RequestHandlers _requestHandlers;
private readonly NotificationHandlers _notificationHandlers;

/// <summary>Collection of requests sent on this session and waiting for responses.</summary>
private readonly ConcurrentDictionary<RequestId, TaskCompletionSource<IJsonRpcMessage>> _pendingRequests = [];
/// <summary>
/// Collection of requests received on this session and currently being handled. The value provides a <see cref="CancellationTokenSource"/>
/// that can be used to request cancellation of the in-flight handler.
/// </summary>
private readonly ConcurrentDictionary<RequestId, CancellationTokenSource> s_handlingRequests = new();
private readonly JsonSerializerOptions _jsonOptions;
private readonly ILogger _logger;

Expand Down Expand Up @@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
{
_logger.TransportMessageRead(EndpointName, message.GetType().Name);

// Fire and forget the message handling task to avoid blocking the transport
// If awaiting the task, the transport will not be able to read more messages,
// which could lead to a deadlock if the handler sends a message back
_ = ProcessMessageAsync();
async Task ProcessMessageAsync()
{
IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
CancellationTokenSource? combinedCts = null;
try
{
// Register before we yield, so that the tracking is guaranteed to be there
// when subsequent messages arrive, even if the asynchronous processing happens
// out of order.
if (messageWithId is not null)
{
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
s_handlingRequests[messageWithId.Id] = combinedCts;
}

// Fire and forget the message handling to avoid blocking the transport
// If awaiting the task, the transport will not be able to read more messages,
// which could lead to a deadlock if the handler sends a message back

#if NET
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
#else
await default(ForceYielding);
await default(ForceYielding);
#endif
try
{
await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false);

// Handle the message.
await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
// Only send responses for request errors that aren't user-initiated cancellation.
bool isUserCancellation =
ex is OperationCanceledException &&
!cancellationToken.IsCancellationRequested &&
combinedCts?.IsCancellationRequested is true;

if (!isUserCancellation && message is JsonRpcRequest request)
{
_logger.RequestHandlerError(EndpointName, request.Method, ex);
await _transport.SendMessageAsync(new JsonRpcError
{
Id = request.Id,
JsonRpc = "2.0",
Error = new JsonRpcErrorDetail
{
Code = ErrorCodes.InternalError,
Message = ex.Message
}
}, cancellationToken).ConfigureAwait(false);
}
else if (ex is not OperationCanceledException)
{
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
}
}
finally
{
if (messageWithId is not null)
{
s_handlingRequests.TryRemove(messageWithId.Id, out _);
combinedCts!.Dispose();
}
}
}
}
Expand Down Expand Up @@ -123,6 +174,25 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken

private async Task HandleNotification(JsonRpcNotification notification)
{
// Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
if (notification.Method == NotificationMethods.CancelledNotification)
{
try
{
if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
s_handlingRequests.TryGetValue(cn.RequestId, out var cts))
{
await cts.CancelAsync().ConfigureAwait(false);
_logger.RequestCanceled(cn.RequestId);
}
}
catch
{
// "Invalid cancellation notifications SHOULD be ignored"
}
}

// Handle user-defined notifications.
if (_notificationHandlers.TryGetValue(notification.Method, out var handlers))
{
foreach (var notificationHandler in handlers)
Expand Down Expand Up @@ -161,33 +231,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
{
if (_requestHandlers.TryGetValue(request.Method, out var handler))
{
try
{
_logger.RequestHandlerCalled(EndpointName, request.Method);
var result = await handler(request, cancellationToken).ConfigureAwait(false);
_logger.RequestHandlerCompleted(EndpointName, request.Method);
await _transport.SendMessageAsync(new JsonRpcResponse
{
Id = request.Id,
JsonRpc = "2.0",
Result = result
}, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
_logger.RequestHandlerCalled(EndpointName, request.Method);
var result = await handler(request, cancellationToken).ConfigureAwait(false);
_logger.RequestHandlerCompleted(EndpointName, request.Method);
await _transport.SendMessageAsync(new JsonRpcResponse
{
_logger.RequestHandlerError(EndpointName, request.Method, ex);
// Send error response
await _transport.SendMessageAsync(new JsonRpcError
{
Id = request.Id,
JsonRpc = "2.0",
Error = new JsonRpcErrorDetail
{
Code = -32000, // Implementation defined error
Message = ex.Message
}
}, cancellationToken).ConfigureAwait(false);
}
Id = request.Id,
JsonRpc = "2.0",
Result = result
}, cancellationToken).ConfigureAwait(false);
}
else
{
Expand Down Expand Up @@ -273,7 +325,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
}
}

public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
Throw.IfNull(message);

Expand All @@ -288,7 +340,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
_logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>()));
}

return _transport.SendMessageAsync(message, cancellationToken);
await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);

// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
{
tcs.TrySetCanceled(default);
}
}

private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams)
{
try
{
switch (notificationParams)
{
case null:
return null;

case CancelledNotification cn:
return cn;

case JsonElement je:
return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());

default:
return JsonSerializer.Deserialize(
JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo<object?>()),
McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
}
}
catch
{
return null;
}
}

public void Dispose()
Expand Down
1 change: 1 addition & 0 deletions src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
// MCP Request Params / Results
[JsonSerializable(typeof(CallToolRequestParams))]
[JsonSerializable(typeof(CallToolResponse))]
[JsonSerializable(typeof(CancelledNotification))]
[JsonSerializable(typeof(CompleteRequestParams))]
[JsonSerializable(typeof(CompleteResult))]
[JsonSerializable(typeof(CreateMessageRequestParams))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public async Task Can_List_Registered_Tools()
IMcpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);

McpClientTool echoTool = tools.First(t => t.Name == "Echo");
Assert.Equal("Echo", echoTool.Name);
Expand Down Expand Up @@ -138,7 +138,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
cancellationToken: TestContext.Current.CancellationToken))
{
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);

McpClientTool echoTool = tools.First(t => t.Name == "Echo");
Assert.Equal("Echo", echoTool.Name);
Expand All @@ -165,7 +165,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
IMcpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);

Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification =>
Expand All @@ -186,7 +186,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
await notificationRead;

tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(13, tools.Count);
Assert.Equal(14, tools.Count);
Assert.Contains(tools, t => t.Name == "NewTool");

notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
Expand All @@ -195,7 +195,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
await notificationRead;

tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);
Assert.DoesNotContain(tools, t => t.Name == "NewTool");
}

Expand Down Expand Up @@ -560,6 +560,35 @@ public async Task HandlesIProgressParameter()
}
}

[Fact]
public async Task CancellationNotificationsPropagateToToolTokens()
{
IMcpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.NotNull(tools);
Assert.NotEmpty(tools);
McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation));

var requestId = new RequestId(Guid.NewGuid().ToString());
var invokeTask = client.SendRequestAsync<CallToolResponse>(new JsonRpcRequest()
{
Method = RequestMethods.ToolsCall,
Id = requestId,
Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name },
}, TestContext.Current.CancellationToken);

await client.SendNotificationAsync(
NotificationMethods.CancelledNotification,
parameters: new CancelledNotification()
{
RequestId = requestId,
},
cancellationToken: TestContext.Current.CancellationToken);

await Assert.ThrowsAnyAsync<OperationCanceledException>(() => invokeTask);
}

[McpServerToolType]
public sealed class EchoTool(ObjectWithId objectFromDI)
{
Expand Down Expand Up @@ -625,6 +654,21 @@ public static string EchoComplex(ComplexObject complex)
return complex.Name!;
}

[McpServerTool]
public static async Task<string> InfiniteCancelableOperation(CancellationToken cancellationToken)
{
try
{
await Task.Delay(Timeout.Infinite, cancellationToken);
}
catch (Exception)
{
return "canceled";
}

return "unreachable";
}

[McpServerTool]
public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}";

Expand Down