diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs index c3135042b032b..5548fee130bc8 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/ResettableValueTaskSource.cs @@ -105,9 +105,7 @@ public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, C _state = State.Awaiting; } // None, Ready, Completed: return the current task. - if (state == State.None || - state == State.Ready || - state == State.Completed) + if (state is State.None or State.Ready or State.Completed) { // Remember that the value task with the current version is being given out. _hasWaiter = true; @@ -167,8 +165,7 @@ private bool TryComplete(Exception? exception, bool final) // If the _valueTaskSource has already been set, we don't want to lose the result by overwriting it. // So keep it as is and store the result in _finalTaskSource. - if (state == State.None || - state == State.Awaiting) + if (state is State.None or State.Awaiting) { _state = final ? State.Completed : State.Ready; } @@ -178,16 +175,14 @@ private bool TryComplete(Exception? exception, bool final) { // Set up the exception stack trace for the caller. exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception; - if (state == State.None || - state == State.Awaiting) + if (state is State.None or State.Awaiting) { _valueTaskSource.SetException(exception); } } else { - if (state == State.None || - state == State.Awaiting) + if (state is State.None or State.Awaiting) { _valueTaskSource.SetResult(final); } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index 049ab54063800..db3adf776d542 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -109,7 +109,26 @@ static async ValueTask StartConnectAsync(QuicClientConnectionOpt private int _disposed; private readonly ValueTaskSource _connectedTcs = new ValueTaskSource(); - private readonly ValueTaskSource _shutdownTcs = new ValueTaskSource(); + private readonly ResettableValueTaskSource _shutdownTcs = new ResettableValueTaskSource() + { + CancellationAction = target => + { + try + { + if (target is QuicConnection connection) + { + // The OCE will be propagated through stored CancellationToken in ResettableValueTaskSource. + connection._shutdownTcs.TrySetResult(); + } + } + catch (ObjectDisposedException) + { + // We collided with a Dispose in another thread. This can happen + // when using CancellationTokenSource.CancelAfter. + // Ignore the exception + } + } + }; private readonly CancellationTokenSource _shutdownTokenSource = new CancellationTokenSource(); @@ -467,7 +486,7 @@ public ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken { ObjectDisposedException.ThrowIf(_disposed == 1, this); - if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken)) + if (_shutdownTcs.TryGetValueTask(out ValueTask valueTask, this, cancellationToken)) { unsafe { @@ -520,7 +539,7 @@ private unsafe int HandleEventShutdownComplete() _acceptQueue.Writer.TryComplete(exception); _connectedTcs.TrySetException(exception); _shutdownTokenSource.Cancel(); - _shutdownTcs.TrySetResult(); + _shutdownTcs.TrySetResult(final: true); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventLocalAddressChanged(ref LOCAL_ADDRESS_CHANGED_DATA data) @@ -626,7 +645,7 @@ public async ValueTask DisposeAsync() } // Check if the connection has been shut down and if not, shut it down. - if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this)) + if (_shutdownTcs.TryGetValueTask(out ValueTask valueTask, this)) { unsafe { @@ -636,9 +655,19 @@ public async ValueTask DisposeAsync() (ulong)_defaultCloseErrorCode); } } + else if (!valueTask.IsCompletedSuccessfully) + { + unsafe + { + MsQuicApi.Api.ConnectionShutdown( + _handle, + QUIC_CONNECTION_SHUTDOWN_FLAGS.SILENT, + (ulong)_defaultCloseErrorCode); + } + } // Wait for SHUTDOWN_COMPLETE, the last event, so that all resources can be safely released. - await valueTask.ConfigureAwait(false); + await _shutdownTcs.GetFinalTask(this).ConfigureAwait(false); Debug.Assert(_connectedTcs.IsCompleted); _handle.Dispose(); _shutdownTokenSource.Dispose(); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index e39d718718d11..6f0a0d8bb5b75 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -220,7 +220,7 @@ private async void StartConnectionHandshake(QuicConnection connection, SslClient { using CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, connection.ConnectionShutdownToken); cancellationToken = linkedCts.Token; - // initial timeout for retrieving connection options + // Initial timeout for retrieving connection options. linkedCts.CancelAfter(handshakeTimeout); wrapException = true; @@ -229,7 +229,7 @@ private async void StartConnectionHandshake(QuicConnection connection, SslClient options.Validate(nameof(options)); - // update handshake timetout based on the returned value + // Update handshake timeout based on the returned value. handshakeTimeout = options.HandshakeTimeout; linkedCts.CancelAfter(handshakeTimeout); @@ -248,12 +248,12 @@ private async void StartConnectionHandshake(QuicConnection connection, SslClient NetEventSource.Info(connection, $"{connection} Connection closed by remote peer"); } - // retrieve the exception which failed the handshake, the parameters are not going to be - // validated because the inner _connectedTcs is already transitioned to faulted state + // Retrieve the exception which failed the handshake, the parameters are not going to be + // validated because the inner _connectedTcs is already transitioned to faulted state. ValueTask task = connection.FinishHandshakeAsync(null!, null!, default); Debug.Assert(task.IsFaulted); - // unwrap AggregateException and propagate it to the accept queue + // Unwrap AggregateException and propagate it to the accept queue. Exception ex = task.AsTask().Exception!.InnerException!; await connection.DisposeAsync().ConfigureAwait(false); diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index 6dcd92395b50c..2e8f6a50e7e64 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -613,6 +613,7 @@ private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE_DATA data) _receiveTcs.TrySetException(exception, final: true); _sendTcs.TrySetException(exception, final: true); } + _startedTcs.TrySetResult(); _shutdownTcs.TrySetResult(); return QUIC_STATUS_SUCCESS; } diff --git a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs index 9e58cab98c095..98d72124f0048 100644 --- a/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs +++ b/src/libraries/System.Net.Quic/tests/FunctionalTests/QuicConnectionTests.cs @@ -124,6 +124,47 @@ await RunClientServer( }); } + [Fact] + public async Task DisposeAfterCloseCanceled() + { + using var sync = new SemaphoreSlim(0); + + await RunClientServer( + async clientConnection => + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + await Assert.ThrowsAsync(async () => await clientConnection.CloseAsync(ExpectedErrorCode, cts.Token)); + await clientConnection.DisposeAsync(); + sync.Release(); + }, + async serverConnection => + { + await sync.WaitAsync(); + await serverConnection.DisposeAsync(); + }); + } + + [Fact] + public async Task DisposeAfterCloseTaskStored() + { + using var sync = new SemaphoreSlim(0); + + await RunClientServer( + async clientConnection => + { + var cts = new CancellationTokenSource(); + var task = clientConnection.CloseAsync(0).AsTask(); + await clientConnection.DisposeAsync(); + sync.Release(); + }, + async serverConnection => + { + await sync.WaitAsync(); + await serverConnection.DisposeAsync(); + }); + } + [Fact] public async Task ConnectionClosedByPeer_WithPendingAcceptAndConnect_PendingAndSubsequentThrowConnectionAbortedException() {