Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Http/Http.Features/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#nullable enable
Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.get -> bool
Microsoft.AspNetCore.Http.Features.IHttpMetricsTagsFeature.MetricsDisabled.set -> void
Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.get -> System.TimeSpan?
Microsoft.AspNetCore.Http.WebSocketAcceptContext.KeepAliveTimeout.set -> void
26 changes: 25 additions & 1 deletion src/Http/Http.Features/src/WebSocketAcceptContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,41 @@ namespace Microsoft.AspNetCore.Http;
public class WebSocketAcceptContext
{
private int _serverMaxWindowBits = 15;
private TimeSpan? _keepAliveTimeout;

/// <summary>
/// Gets or sets the subprotocol being negotiated.
/// </summary>
public virtual string? SubProtocol { get; set; }

/// <summary>
/// The interval to send pong frames. This is a heart-beat that keeps the connection alive.
/// The interval to send keep-alive frames. This is a heart-beat that keeps the connection alive.
/// </summary>
public virtual TimeSpan? KeepAliveInterval { get; set; }

/// <summary>
/// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted.
/// </summary>
/// <remarks>
/// <c>null</c> means use the value from <c>WebSocketOptions.KeepAliveTimeout</c>.
/// <see cref="Timeout.InfiniteTimeSpan"/> and <see cref="TimeSpan.Zero"/> are valid values and will disable the timeout.
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">
/// <see cref="TimeSpan"/> is less than <see cref="TimeSpan.Zero"/>.
/// </exception>
public TimeSpan? KeepAliveTimeout
{
get => _keepAliveTimeout;
set
{
if (value is not null && value != Timeout.InfiniteTimeSpan)
{
ArgumentOutOfRangeException.ThrowIfLessThan(value.Value, TimeSpan.Zero);
}
_keepAliveTimeout = value;
}
}

/// <summary>
/// Enables support for the 'permessage-deflate' WebSocket extension.<para />
/// Be aware that enabling compression over encrypted connections makes the application subject to CRIME/BREACH type attacks.
Expand Down
129 changes: 129 additions & 0 deletions src/Middleware/WebSockets/src/AbortStream.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Net.WebSockets;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.WebSockets;

/// <summary>
/// Used in WebSocketMiddleware to wrap the HttpContext.Request.Body stream
/// so that we can call HttpContext.Abort when the stream is disposed and the WebSocket is in the Aborted state.
/// The Stream provided by Kestrel (and maybe other servers) noops in Dispose as it doesn't know whether it's a graceful close or not
/// and can result in truncated responses if in the graceful case.
///
/// This handles explicit WebSocket.Abort calls as well as the Keep-Alive timeout setting Aborted and disposing the stream.
/// </summary>
/// <remarks>
/// Workaround for https://github.com/dotnet/runtime/issues/44272
/// </remarks>
internal sealed class AbortStream : Stream
{
private readonly Stream _innerStream;
private readonly HttpContext _httpContext;

public WebSocket? WebSocket { get; set; }

public AbortStream(HttpContext httpContext, Stream innerStream)
{
_innerStream = innerStream;
_httpContext = httpContext;
}

public override bool CanRead => _innerStream.CanRead;

public override bool CanSeek => _innerStream.CanSeek;

public override bool CanWrite => _innerStream.CanWrite;

public override bool CanTimeout => _innerStream.CanTimeout;

public override long Length => _innerStream.Length;

public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; }

public override void Flush()
{
_innerStream.Flush();
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return _innerStream.ReadAsync(buffer, cancellationToken);
}

public override int Read(byte[] buffer, int offset, int count)
{
return _innerStream.Read(buffer, offset, count);
}

public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
{
return _innerStream.BeginRead(buffer, offset, count, callback, state);
}

public override int EndRead(IAsyncResult asyncResult)
{
return _innerStream.EndRead(asyncResult);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
{
return _innerStream.BeginWrite(buffer, offset, count, callback, state);
}

public override void EndWrite(IAsyncResult asyncResult)
{
_innerStream.EndWrite(asyncResult);
}

public override long Seek(long offset, SeekOrigin origin)
{
return _innerStream.Seek(offset, origin);
}

public override void SetLength(long value)
{
_innerStream.SetLength(value);
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _innerStream.WriteAsync(buffer, cancellationToken);
}

public override void Write(byte[] buffer, int offset, int count)
{
_innerStream.Write(buffer, offset, count);
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
return _innerStream.FlushAsync(cancellationToken);
}

public override ValueTask DisposeAsync()
{
return _innerStream.DisposeAsync();
}

protected override void Dispose(bool disposing)
{
// Currently, if ManagedWebSocket sets the Aborted state it calls Stream.Dispose after
if (WebSocket?.State == WebSocketState.Aborted)
{
_httpContext.Abort();
}
_innerStream.Dispose();
}
}
2 changes: 2 additions & 0 deletions src/Middleware/WebSockets/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.get -> System.TimeSpan
Microsoft.AspNetCore.Builder.WebSocketOptions.KeepAliveTimeout.set -> void
80 changes: 0 additions & 80 deletions src/Middleware/WebSockets/src/ServerWebSocket.cs

This file was deleted.

9 changes: 7 additions & 2 deletions src/Middleware/WebSockets/src/WebSocketMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,15 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
bool serverContextTakeover = true;
int serverMaxWindowBits = 15;
TimeSpan keepAliveInterval = _options.KeepAliveInterval;
TimeSpan keepAliveTimeout = _options.KeepAliveTimeout;
if (acceptContext != null)
{
subProtocol = acceptContext.SubProtocol;
enableCompression = acceptContext.DangerousEnableCompression;
serverContextTakeover = !acceptContext.DisableServerContextTakeover;
serverMaxWindowBits = acceptContext.ServerMaxWindowBits;
keepAliveInterval = acceptContext.KeepAliveInterval ?? keepAliveInterval;
keepAliveTimeout = acceptContext.KeepAliveTimeout ?? keepAliveTimeout;
}

#pragma warning disable CS0618 // Type or member is obsolete
Expand Down Expand Up @@ -208,15 +210,18 @@ public async Task<WebSocket> AcceptAsync(WebSocketAcceptContext acceptContext)
// Disable request timeout, if there is one, after the websocket has been accepted
_context.Features.Get<IHttpRequestTimeoutFeature>()?.DisableTimeout();

var wrappedSocket = WebSocket.CreateFromStream(opaqueTransport, new WebSocketCreationOptions()
var abortStream = new AbortStream(_context, opaqueTransport);
var wrappedSocket = WebSocket.CreateFromStream(abortStream, new WebSocketCreationOptions()
{
IsServer = true,
KeepAliveInterval = keepAliveInterval,
KeepAliveTimeout = keepAliveTimeout,
SubProtocol = subProtocol,
DangerousDeflateOptions = deflateOptions
});

return new ServerWebSocket(wrappedSocket, _context);
abortStream.WebSocket = wrappedSocket;
return wrappedSocket;
}

public static bool CheckSupportedWebSocketRequest(string method, IHeaderDictionary requestHeaders)
Expand Down
28 changes: 28 additions & 0 deletions src/Middleware/WebSockets/src/WebSocketOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ namespace Microsoft.AspNetCore.Builder;
/// </summary>
public class WebSocketOptions
{
private TimeSpan _keepAliveTimeout = Timeout.InfiniteTimeSpan;

/// <summary>
/// Constructs the <see cref="WebSocketOptions"/> class with default values.
/// </summary>
Expand All @@ -23,6 +25,32 @@ public WebSocketOptions()
/// </summary>
public TimeSpan KeepAliveInterval { get; set; }

/// <summary>
/// The time to wait for a Pong frame response after sending a Ping frame. If the time is exceeded the websocket will be aborted.
/// </summary>
/// <remarks>
/// Default value is <see cref="Timeout.InfiniteTimeSpan"/>.
/// <see cref="Timeout.InfiniteTimeSpan"/> and <see cref="TimeSpan.Zero"/> will disable the timeout.
/// </remarks>
/// <exception cref="ArgumentOutOfRangeException">
/// <see cref="TimeSpan"/> is less than <see cref="TimeSpan.Zero"/>.
/// </exception>
public TimeSpan KeepAliveTimeout
{
get
{
return _keepAliveTimeout;
}
set
{
if (value != Timeout.InfiniteTimeSpan)
{
ArgumentOutOfRangeException.ThrowIfLessThan(value, TimeSpan.Zero);
}
_keepAliveTimeout = value;
}
}

/// <summary>
/// Gets or sets the size of the protocol buffer used to receive and parse frames.
/// The default is 4kb.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.AspNetCore.Builder;
Expand All @@ -17,14 +17,22 @@ public void AddWebSocketsConfiguresOptions()
serviceCollection.AddWebSockets(o =>
{
o.KeepAliveInterval = TimeSpan.FromSeconds(1000);
o.KeepAliveTimeout = TimeSpan.FromSeconds(1234);
o.AllowedOrigins.Add("someString");
});

var services = serviceCollection.BuildServiceProvider();
var socketOptions = services.GetRequiredService<IOptions<WebSocketOptions>>().Value;

Assert.Equal(TimeSpan.FromSeconds(1000), socketOptions.KeepAliveInterval);
Assert.Equal(TimeSpan.FromSeconds(1234), socketOptions.KeepAliveTimeout);
Assert.Single(socketOptions.AllowedOrigins);
Assert.Equal("someString", socketOptions.AllowedOrigins[0]);
}

[Fact]
public void ThrowsForBadOptions()
{
Assert.Throws<ArgumentOutOfRangeException>(() => new WebSocketOptions() { KeepAliveTimeout = TimeSpan.FromMicroseconds(-1) });
}
}
Loading
Loading