diff --git a/src/ModelContextProtocol/Logging/Log.cs b/src/ModelContextProtocol/Logging/Log.cs index b49b4cb5..d22c5d66 100644 --- a/src/ModelContextProtocol/Logging/Log.cs +++ b/src/ModelContextProtocol/Logging/Log.cs @@ -68,6 +68,9 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Error, Message = "Request failed for {endpointName} with method {method}: {message} ({code})")] internal static partial void RequestFailed(this ILogger logger, string endpointName, string method, string message, int code); + [LoggerMessage(Level = LogLevel.Information, Message = "Request '{requestId}' canceled via client notification with reason '{Reason}'.")] + internal static partial void RequestCanceled(this ILogger logger, RequestId requestId, string? reason); + [LoggerMessage(Level = LogLevel.Information, Message = "Request response received payload for {endpointName}: {payload}")] internal static partial void RequestResponseReceivedPayload(this ILogger logger, string endpointName, string payload); diff --git a/src/ModelContextProtocol/Protocol/Messages/CancelledNotification.cs b/src/ModelContextProtocol/Protocol/Messages/CancelledNotification.cs new file mode 100644 index 00000000..c18a0d21 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Messages/CancelledNotification.cs @@ -0,0 +1,21 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Messages; + +/// +/// This notification indicates that the result will be unused, so any associated processing SHOULD cease. +/// +public sealed class CancelledNotification +{ + /// + /// The ID of the request to cancel. + /// + [JsonPropertyName("requestId")] + public RequestId RequestId { get; set; } + + /// + /// An optional string describing the reason for the cancellation. + /// + [JsonPropertyName("reason")] + public string? Reason { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 831d40c4..97cbcb59 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -20,7 +20,13 @@ internal sealed class McpSession : IDisposable private readonly RequestHandlers _requestHandlers; private readonly NotificationHandlers _notificationHandlers; + /// Collection of requests sent on this session and waiting for responses. private readonly ConcurrentDictionary> _pendingRequests = []; + /// + /// Collection of requests received on this session and currently being handled. The value provides a + /// that can be used to request cancellation of the in-flight handler. + /// + private readonly ConcurrentDictionary _handlingRequests = new(); private readonly JsonSerializerOptions _jsonOptions; private readonly ILogger _logger; @@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken) { _logger.TransportMessageRead(EndpointName, message.GetType().Name); - // Fire and forget the message handling task to avoid blocking the transport - // If awaiting the task, the transport will not be able to read more messages, - // which could lead to a deadlock if the handler sends a message back _ = ProcessMessageAsync(); async Task ProcessMessageAsync() { + IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId; + CancellationTokenSource? combinedCts = null; + try + { + // Register before we yield, so that the tracking is guaranteed to be there + // when subsequent messages arrive, even if the asynchronous processing happens + // out of order. + if (messageWithId is not null) + { + combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _handlingRequests[messageWithId.Id] = combinedCts; + } + + // Fire and forget the message handling to avoid blocking the transport + // If awaiting the task, the transport will not be able to read more messages, + // which could lead to a deadlock if the handler sends a message back + #if NET - await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); #else - await default(ForceYielding); + await default(ForceYielding); #endif - try - { - await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false); + + // Handle the message. + await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); } catch (Exception ex) { - var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); - _logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex); + // Only send responses for request errors that aren't user-initiated cancellation. + bool isUserCancellation = + ex is OperationCanceledException && + !cancellationToken.IsCancellationRequested && + combinedCts?.IsCancellationRequested is true; + + if (!isUserCancellation && message is JsonRpcRequest request) + { + _logger.RequestHandlerError(EndpointName, request.Method, ex); + await _transport.SendMessageAsync(new JsonRpcError + { + Id = request.Id, + JsonRpc = "2.0", + Error = new JsonRpcErrorDetail + { + Code = ErrorCodes.InternalError, + Message = ex.Message + } + }, cancellationToken).ConfigureAwait(false); + } + else if (ex is not OperationCanceledException) + { + var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()); + _logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex); + } + } + finally + { + if (messageWithId is not null) + { + _handlingRequests.TryRemove(messageWithId.Id, out _); + combinedCts!.Dispose(); + } } } } @@ -123,6 +174,25 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken private async Task HandleNotification(JsonRpcNotification notification) { + // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) + if (notification.Method == NotificationMethods.CancelledNotification) + { + try + { + if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn && + _handlingRequests.TryGetValue(cn.RequestId, out var cts)) + { + await cts.CancelAsync().ConfigureAwait(false); + _logger.RequestCanceled(cn.RequestId, cn.Reason); + } + } + catch + { + // "Invalid cancellation notifications SHOULD be ignored" + } + } + + // Handle user-defined notifications. if (_notificationHandlers.TryGetValue(notification.Method, out var handlers)) { foreach (var notificationHandler in handlers) @@ -161,33 +231,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance { if (_requestHandlers.TryGetValue(request.Method, out var handler)) { - try - { - _logger.RequestHandlerCalled(EndpointName, request.Method); - var result = await handler(request, cancellationToken).ConfigureAwait(false); - _logger.RequestHandlerCompleted(EndpointName, request.Method); - await _transport.SendMessageAsync(new JsonRpcResponse - { - Id = request.Id, - JsonRpc = "2.0", - Result = result - }, cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) + _logger.RequestHandlerCalled(EndpointName, request.Method); + var result = await handler(request, cancellationToken).ConfigureAwait(false); + _logger.RequestHandlerCompleted(EndpointName, request.Method); + await _transport.SendMessageAsync(new JsonRpcResponse { - _logger.RequestHandlerError(EndpointName, request.Method, ex); - // Send error response - await _transport.SendMessageAsync(new JsonRpcError - { - Id = request.Id, - JsonRpc = "2.0", - Error = new JsonRpcErrorDetail - { - Code = -32000, // Implementation defined error - Message = ex.Message - } - }, cancellationToken).ConfigureAwait(false); - } + Id = request.Id, + JsonRpc = "2.0", + Result = result + }, cancellationToken).ConfigureAwait(false); } else { @@ -273,7 +325,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can } } - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) + public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { Throw.IfNull(message); @@ -288,7 +340,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella _logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo())); } - return _transport.SendMessageAsync(message, cancellationToken); + await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); + + // If the sent notification was a cancellation notification, cancel the pending request's await, as either the + // server won't be sending a response, or per the specification, the response should be ignored. There are inherent + // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. + if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && + GetCancelledNotificationParams(notification.Params) is CancelledNotification cn && + _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + { + tcs.TrySetCanceled(default); + } + } + + private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams) + { + try + { + switch (notificationParams) + { + case null: + return null; + + case CancelledNotification cn: + return cn; + + case JsonElement je: + return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo()); + + default: + return JsonSerializer.Deserialize( + JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo()), + McpJsonUtilities.DefaultOptions.GetTypeInfo()); + } + } + catch + { + return null; + } } public void Dispose() diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 68d13eb2..337cdbda 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -121,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) // MCP Request Params / Results [JsonSerializable(typeof(CallToolRequestParams))] [JsonSerializable(typeof(CallToolResponse))] + [JsonSerializable(typeof(CancelledNotification))] [JsonSerializable(typeof(CompleteRequestParams))] [JsonSerializable(typeof(CompleteResult))] [JsonSerializable(typeof(CreateMessageRequestParams))] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 2a1ddafa..73fdeade 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -91,7 +91,7 @@ public async Task Can_List_Registered_Tools() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(12, tools.Count); + Assert.Equal(13, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -138,7 +138,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T cancellationToken: TestContext.Current.CancellationToken)) { var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(12, tools.Count); + Assert.Equal(13, tools.Count); McpClientTool echoTool = tools.First(t => t.Name == "Echo"); Assert.Equal("Echo", echoTool.Name); @@ -165,7 +165,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(12, tools.Count); + Assert.Equal(13, tools.Count); Channel listChanged = Channel.CreateUnbounded(); client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification => @@ -186,7 +186,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(13, tools.Count); + Assert.Equal(14, tools.Count); Assert.Contains(tools, t => t.Name == "NewTool"); notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); @@ -195,7 +195,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() await notificationRead; tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); - Assert.Equal(12, tools.Count); + Assert.Equal(13, tools.Count); Assert.DoesNotContain(tools, t => t.Name == "NewTool"); } @@ -560,6 +560,35 @@ public async Task HandlesIProgressParameter() } } + [Fact] + public async Task CancellationNotificationsPropagateToToolTokens() + { + IMcpClient client = await CreateMcpClientForServer(); + + var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken); + Assert.NotNull(tools); + Assert.NotEmpty(tools); + McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation)); + + var requestId = new RequestId(Guid.NewGuid().ToString()); + var invokeTask = client.SendRequestAsync(new JsonRpcRequest() + { + Method = RequestMethods.ToolsCall, + Id = requestId, + Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name }, + }, TestContext.Current.CancellationToken); + + await client.SendNotificationAsync( + NotificationMethods.CancelledNotification, + parameters: new CancelledNotification() + { + RequestId = requestId, + }, + cancellationToken: TestContext.Current.CancellationToken); + + await Assert.ThrowsAnyAsync(() => invokeTask); + } + [McpServerToolType] public sealed class EchoTool(ObjectWithId objectFromDI) { @@ -625,6 +654,21 @@ public static string EchoComplex(ComplexObject complex) return complex.Name!; } + [McpServerTool] + public static async Task InfiniteCancelableOperation(CancellationToken cancellationToken) + { + try + { + await Task.Delay(Timeout.Infinite, cancellationToken); + } + catch (Exception) + { + return "canceled"; + } + + return "unreachable"; + } + [McpServerTool] public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}";