diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs index 3eea4d19c692bb..39bd5e80f6cef0 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicConnection.cs @@ -710,6 +710,12 @@ private void Dispose(bool disposing) if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(_state, $"{TraceId()} Stream disposing {disposing}"); + // If we haven't already shutdown gracefully (via a successful CloseAsync call), then force an abortive shutdown. + MsQuicApi.Api.ConnectionShutdownDelegate( + _state.Handle, + QUIC_CONNECTION_SHUTDOWN_FLAGS.SILENT, + 0); + bool releaseHandles = false; lock (_state) { @@ -740,7 +746,10 @@ private void Dispose(bool disposing) // It's unclear how to gracefully wait for a connection to be 100% done. internal override ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken = default) { - ThrowIfDisposed(); + if (_disposed == 1) + { + return default; + } return ShutdownAsync(QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, errorCode); } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs index 41b8d549ed7480..9988961e81e5af 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs @@ -11,6 +11,8 @@ namespace System.Net.Quic.Tests public abstract class QuicConnectionTests : QuicTestBase where T : IQuicImplProviderFactory, new() { + const int ExpectedErrorCode = 1234; + [Fact] public async Task TestConnect() { @@ -39,8 +41,6 @@ public async Task TestConnect() [ActiveIssue("https://github.com/dotnet/runtime/issues/55242", TestPlatforms.Linux)] public async Task AcceptStream_ConnectionAborted_ByClient_Throws() { - const int ExpectedErrorCode = 1234; - using var sync = new SemaphoreSlim(0); await RunClientServer( @@ -56,6 +56,110 @@ await RunClientServer( Assert.Equal(ExpectedErrorCode, ex.ErrorCode); }); } + + private static async Task DoWrites(QuicStream writer, int writeCount) + { + for (int i = 0; i < writeCount; i++) + { + await writer.WriteAsync(new byte[1]); + } + } + + private static async Task DoReads(QuicStream reader, int readCount) + { + for (int i = 0; i < readCount; i++) + { + int bytesRead = await reader.ReadAsync(new byte[1]); + Assert.Equal(1, bytesRead); + } + } + + [Theory] + [InlineData(1)] + [InlineData(10)] + public async Task CloseAsync_WithOpenStream_LocalAndPeerStreamsFailWithQuicOperationAbortedException(int writesBeforeClose) + { + if (typeof(T) == typeof(MockProviderFactory)) + { + return; + } + + using var sync = new SemaphoreSlim(0); + + await RunClientServer( + async clientConnection => + { + using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); + await DoWrites(clientStream, writesBeforeClose); + + // Wait for peer to receive data + await sync.WaitAsync(); + + await clientConnection.CloseAsync(ExpectedErrorCode); + + await Assert.ThrowsAsync(async () => await clientStream.ReadAsync(new byte[1])); + await Assert.ThrowsAsync(async () => await clientStream.WriteAsync(new byte[1])); + }, + async serverConnection => + { + using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await DoReads(serverStream, writesBeforeClose); + + sync.Release(); + + // Since the peer did the abort, we should receive the abort error code in the exception. + QuicConnectionAbortedException ex; + ex = await Assert.ThrowsAsync(async () => await serverStream.ReadAsync(new byte[1])); + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + ex = await Assert.ThrowsAsync(async () => await serverStream.WriteAsync(new byte[1])); + Assert.Equal(ExpectedErrorCode, ex.ErrorCode); + }); + } + + [OuterLoop("Depends on IdleTimeout")] + [Theory] + [InlineData(1)] + [InlineData(10)] + public async Task Dispose_WithOpenLocalStream_LocalStreamFailsWithQuicOperationAbortedException(int writesBeforeClose) + { + if (typeof(T) == typeof(MockProviderFactory)) + { + return; + } + + // Set a short idle timeout so that after we dispose the connection, the peer will discover the connection is dead before too long. + QuicListenerOptions listenerOptions = CreateQuicListenerOptions(); + listenerOptions.IdleTimeout = TimeSpan.FromSeconds(1); + + using var sync = new SemaphoreSlim(0); + + await RunClientServer( + async clientConnection => + { + using QuicStream clientStream = clientConnection.OpenBidirectionalStream(); + await DoWrites(clientStream, writesBeforeClose); + + // Wait for peer to receive data + await sync.WaitAsync(); + + clientConnection.Dispose(); + + await Assert.ThrowsAsync(async () => await clientStream.ReadAsync(new byte[1])); + await Assert.ThrowsAsync(async () => await clientStream.WriteAsync(new byte[1])); + }, + async serverConnection => + { + using QuicStream serverStream = await serverConnection.AcceptStreamAsync(); + await DoReads(serverStream, writesBeforeClose); + + sync.Release(); + + // The client has done an abortive shutdown of the connection, which means we are not notified that the connection has closed. + // But the connection idle timeout should kick in and eventually we will get exceptions. + await Assert.ThrowsAsync(async () => await serverStream.ReadAsync(new byte[1])); + await Assert.ThrowsAsync(async () => await serverStream.WriteAsync(new byte[1])); + }, listenerOptions: listenerOptions); + } } public sealed class QuicConnectionTests_MockProvider : QuicConnectionTests { } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs index da2cfb37f412cd..bcb2bd247ee071 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs @@ -57,15 +57,21 @@ internal QuicConnection CreateQuicConnection(IPEndPoint endpoint) return new QuicConnection(ImplementationProvider, endpoint, GetSslClientAuthenticationOptions()); } - internal QuicListener CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100) + internal QuicListenerOptions CreateQuicListenerOptions() { - var options = new QuicListenerOptions() + return new QuicListenerOptions() { ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0), - ServerAuthenticationOptions = GetSslServerAuthenticationOptions(), - MaxUnidirectionalStreams = maxUnidirectionalStreams, - MaxBidirectionalStreams = maxBidirectionalStreams + ServerAuthenticationOptions = GetSslServerAuthenticationOptions() }; + } + + internal QuicListener CreateQuicListener(int maxUnidirectionalStreams = 100, int maxBidirectionalStreams = 100) + { + var options = CreateQuicListenerOptions(); + options.MaxUnidirectionalStreams = maxUnidirectionalStreams; + options.MaxBidirectionalStreams = maxBidirectionalStreams; + return CreateQuicListener(options); } @@ -111,9 +117,12 @@ internal async Task PingPong(QuicConnection client, QuicConnection server) private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options); - internal async Task RunClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000) + internal async Task RunClientServer(Func clientFunction, Func serverFunction, int iterations = 1, int millisecondsTimeout = 10_000, QuicListenerOptions listenerOptions = null) { - using QuicListener listener = CreateQuicListener(); + const long ClientCloseErrorCode = 11111; + const long ServerCloseErrorCode = 22222; + + using QuicListener listener = CreateQuicListener(listenerOptions ?? CreateQuicListenerOptions()); using var serverFinished = new SemaphoreSlim(0); using var clientFinished = new SemaphoreSlim(0); @@ -126,18 +135,20 @@ await new[] { using QuicConnection serverConnection = await listener.AcceptConnectionAsync(); await serverFunction(serverConnection); + serverFinished.Release(); await clientFinished.WaitAsync(); - await serverConnection.CloseAsync(0); + await serverConnection.CloseAsync(ServerCloseErrorCode); }), Task.Run(async () => { using QuicConnection clientConnection = CreateQuicConnection(listener.ListenEndPoint); await clientConnection.ConnectAsync(); await clientFunction(clientConnection); + clientFinished.Release(); await serverFinished.WaitAsync(); - await clientConnection.CloseAsync(0); + await clientConnection.CloseAsync(ClientCloseErrorCode); }) }.WhenAllOrAnyFailed(millisecondsTimeout); }