diff --git a/samples/ServerApplication/Program.cs b/samples/ServerApplication/Program.cs index 1f99663a..79e19d67 100644 --- a/samples/ServerApplication/Program.cs +++ b/samples/ServerApplication/Program.cs @@ -1,85 +1,55 @@ -using System; using System.Net; using System.Security.Cryptography.X509Certificates; -using System.Threading.Tasks; using Bedrock.Framework; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using ServerApplication; -namespace ServerApplication -{ - public partial class Program - { - public static async Task Main(string[] args) - { - // Manual wire up of the server - var services = new ServiceCollection(); - services.AddLogging(builder => - { - builder.SetMinimumLevel(LogLevel.Debug); - builder.AddConsole(); - }); - - services.AddSignalR(); - - var serviceProvider = services.BuildServiceProvider(); - - var server = new ServerBuilder(serviceProvider) - .UseSockets(sockets => - { - // Echo server - sockets.ListenLocalhost(5000, - builder => builder.UseConnectionLogging().UseConnectionHandler()); +var builder = Host.CreateApplicationBuilder(); - // HTTP/1.1 server - sockets.Listen(IPAddress.Loopback, 5001, - builder => builder.UseConnectionLogging().UseConnectionHandler()); +builder.Logging.SetMinimumLevel(LogLevel.Debug); - // SignalR Hub - sockets.Listen(IPAddress.Loopback, 5002, - builder => builder.UseConnectionLogging().UseHub()); +builder.Services.AddSignalR(); - // MQTT application - sockets.Listen(IPAddress.Loopback, 5003, - builder => builder.UseConnectionLogging().UseConnectionHandler()); - - // Echo Server with TLS - sockets.Listen(IPAddress.Loopback, 5004, - builder => builder.UseServerTls(options => - { - options.LocalCertificate = X509CertificateLoader.LoadPkcs12FromFile("testcert.pfx", "testcert"); - - // NOTE: Do not do this in a production environment - options.AllowAnyRemoteCertificate(); - }) - .UseConnectionLogging().UseConnectionHandler()); +builder.ConfigureServer(server => +{ + server.UseSockets(sockets => + { + // Echo server + sockets.ListenLocalhost(5000, + builder => builder.UseConnectionLogging().UseConnectionHandler()); - sockets.Listen(IPAddress.Loopback, 5005, - builder => builder.UseConnectionLogging().UseConnectionHandler()); - }) - .Build(); + // HTTP/1.1 server + sockets.Listen(IPAddress.Loopback, 5001, + builder => builder.UseConnectionLogging().UseConnectionHandler()); - var logger = serviceProvider.GetRequiredService().CreateLogger(); + // SignalR Hub + sockets.Listen(IPAddress.Loopback, 5002, + builder => builder.UseConnectionLogging().UseHub()); - await server.StartAsync(); + // MQTT application + sockets.Listen(IPAddress.Loopback, 5003, + builder => builder.UseConnectionLogging().UseConnectionHandler()); - foreach (var ep in server.EndPoints) + // Echo Server with TLS + sockets.Listen(IPAddress.Loopback, 5004, + builder => builder.UseServerTls(options => { - logger.LogInformation("Listening on {EndPoint}", ep); - } + options.LocalCertificate = X509CertificateLoader.LoadPkcs12FromFile("testcert.pfx", "testcert"); - var tcs = new TaskCompletionSource(); - Console.CancelKeyPress += (sender, e) => - { - tcs.TrySetResult(null); - e.Cancel = true; - }; + // NOTE: Do not do this in a production environment + options.AllowAnyRemoteCertificate(); + }) + .UseConnectionLogging().UseConnectionHandler()); + + sockets.Listen(IPAddress.Loopback, 5005, + builder => builder.UseConnectionLogging().UseConnectionHandler()); + }); +}); - await tcs.Task; +var host = builder.Build(); - await server.StopAsync(); - } - } -} +host.Run(); \ No newline at end of file diff --git a/src/Bedrock.Framework/Client/Client.cs b/src/Bedrock.Framework/Client/Client.cs index 98b38d29..572124a2 100644 --- a/src/Bedrock.Framework/Client/Client.cs +++ b/src/Bedrock.Framework/Client/Client.cs @@ -3,33 +3,23 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public class Client(IConnectionFactory connectionFactory, ConnectionDelegate application) : IConnectionFactory { - public class Client : IConnectionFactory + public async ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) { - private readonly IConnectionFactory _connectionFactory; - private readonly ConnectionDelegate _application; - - public Client(IConnectionFactory connectionFactory, ConnectionDelegate application) - { - _connectionFactory = connectionFactory; - _application = application; - } - - public async ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) - { - var connection = await _connectionFactory.ConnectAsync(endpoint, cancellationToken).ConfigureAwait(false); + var connection = await connectionFactory.ConnectAsync(endpoint, cancellationToken).ConfigureAwait(false); - // Since nothing is being returned from this middleware, we need to wait for the last middleware to run - // until we yield this call. Stash a tcs in the items bag that allows this code to get notified - // when the middleware ran - var connectionContextWithDelegate = new ConnectionContextWithDelegate(connection, _application); + // Since nothing is being returned from this middleware, we need to wait for the last middleware to run + // until we yield this call. Stash a tcs in the items bag that allows this code to get notified + // when the middleware ran + var connectionContextWithDelegate = new ConnectionContextWithDelegate(connection, application); - // Execute the middleware pipeline - connectionContextWithDelegate.Start(); + // Execute the middleware pipeline + connectionContextWithDelegate.Start(); - // Wait for it the most inner middleware to run - return await connectionContextWithDelegate.Initialized.Task.ConfigureAwait(false); - } + // Wait for it the most inner middleware to run + return await connectionContextWithDelegate.Initialized.Task.ConfigureAwait(false); } } diff --git a/src/Bedrock.Framework/Client/ClientBuilder.cs b/src/Bedrock.Framework/Client/ClientBuilder.cs index b27ba017..7254a333 100644 --- a/src/Bedrock.Framework/Client/ClientBuilder.cs +++ b/src/Bedrock.Framework/Client/ClientBuilder.cs @@ -5,92 +5,91 @@ using Bedrock.Framework.Infrastructure; using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public partial class ClientBuilder : IConnectionBuilder { - public partial class ClientBuilder : IConnectionBuilder - { - private readonly ConnectionBuilder _connectionBuilder; + private readonly ConnectionBuilder _connectionBuilder; - public ClientBuilder() : this(EmptyServiceProvider.Instance) - { + public ClientBuilder() : this(EmptyServiceProvider.Instance) + { - } + } - public ClientBuilder(IServiceProvider serviceProvider) - { - _connectionBuilder = new ConnectionBuilder(serviceProvider); - } + public ClientBuilder(IServiceProvider serviceProvider) + { + _connectionBuilder = new ConnectionBuilder(serviceProvider); + } - internal static object Key { get; } = new object(); + internal static object Key { get; } = new object(); - private IConnectionFactory ConnectionFactory { get; set; } = new ThrowConnectionFactory(); + private IConnectionFactory ConnectionFactory { get; set; } = new ThrowConnectionFactory(); - public IServiceProvider ApplicationServices => _connectionBuilder.ApplicationServices; + public IServiceProvider ApplicationServices => _connectionBuilder.ApplicationServices; - public Client Build() + public Client Build() + { + // Middleware currently a single linear execution flow without a return value. + // We need to return the connection when it reaches the innermost middleware (D in this case) + // Then we need to wait until dispose is called to unwind that pipeline. + + // A -> + // B -> + // C -> + // D + // C <- + // B <- + // A <- + + _connectionBuilder.Run(connection => { - // Middleware currently a single linear execution flow without a return value. - // We need to return the connection when it reaches the innermost middleware (D in this case) - // Then we need to wait until dispose is called to unwind that pipeline. - - // A -> - // B -> - // C -> - // D - // C <- - // B <- - // A <- - - _connectionBuilder.Run(connection => + if (connection is ConnectionContextWithDelegate connectionContextWithDelegate) { - if (connection is ConnectionContextWithDelegate connectionContextWithDelegate) - { - connectionContextWithDelegate.Initialized.TrySetResult(connectionContextWithDelegate); + connectionContextWithDelegate.Initialized.TrySetResult(connectionContextWithDelegate); - // This task needs to stay around until the connection is disposed - // only then can we unwind the middleware chain - return connectionContextWithDelegate.ExecutionTask; - } + // This task needs to stay around until the connection is disposed + // only then can we unwind the middleware chain + return connectionContextWithDelegate.ExecutionTask; + } - // REVIEW: Do we throw in this case? It's edgy but possible to call next with a differnt - // connection delegate that originally given - return Task.CompletedTask; - }); + // REVIEW: Do we throw in this case? It's edgy but possible to call next with a differnt + // connection delegate that originally given + return Task.CompletedTask; + }); - var application = _connectionBuilder.Build(); + var application = _connectionBuilder.Build(); - return new Client(ConnectionFactory, application); - } + return new Client(ConnectionFactory, application); + } - public ClientBuilder UseConnectionFactory(IConnectionFactory connectionFactory) - { - ConnectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory)); - return this; - } + public ClientBuilder UseConnectionFactory(IConnectionFactory connectionFactory) + { + ConnectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory)); + return this; + } - public ClientBuilder Use(Func middleware) - { - ConnectionFactory = middleware(ConnectionFactory); - return this; - } + public ClientBuilder Use(Func middleware) + { + ConnectionFactory = middleware(ConnectionFactory); + return this; + } - public IConnectionBuilder Use(Func middleware) - { - return _connectionBuilder.Use(middleware); - } + public IConnectionBuilder Use(Func middleware) + { + return _connectionBuilder.Use(middleware); + } - ConnectionDelegate IConnectionBuilder.Build() - { - return _connectionBuilder.Build(); - } + ConnectionDelegate IConnectionBuilder.Build() + { + return _connectionBuilder.Build(); + } - private class ThrowConnectionFactory : IConnectionFactory + private class ThrowConnectionFactory : IConnectionFactory + { + public ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) { - public ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) - { - throw new InvalidOperationException("No transport configured. Set the ConnectionFactory property."); - } + throw new InvalidOperationException("No transport configured. Set the ConnectionFactory property."); } } } diff --git a/src/Bedrock.Framework/Hosting/ServerHostedService.cs b/src/Bedrock.Framework/Hosting/ServerHostedService.cs index 99e23485..7d7e7c40 100644 --- a/src/Bedrock.Framework/Hosting/ServerHostedService.cs +++ b/src/Bedrock.Framework/Hosting/ServerHostedService.cs @@ -3,25 +3,19 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Options; -namespace Bedrock.Framework -{ - public class ServerHostedService : IHostedService - { - private readonly Server _server; +namespace Bedrock.Framework; - public ServerHostedService(IOptions options) - { - _server = options.Value.ServerBuilder.Build(); - } +public class ServerHostedService(IOptions options) : IHostedService +{ + private readonly Server _server = options.Value.ServerBuilder.Build(); - public Task StartAsync(CancellationToken cancellationToken) - { - return _server.StartAsync(cancellationToken); - } + public Task StartAsync(CancellationToken cancellationToken) + { + return _server.StartAsync(cancellationToken); + } - public Task StopAsync(CancellationToken cancellationToken) - { - return _server.StopAsync(cancellationToken); - } + public Task StopAsync(CancellationToken cancellationToken) + { + return _server.StopAsync(cancellationToken); } } diff --git a/src/Bedrock.Framework/Hosting/ServerHostedServiceOptions.cs b/src/Bedrock.Framework/Hosting/ServerHostedServiceOptions.cs index 1d487a00..33a566b4 100644 --- a/src/Bedrock.Framework/Hosting/ServerHostedServiceOptions.cs +++ b/src/Bedrock.Framework/Hosting/ServerHostedServiceOptions.cs @@ -1,11 +1,6 @@ -using System; -using System.Collections.Generic; -using System.Text; +namespace Bedrock.Framework; -namespace Bedrock.Framework +public class ServerHostedServiceOptions { - public class ServerHostedServiceOptions - { - public ServerBuilder ServerBuilder { get; set; } - } + public ServerBuilder ServerBuilder { get; set; } } diff --git a/src/Bedrock.Framework/Hosting/ServiceCollectionExtensions.cs b/src/Bedrock.Framework/Hosting/ServiceCollectionExtensions.cs index 9fceb6d5..a0e9538b 100644 --- a/src/Bedrock.Framework/Hosting/ServiceCollectionExtensions.cs +++ b/src/Bedrock.Framework/Hosting/ServiceCollectionExtensions.cs @@ -2,23 +2,28 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public static class ServiceCollectionExtensions { - public static class ServiceCollectionExtensions + public static IHostBuilder ConfigureServer(this IHostBuilder builder, Action configure) => + builder.ConfigureServices(services => ConfigureServices(configure, services)); + + public static IHostApplicationBuilder ConfigureServer(this IHostApplicationBuilder builder, Action configure) + { + ConfigureServices(configure, builder.Services); + return builder; + } + + private static void ConfigureServices(Action configure, IServiceCollection services) { - public static IHostBuilder ConfigureServer(this IHostBuilder builder, Action configure) - { - return builder.ConfigureServices(services => - { - services.AddHostedService(); + services.AddHostedService(); - services.AddOptions() - .Configure((options, sp) => - { - options.ServerBuilder = new ServerBuilder(sp); - configure(options.ServerBuilder); - }); - }); - } + services.AddOptions() + .Configure((options, sp) => + { + options.ServerBuilder = new ServerBuilder(sp); + configure(options.ServerBuilder); + }); } } diff --git a/src/Bedrock.Framework/Infrastructure/DuplexPipe.cs b/src/Bedrock.Framework/Infrastructure/DuplexPipe.cs index 066141c2..ae426502 100644 --- a/src/Bedrock.Framework/Infrastructure/DuplexPipe.cs +++ b/src/Bedrock.Framework/Infrastructure/DuplexPipe.cs @@ -1,42 +1,29 @@ // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. -namespace System.IO.Pipelines -{ - internal class DuplexPipe : IDuplexPipe - { - public DuplexPipe(PipeReader reader, PipeWriter writer) - { - Input = reader; - Output = writer; - } +namespace System.IO.Pipelines; - public PipeReader Input { get; } - - public PipeWriter Output { get; } +internal class DuplexPipe(PipeReader reader, PipeWriter writer) : IDuplexPipe +{ + public PipeReader Input { get; } = reader; - public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); + public PipeWriter Output { get; } = writer; - var transportToApplication = new DuplexPipe(output.Reader, input.Writer); - var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); + public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) + { + var input = new Pipe(inputOptions); + var output = new Pipe(outputOptions); - return new DuplexPipePair(applicationToTransport, transportToApplication); - } + var transportToApplication = new DuplexPipe(output.Reader, input.Writer); + var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); - // This class exists to work around issues with value tuple on .NET Framework - public readonly struct DuplexPipePair - { - public IDuplexPipe Transport { get; } - public IDuplexPipe Application { get; } + return new DuplexPipePair(applicationToTransport, transportToApplication); + } - public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) - { - Transport = transport; - Application = application; - } - } + // This class exists to work around issues with value tuple on .NET Framework + public readonly struct DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) + { + public IDuplexPipe Transport { get; } = transport; + public IDuplexPipe Application { get; } = application; } } \ No newline at end of file diff --git a/src/Bedrock.Framework/Infrastructure/DuplexPipeStream.cs b/src/Bedrock.Framework/Infrastructure/DuplexPipeStream.cs index df2ec1d9..097d47da 100644 --- a/src/Bedrock.Framework/Infrastructure/DuplexPipeStream.cs +++ b/src/Bedrock.Framework/Infrastructure/DuplexPipeStream.cs @@ -5,166 +5,155 @@ using System.Threading; using System.Threading.Tasks; -namespace Bedrock.Framework.Infrastructure -{ - internal class DuplexPipeStream : Stream - { - private readonly PipeReader _input; - private readonly PipeWriter _output; - private readonly bool _throwOnCancelled; - private volatile bool _cancelCalled; - - public DuplexPipeStream(PipeReader input, PipeWriter output, bool throwOnCancelled = false) - { - _input = input; - _output = output; - _throwOnCancelled = throwOnCancelled; - } +namespace Bedrock.Framework.Infrastructure; - public void CancelPendingRead() - { - _cancelCalled = true; - _input.CancelPendingRead(); - } +internal class DuplexPipeStream(PipeReader input, PipeWriter output, bool throwOnCancelled = false) : Stream +{ + private volatile bool _cancelCalled; - public override bool CanRead => true; + public void CancelPendingRead() + { + _cancelCalled = true; + input.CancelPendingRead(); + } - public override bool CanSeek => false; + public override bool CanRead => true; - public override bool CanWrite => true; + public override bool CanSeek => false; - public override long Length - { - get - { - throw new NotSupportedException(); - } - } + public override bool CanWrite => true; - public override long Position + public override long Length + { + get { - get - { - throw new NotSupportedException(); - } - set - { - throw new NotSupportedException(); - } + throw new NotSupportedException(); } + } - public override long Seek(long offset, SeekOrigin origin) + public override long Position + { + get { throw new NotSupportedException(); } - - public override void SetLength(long value) + set { throw new NotSupportedException(); } + } - public override int Read(byte[] buffer, int offset, int count) - { - // ValueTask uses .GetAwaiter().GetResult() if necessary - // https://github.com/dotnet/corefx/blob/f9da3b4af08214764a51b2331f3595ffaf162abe/src/System.Threading.Tasks.Extensions/src/System/Threading/Tasks/ValueTask.cs#L156 - return ReadAsyncInternal(new Memory(buffer, offset, count), default).Result; - } + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - { - return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); - } + public override void SetLength(long value) + { + throw new NotSupportedException(); + } - public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) - { - return ReadAsyncInternal(destination, cancellationToken); - } + public override int Read(byte[] buffer, int offset, int count) + { + // ValueTask uses .GetAwaiter().GetResult() if necessary + // https://github.com/dotnet/corefx/blob/f9da3b4af08214764a51b2331f3595ffaf162abe/src/System.Threading.Tasks.Extensions/src/System/Threading/Tasks/ValueTask.cs#L156 + return ReadAsyncInternal(new Memory(buffer, offset, count), default).Result; + } - public override void Write(byte[] buffer, int offset, int count) - { - WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); - } + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + return ReadAsyncInternal(new Memory(buffer, offset, count), cancellationToken).AsTask(); + } - public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - if (buffer != null) - { - _output.Write(new ReadOnlySpan(buffer, offset, count)); - } + public override ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + return ReadAsyncInternal(destination, cancellationToken); + } - await _output.FlushAsync(cancellationToken).ConfigureAwait(false); - } + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); + } - public override async ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (buffer != null) { - _output.Write(source.Span); - await _output.FlushAsync(cancellationToken).ConfigureAwait(false); + output.Write(new ReadOnlySpan(buffer, offset, count)); } - public override void Flush() - { - FlushAsync(CancellationToken.None).GetAwaiter().GetResult(); - } + await output.FlushAsync(cancellationToken).ConfigureAwait(false); + } - public override Task FlushAsync(CancellationToken cancellationToken) - { - return WriteAsync(null, 0, 0, cancellationToken); - } + public override async ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + output.Write(source.Span); + await output.FlushAsync(cancellationToken).ConfigureAwait(false); + } + + public override void Flush() + { + FlushAsync(CancellationToken.None).GetAwaiter().GetResult(); + } - private async ValueTask ReadAsyncInternal(Memory destination, CancellationToken cancellationToken) + public override Task FlushAsync(CancellationToken cancellationToken) + { + return WriteAsync(null, 0, 0, cancellationToken); + } + + private async ValueTask ReadAsyncInternal(Memory destination, CancellationToken cancellationToken) + { + while (true) { - while (true) + var result = await input.ReadAsync(cancellationToken).ConfigureAwait(false); + var readableBuffer = result.Buffer; + try { - var result = await _input.ReadAsync(cancellationToken).ConfigureAwait(false); - var readableBuffer = result.Buffer; - try + if (throwOnCancelled && result.IsCanceled && _cancelCalled) + { + // Reset the bool + _cancelCalled = false; + throw new OperationCanceledException(); + } + + if (!readableBuffer.IsEmpty) { - if (_throwOnCancelled && result.IsCanceled && _cancelCalled) - { - // Reset the bool - _cancelCalled = false; - throw new OperationCanceledException(); - } - - if (!readableBuffer.IsEmpty) - { - // buffer.Count is int - var count = (int)Math.Min(readableBuffer.Length, destination.Length); - readableBuffer = readableBuffer.Slice(0, count); - readableBuffer.CopyTo(destination.Span); - return count; - } - - if (result.IsCompleted) - { - return 0; - } + // buffer.Count is int + var count = (int)Math.Min(readableBuffer.Length, destination.Length); + readableBuffer = readableBuffer.Slice(0, count); + readableBuffer.CopyTo(destination.Span); + return count; } - finally + + if (result.IsCompleted) { - _input.AdvanceTo(readableBuffer.End, readableBuffer.End); + return 0; } } + finally + { + input.AdvanceTo(readableBuffer.End, readableBuffer.End); + } } + } - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) - { - return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state); - } + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state); + } - public override int EndRead(IAsyncResult asyncResult) - { - return TaskToApm.End(asyncResult); - } + public override int EndRead(IAsyncResult asyncResult) + { + return TaskToApm.End(asyncResult); + } - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) - { - return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state); - } + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state); + } - public override void EndWrite(IAsyncResult asyncResult) - { - TaskToApm.End(asyncResult); - } + public override void EndWrite(IAsyncResult asyncResult) + { + TaskToApm.End(asyncResult); } } diff --git a/src/Bedrock.Framework/Infrastructure/DuplexPipeStreamAdapter.cs b/src/Bedrock.Framework/Infrastructure/DuplexPipeStreamAdapter.cs index 2cfb46a5..186ec50b 100644 --- a/src/Bedrock.Framework/Infrastructure/DuplexPipeStreamAdapter.cs +++ b/src/Bedrock.Framework/Infrastructure/DuplexPipeStreamAdapter.cs @@ -6,56 +6,55 @@ using System.IO.Pipelines; using System.Threading.Tasks; -namespace Bedrock.Framework.Infrastructure +namespace Bedrock.Framework.Infrastructure; + +/// +/// A helper for wrapping a Stream decorator from an . +/// +/// +internal class DuplexPipeStreamAdapter : DuplexPipeStream, IDuplexPipe where TStream : Stream { - /// - /// A helper for wrapping a Stream decorator from an . - /// - /// - internal class DuplexPipeStreamAdapter : DuplexPipeStream, IDuplexPipe where TStream : Stream - { - private bool _disposed; - private readonly object _disposeLock = new object(); + private bool _disposed; + private readonly object _disposeLock = new object(); - public DuplexPipeStreamAdapter(IDuplexPipe duplexPipe, Func createStream) : - this(duplexPipe, new StreamPipeReaderOptions(leaveOpen: true), new StreamPipeWriterOptions(leaveOpen: true), createStream) - { - } + public DuplexPipeStreamAdapter(IDuplexPipe duplexPipe, Func createStream) : + this(duplexPipe, new StreamPipeReaderOptions(leaveOpen: true), new StreamPipeWriterOptions(leaveOpen: true), createStream) + { + } - public DuplexPipeStreamAdapter(IDuplexPipe duplexPipe, StreamPipeReaderOptions readerOptions, StreamPipeWriterOptions writerOptions, Func createStream) : - base(duplexPipe.Input, duplexPipe.Output) - { - var stream = createStream(this); - Stream = stream; - Input = PipeReader.Create(stream, readerOptions); - Output = PipeWriter.Create(stream, writerOptions); - } + public DuplexPipeStreamAdapter(IDuplexPipe duplexPipe, StreamPipeReaderOptions readerOptions, StreamPipeWriterOptions writerOptions, Func createStream) : + base(duplexPipe.Input, duplexPipe.Output) + { + var stream = createStream(this); + Stream = stream; + Input = PipeReader.Create(stream, readerOptions); + Output = PipeWriter.Create(stream, writerOptions); + } - public TStream Stream { get; } + public TStream Stream { get; } - public PipeReader Input { get; } + public PipeReader Input { get; } - public PipeWriter Output { get; } + public PipeWriter Output { get; } - public override async ValueTask DisposeAsync() + public override async ValueTask DisposeAsync() + { + lock (_disposeLock) { - lock (_disposeLock) + if (_disposed) { - if (_disposed) - { - return; - } - _disposed = true; + return; } - - await Input.CompleteAsync().ConfigureAwait(false); - await Output.CompleteAsync().ConfigureAwait(false); + _disposed = true; } - protected override void Dispose(bool disposing) - { - throw new NotSupportedException(); - } + await Input.CompleteAsync().ConfigureAwait(false); + await Output.CompleteAsync().ConfigureAwait(false); + } + + protected override void Dispose(bool disposing) + { + throw new NotSupportedException(); } } diff --git a/src/Bedrock.Framework/Infrastructure/EmptyServiceProvder.cs b/src/Bedrock.Framework/Infrastructure/EmptyServiceProvder.cs index 033fd0e6..05b32d88 100644 --- a/src/Bedrock.Framework/Infrastructure/EmptyServiceProvder.cs +++ b/src/Bedrock.Framework/Infrastructure/EmptyServiceProvder.cs @@ -1,11 +1,10 @@ using System; -namespace Bedrock.Framework.Infrastructure +namespace Bedrock.Framework.Infrastructure; + +internal class EmptyServiceProvider : IServiceProvider { - internal class EmptyServiceProvider : IServiceProvider - { - public static IServiceProvider Instance { get; } = new EmptyServiceProvider(); + public static IServiceProvider Instance { get; } = new EmptyServiceProvider(); - public object GetService(Type serviceType) => null; - } + public object GetService(Type serviceType) => null; } diff --git a/src/Bedrock.Framework/Infrastructure/MemoryPoolExtensions.cs b/src/Bedrock.Framework/Infrastructure/MemoryPoolExtensions.cs index c1e3c40f..543c88f3 100644 --- a/src/Bedrock.Framework/Infrastructure/MemoryPoolExtensions.cs +++ b/src/Bedrock.Framework/Infrastructure/MemoryPoolExtensions.cs @@ -1,31 +1,28 @@ using System; using System.Buffers; -using System.Collections.Generic; -using System.Text; -namespace Bedrock.Framework.Infrastructure +namespace Bedrock.Framework.Infrastructure; + +internal static class MemoryPoolExtensions { - internal static class MemoryPoolExtensions + /// + /// Computes a minimum segment size + /// + /// + /// + public static int GetMinimumSegmentSize(this MemoryPool pool) { - /// - /// Computes a minimum segment size - /// - /// - /// - public static int GetMinimumSegmentSize(this MemoryPool pool) + if (pool == null) { - if (pool == null) - { - return 4096; - } - - return Math.Min(4096, pool.MaxBufferSize); + return 4096; } - public static int GetMinimumAllocSize(this MemoryPool pool) - { - // 1/2 of a segment - return pool.GetMinimumSegmentSize() / 2; - } + return Math.Min(4096, pool.MaxBufferSize); + } + + public static int GetMinimumAllocSize(this MemoryPool pool) + { + // 1/2 of a segment + return pool.GetMinimumSegmentSize() / 2; } } diff --git a/src/Bedrock.Framework/Infrastructure/TaskExtensions.cs b/src/Bedrock.Framework/Infrastructure/TaskExtensions.cs index c241ff1e..69119212 100644 --- a/src/Bedrock.Framework/Infrastructure/TaskExtensions.cs +++ b/src/Bedrock.Framework/Infrastructure/TaskExtensions.cs @@ -1,55 +1,35 @@ using System; -using System.Collections.Generic; -using System.Text; using System.Threading; using System.Threading.Tasks; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +internal static class TaskExtensions { - internal static class TaskExtensions + public static async Task WithCancellation(this Task task, CancellationToken cancellationToken) { - public static async Task WithCancellation(this Task task, CancellationToken cancellationToken) + try { - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - // This disposes the registration as soon as one of the tasks trigger - using (cancellationToken.Register(state => - { - ((TaskCompletionSource)state).TrySetResult(null); - }, - tcs)) - { - var resultTask = await Task.WhenAny(task, tcs.Task).ConfigureAwait(false); - if (resultTask == tcs.Task) - { - // Operation cancelled - return false; - } - - await task.ConfigureAwait(false); - return true; - } + await task.WaitAsync(cancellationToken); + return true; } - - public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + catch (OperationCanceledException) { - using var cts = new CancellationTokenSource(); - var delayTask = Task.Delay(timeout, cts.Token); - - var resultTask = await Task.WhenAny(task, delayTask).ConfigureAwait(false); - if (resultTask == delayTask) - { - // Operation cancelled - return false; - } - else - { - // Cancel the timer task so that it does not fire - cts.Cancel(); - } + return false; + } + } - await task.ConfigureAwait(false); + public static async Task TimeoutAfter(this Task task, TimeSpan timeout) + { + try + { + using var cts = new CancellationTokenSource(timeout); + await task.WaitAsync(cts.Token); return true; } + catch (OperationCanceledException) + { + return false; + } } } diff --git a/src/Bedrock.Framework/Infrastructure/TaskToApm.cs b/src/Bedrock.Framework/Infrastructure/TaskToApm.cs index 64989529..51ffe019 100644 --- a/src/Bedrock.Framework/Infrastructure/TaskToApm.cs +++ b/src/Bedrock.Framework/Infrastructure/TaskToApm.cs @@ -15,107 +15,106 @@ #nullable enable using System.Diagnostics; -namespace System.Threading.Tasks +namespace System.Threading.Tasks; + +/// +/// Provides support for efficiently using Tasks to implement the APM (Begin/End) pattern. +/// +internal static class TaskToApm { /// - /// Provides support for efficiently using Tasks to implement the APM (Begin/End) pattern. + /// Marshals the Task as an IAsyncResult, using the supplied callback and state + /// to implement the APM pattern. /// - internal static class TaskToApm - { - /// - /// Marshals the Task as an IAsyncResult, using the supplied callback and state - /// to implement the APM pattern. - /// - /// The Task to be marshaled. - /// The callback to be invoked upon completion. - /// The state to be stored in the IAsyncResult. - /// An IAsyncResult to represent the task's asynchronous operation. - public static IAsyncResult Begin(Task task, AsyncCallback? callback, object? state) => - new TaskAsyncResult(task, state, callback); + /// The Task to be marshaled. + /// The callback to be invoked upon completion. + /// The state to be stored in the IAsyncResult. + /// An IAsyncResult to represent the task's asynchronous operation. + public static IAsyncResult Begin(Task task, AsyncCallback? callback, object? state) => + new TaskAsyncResult(task, state, callback); - /// Processes an IAsyncResult returned by Begin. - /// The IAsyncResult to unwrap. - public static void End(IAsyncResult asyncResult) + /// Processes an IAsyncResult returned by Begin. + /// The IAsyncResult to unwrap. + public static void End(IAsyncResult asyncResult) + { + if (asyncResult is TaskAsyncResult twar) { - if (asyncResult is TaskAsyncResult twar) - { - twar._task.GetAwaiter().GetResult(); - return; - } - - throw new ArgumentNullException(); + twar._task.GetAwaiter().GetResult(); + return; } - /// Processes an IAsyncResult returned by Begin. - /// The IAsyncResult to unwrap. - public static TResult End(IAsyncResult asyncResult) - { - if (asyncResult is TaskAsyncResult twar && twar._task is Task task) - { - return task.GetAwaiter().GetResult(); - } + throw new ArgumentNullException(); + } - throw new ArgumentNullException(); + /// Processes an IAsyncResult returned by Begin. + /// The IAsyncResult to unwrap. + public static TResult End(IAsyncResult asyncResult) + { + if (asyncResult is TaskAsyncResult twar && twar._task is Task task) + { + return task.GetAwaiter().GetResult(); } - /// Provides a simple IAsyncResult that wraps a Task. - /// - /// We could use the Task as the IAsyncResult if the Task's AsyncState is the same as the object state, - /// but that's very rare, in particular in a situation where someone cares about allocation, and always - /// using TaskAsyncResult simplifies things and enables additional optimizations. - /// - internal sealed class TaskAsyncResult : IAsyncResult + throw new ArgumentNullException(); + } + + /// Provides a simple IAsyncResult that wraps a Task. + /// + /// We could use the Task as the IAsyncResult if the Task's AsyncState is the same as the object state, + /// but that's very rare, in particular in a situation where someone cares about allocation, and always + /// using TaskAsyncResult simplifies things and enables additional optimizations. + /// + internal sealed class TaskAsyncResult : IAsyncResult + { + /// The wrapped Task. + internal readonly Task _task; + /// Callback to invoke when the wrapped task completes. + private readonly AsyncCallback? _callback; + + /// Initializes the IAsyncResult with the Task to wrap and the associated object state. + /// The Task to wrap. + /// The new AsyncState value. + /// Callback to invoke when the wrapped task completes. + internal TaskAsyncResult(Task task, object? state, AsyncCallback? callback) { - /// The wrapped Task. - internal readonly Task _task; - /// Callback to invoke when the wrapped task completes. - private readonly AsyncCallback? _callback; + Debug.Assert(task != null); + _task = task; + AsyncState = state; - /// Initializes the IAsyncResult with the Task to wrap and the associated object state. - /// The Task to wrap. - /// The new AsyncState value. - /// Callback to invoke when the wrapped task completes. - internal TaskAsyncResult(Task task, object? state, AsyncCallback? callback) + if (task.IsCompleted) { - Debug.Assert(task != null); - _task = task; - AsyncState = state; - - if (task.IsCompleted) - { - // Synchronous completion. Invoke the callback. No need to store it. - CompletedSynchronously = true; - callback?.Invoke(this); - } - else if (callback != null) - { - // Asynchronous completion, and we have a callback; schedule it. We use OnCompleted rather than ContinueWith in - // order to avoid running synchronously if the task has already completed by the time we get here but still run - // synchronously as part of the task's completion if the task completes after (the more common case). - _callback = callback; - _task.ConfigureAwait(continueOnCapturedContext: false) - .GetAwaiter() - .OnCompleted(InvokeCallback); // allocates a delegate, but avoids a closure - } + // Synchronous completion. Invoke the callback. No need to store it. + CompletedSynchronously = true; + callback?.Invoke(this); } - - /// Invokes the callback. - private void InvokeCallback() + else if (callback != null) { - Debug.Assert(!CompletedSynchronously); - Debug.Assert(_callback != null); - _callback.Invoke(this); + // Asynchronous completion, and we have a callback; schedule it. We use OnCompleted rather than ContinueWith in + // order to avoid running synchronously if the task has already completed by the time we get here but still run + // synchronously as part of the task's completion if the task completes after (the more common case). + _callback = callback; + _task.ConfigureAwait(continueOnCapturedContext: false) + .GetAwaiter() + .OnCompleted(InvokeCallback); // allocates a delegate, but avoids a closure } + } - /// Gets a user-defined object that qualifies or contains information about an asynchronous operation. - public object? AsyncState { get; } - /// Gets a value that indicates whether the asynchronous operation completed synchronously. - /// This is set lazily based on whether the has completed by the time this object is created. - public bool CompletedSynchronously { get; } - /// Gets a value that indicates whether the asynchronous operation has completed. - public bool IsCompleted => _task.IsCompleted; - /// Gets a that is used to wait for an asynchronous operation to complete. - public WaitHandle AsyncWaitHandle => ((IAsyncResult)_task).AsyncWaitHandle; + /// Invokes the callback. + private void InvokeCallback() + { + Debug.Assert(!CompletedSynchronously); + Debug.Assert(_callback != null); + _callback.Invoke(this); } + + /// Gets a user-defined object that qualifies or contains information about an asynchronous operation. + public object? AsyncState { get; } + /// Gets a value that indicates whether the asynchronous operation completed synchronously. + /// This is set lazily based on whether the has completed by the time this object is created. + public bool CompletedSynchronously { get; } + /// Gets a value that indicates whether the asynchronous operation has completed. + public bool IsCompleted => _task.IsCompleted; + /// Gets a that is used to wait for an asynchronous operation to complete. + public WaitHandle AsyncWaitHandle => ((IAsyncResult)_task).AsyncWaitHandle; } } \ No newline at end of file diff --git a/src/Bedrock.Framework/Infrastructure/TimerAwaitable.cs b/src/Bedrock.Framework/Infrastructure/TimerAwaitable.cs deleted file mode 100644 index b274442a..00000000 --- a/src/Bedrock.Framework/Infrastructure/TimerAwaitable.cs +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace Bedrock.Framework -{ - internal class TimerAwaitable : IDisposable, ICriticalNotifyCompletion - { - private Timer _timer; - private Action _callback; - private static readonly Action _callbackCompleted = () => { }; - - private readonly TimeSpan _period; - - private readonly TimeSpan _dueTime; - private bool _disposed; - private bool _running = true; - private readonly object _lockObj = new object(); - - public TimerAwaitable(TimeSpan dueTime, TimeSpan period) - { - _dueTime = dueTime; - _period = period; - } - - public void Start() - { - if (_timer == null) - { - lock (_lockObj) - { - if (_disposed) - { - return; - } - - if (_timer == null) - { - _timer = new Timer(state => ((TimerAwaitable)state).Tick(), this, _dueTime, _period); - } - } - } - } - - public TimerAwaitable GetAwaiter() => this; - public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); - - public bool GetResult() - { - _callback = null; - - return _running; - } - - private void Tick() - { - var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); - continuation?.Invoke(); - } - - public void OnCompleted(Action continuation) - { - if (ReferenceEquals(_callback, _callbackCompleted) || - ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) - { - Task.Run(continuation); - } - } - - public void UnsafeOnCompleted(Action continuation) - { - OnCompleted(continuation); - } - - public void Stop() - { - lock (_lockObj) - { - // Stop should be used to trigger the call to end the loop which disposes - if (_disposed) - { - throw new ObjectDisposedException(GetType().FullName); - } - - _running = false; - } - - // Call tick here to make sure that we yield the callback, - // if it's currently waiting, we don't need to wait for the next period - Tick(); - } - - void IDisposable.Dispose() - { - lock (_lockObj) - { - _disposed = true; - - _timer?.Dispose(); - - _timer = null; - } - } - } -} \ No newline at end of file diff --git a/src/Bedrock.Framework/Middleware/ConnectionBuilderExtensions.cs b/src/Bedrock.Framework/Middleware/ConnectionBuilderExtensions.cs index e9b95a9e..5f03d24e 100644 --- a/src/Bedrock.Framework/Middleware/ConnectionBuilderExtensions.cs +++ b/src/Bedrock.Framework/Middleware/ConnectionBuilderExtensions.cs @@ -1,91 +1,80 @@ using System; -using System.Collections.Generic; -using System.Text; -using Bedrock.Framework.Infrastructure; using Bedrock.Framework.Middleware.Tls; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public delegate void LoggingFormatter(ILogger logger, string method, ReadOnlySpan buffer); + +public static class ConnectionBuilderExtensions { - public delegate void LoggingFormatter(ILogger logger, string method, ReadOnlySpan buffer); + /// + /// Emits verbose logs for bytes read from and written to the connection. + /// + public static TBuilder UseConnectionLogging(this TBuilder builder, string loggerName = null, ILoggerFactory loggerFactory = null, LoggingFormatter loggingFormatter = null) where TBuilder : IConnectionBuilder + { + loggerFactory ??= builder.ApplicationServices.GetRequiredService(); + var logger = loggerName == null ? loggerFactory.CreateLogger() : loggerFactory.CreateLogger(loggerName); + builder.Use(next => new LoggingConnectionMiddleware(next, logger, loggingFormatter).OnConnectionAsync); + return builder; + } - public static class ConnectionBuilderExtensions + public static TBuilder UseConnectionLimits(this TBuilder builder, int connectionLimit) where TBuilder : IConnectionBuilder { - /// - /// Emits verbose logs for bytes read from and written to the connection. - /// - public static TBuilder UseConnectionLogging(this TBuilder builder, string loggerName = null, ILoggerFactory loggerFactory = null, LoggingFormatter loggingFormatter = null) where TBuilder : IConnectionBuilder - { - loggerFactory ??= builder.ApplicationServices.GetRequiredService(); - var logger = loggerName == null ? loggerFactory.CreateLogger() : loggerFactory.CreateLogger(loggerName); - builder.Use(next => new LoggingConnectionMiddleware(next, logger, loggingFormatter).OnConnectionAsync); - return builder; - } + var loggerFactory = builder.ApplicationServices.GetService() ?? NullLoggerFactory.Instance; + var logger = loggerFactory.CreateLogger(); + builder.Use(next => new ConnectionLimitMiddleware(next, logger, connectionLimit).OnConnectionAsync); + return builder; + } - public static TBuilder UseConnectionLimits(this TBuilder builder, int connectionLimit) where TBuilder : IConnectionBuilder - { - var loggerFactory = builder.ApplicationServices.GetService() ?? NullLoggerFactory.Instance; - var logger = loggerFactory.CreateLogger(); - builder.Use(next => new ConnectionLimitMiddleware(next, logger, connectionLimit).OnConnectionAsync); - return builder; - } + public static TBuilder UseServerTls( + this TBuilder builder, + Action configure) where TBuilder : IConnectionBuilder + { + var options = new TlsOptions(); + configure(options); + return builder.UseServerTls(options); + } - public static TBuilder UseServerTls( - this TBuilder builder, - Action configure) where TBuilder : IConnectionBuilder - { - var options = new TlsOptions(); - configure(options); - return builder.UseServerTls(options); - } + public static TBuilder UseServerTls( + this TBuilder builder, + TlsOptions options) where TBuilder : IConnectionBuilder + { + ArgumentNullException.ThrowIfNull(options); - public static TBuilder UseServerTls( - this TBuilder builder, - TlsOptions options) where TBuilder : IConnectionBuilder + var loggerFactory = builder.ApplicationServices.GetService(typeof(ILoggerFactory)) as ILoggerFactory ?? NullLoggerFactory.Instance; + builder.Use(next => { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } + var middleware = new TlsServerConnectionMiddleware(next, options, loggerFactory); + return middleware.OnConnectionAsync; + }); + return builder; + } - var loggerFactory = builder.ApplicationServices.GetService(typeof(ILoggerFactory)) as ILoggerFactory ?? NullLoggerFactory.Instance; - builder.Use(next => - { - var middleware = new TlsServerConnectionMiddleware(next, options, loggerFactory); - return middleware.OnConnectionAsync; - }); - return builder; - } + public static TBuilder UseClientTls( + this TBuilder builder, + Action configure) where TBuilder : IConnectionBuilder + { + var options = new TlsOptions(); + configure(options); + return builder.UseClientTls(options); + } - public static TBuilder UseClientTls( - this TBuilder builder, - Action configure) where TBuilder : IConnectionBuilder - { - var options = new TlsOptions(); - configure(options); - return builder.UseClientTls(options); - } + public static TBuilder UseClientTls( + this TBuilder builder, + TlsOptions options) where TBuilder : IConnectionBuilder + { + ArgumentNullException.ThrowIfNull(options); - public static TBuilder UseClientTls( - this TBuilder builder, - TlsOptions options) where TBuilder : IConnectionBuilder + var loggerFactory = builder.ApplicationServices.GetService(typeof(ILoggerFactory)) as ILoggerFactory ?? NullLoggerFactory.Instance; + builder.Use(next => { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - - var loggerFactory = builder.ApplicationServices.GetService(typeof(ILoggerFactory)) as ILoggerFactory ?? NullLoggerFactory.Instance; - builder.Use(next => - { - var middleware = new TlsClientConnectionMiddleware(next, options, loggerFactory); - return middleware.OnConnectionAsync; - }); - return builder; - } + var middleware = new TlsClientConnectionMiddleware(next, options, loggerFactory); + return middleware.OnConnectionAsync; + }); + return builder; } - } diff --git a/src/Bedrock.Framework/Middleware/ConnectionLimitMiddleware.cs b/src/Bedrock.Framework/Middleware/ConnectionLimitMiddleware.cs index 4a7928fb..f78692cc 100644 --- a/src/Bedrock.Framework/Middleware/ConnectionLimitMiddleware.cs +++ b/src/Bedrock.Framework/Middleware/ConnectionLimitMiddleware.cs @@ -4,45 +4,35 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public class ConnectionLimitMiddleware(ConnectionDelegate next, ILogger logger, int limit) { - public class ConnectionLimitMiddleware - { - private readonly ConnectionDelegate _next; - private readonly SemaphoreSlim _limiter; - private readonly ILogger _logger; + private readonly SemaphoreSlim _limiter = new(limit); - public ConnectionLimitMiddleware(ConnectionDelegate next, ILogger logger, int limit) - { - _next = next; - _logger = logger; - _limiter = new SemaphoreSlim(limit); - } + public async Task OnConnectionAsync(ConnectionContext connectionContext) + { + // Wait 10 seconds for a connection + var task = _limiter.WaitAsync(TimeSpan.FromSeconds(10)); - public async Task OnConnectionAsync(ConnectionContext connectionContext) + if (!task.IsCompletedSuccessfully) { - // Wait 10 seconds for a connection - var task = _limiter.WaitAsync(TimeSpan.FromSeconds(10)); + logger.LogInformation("{ConnectionId} queued", connectionContext.ConnectionId); - if (!task.IsCompletedSuccessfully) + if (!await task.ConfigureAwait(false)) { - _logger.LogInformation("{ConnectionId} queued", connectionContext.ConnectionId); - - if (!await task.ConfigureAwait(false)) - { - _logger.LogInformation("{ConnectionId} timed out in the connection queue", connectionContext.ConnectionId); - return; - } + logger.LogInformation("{ConnectionId} timed out in the connection queue", connectionContext.ConnectionId); + return; } + } - try - { - await _next(connectionContext).ConfigureAwait(false); - } - finally - { - _limiter.Release(); - } + try + { + await next(connectionContext).ConfigureAwait(false); + } + finally + { + _limiter.Release(); } } } diff --git a/src/Bedrock.Framework/Middleware/Internal/LoggingStream.cs b/src/Bedrock.Framework/Middleware/Internal/LoggingStream.cs index c750eca9..c48caa21 100644 --- a/src/Bedrock.Framework/Middleware/Internal/LoggingStream.cs +++ b/src/Bedrock.Framework/Middleware/Internal/LoggingStream.cs @@ -8,217 +8,205 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; -namespace Bedrock.Framework.Infrastructure +namespace Bedrock.Framework.Infrastructure; + +internal sealed class LoggingStream(Stream inner, ILogger logger, LoggingFormatter logFormatter = null) : Stream { - internal sealed class LoggingStream : Stream + public override bool CanRead { - private readonly Stream _inner; - private readonly ILogger _logger; - private readonly LoggingFormatter _logFormatter; - - public LoggingStream(Stream inner, ILogger logger, LoggingFormatter logFormatter = null) + get { - _inner = inner; - _logger = logger; - _logFormatter = logFormatter; + return inner.CanRead; } + } - public override bool CanRead + public override bool CanSeek + { + get { - get - { - return _inner.CanRead; - } + return inner.CanSeek; } + } - public override bool CanSeek + public override bool CanWrite + { + get { - get - { - return _inner.CanSeek; - } + return inner.CanWrite; } + } - public override bool CanWrite + public override long Length + { + get { - get - { - return _inner.CanWrite; - } + return inner.Length; } + } - public override long Length + public override long Position + { + get { - get - { - return _inner.Length; - } + return inner.Position; } - public override long Position + set { - get - { - return _inner.Position; - } - - set - { - _inner.Position = value; - } + inner.Position = value; } + } - public override void Flush() - { - _inner.Flush(); - } + public override void Flush() + { + inner.Flush(); + } - public override Task FlushAsync(CancellationToken cancellationToken) - { - return _inner.FlushAsync(cancellationToken); - } + public override Task FlushAsync(CancellationToken cancellationToken) + { + return inner.FlushAsync(cancellationToken); + } - public override int Read(byte[] buffer, int offset, int count) - { - int read = _inner.Read(buffer, offset, count); - Log("Read", new ReadOnlySpan(buffer, offset, read)); - return read; - } + public override int Read(byte[] buffer, int offset, int count) + { + int read = inner.Read(buffer, offset, count); + Log("Read", new ReadOnlySpan(buffer, offset, read)); + return read; + } - public override int Read(Span destination) - { - int read = _inner.Read(destination); - Log("Read", destination.Slice(0, read)); - return read; - } + public override int Read(Span destination) + { + int read = inner.Read(destination); + Log("Read", destination.Slice(0, read)); + return read; + } - public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); - Log("ReadAsync", new ReadOnlySpan(buffer, offset, read)); - return read; - } + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int read = await inner.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + Log("ReadAsync", new ReadOnlySpan(buffer, offset, read)); + return read; + } - public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) - { - int read = await _inner.ReadAsync(destination, cancellationToken).ConfigureAwait(false); - Log("ReadAsync", destination.Span.Slice(0, read)); - return read; - } + public override async ValueTask ReadAsync(Memory destination, CancellationToken cancellationToken = default) + { + int read = await inner.ReadAsync(destination, cancellationToken).ConfigureAwait(false); + Log("ReadAsync", destination.Span.Slice(0, read)); + return read; + } - public override long Seek(long offset, SeekOrigin origin) - { - return _inner.Seek(offset, origin); - } + public override long Seek(long offset, SeekOrigin origin) + { + return inner.Seek(offset, origin); + } - public override void SetLength(long value) - { - _inner.SetLength(value); - } + public override void SetLength(long value) + { + inner.SetLength(value); + } - public override void Write(byte[] buffer, int offset, int count) - { - Log("Write", new ReadOnlySpan(buffer, offset, count)); - _inner.Write(buffer, offset, count); - } + public override void Write(byte[] buffer, int offset, int count) + { + Log("Write", new ReadOnlySpan(buffer, offset, count)); + inner.Write(buffer, offset, count); + } - public override void Write(ReadOnlySpan source) - { - Log("Write", source); - _inner.Write(source); - } + public override void Write(ReadOnlySpan source) + { + Log("Write", source); + inner.Write(source); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Log("WriteAsync", new ReadOnlySpan(buffer, offset, count)); + return inner.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + Log("WriteAsync", source.Span); + return inner.WriteAsync(source, cancellationToken); + } - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + private void Log(string method, ReadOnlySpan buffer) + { + if (logFormatter != null) { - Log("WriteAsync", new ReadOnlySpan(buffer, offset, count)); - return _inner.WriteAsync(buffer, offset, count, cancellationToken); + logFormatter(logger, method, buffer); + return; } - public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + if (!logger.IsEnabled(LogLevel.Debug)) { - Log("WriteAsync", source.Span); - return _inner.WriteAsync(source, cancellationToken); + return; } - private void Log(string method, ReadOnlySpan buffer) + var builder = new StringBuilder(); + builder.AppendLine($"{method}[{buffer.Length}]"); + var charBuilder = new StringBuilder(); + + // Write the hex + for (int i = 0; i < buffer.Length; i++) { - if (_logFormatter != null) - { - _logFormatter(_logger, method, buffer); - return; - } + builder.Append(buffer[i].ToString("X2")); + builder.Append(" "); - if (!_logger.IsEnabled(LogLevel.Debug)) + var bufferChar = (char)buffer[i]; + if (char.IsControl(bufferChar)) { - return; + charBuilder.Append("."); } - - var builder = new StringBuilder(); - builder.AppendLine($"{method}[{buffer.Length}]"); - var charBuilder = new StringBuilder(); - - // Write the hex - for (int i = 0; i < buffer.Length; i++) + else { - builder.Append(buffer[i].ToString("X2")); - builder.Append(" "); - - var bufferChar = (char)buffer[i]; - if (char.IsControl(bufferChar)) - { - charBuilder.Append("."); - } - else - { - charBuilder.Append(bufferChar); - } - - if ((i + 1) % 16 == 0) - { - builder.Append(" "); - builder.Append(charBuilder.ToString()); - builder.AppendLine(); - charBuilder.Clear(); - } - else if ((i + 1) % 8 == 0) - { - builder.Append(" "); - charBuilder.Append(" "); - } + charBuilder.Append(bufferChar); } - if (charBuilder.Length > 0) + if ((i + 1) % 16 == 0) { - // 2 (between hex and char blocks) + num bytes left (3 per byte) - builder.Append(string.Empty.PadRight(2 + (3 * (16 - charBuilder.Length)))); - // extra for space after 8th byte - if (charBuilder.Length < 8) - builder.Append(" "); + builder.Append(" "); builder.Append(charBuilder.ToString()); + builder.AppendLine(); + charBuilder.Clear(); + } + else if ((i + 1) % 8 == 0) + { + builder.Append(" "); + charBuilder.Append(" "); } - - _logger.LogDebug(builder.ToString()); } - // The below APM methods call the underlying Read/WriteAsync methods which will still be logged. - public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + if (charBuilder.Length > 0) { - return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state); + // 2 (between hex and char blocks) + num bytes left (3 per byte) + builder.Append(string.Empty.PadRight(2 + (3 * (16 - charBuilder.Length)))); + // extra for space after 8th byte + if (charBuilder.Length < 8) + builder.Append(" "); + builder.Append(charBuilder.ToString()); } - public override int EndRead(IAsyncResult asyncResult) - { - return TaskToApm.End(asyncResult); - } + logger.LogDebug(builder.ToString()); + } - public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) - { - return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state); - } + // The below APM methods call the underlying Read/WriteAsync methods which will still be logged. + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return TaskToApm.Begin(ReadAsync(buffer, offset, count), callback, state); + } - public override void EndWrite(IAsyncResult asyncResult) - { - TaskToApm.End(asyncResult); - } + public override int EndRead(IAsyncResult asyncResult) + { + return TaskToApm.End(asyncResult); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return TaskToApm.Begin(WriteAsync(buffer, offset, count), callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + TaskToApm.End(asyncResult); } } diff --git a/src/Bedrock.Framework/Middleware/LoggingConnectionMiddleware.cs b/src/Bedrock.Framework/Middleware/LoggingConnectionMiddleware.cs index 00c4108a..45e620b2 100644 --- a/src/Bedrock.Framework/Middleware/LoggingConnectionMiddleware.cs +++ b/src/Bedrock.Framework/Middleware/LoggingConnectionMiddleware.cs @@ -8,46 +8,33 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +internal class LoggingConnectionMiddleware(ConnectionDelegate next, ILogger logger, LoggingFormatter loggingFormatter = null) { - internal class LoggingConnectionMiddleware + private readonly ConnectionDelegate _next = next ?? throw new ArgumentNullException(nameof(next)); + private readonly ILogger _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + + public async Task OnConnectionAsync(ConnectionContext context) { - private readonly ConnectionDelegate _next; - private readonly ILogger _logger; - private readonly LoggingFormatter _loggingFormatter; + var oldTransport = context.Transport; - public LoggingConnectionMiddleware(ConnectionDelegate next, ILogger logger, LoggingFormatter loggingFormatter = null) + try { - _next = next ?? throw new ArgumentNullException(nameof(next)); - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - _loggingFormatter = loggingFormatter; - } + await using var loggingDuplexPipe = new LoggingDuplexPipe(context.Transport, _logger, loggingFormatter); - public async Task OnConnectionAsync(ConnectionContext context) - { - var oldTransport = context.Transport; - - try - { - await using (var loggingDuplexPipe = new LoggingDuplexPipe(context.Transport, _logger, _loggingFormatter)) - { - context.Transport = loggingDuplexPipe; - - await _next(context).ConfigureAwait(false); - } - } - finally - { - context.Transport = oldTransport; - } - } + context.Transport = loggingDuplexPipe; - private class LoggingDuplexPipe : DuplexPipeStreamAdapter + await _next(context).ConfigureAwait(false); + } + finally { - public LoggingDuplexPipe(IDuplexPipe transport, ILogger logger, LoggingFormatter loggingFormatter) : - base(transport, stream => new LoggingStream(stream, logger, loggingFormatter)) - { - } + context.Transport = oldTransport; } } + + private class LoggingDuplexPipe(IDuplexPipe transport, ILogger logger, LoggingFormatter loggingFormatter) : + DuplexPipeStreamAdapter(transport, stream => new LoggingStream(stream, logger, loggingFormatter)) + { + } } diff --git a/src/Bedrock.Framework/Middleware/Tls/CertificateLoader.cs b/src/Bedrock.Framework/Middleware/Tls/CertificateLoader.cs index b54727cb..523859e0 100644 --- a/src/Bedrock.Framework/Middleware/Tls/CertificateLoader.cs +++ b/src/Bedrock.Framework/Middleware/Tls/CertificateLoader.cs @@ -4,102 +4,99 @@ using System.Security.Cryptography.X509Certificates; using System.Text; -namespace Bedrock.Framework.Middleware.Tls -{ - public static class CertificateLoader - { - // See http://oid-info.com/get/1.3.6.1.5.5.7.3.1 - // Indicates that a certificate can be used as a TLS server certificate - private const string ServerAuthenticationOid = "1.3.6.1.5.5.7.3.1"; +namespace Bedrock.Framework.Middleware.Tls; - // See http://oid-info.com/get/1.3.6.1.5.5.7.3.2 - // Indicates that a certificate can be used as a TLS client certificate - private const string ClientAuthenticationOid = "1.3.6.1.5.5.7.3.2"; +public static class CertificateLoader +{ + // See http://oid-info.com/get/1.3.6.1.5.5.7.3.1 + // Indicates that a certificate can be used as a TLS server certificate + private const string ServerAuthenticationOid = "1.3.6.1.5.5.7.3.1"; - public static X509Certificate2 LoadFromStoreCert(string subject, string storeName, StoreLocation storeLocation, bool allowInvalid, bool server) - { - using (var store = new X509Store(storeName, storeLocation)) - { - X509Certificate2Collection storeCertificates = null; - X509Certificate2 foundCertificate = null; + // See http://oid-info.com/get/1.3.6.1.5.5.7.3.2 + // Indicates that a certificate can be used as a TLS client certificate + private const string ClientAuthenticationOid = "1.3.6.1.5.5.7.3.2"; - try - { - store.Open(OpenFlags.ReadOnly); - storeCertificates = store.Certificates; - var foundCertificates = storeCertificates.Find(X509FindType.FindBySubjectName, subject, !allowInvalid); - foundCertificate = foundCertificates - .OfType() - .Where(c => server ? IsCertificateAllowedForServerAuth(c) : IsCertificateAllowedForClientAuth(c)) - .Where(DoesCertificateHaveAnAccessiblePrivateKey) - .OrderByDescending(certificate => certificate.NotAfter) - .FirstOrDefault(); + public static X509Certificate2 LoadFromStoreCert(string subject, string storeName, StoreLocation storeLocation, bool allowInvalid, bool server) + { + using var store = new X509Store(storeName, storeLocation); + X509Certificate2Collection storeCertificates = null; + X509Certificate2 foundCertificate = null; - if (foundCertificate == null) - { - throw new InvalidOperationException($"Certificate {subject} not found in store {storeLocation} / {storeName}. AllowInvalid: {allowInvalid}"); - } + try + { + store.Open(OpenFlags.ReadOnly); + storeCertificates = store.Certificates; + var foundCertificates = storeCertificates.Find(X509FindType.FindBySubjectName, subject, !allowInvalid); + foundCertificate = foundCertificates + .OfType() + .Where(c => server ? IsCertificateAllowedForServerAuth(c) : IsCertificateAllowedForClientAuth(c)) + .Where(DoesCertificateHaveAnAccessiblePrivateKey) + .OrderByDescending(certificate => certificate.NotAfter) + .FirstOrDefault(); - return foundCertificate; - } - finally - { - DisposeCertificates(storeCertificates, except: foundCertificate); - } + if (foundCertificate == null) + { + throw new InvalidOperationException($"Certificate {subject} not found in store {storeLocation} / {storeName}. AllowInvalid: {allowInvalid}"); } + + return foundCertificate; + } + finally + { + DisposeCertificates(storeCertificates, except: foundCertificate); } + } - internal static bool IsCertificateAllowedForServerAuth(X509Certificate2 certificate) => IsCertificateAllowedForKeyUsage(certificate, ServerAuthenticationOid); + internal static bool IsCertificateAllowedForServerAuth(X509Certificate2 certificate) => IsCertificateAllowedForKeyUsage(certificate, ServerAuthenticationOid); - internal static bool IsCertificateAllowedForClientAuth(X509Certificate2 certificate) => IsCertificateAllowedForKeyUsage(certificate, ClientAuthenticationOid); + internal static bool IsCertificateAllowedForClientAuth(X509Certificate2 certificate) => IsCertificateAllowedForKeyUsage(certificate, ClientAuthenticationOid); - private static bool IsCertificateAllowedForKeyUsage(X509Certificate2 certificate, string purposeOid) - { - /* If the Extended Key Usage extension is included, then we check that the serverAuth usage is included. (http://oid-info.com/get/1.3.6.1.5.5.7.3.1) - * If the Extended Key Usage extension is not included, then we assume the certificate is allowed for all usages. - * - * See also https://blogs.msdn.microsoft.com/kaushal/2012/02/17/client-certificates-vs-server-certificates/ - * - * From https://tools.ietf.org/html/rfc3280#section-4.2.1.13 "Certificate Extensions: Extended Key Usage" - * - * If the (Extended Key Usage) extension is present, then the certificate MUST only be used - * for one of the purposes indicated. If multiple purposes are - * indicated the application need not recognize all purposes indicated, - * as long as the intended purpose is present. Certificate using - * applications MAY require that a particular purpose be indicated in - * order for the certificate to be acceptable to that application. - */ + private static bool IsCertificateAllowedForKeyUsage(X509Certificate2 certificate, string purposeOid) + { + /* If the Extended Key Usage extension is included, then we check that the serverAuth usage is included. (http://oid-info.com/get/1.3.6.1.5.5.7.3.1) + * If the Extended Key Usage extension is not included, then we assume the certificate is allowed for all usages. + * + * See also https://blogs.msdn.microsoft.com/kaushal/2012/02/17/client-certificates-vs-server-certificates/ + * + * From https://tools.ietf.org/html/rfc3280#section-4.2.1.13 "Certificate Extensions: Extended Key Usage" + * + * If the (Extended Key Usage) extension is present, then the certificate MUST only be used + * for one of the purposes indicated. If multiple purposes are + * indicated the application need not recognize all purposes indicated, + * as long as the intended purpose is present. Certificate using + * applications MAY require that a particular purpose be indicated in + * order for the certificate to be acceptable to that application. + */ - var hasEkuExtension = false; + var hasEkuExtension = false; - foreach (var extension in certificate.Extensions.OfType()) + foreach (var extension in certificate.Extensions.OfType()) + { + hasEkuExtension = true; + foreach (var oid in extension.EnhancedKeyUsages) { - hasEkuExtension = true; - foreach (var oid in extension.EnhancedKeyUsages) + if (oid.Value.Equals(purposeOid, StringComparison.Ordinal)) { - if (oid.Value.Equals(purposeOid, StringComparison.Ordinal)) - { - return true; - } + return true; } } - - return !hasEkuExtension; } - internal static bool DoesCertificateHaveAnAccessiblePrivateKey(X509Certificate2 certificate) - => certificate.HasPrivateKey; + return !hasEkuExtension; + } - private static void DisposeCertificates(X509Certificate2Collection certificates, X509Certificate2 except) + internal static bool DoesCertificateHaveAnAccessiblePrivateKey(X509Certificate2 certificate) + => certificate.HasPrivateKey; + + private static void DisposeCertificates(X509Certificate2Collection certificates, X509Certificate2 except) + { + if (certificates != null) { - if (certificates != null) + foreach (var certificate in certificates) { - foreach (var certificate in certificates) + if (!certificate.Equals(except)) { - if (!certificate.Equals(except)) - { - certificate.Dispose(); - } + certificate.Dispose(); } } } diff --git a/src/Bedrock.Framework/Middleware/Tls/ITlsApplicationProtocolFeature.cs b/src/Bedrock.Framework/Middleware/Tls/ITlsApplicationProtocolFeature.cs index c39f4e37..38b67d75 100644 --- a/src/Bedrock.Framework/Middleware/Tls/ITlsApplicationProtocolFeature.cs +++ b/src/Bedrock.Framework/Middleware/Tls/ITlsApplicationProtocolFeature.cs @@ -2,10 +2,9 @@ using System.Collections.Generic; using System.Text; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +public interface ITlsApplicationProtocolFeature { - public interface ITlsApplicationProtocolFeature - { - ReadOnlyMemory ApplicationProtocol { get; } - } + ReadOnlyMemory ApplicationProtocol { get; } } diff --git a/src/Bedrock.Framework/Middleware/Tls/ITlsConnectionFeature.cs b/src/Bedrock.Framework/Middleware/Tls/ITlsConnectionFeature.cs index e67215d5..90a79b08 100644 --- a/src/Bedrock.Framework/Middleware/Tls/ITlsConnectionFeature.cs +++ b/src/Bedrock.Framework/Middleware/Tls/ITlsConnectionFeature.cs @@ -5,19 +5,18 @@ using System.Threading; using System.Threading.Tasks; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +public interface ITlsConnectionFeature { - public interface ITlsConnectionFeature - { - /// - /// Synchronously retrieves the remote endpoint's certificate, if any. - /// - X509Certificate2 RemoteCertificate { get; set; } + /// + /// Synchronously retrieves the remote endpoint's certificate, if any. + /// + X509Certificate2 RemoteCertificate { get; set; } - /// - /// Asynchronously retrieves the remote endpoint's certificate, if any. - /// - /// - Task GetRemoteCertificateAsync(CancellationToken cancellationToken); - } + /// + /// Asynchronously retrieves the remote endpoint's certificate, if any. + /// + /// + Task GetRemoteCertificateAsync(CancellationToken cancellationToken); } diff --git a/src/Bedrock.Framework/Middleware/Tls/ITlsHandshakeFeature.cs b/src/Bedrock.Framework/Middleware/Tls/ITlsHandshakeFeature.cs index 8b1c378e..a169f558 100644 --- a/src/Bedrock.Framework/Middleware/Tls/ITlsHandshakeFeature.cs +++ b/src/Bedrock.Framework/Middleware/Tls/ITlsHandshakeFeature.cs @@ -3,22 +3,21 @@ using System.Security.Authentication; using System.Text; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +public interface ITlsHandshakeFeature { - public interface ITlsHandshakeFeature - { - SslProtocols Protocol { get; } + SslProtocols Protocol { get; } - CipherAlgorithmType CipherAlgorithm { get; } + CipherAlgorithmType CipherAlgorithm { get; } - int CipherStrength { get; } + int CipherStrength { get; } - HashAlgorithmType HashAlgorithm { get; } + HashAlgorithmType HashAlgorithm { get; } - int HashStrength { get; } + int HashStrength { get; } - ExchangeAlgorithmType KeyExchangeAlgorithm { get; } + ExchangeAlgorithmType KeyExchangeAlgorithm { get; } - int KeyExchangeStrength { get; } - } + int KeyExchangeStrength { get; } } diff --git a/src/Bedrock.Framework/Middleware/Tls/RemoteCertificateMode.cs b/src/Bedrock.Framework/Middleware/Tls/RemoteCertificateMode.cs index c22fd698..57badc52 100644 --- a/src/Bedrock.Framework/Middleware/Tls/RemoteCertificateMode.cs +++ b/src/Bedrock.Framework/Middleware/Tls/RemoteCertificateMode.cs @@ -2,26 +2,25 @@ using System.Collections.Generic; using System.Text; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +/// +/// Describes the remote certificate requirements for a TLS connection. +/// +public enum RemoteCertificateMode { /// - /// Describes the remote certificate requirements for a TLS connection. + /// A remote certificate is not required and will not be requested from remote endpoints. /// - public enum RemoteCertificateMode - { - /// - /// A remote certificate is not required and will not be requested from remote endpoints. - /// - NoCertificate, + NoCertificate, - /// - /// A remote certificate will be requested; however, authentication will not fail if a certificate is not provided by the remote endpoint. - /// - AllowCertificate, + /// + /// A remote certificate will be requested; however, authentication will not fail if a certificate is not provided by the remote endpoint. + /// + AllowCertificate, - /// - /// A remote certificate will be requested, and the remote endpoint must provide a valid certificate for authentication. - /// - RequireCertificate - } + /// + /// A remote certificate will be requested, and the remote endpoint must provide a valid certificate for authentication. + /// + RequireCertificate } diff --git a/src/Bedrock.Framework/Middleware/Tls/SslDuplexPipe.cs b/src/Bedrock.Framework/Middleware/Tls/SslDuplexPipe.cs index 7cdb00a1..1633ba33 100644 --- a/src/Bedrock.Framework/Middleware/Tls/SslDuplexPipe.cs +++ b/src/Bedrock.Framework/Middleware/Tls/SslDuplexPipe.cs @@ -1,24 +1,21 @@ using System; -using System.Collections.Generic; using System.IO; using System.IO.Pipelines; using System.Net.Security; -using System.Text; using Bedrock.Framework.Infrastructure; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +internal class SslDuplexPipe : DuplexPipeStreamAdapter { - internal class SslDuplexPipe : DuplexPipeStreamAdapter + public SslDuplexPipe(IDuplexPipe transport, StreamPipeReaderOptions readerOptions, StreamPipeWriterOptions writerOptions) + : this(transport, readerOptions, writerOptions, s => new SslStream(s)) { - public SslDuplexPipe(IDuplexPipe transport, StreamPipeReaderOptions readerOptions, StreamPipeWriterOptions writerOptions) - : this(transport, readerOptions, writerOptions, s => new SslStream(s)) - { - } + } - public SslDuplexPipe(IDuplexPipe transport, StreamPipeReaderOptions readerOptions, StreamPipeWriterOptions writerOptions, Func factory) : - base(transport, readerOptions, writerOptions, factory) - { - } + public SslDuplexPipe(IDuplexPipe transport, StreamPipeReaderOptions readerOptions, StreamPipeWriterOptions writerOptions, Func factory) : + base(transport, readerOptions, writerOptions, factory) + { } } diff --git a/src/Bedrock.Framework/Middleware/Tls/TlsClientConnectionMiddleware.cs b/src/Bedrock.Framework/Middleware/Tls/TlsClientConnectionMiddleware.cs index 8136f29e..cc894175 100644 --- a/src/Bedrock.Framework/Middleware/Tls/TlsClientConnectionMiddleware.cs +++ b/src/Bedrock.Framework/Middleware/Tls/TlsClientConnectionMiddleware.cs @@ -13,186 +13,185 @@ using Microsoft.AspNetCore.Connections.Features; using Microsoft.Extensions.Logging; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +internal class TlsClientConnectionMiddleware { - internal class TlsClientConnectionMiddleware - { - private readonly ConnectionDelegate _next; - private readonly TlsOptions _options; - private readonly ILogger _logger; - private readonly X509Certificate2 _certificate; + private readonly ConnectionDelegate _next; + private readonly TlsOptions _options; + private readonly ILogger _logger; + private readonly X509Certificate2 _certificate; - public TlsClientConnectionMiddleware(ConnectionDelegate next, TlsOptions options, ILoggerFactory loggerFactory) + public TlsClientConnectionMiddleware(ConnectionDelegate next, TlsOptions options, ILoggerFactory loggerFactory) + { + if (options == null) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - - _next = next; + throw new ArgumentNullException(nameof(options)); + } - // capture the certificate now so it can't be switched after validation - _certificate = options.LocalCertificate; + _next = next; - if (_certificate != null) - { - EnsureCertificateIsAllowedForClientAuth(_certificate); - } + // capture the certificate now so it can't be switched after validation + _certificate = options.LocalCertificate; - _options = options; - _logger = loggerFactory?.CreateLogger(); + if (_certificate != null) + { + EnsureCertificateIsAllowedForClientAuth(_certificate); } - public async Task OnConnectionAsync(ConnectionContext context) - { - await Task.Yield(); + _options = options; + _logger = loggerFactory?.CreateLogger(); + } - var feature = new TlsConnectionFeature(); - context.Features.Set(feature); - context.Features.Set(feature); + public async Task OnConnectionAsync(ConnectionContext context) + { + await Task.Yield(); - var memoryPool = context.Features.Get()?.MemoryPool; + var feature = new TlsConnectionFeature(); + context.Features.Set(feature); + context.Features.Set(feature); - var inputPipeOptions = new StreamPipeReaderOptions - ( - pool: memoryPool, - bufferSize: memoryPool.GetMinimumSegmentSize(), - minimumReadSize: memoryPool.GetMinimumAllocSize(), - leaveOpen: true - ); + var memoryPool = context.Features.Get()?.MemoryPool; - var outputPipeOptions = new StreamPipeWriterOptions - ( - pool: memoryPool, - leaveOpen: true - ); + var inputPipeOptions = new StreamPipeReaderOptions + ( + pool: memoryPool, + bufferSize: memoryPool.GetMinimumSegmentSize(), + minimumReadSize: memoryPool.GetMinimumAllocSize(), + leaveOpen: true + ); - SslDuplexPipe sslDuplexPipe = null; + var outputPipeOptions = new StreamPipeWriterOptions + ( + pool: memoryPool, + leaveOpen: true + ); - if (_options.RemoteCertificateMode == RemoteCertificateMode.NoCertificate) - { - sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions); - } - else - { - sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions, s => new SslStream( - s, - leaveInnerStreamOpen: false, - userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => - { - if (certificate == null) - { - return _options.RemoteCertificateMode != RemoteCertificateMode.RequireCertificate; - } + SslDuplexPipe sslDuplexPipe = null; - if (_options.RemoteCertificateValidation == null) - { - if (sslPolicyErrors != SslPolicyErrors.None) - { - return false; - } - } + if (_options.RemoteCertificateMode == RemoteCertificateMode.NoCertificate) + { + sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions); + } + else + { + sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions, s => new SslStream( + s, + leaveInnerStreamOpen: false, + userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => + { + if (certificate == null) + { + return _options.RemoteCertificateMode != RemoteCertificateMode.RequireCertificate; + } - var certificate2 = ConvertToX509Certificate2(certificate); - if (certificate2 == null) + if (_options.RemoteCertificateValidation == null) + { + if (sslPolicyErrors != SslPolicyErrors.None) { return false; } + } + + var certificate2 = ConvertToX509Certificate2(certificate); + if (certificate2 == null) + { + return false; + } - if (_options.RemoteCertificateValidation != null) + if (_options.RemoteCertificateValidation != null) + { + if (!_options.RemoteCertificateValidation(certificate2, chain, sslPolicyErrors)) { - if (!_options.RemoteCertificateValidation(certificate2, chain, sslPolicyErrors)) - { - return false; - } + return false; } + } - return true; - })); - } + return true; + })); + } - var sslStream = sslDuplexPipe.Stream; + var sslStream = sslDuplexPipe.Stream; - using (var cancellationTokeSource = new CancellationTokenSource(_options.HandshakeTimeout)) + using (var cancellationTokeSource = new CancellationTokenSource(_options.HandshakeTimeout)) + { + try { - try + var sslOptions = new SslClientAuthenticationOptions { - var sslOptions = new SslClientAuthenticationOptions - { - ClientCertificates = new X509CertificateCollection(new[] { _certificate }), - EnabledSslProtocols = _options.SslProtocols, - CertificateRevocationCheckMode = _options.CheckCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - ApplicationProtocols = new List(), - }; + ClientCertificates = new X509CertificateCollection(new[] { _certificate }), + EnabledSslProtocols = _options.SslProtocols, + CertificateRevocationCheckMode = _options.CheckCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + ApplicationProtocols = new List(), + }; - _options.OnAuthenticateAsClient?.Invoke(context, sslOptions); + _options.OnAuthenticateAsClient?.Invoke(context, sslOptions); - await sslStream.AuthenticateAsClientAsync(sslOptions, cancellationTokeSource.Token).ConfigureAwait(false); - } - catch (OperationCanceledException) - { - _logger?.LogDebug(2, "Authentication timed out"); - await sslStream.DisposeAsync().ConfigureAwait(false); - return; - } - catch (Exception ex) when (ex is IOException || ex is AuthenticationException) - { - _logger?.LogDebug(1, ex, "Authentication failed"); - await sslStream.DisposeAsync().ConfigureAwait(false); - return; - } + await sslStream.AuthenticateAsClientAsync(sslOptions, cancellationTokeSource.Token).ConfigureAwait(false); } - - feature.ApplicationProtocol = sslStream.NegotiatedApplicationProtocol.Protocol; - context.Features.Set(feature); - feature.LocalCertificate = ConvertToX509Certificate2(sslStream.LocalCertificate); - feature.RemoteCertificate = ConvertToX509Certificate2(sslStream.RemoteCertificate); - feature.CipherAlgorithm = sslStream.CipherAlgorithm; - feature.CipherStrength = sslStream.CipherStrength; - feature.HashAlgorithm = sslStream.HashAlgorithm; - feature.HashStrength = sslStream.HashStrength; - feature.KeyExchangeAlgorithm = sslStream.KeyExchangeAlgorithm; - feature.KeyExchangeStrength = sslStream.KeyExchangeStrength; - feature.Protocol = sslStream.SslProtocol; - - var originalTransport = context.Transport; - - try + catch (OperationCanceledException) { - context.Transport = sslDuplexPipe; - - // Disposing the stream will dispose the sslDuplexPipe - await using (sslStream) - await using (sslDuplexPipe) - { - await _next(context).ConfigureAwait(false); - // Dispose the inner stream (SslDuplexPipe) before disposing the SslStream - // as the duplex pipe can hit an ODE as it still may be writing. - } + _logger?.LogDebug(2, "Authentication timed out"); + await sslStream.DisposeAsync().ConfigureAwait(false); + return; } - finally + catch (Exception ex) when (ex is IOException || ex is AuthenticationException) { - // Restore the original so that it gets closed appropriately - context.Transport = originalTransport; + _logger?.LogDebug(1, ex, "Authentication failed"); + await sslStream.DisposeAsync().ConfigureAwait(false); + return; } } - protected static void EnsureCertificateIsAllowedForClientAuth(X509Certificate2 certificate) + feature.ApplicationProtocol = sslStream.NegotiatedApplicationProtocol.Protocol; + context.Features.Set(feature); + feature.LocalCertificate = ConvertToX509Certificate2(sslStream.LocalCertificate); + feature.RemoteCertificate = ConvertToX509Certificate2(sslStream.RemoteCertificate); + feature.CipherAlgorithm = sslStream.CipherAlgorithm; + feature.CipherStrength = sslStream.CipherStrength; + feature.HashAlgorithm = sslStream.HashAlgorithm; + feature.HashStrength = sslStream.HashStrength; + feature.KeyExchangeAlgorithm = sslStream.KeyExchangeAlgorithm; + feature.KeyExchangeStrength = sslStream.KeyExchangeStrength; + feature.Protocol = sslStream.SslProtocol; + + var originalTransport = context.Transport; + + try { - if (!CertificateLoader.IsCertificateAllowedForClientAuth(certificate)) + context.Transport = sslDuplexPipe; + + // Disposing the stream will dispose the sslDuplexPipe + await using (sslStream) + await using (sslDuplexPipe) { - throw new InvalidOperationException($"Invalid client certificate for client authentication: {certificate.Thumbprint}"); + await _next(context).ConfigureAwait(false); + // Dispose the inner stream (SslDuplexPipe) before disposing the SslStream + // as the duplex pipe can hit an ODE as it still may be writing. } } + finally + { + // Restore the original so that it gets closed appropriately + context.Transport = originalTransport; + } + } - private static X509Certificate2 ConvertToX509Certificate2(X509Certificate certificate) + protected static void EnsureCertificateIsAllowedForClientAuth(X509Certificate2 certificate) + { + if (!CertificateLoader.IsCertificateAllowedForClientAuth(certificate)) { - if (certificate is null) - { - return null; - } + throw new InvalidOperationException($"Invalid client certificate for client authentication: {certificate.Thumbprint}"); + } + } - return certificate as X509Certificate2 ?? new X509Certificate2(certificate); + private static X509Certificate2 ConvertToX509Certificate2(X509Certificate certificate) + { + if (certificate is null) + { + return null; } + + return certificate as X509Certificate2 ?? new X509Certificate2(certificate); } } diff --git a/src/Bedrock.Framework/Middleware/Tls/TlsConnectionFeature.cs b/src/Bedrock.Framework/Middleware/Tls/TlsConnectionFeature.cs index e5b12789..cb8cd433 100644 --- a/src/Bedrock.Framework/Middleware/Tls/TlsConnectionFeature.cs +++ b/src/Bedrock.Framework/Middleware/Tls/TlsConnectionFeature.cs @@ -6,33 +6,32 @@ using System.Threading; using System.Threading.Tasks; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +internal class TlsConnectionFeature : ITlsConnectionFeature, ITlsApplicationProtocolFeature, ITlsHandshakeFeature { - internal class TlsConnectionFeature : ITlsConnectionFeature, ITlsApplicationProtocolFeature, ITlsHandshakeFeature - { - public X509Certificate2 LocalCertificate { get; set; } + public X509Certificate2 LocalCertificate { get; set; } - public X509Certificate2 RemoteCertificate { get; set; } + public X509Certificate2 RemoteCertificate { get; set; } - public ReadOnlyMemory ApplicationProtocol { get; set; } + public ReadOnlyMemory ApplicationProtocol { get; set; } - public SslProtocols Protocol { get; set; } + public SslProtocols Protocol { get; set; } - public CipherAlgorithmType CipherAlgorithm { get; set; } + public CipherAlgorithmType CipherAlgorithm { get; set; } - public int CipherStrength { get; set; } + public int CipherStrength { get; set; } - public HashAlgorithmType HashAlgorithm { get; set; } + public HashAlgorithmType HashAlgorithm { get; set; } - public int HashStrength { get; set; } + public int HashStrength { get; set; } - public ExchangeAlgorithmType KeyExchangeAlgorithm { get; set; } + public ExchangeAlgorithmType KeyExchangeAlgorithm { get; set; } - public int KeyExchangeStrength { get; set; } + public int KeyExchangeStrength { get; set; } - public Task GetRemoteCertificateAsync(CancellationToken cancellationToken) - { - return Task.FromResult(RemoteCertificate); - } + public Task GetRemoteCertificateAsync(CancellationToken cancellationToken) + { + return Task.FromResult(RemoteCertificate); } } diff --git a/src/Bedrock.Framework/Middleware/Tls/TlsOptions.cs b/src/Bedrock.Framework/Middleware/Tls/TlsOptions.cs index e9c833ff..b92dd318 100644 --- a/src/Bedrock.Framework/Middleware/Tls/TlsOptions.cs +++ b/src/Bedrock.Framework/Middleware/Tls/TlsOptions.cs @@ -6,107 +6,106 @@ using System.Threading; using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +public delegate bool RemoteCertificateValidator(X509Certificate2 certificate, X509Chain chain, SslPolicyErrors policyErrors); + +/// +/// Settings for how TLS connections are handled. +/// +public class TlsOptions { - public delegate bool RemoteCertificateValidator(X509Certificate2 certificate, X509Chain chain, SslPolicyErrors policyErrors); + private TimeSpan _handshakeTimeout; /// - /// Settings for how TLS connections are handled. + /// Initializes a new instance of . /// - public class TlsOptions + public TlsOptions() { - private TimeSpan _handshakeTimeout; - - /// - /// Initializes a new instance of . - /// - public TlsOptions() - { - RemoteCertificateMode = RemoteCertificateMode.RequireCertificate; - HandshakeTimeout = TimeSpan.FromSeconds(10); - } + RemoteCertificateMode = RemoteCertificateMode.RequireCertificate; + HandshakeTimeout = TimeSpan.FromSeconds(10); + } - /// - /// - /// Specifies the local certificate used to authenticate TLS connections. This is ignored if LocalCertificateSelector is set. - /// - /// - /// If the certificate has an Extended Key Usage extension, the usages must include Server Authentication (OID 1.3.6.1.5.5.7.3.1). - /// - /// - public X509Certificate2? LocalCertificate { get; set; } + /// + /// + /// Specifies the local certificate used to authenticate TLS connections. This is ignored if LocalCertificateSelector is set. + /// + /// + /// If the certificate has an Extended Key Usage extension, the usages must include Server Authentication (OID 1.3.6.1.5.5.7.3.1). + /// + /// + public X509Certificate2? LocalCertificate { get; set; } - /// - /// - /// A callback that will be invoked to dynamically select a local server certificate. This is higher priority than LocalCertificate. - /// If SNI is not available then the name parameter will be null. - /// - /// - /// If the certificate has an Extended Key Usage extension, the usages must include Server Authentication (OID 1.3.6.1.5.5.7.3.1). - /// - /// - public Func? LocalServerCertificateSelector { get; set; } + /// + /// + /// A callback that will be invoked to dynamically select a local server certificate. This is higher priority than LocalCertificate. + /// If SNI is not available then the name parameter will be null. + /// + /// + /// If the certificate has an Extended Key Usage extension, the usages must include Server Authentication (OID 1.3.6.1.5.5.7.3.1). + /// + /// + public Func? LocalServerCertificateSelector { get; set; } - /// - /// Specifies the remote endpoint certificate requirements for a TLS connection. Defaults to . - /// - public RemoteCertificateMode RemoteCertificateMode { get; set; } + /// + /// Specifies the remote endpoint certificate requirements for a TLS connection. Defaults to . + /// + public RemoteCertificateMode RemoteCertificateMode { get; set; } - /// - /// Specifies a callback for additional remote certificate validation that will be invoked during authentication. This will be ignored - /// if is called after this callback is set. - /// - public RemoteCertificateValidator? RemoteCertificateValidation { get; set; } + /// + /// Specifies a callback for additional remote certificate validation that will be invoked during authentication. This will be ignored + /// if is called after this callback is set. + /// + public RemoteCertificateValidator? RemoteCertificateValidation { get; set; } - /// - /// Specifies allowable SSL protocols. Defaults to and . - /// - public SslProtocols SslProtocols { get; set; } + /// + /// Specifies allowable SSL protocols. Defaults to and . + /// + public SslProtocols SslProtocols { get; set; } - /// - /// Specifies whether the certificate revocation list is checked during authentication. - /// - public bool CheckCertificateRevocation { get; set; } + /// + /// Specifies whether the certificate revocation list is checked during authentication. + /// + public bool CheckCertificateRevocation { get; set; } - /// - /// Specifies the cipher suites allowed for TLS. When set to null, the operating system default is used. - /// - public CipherSuitesPolicy? CipherSuitesPolicy { get; set; } + /// + /// Specifies the cipher suites allowed for TLS. When set to null, the operating system default is used. + /// + public CipherSuitesPolicy? CipherSuitesPolicy { get; set; } - /// - /// Overrides the current callback and allows any client certificate. - /// - public void AllowAnyRemoteCertificate() - { - RemoteCertificateValidation = (_, __, ___) => true; - } + /// + /// Overrides the current callback and allows any client certificate. + /// + public void AllowAnyRemoteCertificate() + { + RemoteCertificateValidation = (_, __, ___) => true; + } - /// - /// Provides direct configuration of the on a per-connection basis. - /// This is called after all of the other settings have already been applied. - /// - public Action? OnAuthenticateAsServer { get; set; } + /// + /// Provides direct configuration of the on a per-connection basis. + /// This is called after all of the other settings have already been applied. + /// + public Action? OnAuthenticateAsServer { get; set; } - /// - /// Provides direct configuration of the on a per-connection basis. - /// This is called after all of the other settings have already been applied. - /// - public Action? OnAuthenticateAsClient { get; set; } + /// + /// Provides direct configuration of the on a per-connection basis. + /// This is called after all of the other settings have already been applied. + /// + public Action? OnAuthenticateAsClient { get; set; } - /// - /// Specifies the maximum amount of time allowed for the TLS/SSL handshake. This must be positive and finite. - /// - public TimeSpan HandshakeTimeout + /// + /// Specifies the maximum amount of time allowed for the TLS/SSL handshake. This must be positive and finite. + /// + public TimeSpan HandshakeTimeout + { + get => _handshakeTimeout; + set { - get => _handshakeTimeout; - set + if (value <= TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) { - if (value <= TimeSpan.Zero && value != Timeout.InfiniteTimeSpan) - { - throw new ArgumentOutOfRangeException(nameof(value), nameof(HandshakeTimeout) + " must be positive"); - } - _handshakeTimeout = value != Timeout.InfiniteTimeSpan ? value : TimeSpan.MaxValue; + throw new ArgumentOutOfRangeException(nameof(value), nameof(HandshakeTimeout) + " must be positive"); } + _handshakeTimeout = value != Timeout.InfiniteTimeSpan ? value : TimeSpan.MaxValue; } } } diff --git a/src/Bedrock.Framework/Middleware/Tls/TlsServerConnectionMiddleware.cs b/src/Bedrock.Framework/Middleware/Tls/TlsServerConnectionMiddleware.cs index 4524f0b3..359b63d2 100644 --- a/src/Bedrock.Framework/Middleware/Tls/TlsServerConnectionMiddleware.cs +++ b/src/Bedrock.Framework/Middleware/Tls/TlsServerConnectionMiddleware.cs @@ -13,224 +13,223 @@ using Microsoft.AspNetCore.Connections.Features; using Microsoft.Extensions.Logging; -namespace Bedrock.Framework.Middleware.Tls +namespace Bedrock.Framework.Middleware.Tls; + +internal class TlsServerConnectionMiddleware { - internal class TlsServerConnectionMiddleware - { - private readonly ConnectionDelegate _next; - private readonly TlsOptions _options; - private readonly ILogger _logger; - private readonly X509Certificate2 _certificate; - private readonly Func _certificateSelector; + private readonly ConnectionDelegate _next; + private readonly TlsOptions _options; + private readonly ILogger _logger; + private readonly X509Certificate2 _certificate; + private readonly Func _certificateSelector; - public TlsServerConnectionMiddleware(ConnectionDelegate next, TlsOptions options, ILoggerFactory loggerFactory) + public TlsServerConnectionMiddleware(ConnectionDelegate next, TlsOptions options, ILoggerFactory loggerFactory) + { + if (options == null) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } + throw new ArgumentNullException(nameof(options)); + } - _next = next; + _next = next; - // capture the certificate now so it can't be switched after validation - _certificate = options.LocalCertificate; - _certificateSelector = options.LocalServerCertificateSelector; - if (_certificate == null && _certificateSelector == null) - { - throw new ArgumentException("Server certificate is required", nameof(options)); - } - - // If a selector is provided then ignore the cert, it may be a default cert. - if (_certificateSelector != null) - { - // SslStream doesn't allow both. - _certificate = null; - } - else - { - EnsureCertificateIsAllowedForServerAuth(_certificate); - } - - _options = options; - _logger = loggerFactory?.CreateLogger(); + // capture the certificate now so it can't be switched after validation + _certificate = options.LocalCertificate; + _certificateSelector = options.LocalServerCertificateSelector; + if (_certificate == null && _certificateSelector == null) + { + throw new ArgumentException("Server certificate is required", nameof(options)); } - public Task OnConnectionAsync(ConnectionContext context) + // If a selector is provided then ignore the cert, it may be a default cert. + if (_certificateSelector != null) { - return Task.Run(() => InnerOnConnectionAsync(context)); + // SslStream doesn't allow both. + _certificate = null; } + else + { + EnsureCertificateIsAllowedForServerAuth(_certificate); + } + + _options = options; + _logger = loggerFactory?.CreateLogger(); + } - private async Task InnerOnConnectionAsync(ConnectionContext context) + public Task OnConnectionAsync(ConnectionContext context) + { + return Task.Run(() => InnerOnConnectionAsync(context)); + } + + private async Task InnerOnConnectionAsync(ConnectionContext context) + { + bool certificateRequired; + var feature = new TlsConnectionFeature(); + context.Features.Set(feature); + context.Features.Set(feature); + + var memoryPool = context.Features.Get()?.MemoryPool; + + var inputPipeOptions = new StreamPipeReaderOptions + ( + pool: memoryPool, + bufferSize: memoryPool.GetMinimumSegmentSize(), + minimumReadSize: memoryPool.GetMinimumAllocSize(), + leaveOpen: true + ); + + var outputPipeOptions = new StreamPipeWriterOptions + ( + pool: memoryPool, + leaveOpen: true + ); + + SslDuplexPipe sslDuplexPipe = null; + + if (_options.RemoteCertificateMode == RemoteCertificateMode.NoCertificate) { - bool certificateRequired; - var feature = new TlsConnectionFeature(); - context.Features.Set(feature); - context.Features.Set(feature); - - var memoryPool = context.Features.Get()?.MemoryPool; - - var inputPipeOptions = new StreamPipeReaderOptions - ( - pool: memoryPool, - bufferSize: memoryPool.GetMinimumSegmentSize(), - minimumReadSize: memoryPool.GetMinimumAllocSize(), - leaveOpen: true - ); - - var outputPipeOptions = new StreamPipeWriterOptions - ( - pool: memoryPool, - leaveOpen: true - ); - - SslDuplexPipe sslDuplexPipe = null; - - if (_options.RemoteCertificateMode == RemoteCertificateMode.NoCertificate) - { - sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions); - certificateRequired = false; - } - else - { - sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions, s => new SslStream( - s, - leaveInnerStreamOpen: false, - userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => + sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions); + certificateRequired = false; + } + else + { + sslDuplexPipe = new SslDuplexPipe(context.Transport, inputPipeOptions, outputPipeOptions, s => new SslStream( + s, + leaveInnerStreamOpen: false, + userCertificateValidationCallback: (sender, certificate, chain, sslPolicyErrors) => + { + if (certificate == null) { - if (certificate == null) - { - return _options.RemoteCertificateMode != RemoteCertificateMode.RequireCertificate; - } - - if (_options.RemoteCertificateValidation == null) - { - if (sslPolicyErrors != SslPolicyErrors.None) - { - return false; - } - } + return _options.RemoteCertificateMode != RemoteCertificateMode.RequireCertificate; + } - var certificate2 = ConvertToX509Certificate2(certificate); - if (certificate2 == null) + if (_options.RemoteCertificateValidation == null) + { + if (sslPolicyErrors != SslPolicyErrors.None) { return false; } + } + + var certificate2 = ConvertToX509Certificate2(certificate); + if (certificate2 == null) + { + return false; + } - if (_options.RemoteCertificateValidation != null) + if (_options.RemoteCertificateValidation != null) + { + if (!_options.RemoteCertificateValidation(certificate2, chain, sslPolicyErrors)) { - if (!_options.RemoteCertificateValidation(certificate2, chain, sslPolicyErrors)) - { - return false; - } + return false; } + } - return true; - })); + return true; + })); - certificateRequired = true; - } + certificateRequired = true; + } - var sslStream = sslDuplexPipe.Stream; + var sslStream = sslDuplexPipe.Stream; - using (var cancellationTokeSource = new CancellationTokenSource(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : _options.HandshakeTimeout)) + using (var cancellationTokeSource = new CancellationTokenSource(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : _options.HandshakeTimeout)) + { + try { - try + // Adapt to the SslStream signature + ServerCertificateSelectionCallback selector = null; + if (_certificateSelector != null) { - // Adapt to the SslStream signature - ServerCertificateSelectionCallback selector = null; - if (_certificateSelector != null) + selector = (sender, name) => { - selector = (sender, name) => + context.Features.Set(sslStream); + var cert = _certificateSelector(context, name); + if (cert != null) { - context.Features.Set(sslStream); - var cert = _certificateSelector(context, name); - if (cert != null) - { - EnsureCertificateIsAllowedForServerAuth(cert); - } - return cert; - }; - } - - var sslOptions = new SslServerAuthenticationOptions - { - ServerCertificate = _certificate, - ServerCertificateSelectionCallback = selector, - ClientCertificateRequired = certificateRequired, - EnabledSslProtocols = _options.SslProtocols, - CertificateRevocationCheckMode = _options.CheckCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, - ApplicationProtocols = new List(), - CipherSuitesPolicy = _options.CipherSuitesPolicy + EnsureCertificateIsAllowedForServerAuth(cert); + } + return cert; }; - - _options.OnAuthenticateAsServer?.Invoke(context, sslOptions); - - await sslStream.AuthenticateAsServerAsync(sslOptions, cancellationTokeSource.Token).ConfigureAwait(false); } - catch (OperationCanceledException) - { - _logger?.LogDebug(2, "Authentication timed out"); - await sslStream.DisposeAsync().ConfigureAwait(false); - return; - } - catch (Exception ex) when (ex is IOException || ex is AuthenticationException) + + var sslOptions = new SslServerAuthenticationOptions { - _logger?.LogDebug(1, ex, "Authentication failed"); - await sslStream.DisposeAsync().ConfigureAwait(false); - return; - } + ServerCertificate = _certificate, + ServerCertificateSelectionCallback = selector, + ClientCertificateRequired = certificateRequired, + EnabledSslProtocols = _options.SslProtocols, + CertificateRevocationCheckMode = _options.CheckCertificateRevocation ? X509RevocationMode.Online : X509RevocationMode.NoCheck, + ApplicationProtocols = new List(), + CipherSuitesPolicy = _options.CipherSuitesPolicy + }; + + _options.OnAuthenticateAsServer?.Invoke(context, sslOptions); + + await sslStream.AuthenticateAsServerAsync(sslOptions, cancellationTokeSource.Token).ConfigureAwait(false); } - - feature.ApplicationProtocol = sslStream.NegotiatedApplicationProtocol.Protocol; - context.Features.Set(feature); - feature.LocalCertificate = ConvertToX509Certificate2(sslStream.LocalCertificate); - feature.RemoteCertificate = ConvertToX509Certificate2(sslStream.RemoteCertificate); - feature.CipherAlgorithm = sslStream.CipherAlgorithm; - feature.CipherStrength = sslStream.CipherStrength; - feature.HashAlgorithm = sslStream.HashAlgorithm; - feature.HashStrength = sslStream.HashStrength; - feature.KeyExchangeAlgorithm = sslStream.KeyExchangeAlgorithm; - feature.KeyExchangeStrength = sslStream.KeyExchangeStrength; - feature.Protocol = sslStream.SslProtocol; - - var originalTransport = context.Transport; - - try + catch (OperationCanceledException) { - context.Transport = sslDuplexPipe; - - // Disposing the stream will dispose the sslDuplexPipe - await using (sslStream) - await using (sslDuplexPipe) - { - await _next(context).ConfigureAwait(false); - // Dispose the inner stream (SslDuplexPipe) before disposing the SslStream - // as the duplex pipe can hit an ODE as it still may be writing. - } + _logger?.LogDebug(2, "Authentication timed out"); + await sslStream.DisposeAsync().ConfigureAwait(false); + return; } - finally + catch (Exception ex) when (ex is IOException || ex is AuthenticationException) { - // Restore the original so that it gets closed appropriately - context.Transport = originalTransport; + _logger?.LogDebug(1, ex, "Authentication failed"); + await sslStream.DisposeAsync().ConfigureAwait(false); + return; } } - protected static void EnsureCertificateIsAllowedForServerAuth(X509Certificate2 certificate) + feature.ApplicationProtocol = sslStream.NegotiatedApplicationProtocol.Protocol; + context.Features.Set(feature); + feature.LocalCertificate = ConvertToX509Certificate2(sslStream.LocalCertificate); + feature.RemoteCertificate = ConvertToX509Certificate2(sslStream.RemoteCertificate); + feature.CipherAlgorithm = sslStream.CipherAlgorithm; + feature.CipherStrength = sslStream.CipherStrength; + feature.HashAlgorithm = sslStream.HashAlgorithm; + feature.HashStrength = sslStream.HashStrength; + feature.KeyExchangeAlgorithm = sslStream.KeyExchangeAlgorithm; + feature.KeyExchangeStrength = sslStream.KeyExchangeStrength; + feature.Protocol = sslStream.SslProtocol; + + var originalTransport = context.Transport; + + try { - if (!CertificateLoader.IsCertificateAllowedForServerAuth(certificate)) + context.Transport = sslDuplexPipe; + + // Disposing the stream will dispose the sslDuplexPipe + await using (sslStream) + await using (sslDuplexPipe) { - throw new InvalidOperationException($"Invalid server certificate for server authentication: {certificate.Thumbprint}"); + await _next(context).ConfigureAwait(false); + // Dispose the inner stream (SslDuplexPipe) before disposing the SslStream + // as the duplex pipe can hit an ODE as it still may be writing. } } + finally + { + // Restore the original so that it gets closed appropriately + context.Transport = originalTransport; + } + } - private static X509Certificate2 ConvertToX509Certificate2(X509Certificate certificate) + protected static void EnsureCertificateIsAllowedForServerAuth(X509Certificate2 certificate) + { + if (!CertificateLoader.IsCertificateAllowedForServerAuth(certificate)) { - if (certificate is null) - { - return null; - } + throw new InvalidOperationException($"Invalid server certificate for server authentication: {certificate.Thumbprint}"); + } + } - return certificate as X509Certificate2 ?? new X509Certificate2(certificate); + private static X509Certificate2 ConvertToX509Certificate2(X509Certificate certificate) + { + if (certificate is null) + { + return null; } + + return certificate as X509Certificate2 ?? new X509Certificate2(certificate); } } diff --git a/src/Bedrock.Framework/Server/EndPointBinding.cs b/src/Bedrock.Framework/Server/EndPointBinding.cs index 81844bef..12c9b49d 100644 --- a/src/Bedrock.Framework/Server/EndPointBinding.cs +++ b/src/Bedrock.Framework/Server/EndPointBinding.cs @@ -4,31 +4,19 @@ using System.Threading; using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework -{ - public class EndPointBinding : ServerBinding - { - private readonly ConnectionDelegate _application; - public EndPointBinding(EndPoint endPoint, ConnectionDelegate application, IConnectionListenerFactory connectionListenerFactory) - { - EndPoint = endPoint; - _application = application; - ConnectionListenerFactory = connectionListenerFactory; - } +namespace Bedrock.Framework; - private EndPoint EndPoint { get; } - private IConnectionListenerFactory ConnectionListenerFactory { get; } - - public override ConnectionDelegate Application => _application; +public class EndPointBinding(EndPoint endPoint, ConnectionDelegate application, IConnectionListenerFactory connectionListenerFactory) : ServerBinding +{ + public override ConnectionDelegate Application => application; - public override async IAsyncEnumerable BindAsync([EnumeratorCancellation]CancellationToken cancellationToken) - { - yield return await ConnectionListenerFactory.BindAsync(EndPoint, cancellationToken); - } + public override async IAsyncEnumerable BindAsync([EnumeratorCancellation]CancellationToken cancellationToken) + { + yield return await connectionListenerFactory.BindAsync(endPoint, cancellationToken); + } - public override string ToString() - { - return EndPoint?.ToString(); - } + public override string ToString() + { + return endPoint?.ToString(); } } diff --git a/src/Bedrock.Framework/Server/LocalHostBinding.cs b/src/Bedrock.Framework/Server/LocalHostBinding.cs index 60adc6b2..1a749168 100644 --- a/src/Bedrock.Framework/Server/LocalHostBinding.cs +++ b/src/Bedrock.Framework/Server/LocalHostBinding.cs @@ -6,68 +6,55 @@ using System.Threading; using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public class LocalHostBinding(int port, ConnectionDelegate application, IConnectionListenerFactory connectionListenerFactory) : ServerBinding { - public class LocalHostBinding : ServerBinding + public override ConnectionDelegate Application => application; + + public override async IAsyncEnumerable BindAsync([EnumeratorCancellation]CancellationToken cancellationToken = default) { - private readonly ConnectionDelegate _application; + var exceptions = new List(); + + IConnectionListener ipv6Listener = null; + IConnectionListener ipv4Listener = null; - public LocalHostBinding(int port, ConnectionDelegate application, IConnectionListenerFactory connectionListenerFactory) + try { - Port = port; - _application = application; - ConnectionListenerFactory = connectionListenerFactory; + ipv6Listener = await connectionListenerFactory.BindAsync(new IPEndPoint(IPAddress.IPv6Loopback, port), cancellationToken); } - - private int Port { get; } - private IConnectionListenerFactory ConnectionListenerFactory { get; } - - public override ConnectionDelegate Application => _application; - - public override async IAsyncEnumerable BindAsync([EnumeratorCancellation]CancellationToken cancellationToken = default) + catch (Exception ex) when (ex is not IOException) { - var exceptions = new List(); - - IConnectionListener ipv6Listener = null; - IConnectionListener ipv4Listener = null; - - try - { - ipv6Listener = await ConnectionListenerFactory.BindAsync(new IPEndPoint(IPAddress.IPv6Loopback, Port), cancellationToken); - } - catch (Exception ex) when (!(ex is IOException)) - { - exceptions.Add(ex); - } - - if (ipv6Listener != null) - { - yield return ipv6Listener; - } + exceptions.Add(ex); + } - try - { - ipv4Listener = await ConnectionListenerFactory.BindAsync(new IPEndPoint(IPAddress.Loopback, Port), cancellationToken); - } - catch (Exception ex) when (!(ex is IOException)) - { - exceptions.Add(ex); - } + if (ipv6Listener != null) + { + yield return ipv6Listener; + } - if (exceptions.Count == 2) - { - throw new IOException($"Failed to bind to {this}", new AggregateException(exceptions)); - } + try + { + ipv4Listener = await connectionListenerFactory.BindAsync(new IPEndPoint(IPAddress.Loopback, port), cancellationToken); + } + catch (Exception ex) when (ex is not IOException) + { + exceptions.Add(ex); + } - if (ipv4Listener != null) - { - yield return ipv4Listener; - } + if (exceptions.Count == 2) + { + throw new IOException($"Failed to bind to {this}", new AggregateException(exceptions)); } - public override string ToString() + if (ipv4Listener != null) { - return $"localhost:{Port}"; + yield return ipv4Listener; } } + + public override string ToString() + { + return $"localhost:{port}"; + } } diff --git a/src/Bedrock.Framework/Server/Server.cs b/src/Bedrock.Framework/Server/Server.cs index 3253314f..f736c7a0 100644 --- a/src/Bedrock.Framework/Server/Server.cs +++ b/src/Bedrock.Framework/Server/Server.cs @@ -7,250 +7,239 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public class Server { - public class Server + private readonly ServerBuilder _builder; + private readonly ILogger _logger; + private readonly List _listeners = []; + private readonly TaskCompletionSource _shutdownTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly PeriodicTimer _timer; + private Task _timerTask = Task.CompletedTask; + + internal Server(ServerBuilder builder) { - private readonly ServerBuilder _builder; - private readonly ILogger _logger; - private readonly List _listeners = new List(); - private readonly TaskCompletionSource _shutdownTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly TimerAwaitable _timerAwaitable; - private Task _timerTask = Task.CompletedTask; - - internal Server(ServerBuilder builder) - { - _logger = builder.ApplicationServices.GetLoggerFactory().CreateLogger(); - _builder = builder; - _timerAwaitable = new TimerAwaitable(_builder.HeartBeatInterval, _builder.HeartBeatInterval); - } + _logger = builder.ApplicationServices.GetLoggerFactory().CreateLogger(); + _builder = builder; + _timer = new PeriodicTimer(_builder.HeartBeatInterval); + } - public IEnumerable EndPoints + public IEnumerable EndPoints + { + get { - get + foreach (var listener in _listeners) { - foreach (var listener in _listeners) - { - yield return listener.Listener.EndPoint; - } + yield return listener.Listener.EndPoint; } } + } - public async Task StartAsync(CancellationToken cancellationToken = default) + public async Task StartAsync(CancellationToken cancellationToken = default) + { + try { - try + foreach (var binding in _builder.Bindings) { - foreach (var binding in _builder.Bindings) + await foreach (var listener in binding.BindAsync(cancellationToken).ConfigureAwait(false)) { - await foreach (var listener in binding.BindAsync(cancellationToken).ConfigureAwait(false)) - { - var runningListener = new RunningListener(this, binding, listener); - _listeners.Add(runningListener); - runningListener.Start(); - } + var runningListener = new RunningListener(this, binding, listener); + _listeners.Add(runningListener); + runningListener.Start(); } } - catch - { - await StopAsync().ConfigureAwait(false); - - throw; - } + } + catch + { + await StopAsync().ConfigureAwait(false); - _timerAwaitable.Start(); - _timerTask = StartTimerAsync(); + throw; } - private async Task StartTimerAsync() + _timerTask = StartTimerAsync(); + } + + private async Task StartTimerAsync() + { + using (_timer) { - using (_timerAwaitable) + while (await _timer.WaitForNextTickAsync()) { - while (await _timerAwaitable) + foreach (var listener in _listeners) { - foreach (var listener in _listeners) - { - listener.TickHeartbeat(); - } + listener.TickHeartbeat(); } } } + } - public async Task StopAsync(CancellationToken cancellationToken = default) - { - var tasks = new Task[_listeners.Count]; + public async Task StopAsync(CancellationToken cancellationToken = default) + { + var tasks = new Task[_listeners.Count]; - for (int i = 0; i < _listeners.Count; i++) - { - tasks[i] = _listeners[i].Listener.UnbindAsync(cancellationToken).AsTask(); - } + for (int i = 0; i < _listeners.Count; i++) + { + tasks[i] = _listeners[i].Listener.UnbindAsync(cancellationToken).AsTask(); + } - await Task.WhenAll(tasks).ConfigureAwait(false); + await Task.WhenAll(tasks).ConfigureAwait(false); - // Signal to all of the listeners that it's time to start the shutdown process - // We call this after unbind so that we're not touching the listener anymore (each loop will dispose the listener) - _shutdownTcs.TrySetResult(null); + // Signal to all of the listeners that it's time to start the shutdown process + // We call this after unbind so that we're not touching the listener anymore (each loop will dispose the listener) + _shutdownTcs.TrySetResult(null); - for (int i = 0; i < _listeners.Count; i++) - { - tasks[i] = _listeners[i].ExecutionTask; - } + for (int i = 0; i < _listeners.Count; i++) + { + tasks[i] = _listeners[i].ExecutionTask; + } - var shutdownTask = Task.WhenAll(tasks); + var shutdownTask = Task.WhenAll(tasks); - if (cancellationToken.CanBeCanceled) - { - await shutdownTask.WithCancellation(cancellationToken).ConfigureAwait(false); - } - else - { - await shutdownTask.ConfigureAwait(false); - } + if (cancellationToken.CanBeCanceled) + { + await shutdownTask.WithCancellation(cancellationToken).ConfigureAwait(false); + } + else + { + await shutdownTask.ConfigureAwait(false); + } - if (_timerAwaitable != null) - { - _timerAwaitable.Stop(); + if (_timer != null) + { + _timer.Dispose(); - await _timerTask.ConfigureAwait(false); - } + await _timerTask.ConfigureAwait(false); } + } - private class RunningListener + private class RunningListener(Server server, ServerBinding binding, IConnectionListener listener) + { + private readonly ConcurrentDictionary _connections = []; + + public void Start() { - private readonly Server _server; - private readonly ServerBinding _binding; - private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + ExecutionTask = RunListenerAsync(); + } - public RunningListener(Server server, ServerBinding binding, IConnectionListener listener) - { - _server = server; - _binding = binding; - Listener = listener; - } + public IConnectionListener Listener { get; } = listener; + public Task ExecutionTask { get; private set; } - public void Start() + public void TickHeartbeat() + { + foreach (var pair in _connections) { - ExecutionTask = RunListenerAsync(); + pair.Value.Connection.TickHeartbeat(); } + } - public IConnectionListener Listener { get; } - public Task ExecutionTask { get; private set; } + private async Task RunListenerAsync() + { + var connectionDelegate = binding.Application; + var listener = Listener; - public void TickHeartbeat() + async Task ExecuteConnectionAsync(ServerConnection serverConnection) { - foreach (var pair in _connections) - { - pair.Value.Connection.TickHeartbeat(); - } - } + await Task.Yield(); - private async Task RunListenerAsync() - { - var connectionDelegate = _binding.Application; - var listener = Listener; + var connection = serverConnection.TransportConnection; - async Task ExecuteConnectionAsync(ServerConnection serverConnection) + try { - await Task.Yield(); - - var connection = serverConnection.TransportConnection; - - try - { - using var scope = BeginConnectionScope(serverConnection); + using var scope = BeginConnectionScope(serverConnection); - await connectionDelegate(connection).ConfigureAwait(false); - } - catch (ConnectionResetException) - { - // Don't let connection aborted exceptions out - } - catch (ConnectionAbortedException) - { - // Don't let connection aborted exceptions out - } - catch (Exception ex) - { - _server._logger.LogError(ex, "Unexpected exception from connection {ConnectionId}", connection.ConnectionId); - } - finally - { - // Fire the OnCompleted callbacks - await serverConnection.FireOnCompletedAsync().ConfigureAwait(false); + await connectionDelegate(connection).ConfigureAwait(false); + } + catch (ConnectionResetException) + { + // Don't let connection aborted exceptions out + } + catch (ConnectionAbortedException) + { + // Don't let connection aborted exceptions out + } + catch (Exception ex) + { + server._logger.LogError(ex, "Unexpected exception from connection {ConnectionId}", connection.ConnectionId); + } + finally + { + // Fire the OnCompleted callbacks + await serverConnection.FireOnCompletedAsync().ConfigureAwait(false); - await connection.DisposeAsync().ConfigureAwait(false); + await connection.DisposeAsync().ConfigureAwait(false); - // Remove the connection from tracking - _connections.TryRemove(serverConnection.Id, out _); - } + // Remove the connection from tracking + _connections.TryRemove(serverConnection.Id, out _); } + } - long id = 0; + long id = 0; - while (true) + while (true) + { + try { - try - { - var connection = await listener.AcceptAsync().ConfigureAwait(false); - - if (connection == null) - { - // Null means we don't have anymore connections - break; - } + var connection = await listener.AcceptAsync().ConfigureAwait(false); - var serverConnection = new ServerConnection(id, connection, _server._logger); - - _connections[id] = (serverConnection, ExecuteConnectionAsync(serverConnection)); - } - catch (OperationCanceledException) + if (connection == null) { + // Null means we don't have anymore connections break; } - catch (Exception ex) - { - _server._logger.LogCritical(ex, "Stopped accepting connections on {endpoint}", listener.EndPoint); - break; - } - - id++; - } - // Don't shut down connections until entire server is shutting down - await _server._shutdownTcs.Task.ConfigureAwait(false); + var serverConnection = new ServerConnection(id, connection, server._logger); - // Give connections a chance to close gracefully - var tasks = new List(_connections.Count); - - foreach (var pair in _connections) + _connections[id] = (serverConnection, ExecuteConnectionAsync(serverConnection)); + } + catch (OperationCanceledException) { - pair.Value.Connection.RequestClose(); - tasks.Add(pair.Value.ExecutionTask); + break; } - - if (!await Task.WhenAll(tasks).TimeoutAfter(_server._builder.ShutdownTimeout).ConfigureAwait(false)) + catch (Exception ex) { - // Abort all connections still in flight - foreach (var pair in _connections) - { - pair.Value.Connection.TransportConnection.Abort(); - } - - await Task.WhenAll(tasks).ConfigureAwait(false); + server._logger.LogCritical(ex, "Stopped accepting connections on {endpoint}", listener.EndPoint); + break; } - await listener.DisposeAsync().ConfigureAwait(false); + id++; } + // Don't shut down connections until entire server is shutting down + await server._shutdownTcs.Task.ConfigureAwait(false); - private IDisposable BeginConnectionScope(ServerConnection connection) + // Give connections a chance to close gracefully + var tasks = new List(_connections.Count); + + foreach (var pair in _connections) + { + pair.Value.Connection.RequestClose(); + tasks.Add(pair.Value.ExecutionTask); + } + + if (!await Task.WhenAll(tasks).TimeoutAfter(server._builder.ShutdownTimeout).ConfigureAwait(false)) { - if (_server._logger.IsEnabled(LogLevel.Critical)) + // Abort all connections still in flight + foreach (var pair in _connections) { - return _server._logger.BeginScope(connection); + pair.Value.Connection.TransportConnection.Abort(); } - return null; + await Task.WhenAll(tasks).ConfigureAwait(false); } + + await listener.DisposeAsync().ConfigureAwait(false); + } + + + private IDisposable BeginConnectionScope(ServerConnection connection) + { + if (server._logger.IsEnabled(LogLevel.Critical)) + { + return server._logger.BeginScope(connection); + } + + return null; } } } diff --git a/src/Bedrock.Framework/Server/ServerBuilder.cs b/src/Bedrock.Framework/Server/ServerBuilder.cs index c2534b3b..f01ddba2 100644 --- a/src/Bedrock.Framework/Server/ServerBuilder.cs +++ b/src/Bedrock.Framework/Server/ServerBuilder.cs @@ -2,31 +2,30 @@ using System.Collections.Generic; using Bedrock.Framework.Infrastructure; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public class ServerBuilder { - public class ServerBuilder + public ServerBuilder() : this(EmptyServiceProvider.Instance) { - public ServerBuilder() : this(EmptyServiceProvider.Instance) - { - } + } - public ServerBuilder(IServiceProvider serviceProvider) - { - ApplicationServices = serviceProvider; - } + public ServerBuilder(IServiceProvider serviceProvider) + { + ApplicationServices = serviceProvider; + } - public IList Bindings { get; } = new List(); + public IList Bindings { get; } = []; - public TimeSpan ShutdownTimeout { get; set; } = TimeSpan.FromSeconds(5); + public TimeSpan ShutdownTimeout { get; set; } = TimeSpan.FromSeconds(5); - public TimeSpan HeartBeatInterval { get; set; } = TimeSpan.FromSeconds(1); + public TimeSpan HeartBeatInterval { get; set; } = TimeSpan.FromSeconds(1); - public IServiceProvider ApplicationServices { get; } + public IServiceProvider ApplicationServices { get; } - public Server Build() - { - return new Server(this); - } + public Server Build() + { + return new Server(this); } } diff --git a/src/Bedrock.Framework/Transports/Memory/MemoryEndPoint.cs b/src/Bedrock.Framework/Transports/Memory/MemoryEndPoint.cs index 2e46c7b8..9e4b8aea 100644 --- a/src/Bedrock.Framework/Transports/Memory/MemoryEndPoint.cs +++ b/src/Bedrock.Framework/Transports/Memory/MemoryEndPoint.cs @@ -1,20 +1,14 @@ using System.Net; -namespace Bedrock.Framework.Transports.Memory -{ - public class MemoryEndPoint : EndPoint - { - public static readonly MemoryEndPoint Default = new MemoryEndPoint("default"); +namespace Bedrock.Framework.Transports.Memory; - public MemoryEndPoint(string name) - { - Name = name; - } +public class MemoryEndPoint(string name) : EndPoint +{ + public static readonly MemoryEndPoint Default = new MemoryEndPoint("default"); - public string Name { get; } + public string Name { get; } = name; - public override string ToString() => Name; + public override string ToString() => Name; - public override int GetHashCode() => Name.GetHashCode(); - } + public override int GetHashCode() => Name.GetHashCode(); } diff --git a/src/Bedrock.Framework/Transports/Memory/MemoryTransport.cs b/src/Bedrock.Framework/Transports/Memory/MemoryTransport.cs index e48b51c4..274833ca 100644 --- a/src/Bedrock.Framework/Transports/Memory/MemoryTransport.cs +++ b/src/Bedrock.Framework/Transports/Memory/MemoryTransport.cs @@ -7,101 +7,96 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework.Transports.Memory +namespace Bedrock.Framework.Transports.Memory; + +public partial class MemoryTransport : IConnectionListenerFactory, IConnectionFactory { - public partial class MemoryTransport : IConnectionListenerFactory, IConnectionFactory + private readonly ConcurrentDictionary _listeners = new ConcurrentDictionary(); + + public ValueTask BindAsync(EndPoint endpoint, CancellationToken cancellationToken = default) { - private readonly ConcurrentDictionary _listeners = new ConcurrentDictionary(); + endpoint ??= MemoryEndPoint.Default; + MemoryConnectionListener listener; - public ValueTask BindAsync(EndPoint endpoint, CancellationToken cancellationToken = default) + if (_listeners.TryGetValue(endpoint, out _) || + !_listeners.TryAdd(endpoint, listener = new MemoryConnectionListener() { EndPoint = endpoint })) { - endpoint ??= MemoryEndPoint.Default; - MemoryConnectionListener listener; + throw new AddressInUseException($"{endpoint} listener already bound"); + } - if (_listeners.TryGetValue(endpoint, out _) || - !_listeners.TryAdd(endpoint, listener = new MemoryConnectionListener() { EndPoint = endpoint })) - { - throw new AddressInUseException($"{endpoint} listener already bound"); - } + return new ValueTask(listener); + } - return new ValueTask(listener); - } + public ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) + { + endpoint ??= MemoryEndPoint.Default; - public ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) + if (!_listeners.TryGetValue(endpoint, out var listener)) { - endpoint ??= MemoryEndPoint.Default; - - if (!_listeners.TryGetValue(endpoint, out var listener)) - { - throw new InvalidOperationException($"{endpoint} not bound!"); - } + throw new InvalidOperationException($"{endpoint} not bound!"); + } - var pair = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); + var pair = DuplexPipe.CreateConnectionPair(new PipeOptions(), new PipeOptions()); - var serverConnection = new MemoryConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application) - { - LocalEndPoint = endpoint, - RemoteEndPoint = endpoint - }; + var serverConnection = new MemoryConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application) + { + LocalEndPoint = endpoint, + RemoteEndPoint = endpoint + }; - var clientConnection = new DefaultConnectionContext(serverConnection.ConnectionId, pair.Application, pair.Transport) - { - LocalEndPoint = endpoint, - RemoteEndPoint = endpoint - }; + var clientConnection = new DefaultConnectionContext(serverConnection.ConnectionId, pair.Application, pair.Transport) + { + LocalEndPoint = endpoint, + RemoteEndPoint = endpoint + }; - listener.AcceptQueue.Writer.TryWrite(serverConnection); - return new ValueTask(clientConnection); - } + listener.AcceptQueue.Writer.TryWrite(serverConnection); + return new ValueTask(clientConnection); + } - private class MemoryConnectionContext : DefaultConnectionContext + private class MemoryConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) : DefaultConnectionContext(id, transport, application) + { + public override async ValueTask DisposeAsync() { - public MemoryConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) - : base(id, transport, application) - { } - - public override async ValueTask DisposeAsync() + if (Transport != null) { - if (Transport != null) - { - await Transport.Output.CompleteAsync().ConfigureAwait(false); - await Transport.Input.CompleteAsync().ConfigureAwait(false); - } - - await base.DisposeAsync().ConfigureAwait(false); + await Transport.Output.CompleteAsync().ConfigureAwait(false); + await Transport.Input.CompleteAsync().ConfigureAwait(false); } + + await base.DisposeAsync().ConfigureAwait(false); } + } - private class MemoryConnectionListener : IConnectionListener - { - public EndPoint EndPoint { get; set; } + private class MemoryConnectionListener : IConnectionListener + { + public EndPoint EndPoint { get; set; } - internal Channel AcceptQueue { get; } = Channel.CreateUnbounded(); + internal Channel AcceptQueue { get; } = Channel.CreateUnbounded(); - public async ValueTask AcceptAsync(CancellationToken cancellationToken = default) + public async ValueTask AcceptAsync(CancellationToken cancellationToken = default) + { + if (await AcceptQueue.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) { - if (await AcceptQueue.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + while (AcceptQueue.Reader.TryRead(out var item)) { - while (AcceptQueue.Reader.TryRead(out var item)) - { - return item; - } + return item; } - - return null; } - public ValueTask DisposeAsync() - { - return UnbindAsync(); - } + return null; + } - public ValueTask UnbindAsync(CancellationToken cancellationToken = default) - { - AcceptQueue.Writer.TryComplete(); + public ValueTask DisposeAsync() + { + return UnbindAsync(); + } - return default; - } + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) + { + AcceptQueue.Writer.TryComplete(); + + return default; } } } diff --git a/src/Bedrock.Framework/Transports/ServerBuilderExtensions.cs b/src/Bedrock.Framework/Transports/ServerBuilderExtensions.cs index 6ca908fd..d113768b 100644 --- a/src/Bedrock.Framework/Transports/ServerBuilderExtensions.cs +++ b/src/Bedrock.Framework/Transports/ServerBuilderExtensions.cs @@ -3,21 +3,20 @@ using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.DependencyInjection; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public static partial class ServerBuilderExtensions { - public static partial class ServerBuilderExtensions + public static ServerBuilder Listen(this ServerBuilder builder, EndPoint endPoint, Action configure) where TTransport : IConnectionListenerFactory { - public static ServerBuilder Listen(this ServerBuilder builder, EndPoint endPoint, Action configure) where TTransport : IConnectionListenerFactory - { - return builder.Listen(endPoint, ActivatorUtilities.CreateInstance(builder.ApplicationServices), configure); - } + return builder.Listen(endPoint, ActivatorUtilities.CreateInstance(builder.ApplicationServices), configure); + } - public static ServerBuilder Listen(this ServerBuilder builder, EndPoint endPoint, IConnectionListenerFactory connectionListenerFactory, Action configure) - { - var connectionBuilder = new ConnectionBuilder(builder.ApplicationServices); - configure(connectionBuilder); - builder.Bindings.Add(new EndPointBinding(endPoint, connectionBuilder.Build(), connectionListenerFactory)); - return builder; - } + public static ServerBuilder Listen(this ServerBuilder builder, EndPoint endPoint, IConnectionListenerFactory connectionListenerFactory, Action configure) + { + var connectionBuilder = new ConnectionBuilder(builder.ApplicationServices); + configure(connectionBuilder); + builder.Bindings.Add(new EndPointBinding(endPoint, connectionBuilder.Build(), connectionListenerFactory)); + return builder; } } diff --git a/src/Bedrock.Framework/Transports/Sockets/BufferExtensions.cs b/src/Bedrock.Framework/Transports/Sockets/BufferExtensions.cs index f5e1e9da..737487b2 100644 --- a/src/Bedrock.Framework/Transports/Sockets/BufferExtensions.cs +++ b/src/Bedrock.Framework/Transports/Sockets/BufferExtensions.cs @@ -1,22 +1,21 @@ using System; using System.Runtime.InteropServices; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +internal static class BufferExtensions { - internal static class BufferExtensions + public static ArraySegment GetArray(this Memory memory) { - public static ArraySegment GetArray(this Memory memory) - { - return ((ReadOnlyMemory)memory).GetArray(); - } + return ((ReadOnlyMemory)memory).GetArray(); + } - public static ArraySegment GetArray(this ReadOnlyMemory memory) + public static ArraySegment GetArray(this ReadOnlyMemory memory) + { + if (!MemoryMarshal.TryGetArray(memory, out var result)) { - if (!MemoryMarshal.TryGetArray(memory, out var result)) - { - throw new InvalidOperationException("Buffer backed by array was expected"); - } - return result; + throw new InvalidOperationException("Buffer backed by array was expected"); } + return result; } } diff --git a/src/Bedrock.Framework/Transports/Sockets/ServerBuilderExtensions.cs b/src/Bedrock.Framework/Transports/Sockets/ServerBuilderExtensions.cs index 814aec9a..8b2a91f2 100644 --- a/src/Bedrock.Framework/Transports/Sockets/ServerBuilderExtensions.cs +++ b/src/Bedrock.Framework/Transports/Sockets/ServerBuilderExtensions.cs @@ -1,24 +1,19 @@ using System; -using System.Net; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public static partial class ServerBuilderExtensions { - public static partial class ServerBuilderExtensions + public static ServerBuilder UseSockets(this ServerBuilder serverBuilder, Action configure) { - public static ServerBuilder UseSockets(this ServerBuilder serverBuilder, Action configure) - { - var socketsBuilder = new SocketsServerBuilder(); - configure(socketsBuilder); - socketsBuilder.Apply(serverBuilder); - return serverBuilder; - } + var socketsBuilder = new SocketsServerBuilder(); + configure(socketsBuilder); + socketsBuilder.Apply(serverBuilder); + return serverBuilder; + } - public static ClientBuilder UseSockets(this ClientBuilder clientBuilder) - { - return clientBuilder.UseConnectionFactory(new SocketConnectionFactory()); - } + public static ClientBuilder UseSockets(this ClientBuilder clientBuilder) + { + return clientBuilder.UseConnectionFactory(new SocketConnectionFactory()); } } \ No newline at end of file diff --git a/src/Bedrock.Framework/Transports/Sockets/SocketAwaitable.cs b/src/Bedrock.Framework/Transports/Sockets/SocketAwaitable.cs index c8b417da..45e4612e 100644 --- a/src/Bedrock.Framework/Transports/Sockets/SocketAwaitable.cs +++ b/src/Bedrock.Framework/Transports/Sockets/SocketAwaitable.cs @@ -1,71 +1,68 @@ using System; -using System.Collections.Generic; using System.Diagnostics; using System.IO.Pipelines; using System.Net.Sockets; using System.Runtime.CompilerServices; -using System.Text; using System.Threading; using System.Threading.Tasks; -namespace Bedrock.Framework -{ - internal class SocketAwaitable : ICriticalNotifyCompletion - { - private static readonly Action _callbackCompleted = () => { }; - - private readonly PipeScheduler _ioScheduler; +namespace Bedrock.Framework; - private Action _callback; - private int _bytesTransferred; - private SocketError _error; +internal class SocketAwaitable : ICriticalNotifyCompletion +{ + private static readonly Action _callbackCompleted = () => { }; - public SocketAwaitable(PipeScheduler ioScheduler) - { - _ioScheduler = ioScheduler; - } + private readonly PipeScheduler _ioScheduler; - public SocketAwaitable GetAwaiter() => this; - public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); + private Action _callback; + private int _bytesTransferred; + private SocketError _error; - public int GetResult() - { - Debug.Assert(ReferenceEquals(_callback, _callbackCompleted)); + public SocketAwaitable(PipeScheduler ioScheduler) + { + _ioScheduler = ioScheduler; + } - _callback = null; + public SocketAwaitable GetAwaiter() => this; + public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); - if (_error != SocketError.Success) - { - throw new SocketException((int)_error); - } + public int GetResult() + { + Debug.Assert(ReferenceEquals(_callback, _callbackCompleted)); - return _bytesTransferred; - } + _callback = null; - public void OnCompleted(Action continuation) + if (_error != SocketError.Success) { - if (ReferenceEquals(_callback, _callbackCompleted) || - ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) - { - Task.Run(continuation); - } + throw new SocketException((int)_error); } - public void UnsafeOnCompleted(Action continuation) + return _bytesTransferred; + } + + public void OnCompleted(Action continuation) + { + if (ReferenceEquals(_callback, _callbackCompleted) || + ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) { - OnCompleted(continuation); + Task.Run(continuation); } + } - public void Complete(int bytesTransferred, SocketError socketError) - { - _error = socketError; - _bytesTransferred = bytesTransferred; - var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); + public void UnsafeOnCompleted(Action continuation) + { + OnCompleted(continuation); + } - if (continuation != null) - { - _ioScheduler.Schedule(state => ((Action)state)(), continuation); - } + public void Complete(int bytesTransferred, SocketError socketError) + { + _error = socketError; + _bytesTransferred = bytesTransferred; + var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); + + if (continuation != null) + { + _ioScheduler.Schedule(state => ((Action)state)(), continuation); } } } diff --git a/src/Bedrock.Framework/Transports/Sockets/SocketConnection.cs b/src/Bedrock.Framework/Transports/Sockets/SocketConnection.cs index 00ec6ca6..068bd18b 100644 --- a/src/Bedrock.Framework/Transports/Sockets/SocketConnection.cs +++ b/src/Bedrock.Framework/Transports/Sockets/SocketConnection.cs @@ -4,262 +4,260 @@ using System.IO.Pipelines; using System.Net; using System.Net.Sockets; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +internal class SocketConnection : ConnectionContext, IConnectionInherentKeepAliveFeature { - internal class SocketConnection : ConnectionContext, IConnectionInherentKeepAliveFeature + private readonly Socket _socket; + private volatile bool _aborted; + private readonly EndPoint _endPoint; + private IDuplexPipe _application; + private readonly SocketSender _sender; + private readonly SocketReceiver _receiver; + + public SocketConnection(EndPoint endPoint) { - private readonly Socket _socket; - private volatile bool _aborted; - private readonly EndPoint _endPoint; - private IDuplexPipe _application; - private readonly SocketSender _sender; - private readonly SocketReceiver _receiver; - - public SocketConnection(EndPoint endPoint) - { - _socket = new Socket(endPoint.AddressFamily, SocketType.Stream, DetermineProtocolType(endPoint)); - _endPoint = endPoint; + _socket = new Socket(endPoint.AddressFamily, SocketType.Stream, DetermineProtocolType(endPoint)); + _endPoint = endPoint; - _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); - _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); + _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); + _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); - // Add IConnectionInherentKeepAliveFeature to the tcp connection impl since Kestrel doesn't implement - // the IConnectionHeartbeatFeature - Features.Set(this); - } + // Add IConnectionInherentKeepAliveFeature to the tcp connection impl since Kestrel doesn't implement + // the IConnectionHeartbeatFeature + Features.Set(this); + } - public override IDuplexPipe Transport { get; set; } + public override IDuplexPipe Transport { get; set; } - public override IFeatureCollection Features { get; } = new FeatureCollection(); - public override string ConnectionId { get; set; } = Guid.NewGuid().ToString(); - public override IDictionary Items { get; set; } = new ConnectionItems(); + public override IFeatureCollection Features { get; } = new FeatureCollection(); + public override string ConnectionId { get; set; } = Guid.NewGuid().ToString(); + public override IDictionary Items { get; set; } = new ConnectionItems(); - // We claim to have inherent keep-alive so the client doesn't kill the connection when it hasn't seen ping frames. - public bool HasInherentKeepAlive { get; } = true; + // We claim to have inherent keep-alive so the client doesn't kill the connection when it hasn't seen ping frames. + public bool HasInherentKeepAlive { get; } = true; - public override async ValueTask DisposeAsync() + public override async ValueTask DisposeAsync() + { + if (Transport != null) { - if (Transport != null) - { - await Transport.Output.CompleteAsync().ConfigureAwait(false); - await Transport.Input.CompleteAsync().ConfigureAwait(false); - } - - // Completing these loops will cause ExecuteAsync to Dispose the socket. + await Transport.Output.CompleteAsync().ConfigureAwait(false); + await Transport.Input.CompleteAsync().ConfigureAwait(false); } - public async ValueTask StartAsync() - { - await _socket.ConnectAsync(_endPoint).ConfigureAwait(false); + // Completing these loops will cause ExecuteAsync to Dispose the socket. + } - var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + public async ValueTask StartAsync() + { + await _socket.ConnectAsync(_endPoint).ConfigureAwait(false); - LocalEndPoint = _socket.LocalEndPoint; - RemoteEndPoint = _socket.RemoteEndPoint; + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); - Transport = pair.Transport; - _application = pair.Application; + LocalEndPoint = _socket.LocalEndPoint; + RemoteEndPoint = _socket.RemoteEndPoint; - _ = ExecuteAsync(); + Transport = pair.Transport; + _application = pair.Application; - return this; - } + _ = ExecuteAsync(); - private async Task ExecuteAsync() + return this; + } + + private async Task ExecuteAsync() + { + Exception sendError = null; + try { - Exception sendError = null; - try + // Spawn send and receive logic + var receiveTask = DoReceive(); + var sendTask = DoSend(); + + // If the sending task completes then close the receive + // We don't need to do this in the other direction because the kestrel + // will trigger the output closing once the input is complete. + if (await Task.WhenAny(receiveTask, sendTask).ConfigureAwait(false) == sendTask) { - // Spawn send and receive logic - var receiveTask = DoReceive(); - var sendTask = DoSend(); - - // If the sending task completes then close the receive - // We don't need to do this in the other direction because the kestrel - // will trigger the output closing once the input is complete. - if (await Task.WhenAny(receiveTask, sendTask).ConfigureAwait(false) == sendTask) - { - // Tell the reader it's being aborted - _socket.Dispose(); - } - - // Now wait for both to complete - await receiveTask; - sendError = await sendTask; - - // Dispose the socket(should noop if already called) + // Tell the reader it's being aborted _socket.Dispose(); } - catch (Exception ex) - { - Console.WriteLine($"Unexpected exception in {nameof(SocketConnection)}.{nameof(StartAsync)}: " + ex); - } - finally - { - // Complete the output after disposing the socket - _application.Input.Complete(sendError); - } - } - private async Task DoReceive() + // Now wait for both to complete + await receiveTask; + sendError = await sendTask; + + // Dispose the socket(should noop if already called) + _socket.Dispose(); + } + catch (Exception ex) + { + Console.WriteLine($"Unexpected exception in {nameof(SocketConnection)}.{nameof(StartAsync)}: " + ex); + } + finally { - Exception error = null; + // Complete the output after disposing the socket + _application.Input.Complete(sendError); + } + } - try - { - await ProcessReceives().ConfigureAwait(false); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.ConnectionReset) - { - error = new ConnectionResetException(ex.Message, ex); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted || - ex.SocketErrorCode == SocketError.ConnectionAborted || - ex.SocketErrorCode == SocketError.Interrupted || - ex.SocketErrorCode == SocketError.InvalidArgument) - { - if (!_aborted) - { - // Calling Dispose after ReceiveAsync can cause an "InvalidArgument" error on *nix. - error = new ConnectionAbortedException(); - } - } - catch (ObjectDisposedException) + private async Task DoReceive() + { + Exception error = null; + + try + { + await ProcessReceives().ConfigureAwait(false); + } + catch (SocketException ex) when (ex.SocketErrorCode == SocketError.ConnectionReset) + { + error = new ConnectionResetException(ex.Message, ex); + } + catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted || + ex.SocketErrorCode == SocketError.ConnectionAborted || + ex.SocketErrorCode == SocketError.Interrupted || + ex.SocketErrorCode == SocketError.InvalidArgument) + { + if (!_aborted) { - if (!_aborted) - { - error = new ConnectionAbortedException(); - } + // Calling Dispose after ReceiveAsync can cause an "InvalidArgument" error on *nix. + error = new ConnectionAbortedException(); } - catch (IOException ex) + } + catch (ObjectDisposedException) + { + if (!_aborted) { - error = ex; + error = new ConnectionAbortedException(); } - catch (Exception ex) + } + catch (IOException ex) + { + error = ex; + } + catch (Exception ex) + { + error = new IOException(ex.Message, ex); + } + finally + { + if (_aborted) { - error = new IOException(ex.Message, ex); + error ??= new ConnectionAbortedException(); } - finally - { - if (_aborted) - { - error ??= new ConnectionAbortedException(); - } - await _application.Output.CompleteAsync(error).ConfigureAwait(false); - } + await _application.Output.CompleteAsync(error).ConfigureAwait(false); } + } - private async Task ProcessReceives() + private async Task ProcessReceives() + { + while (true) { - while (true) - { - // Ensure we have some reasonable amount of buffer space - var buffer = _application.Output.GetMemory(); + // Ensure we have some reasonable amount of buffer space + var buffer = _application.Output.GetMemory(); - var bytesReceived = await _receiver.ReceiveAsync(buffer); + var bytesReceived = await _receiver.ReceiveAsync(buffer); - if (bytesReceived == 0) - { - // FIN - break; - } + if (bytesReceived == 0) + { + // FIN + break; + } - _application.Output.Advance(bytesReceived); + _application.Output.Advance(bytesReceived); - var flushTask = _application.Output.FlushAsync(); + var flushTask = _application.Output.FlushAsync(); - if (!flushTask.IsCompleted) - { - await flushTask.ConfigureAwait(false); - } + if (!flushTask.IsCompleted) + { + await flushTask.ConfigureAwait(false); + } - var result = flushTask.Result; - if (result.IsCompleted) - { - // Pipe consumer is shut down, do we stop writing - break; - } + var result = flushTask.Result; + if (result.IsCompleted) + { + // Pipe consumer is shut down, do we stop writing + break; } } + } - private async Task DoSend() + private async Task DoSend() + { + Exception error = null; + + try + { + await ProcessSends().ConfigureAwait(false); + } + catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted) + { + error = null; + } + catch (ObjectDisposedException) { - Exception error = null; + error = null; + } + catch (IOException ex) + { + error = ex; + } + catch (Exception ex) + { + error = new IOException(ex.Message, ex); + } + finally + { + _aborted = true; + _socket.Shutdown(SocketShutdown.Both); + } - try - { - await ProcessSends().ConfigureAwait(false); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted) - { - error = null; - } - catch (ObjectDisposedException) - { - error = null; - } - catch (IOException ex) - { - error = ex; - } - catch (Exception ex) + return error; + } + + private async Task ProcessSends() + { + while (true) + { + // Wait for data to write from the pipe producer + var result = await _application.Input.ReadAsync().ConfigureAwait(false); + var buffer = result.Buffer; + + if (result.IsCanceled) { - error = new IOException(ex.Message, ex); + break; } - finally + + var end = buffer.End; + var isCompleted = result.IsCompleted; + if (!buffer.IsEmpty) { - _aborted = true; - _socket.Shutdown(SocketShutdown.Both); + await _sender.SendAsync(buffer); } - return error; - } + _application.Input.AdvanceTo(end); - private async Task ProcessSends() - { - while (true) + if (isCompleted) { - // Wait for data to write from the pipe producer - var result = await _application.Input.ReadAsync().ConfigureAwait(false); - var buffer = result.Buffer; - - if (result.IsCanceled) - { - break; - } - - var end = buffer.End; - var isCompleted = result.IsCompleted; - if (!buffer.IsEmpty) - { - await _sender.SendAsync(buffer); - } - - _application.Input.AdvanceTo(end); - - if (isCompleted) - { - break; - } + break; } } + } - private static ProtocolType DetermineProtocolType(EndPoint endPoint) + private static ProtocolType DetermineProtocolType(EndPoint endPoint) + { + switch (endPoint) { - switch (endPoint) - { - case UnixDomainSocketEndPoint _: - return ProtocolType.Unspecified; - default: - return ProtocolType.Tcp; - } + case UnixDomainSocketEndPoint _: + return ProtocolType.Unspecified; + default: + return ProtocolType.Tcp; } } } diff --git a/src/Bedrock.Framework/Transports/Sockets/SocketConnectionFactory.cs b/src/Bedrock.Framework/Transports/Sockets/SocketConnectionFactory.cs index 3a484029..b63db75b 100644 --- a/src/Bedrock.Framework/Transports/Sockets/SocketConnectionFactory.cs +++ b/src/Bedrock.Framework/Transports/Sockets/SocketConnectionFactory.cs @@ -1,19 +1,14 @@ -using System; -using System.Collections.Generic; -using System.Net; -using System.Net.Sockets; -using System.Text; +using System.Net; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public class SocketConnectionFactory : IConnectionFactory { - public class SocketConnectionFactory : IConnectionFactory + public ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) { - public ValueTask ConnectAsync(EndPoint endpoint, CancellationToken cancellationToken = default) - { - return new SocketConnection(endpoint).StartAsync(); - } + return new SocketConnection(endpoint).StartAsync(); } } diff --git a/src/Bedrock.Framework/Transports/Sockets/SocketReceiver.cs b/src/Bedrock.Framework/Transports/Sockets/SocketReceiver.cs index 6ad5c235..0a3d2756 100644 --- a/src/Bedrock.Framework/Transports/Sockets/SocketReceiver.cs +++ b/src/Bedrock.Framework/Transports/Sockets/SocketReceiver.cs @@ -4,37 +4,36 @@ using System.Net.Sockets; using System.Text; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +internal class SocketReceiver { - internal class SocketReceiver - { - private readonly Socket _socket; - private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); - private readonly SocketAwaitable _awaitable; + private readonly Socket _socket; + private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); + private readonly SocketAwaitable _awaitable; - public SocketReceiver(Socket socket, PipeScheduler scheduler) - { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); - } + public SocketReceiver(Socket socket, PipeScheduler scheduler) + { + _socket = socket; + _awaitable = new SocketAwaitable(scheduler); + _eventArgs.UserToken = _awaitable; + _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); + } - public SocketAwaitable ReceiveAsync(Memory buffer) - { + public SocketAwaitable ReceiveAsync(Memory buffer) + { #if NETCOREAPP - _eventArgs.SetBuffer(buffer); + _eventArgs.SetBuffer(buffer); #else - var segment = buffer.GetArray(); + var segment = buffer.GetArray(); - _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); + _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); #endif - if (!_socket.ReceiveAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } - - return _awaitable; + if (!_socket.ReceiveAsync(_eventArgs)) + { + _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); } + + return _awaitable; } } diff --git a/src/Bedrock.Framework/Transports/Sockets/SocketSender.cs b/src/Bedrock.Framework/Transports/Sockets/SocketSender.cs index 17793032..3c979011 100644 --- a/src/Bedrock.Framework/Transports/Sockets/SocketSender.cs +++ b/src/Bedrock.Framework/Transports/Sockets/SocketSender.cs @@ -5,96 +5,94 @@ using System.IO.Pipelines; using System.Net.Sockets; using System.Runtime.InteropServices; -using System.Text; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +internal class SocketSender { - internal class SocketSender - { - private readonly Socket _socket; - private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); - private readonly SocketAwaitable _awaitable; + private readonly Socket _socket; + private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); + private readonly SocketAwaitable _awaitable; - private List> _bufferList; + private List> _bufferList; - public SocketSender(Socket socket, PipeScheduler scheduler) - { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); - } + public SocketSender(Socket socket, PipeScheduler scheduler) + { + _socket = socket; + _awaitable = new SocketAwaitable(scheduler); + _eventArgs.UserToken = _awaitable; + _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); + } - public SocketAwaitable SendAsync(in ReadOnlySequence buffers) + public SocketAwaitable SendAsync(in ReadOnlySequence buffers) + { + if (buffers.IsSingleSegment) { - if (buffers.IsSingleSegment) - { - return SendAsync(buffers.First); - } + return SendAsync(buffers.First); + } #if NETCOREAPP - if (!_eventArgs.MemoryBuffer.Equals(Memory.Empty)) + if (!_eventArgs.MemoryBuffer.Equals(Memory.Empty)) #else - if (_eventArgs.Buffer != null) + if (_eventArgs.Buffer != null) #endif - { - _eventArgs.SetBuffer(null, 0, 0); - } - - _eventArgs.BufferList = GetBufferList(buffers); + { + _eventArgs.SetBuffer(null, 0, 0); + } - if (!_socket.SendAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } + _eventArgs.BufferList = GetBufferList(buffers); - return _awaitable; + if (!_socket.SendAsync(_eventArgs)) + { + _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); } - private SocketAwaitable SendAsync(ReadOnlyMemory memory) + return _awaitable; + } + + private SocketAwaitable SendAsync(ReadOnlyMemory memory) + { + // The BufferList getter is much less expensive then the setter. + if (_eventArgs.BufferList != null) { - // The BufferList getter is much less expensive then the setter. - if (_eventArgs.BufferList != null) - { - _eventArgs.BufferList = null; - } + _eventArgs.BufferList = null; + } #if NETCOREAPP - _eventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); + _eventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); #else - var segment = memory.GetArray(); + var segment = memory.GetArray(); - _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); + _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); #endif - if (!_socket.SendAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } + if (!_socket.SendAsync(_eventArgs)) + { + _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + } + + return _awaitable; + } - return _awaitable; + private List> GetBufferList(in ReadOnlySequence buffer) + { + Debug.Assert(!buffer.IsEmpty); + Debug.Assert(!buffer.IsSingleSegment); + + if (_bufferList == null) + { + _bufferList = new List>(); + } + else + { + // Buffers are pooled, so it's OK to root them until the next multi-buffer write. + _bufferList.Clear(); } - private List> GetBufferList(in ReadOnlySequence buffer) + foreach (var b in buffer) { - Debug.Assert(!buffer.IsEmpty); - Debug.Assert(!buffer.IsSingleSegment); - - if (_bufferList == null) - { - _bufferList = new List>(); - } - else - { - // Buffers are pooled, so it's OK to root them until the next multi-buffer write. - _bufferList.Clear(); - } - - foreach (var b in buffer) - { - _bufferList.Add(b.GetArray()); - } - - return _bufferList; + _bufferList.Add(b.GetArray()); } + + return _bufferList; } } diff --git a/src/Bedrock.Framework/Transports/Sockets/SocketsServerBuilder.cs b/src/Bedrock.Framework/Transports/Sockets/SocketsServerBuilder.cs index 455b3985..bd091ad7 100644 --- a/src/Bedrock.Framework/Transports/Sockets/SocketsServerBuilder.cs +++ b/src/Bedrock.Framework/Transports/Sockets/SocketsServerBuilder.cs @@ -5,58 +5,57 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; -namespace Bedrock.Framework +namespace Bedrock.Framework; + +public class SocketsServerBuilder { - public class SocketsServerBuilder - { - private List<(EndPoint EndPoint, int Port, Action Application)> _bindings = new List<(EndPoint, int, Action)>(); + private List<(EndPoint EndPoint, int Port, Action Application)> _bindings = []; - public SocketTransportOptions Options { get; } = new SocketTransportOptions(); + public SocketTransportOptions Options { get; } = new SocketTransportOptions(); - public SocketsServerBuilder Listen(EndPoint endPoint, Action configure) - { - _bindings.Add((endPoint, 0, configure)); - return this; - } + public SocketsServerBuilder Listen(EndPoint endPoint, Action configure) + { + _bindings.Add((endPoint, 0, configure)); + return this; + } - public SocketsServerBuilder Listen(IPAddress address, int port, Action configure) - { - return Listen(new IPEndPoint(address, port), configure); - } + public SocketsServerBuilder Listen(IPAddress address, int port, Action configure) + { + return Listen(new IPEndPoint(address, port), configure); + } - public SocketsServerBuilder ListenAnyIP(int port, Action configure) - { - return Listen(IPAddress.Any, port, configure); - } + public SocketsServerBuilder ListenAnyIP(int port, Action configure) + { + return Listen(IPAddress.Any, port, configure); + } - public SocketsServerBuilder ListenLocalhost(int port, Action configure) - { - _bindings.Add((null, port, configure)); - return this; - } + public SocketsServerBuilder ListenLocalhost(int port, Action configure) + { + _bindings.Add((null, port, configure)); + return this; + } - public SocketsServerBuilder ListenUnixSocket(string socketPath, Action configure) - { - return Listen(new UnixDomainSocketEndPoint(socketPath), configure); - } + public SocketsServerBuilder ListenUnixSocket(string socketPath, Action configure) + { + return Listen(new UnixDomainSocketEndPoint(socketPath), configure); + } - internal void Apply(ServerBuilder builder) - { - var socketTransportFactory = new SocketTransportFactory(Microsoft.Extensions.Options.Options.Create(Options), builder.ApplicationServices.GetLoggerFactory()); + internal void Apply(ServerBuilder builder) + { + var socketTransportFactory = new SocketTransportFactory(Microsoft.Extensions.Options.Options.Create(Options), builder.ApplicationServices.GetLoggerFactory()); - foreach (var binding in _bindings) + foreach (var binding in _bindings) + { + if (binding.EndPoint == null) { - if (binding.EndPoint == null) - { - var connectionBuilder = new ConnectionBuilder(builder.ApplicationServices); - binding.Application(connectionBuilder); - builder.Bindings.Add(new LocalHostBinding(binding.Port, connectionBuilder.Build(), socketTransportFactory)); - } - else - { - - builder.Listen(binding.EndPoint, socketTransportFactory, binding.Application); - } + var connectionBuilder = new ConnectionBuilder(builder.ApplicationServices); + binding.Application(connectionBuilder); + builder.Bindings.Add(new LocalHostBinding(binding.Port, connectionBuilder.Build(), socketTransportFactory)); + } + else + { + + builder.Listen(binding.EndPoint, socketTransportFactory, binding.Application); } } } diff --git a/tests/Bedrock.Framework.Tests/MessagePipeReaderTests.cs b/tests/Bedrock.Framework.Tests/MessagePipeReaderTests.cs index 8d93b6c9..19d80acf 100644 --- a/tests/Bedrock.Framework.Tests/MessagePipeReaderTests.cs +++ b/tests/Bedrock.Framework.Tests/MessagePipeReaderTests.cs @@ -29,7 +29,7 @@ private static MessagePipeReader CreateReader(out Func writeFunc) var position = stream.Position; stream.Position = written; protocol.WriteMessage(bytes, writer); - await writer.FlushAsync().ConfigureAwait(false); + await writer.FlushAsync(); written = stream.Position; stream.Position = position; }; @@ -39,14 +39,14 @@ private static MessagePipeReader CreateReader(out Func writeFunc) private static async Task CreateReaderOverBytes(byte[] bytes) { var reader = CreateReader(out var writeFunc); - await writeFunc(bytes).ConfigureAwait(false); + await writeFunc(bytes); return reader; } [Fact] public async Task CanRead() { - var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -63,8 +63,8 @@ public async Task CanRead() public async Task ReadAsyncReturnsFullBacklogWhenNotFullyConsumed() { var reader = CreateReader(out var writeFunc); - await writeFunc(Encoding.ASCII.GetBytes("Hello ")).ConfigureAwait(false); - await writeFunc(Encoding.ASCII.GetBytes("World")).ConfigureAwait(false); + await writeFunc(Encoding.ASCII.GetBytes("Hello ")); + await writeFunc(Encoding.ASCII.GetBytes("World")); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -81,7 +81,7 @@ public async Task ReadAsyncReturnsFullBacklogWhenNotFullyConsumed() Assert.True(buffer.IsSingleSegment); Assert.Equal("lo World", Encoding.ASCII.GetString(buffer.ToArray())); - await writeFunc(Encoding.ASCII.GetBytes("\nLorem Ipsum")).ConfigureAwait(false); + await writeFunc(Encoding.ASCII.GetBytes("\nLorem Ipsum")); reader.AdvanceTo(buffer.GetPosition(2)); readResult = await reader.ReadAsync(); buffer = readResult.Buffer; @@ -96,7 +96,7 @@ public async Task ReadAsyncReturnsFullBacklogWhenNotFullyConsumed() [Fact] public async Task TryReadReturnsTrueIfBufferedBytesAndNotExaminedEverything() { - var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -116,7 +116,7 @@ public async Task TryReadReturnsTrueIfBufferedBytesAndNotExaminedEverything() [Fact] public async Task TryReadReturnsFalseIfBufferedBytesAndEverythingConsumed() { - var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -131,7 +131,7 @@ public async Task TryReadReturnsFalseIfBufferedBytesAndEverythingConsumed() [Fact] public async Task TryReadReturnsFalseIfBufferedBytesAndEverythingExamined() { - var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(Encoding.ASCII.GetBytes("Hello World")); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -147,8 +147,8 @@ public async Task TryReadReturnsFalseIfBufferedBytesAndEverythingExamined() public async Task TryReadReturnsTrueIfBufferedBytesAndEverythingExaminedButMoreDataSynchronouslyAvailabe() { var reader = CreateReader(out var writeFunc); - await writeFunc(Encoding.ASCII.GetBytes("Hello ")).ConfigureAwait(false); - await writeFunc(Encoding.ASCII.GetBytes("World")).ConfigureAwait(false); + await writeFunc(Encoding.ASCII.GetBytes("Hello ")); + await writeFunc(Encoding.ASCII.GetBytes("World")); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -169,8 +169,8 @@ public async Task TryReadReturnsTrueIfBufferedBytesAndEverythingExaminedButMoreD public async Task TryReadReturnsTrueIfBufferedBytesAndEverythingConsumedButMoreDataSynchronouslyAvailabe() { var reader = CreateReader(out var writeFunc); - await writeFunc(Encoding.ASCII.GetBytes("Hello ")).ConfigureAwait(false); - await writeFunc(Encoding.ASCII.GetBytes("World")).ConfigureAwait(false); + await writeFunc(Encoding.ASCII.GetBytes("Hello ")); + await writeFunc(Encoding.ASCII.GetBytes("World")); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -199,7 +199,7 @@ async Task DoAsyncRead(PipeReader reader, int[] bufferSizes) var index = 0; while (true) { - var readResult = await reader.ReadAsync().ConfigureAwait(false); + var readResult = await reader.ReadAsync(); if (readResult.IsCompleted) { @@ -221,7 +221,7 @@ async Task DoAsyncWrites(PipeWriter writer, int[] bufferSizes) { writer.WriteEmpty(protocol, bufferSizes[i]); waitForRead = new TaskCompletionSource(); - await writer.FlushAsync().ConfigureAwait(false); + await writer.FlushAsync(); await waitForRead.Task; } @@ -247,7 +247,7 @@ async Task DoAsyncWrites(PipeWriter writer, int[] bufferSizes) [Fact] public async Task CanConsumeAllBytes() { - var reader = await CreateReaderOverBytes(new byte[100]).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(new byte[100]); var buffer = (await reader.ReadAsync()).Buffer; reader.AdvanceTo(buffer.End); @@ -258,7 +258,7 @@ public async Task CanConsumeAllBytes() [Fact] public async Task CanConsumeNoBytes() { - var reader = await CreateReaderOverBytes(new byte[100]).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(new byte[100]); var buffer = (await reader.ReadAsync()).Buffer; reader.AdvanceTo(buffer.Start); @@ -269,7 +269,7 @@ public async Task CanConsumeNoBytes() [Fact] public async Task CanExamineAllBytes() { - var reader = await CreateReaderOverBytes(new byte[100]).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(new byte[100]); var buffer = (await reader.ReadAsync()).Buffer; reader.AdvanceTo(buffer.Start, buffer.End); @@ -280,7 +280,7 @@ public async Task CanExamineAllBytes() [Fact] public async Task CanExamineNoBytes() { - var reader = await CreateReaderOverBytes(new byte[100]).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(new byte[100]); var buffer = (await reader.ReadAsync()).Buffer; reader.AdvanceTo(buffer.Start, buffer.Start); @@ -314,7 +314,7 @@ public async Task BufferingDataPastEndOfStreamCanBeReadAgain() var reader = new MessagePipeReader(PipeReader.Create(stream), protocol); protocol.WriteMessage(Encoding.ASCII.GetBytes("Hello World"), writer); - await writer.FlushAsync().ConfigureAwait(false); + await writer.FlushAsync(); stream.Position = 0; var readResult = await reader.ReadAsync(); @@ -343,7 +343,7 @@ public async Task CompleteReaderWithoutAdvanceDoesNotThrow() [Fact] public async Task AdvanceAfterCompleteThrows() { - var reader = await CreateReaderOverBytes(new byte[100]).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(new byte[100]); var buffer = (await reader.ReadAsync()).Buffer; reader.Complete(); @@ -541,7 +541,7 @@ public void NullMessageReaderThrows() [Fact] public async Task CanReadLargeMessages() { - var reader = await CreateReaderOverBytes(new byte[10000]).ConfigureAwait(false); + var reader = await CreateReaderOverBytes(new byte[10000]); var readResult = await reader.ReadAsync(); Assert.Equal(10000, readResult.Buffer.Length); @@ -553,9 +553,9 @@ public async Task CanReadLargeMessages() public async Task EmptyMessageCausesResultToBeCompleted() { var reader = CreateReader(out var writeFunc); - await writeFunc(new byte[100]).ConfigureAwait(false); - await writeFunc(new byte[0]).ConfigureAwait(false); - await writeFunc(new byte[100]).ConfigureAwait(false); + await writeFunc(new byte[100]); + await writeFunc(new byte[0]); + await writeFunc(new byte[100]); var readResult = await reader.ReadAsync(); var buffer = readResult.Buffer; @@ -663,10 +663,10 @@ async Task WritingTask() for (var i = 0; i < 3; i++) { protocol.WriteMessage(data, writer); - await writer.FlushAsync().ConfigureAwait(false); + await writer.FlushAsync(); } - await writer.CompleteAsync().ConfigureAwait(false); + await writer.CompleteAsync(); } async Task ReadingTask() @@ -675,7 +675,7 @@ async Task ReadingTask() while (true) { - var result = await reader.ReadAsync().ConfigureAwait(false); + var result = await reader.ReadAsync(); var buffer = result.Buffer; if (buffer.Length < 3 * data.Length) @@ -687,7 +687,7 @@ async Task ReadingTask() Assert.Equal(Enumerable.Repeat(data, 3).SelectMany(a => a).ToArray(), buffer.ToArray()); reader.AdvanceTo(buffer.End); - result = await reader.ReadAsync().ConfigureAwait(false); + result = await reader.ReadAsync(); Assert.True(result.IsCompleted); break; }