diff --git a/src/ModelContextProtocol/Logging/Log.cs b/src/ModelContextProtocol/Logging/Log.cs
index b49b4cb5..d22c5d66 100644
--- a/src/ModelContextProtocol/Logging/Log.cs
+++ b/src/ModelContextProtocol/Logging/Log.cs
@@ -68,6 +68,9 @@ internal static partial class Log
[LoggerMessage(Level = LogLevel.Error, Message = "Request failed for {endpointName} with method {method}: {message} ({code})")]
internal static partial void RequestFailed(this ILogger logger, string endpointName, string method, string message, int code);
+ [LoggerMessage(Level = LogLevel.Information, Message = "Request '{requestId}' canceled via client notification with reason '{Reason}'.")]
+ internal static partial void RequestCanceled(this ILogger logger, RequestId requestId, string? reason);
+
[LoggerMessage(Level = LogLevel.Information, Message = "Request response received payload for {endpointName}: {payload}")]
internal static partial void RequestResponseReceivedPayload(this ILogger logger, string endpointName, string payload);
diff --git a/src/ModelContextProtocol/Protocol/Messages/CancelledNotification.cs b/src/ModelContextProtocol/Protocol/Messages/CancelledNotification.cs
new file mode 100644
index 00000000..c18a0d21
--- /dev/null
+++ b/src/ModelContextProtocol/Protocol/Messages/CancelledNotification.cs
@@ -0,0 +1,21 @@
+using System.Text.Json.Serialization;
+
+namespace ModelContextProtocol.Protocol.Messages;
+
+///
+/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
+///
+public sealed class CancelledNotification
+{
+ ///
+ /// The ID of the request to cancel.
+ ///
+ [JsonPropertyName("requestId")]
+ public RequestId RequestId { get; set; }
+
+ ///
+ /// An optional string describing the reason for the cancellation.
+ ///
+ [JsonPropertyName("reason")]
+ public string? Reason { get; set; }
+}
\ No newline at end of file
diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs
index 831d40c4..97cbcb59 100644
--- a/src/ModelContextProtocol/Shared/McpSession.cs
+++ b/src/ModelContextProtocol/Shared/McpSession.cs
@@ -20,7 +20,13 @@ internal sealed class McpSession : IDisposable
private readonly RequestHandlers _requestHandlers;
private readonly NotificationHandlers _notificationHandlers;
+ /// Collection of requests sent on this session and waiting for responses.
private readonly ConcurrentDictionary> _pendingRequests = [];
+ ///
+ /// Collection of requests received on this session and currently being handled. The value provides a
+ /// that can be used to request cancellation of the in-flight handler.
+ ///
+ private readonly ConcurrentDictionary _handlingRequests = new();
private readonly JsonSerializerOptions _jsonOptions;
private readonly ILogger _logger;
@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
{
_logger.TransportMessageRead(EndpointName, message.GetType().Name);
- // Fire and forget the message handling task to avoid blocking the transport
- // If awaiting the task, the transport will not be able to read more messages,
- // which could lead to a deadlock if the handler sends a message back
_ = ProcessMessageAsync();
async Task ProcessMessageAsync()
{
+ IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
+ CancellationTokenSource? combinedCts = null;
+ try
+ {
+ // Register before we yield, so that the tracking is guaranteed to be there
+ // when subsequent messages arrive, even if the asynchronous processing happens
+ // out of order.
+ if (messageWithId is not null)
+ {
+ combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+ _handlingRequests[messageWithId.Id] = combinedCts;
+ }
+
+ // Fire and forget the message handling to avoid blocking the transport
+ // If awaiting the task, the transport will not be able to read more messages,
+ // which could lead to a deadlock if the handler sends a message back
+
#if NET
- await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
+ await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
#else
- await default(ForceYielding);
+ await default(ForceYielding);
#endif
- try
- {
- await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false);
+
+ // Handle the message.
+ await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
- var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo());
- _logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
+ // Only send responses for request errors that aren't user-initiated cancellation.
+ bool isUserCancellation =
+ ex is OperationCanceledException &&
+ !cancellationToken.IsCancellationRequested &&
+ combinedCts?.IsCancellationRequested is true;
+
+ if (!isUserCancellation && message is JsonRpcRequest request)
+ {
+ _logger.RequestHandlerError(EndpointName, request.Method, ex);
+ await _transport.SendMessageAsync(new JsonRpcError
+ {
+ Id = request.Id,
+ JsonRpc = "2.0",
+ Error = new JsonRpcErrorDetail
+ {
+ Code = ErrorCodes.InternalError,
+ Message = ex.Message
+ }
+ }, cancellationToken).ConfigureAwait(false);
+ }
+ else if (ex is not OperationCanceledException)
+ {
+ var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo());
+ _logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
+ }
+ }
+ finally
+ {
+ if (messageWithId is not null)
+ {
+ _handlingRequests.TryRemove(messageWithId.Id, out _);
+ combinedCts!.Dispose();
+ }
}
}
}
@@ -123,6 +174,25 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
private async Task HandleNotification(JsonRpcNotification notification)
{
+ // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
+ if (notification.Method == NotificationMethods.CancelledNotification)
+ {
+ try
+ {
+ if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
+ _handlingRequests.TryGetValue(cn.RequestId, out var cts))
+ {
+ await cts.CancelAsync().ConfigureAwait(false);
+ _logger.RequestCanceled(cn.RequestId, cn.Reason);
+ }
+ }
+ catch
+ {
+ // "Invalid cancellation notifications SHOULD be ignored"
+ }
+ }
+
+ // Handle user-defined notifications.
if (_notificationHandlers.TryGetValue(notification.Method, out var handlers))
{
foreach (var notificationHandler in handlers)
@@ -161,33 +231,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
{
if (_requestHandlers.TryGetValue(request.Method, out var handler))
{
- try
- {
- _logger.RequestHandlerCalled(EndpointName, request.Method);
- var result = await handler(request, cancellationToken).ConfigureAwait(false);
- _logger.RequestHandlerCompleted(EndpointName, request.Method);
- await _transport.SendMessageAsync(new JsonRpcResponse
- {
- Id = request.Id,
- JsonRpc = "2.0",
- Result = result
- }, cancellationToken).ConfigureAwait(false);
- }
- catch (Exception ex)
+ _logger.RequestHandlerCalled(EndpointName, request.Method);
+ var result = await handler(request, cancellationToken).ConfigureAwait(false);
+ _logger.RequestHandlerCompleted(EndpointName, request.Method);
+ await _transport.SendMessageAsync(new JsonRpcResponse
{
- _logger.RequestHandlerError(EndpointName, request.Method, ex);
- // Send error response
- await _transport.SendMessageAsync(new JsonRpcError
- {
- Id = request.Id,
- JsonRpc = "2.0",
- Error = new JsonRpcErrorDetail
- {
- Code = -32000, // Implementation defined error
- Message = ex.Message
- }
- }, cancellationToken).ConfigureAwait(false);
- }
+ Id = request.Id,
+ JsonRpc = "2.0",
+ Result = result
+ }, cancellationToken).ConfigureAwait(false);
}
else
{
@@ -273,7 +325,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can
}
}
- public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
+ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
Throw.IfNull(message);
@@ -288,7 +340,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
_logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo()));
}
- return _transport.SendMessageAsync(message, cancellationToken);
+ await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
+
+ // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
+ // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
+ // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
+ if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
+ GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
+ _pendingRequests.TryRemove(cn.RequestId, out var tcs))
+ {
+ tcs.TrySetCanceled(default);
+ }
+ }
+
+ private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams)
+ {
+ try
+ {
+ switch (notificationParams)
+ {
+ case null:
+ return null;
+
+ case CancelledNotification cn:
+ return cn;
+
+ case JsonElement je:
+ return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo());
+
+ default:
+ return JsonSerializer.Deserialize(
+ JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo