Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[QUIC] Call silent shutdown in case CloseAsync failed. #96807

Merged
merged 8 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ public bool TryGetValueTask(out ValueTask valueTask, object? keepAlive = null, C
_state = State.Awaiting;
}
// None, Ready, Completed: return the current task.
if (state == State.None ||
state == State.Ready ||
state == State.Completed)
if (state is State.None or State.Ready or State.Completed)
{
// Remember that the value task with the current version is being given out.
_hasWaiter = true;
Expand Down Expand Up @@ -167,8 +165,7 @@ private bool TryComplete(Exception? exception, bool final)

// If the _valueTaskSource has already been set, we don't want to lose the result by overwriting it.
// So keep it as is and store the result in _finalTaskSource.
if (state == State.None ||
state == State.Awaiting)
if (state is State.None or State.Awaiting)
{
_state = final ? State.Completed : State.Ready;
}
Expand All @@ -178,16 +175,14 @@ private bool TryComplete(Exception? exception, bool final)
{
// Set up the exception stack trace for the caller.
exception = exception.StackTrace is null ? ExceptionDispatchInfo.SetCurrentStackTrace(exception) : exception;
if (state == State.None ||
state == State.Awaiting)
if (state is State.None or State.Awaiting)
{
_valueTaskSource.SetException(exception);
}
}
else
{
if (state == State.None ||
state == State.Awaiting)
if (state is State.None or State.Awaiting)
{
_valueTaskSource.SetResult(final);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,26 @@ static async ValueTask<QuicConnection> StartConnectAsync(QuicClientConnectionOpt
private int _disposed;

private readonly ValueTaskSource _connectedTcs = new ValueTaskSource();
private readonly ValueTaskSource _shutdownTcs = new ValueTaskSource();
private readonly ResettableValueTaskSource _shutdownTcs = new ResettableValueTaskSource()
{
CancellationAction = target =>
{
try
{
if (target is QuicConnection connection)
{
// The OCE will be propagated through stored CancellationToken in ResettableValueTaskSource.
connection._shutdownTcs.TrySetResult();
}
}
catch (ObjectDisposedException)
{
// We collided with a Dispose in another thread. This can happen
// when using CancellationTokenSource.CancelAfter.
// Ignore the exception
}
}
};

private readonly CancellationTokenSource _shutdownTokenSource = new CancellationTokenSource();

Expand Down Expand Up @@ -467,7 +486,7 @@ public ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken
{
ObjectDisposedException.ThrowIf(_disposed == 1, this);

if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this, cancellationToken))
if (_shutdownTcs.TryGetValueTask(out ValueTask valueTask, this, cancellationToken))
{
unsafe
{
Expand Down Expand Up @@ -520,7 +539,7 @@ private unsafe int HandleEventShutdownComplete()
_acceptQueue.Writer.TryComplete(exception);
_connectedTcs.TrySetException(exception);
_shutdownTokenSource.Cancel();
_shutdownTcs.TrySetResult();
_shutdownTcs.TrySetResult(final: true);
return QUIC_STATUS_SUCCESS;
}
private unsafe int HandleEventLocalAddressChanged(ref LOCAL_ADDRESS_CHANGED_DATA data)
Expand Down Expand Up @@ -626,7 +645,7 @@ public async ValueTask DisposeAsync()
}

// Check if the connection has been shut down and if not, shut it down.
if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this))
if (_shutdownTcs.TryGetValueTask(out ValueTask valueTask, this))
{
unsafe
{
Expand All @@ -636,9 +655,19 @@ public async ValueTask DisposeAsync()
(ulong)_defaultCloseErrorCode);
}
}
else if (!valueTask.IsCompletedSuccessfully)
{
unsafe
{
MsQuicApi.Api.ConnectionShutdown(
_handle,
QUIC_CONNECTION_SHUTDOWN_FLAGS.SILENT,
(ulong)_defaultCloseErrorCode);
}
}

// Wait for SHUTDOWN_COMPLETE, the last event, so that all resources can be safely released.
await valueTask.ConfigureAwait(false);
await _shutdownTcs.GetFinalTask(this).ConfigureAwait(false);
Debug.Assert(_connectedTcs.IsCompleted);
_handle.Dispose();
_shutdownTokenSource.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ private async void StartConnectionHandshake(QuicConnection connection, SslClient
{
using CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, connection.ConnectionShutdownToken);
cancellationToken = linkedCts.Token;
// initial timeout for retrieving connection options
// Initial timeout for retrieving connection options.
linkedCts.CancelAfter(handshakeTimeout);

wrapException = true;
Expand All @@ -229,7 +229,7 @@ private async void StartConnectionHandshake(QuicConnection connection, SslClient

options.Validate(nameof(options));

// update handshake timetout based on the returned value
// Update handshake timeout based on the returned value.
handshakeTimeout = options.HandshakeTimeout;
linkedCts.CancelAfter(handshakeTimeout);

Expand All @@ -248,12 +248,12 @@ private async void StartConnectionHandshake(QuicConnection connection, SslClient
NetEventSource.Info(connection, $"{connection} Connection closed by remote peer");
}

// retrieve the exception which failed the handshake, the parameters are not going to be
// validated because the inner _connectedTcs is already transitioned to faulted state
// Retrieve the exception which failed the handshake, the parameters are not going to be
// validated because the inner _connectedTcs is already transitioned to faulted state.
ValueTask task = connection.FinishHandshakeAsync(null!, null!, default);
Debug.Assert(task.IsFaulted);

// unwrap AggregateException and propagate it to the accept queue
// Unwrap AggregateException and propagate it to the accept queue.
Exception ex = task.AsTask().Exception!.InnerException!;

await connection.DisposeAsync().ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE_DATA data)
_receiveTcs.TrySetException(exception, final: true);
_sendTcs.TrySetException(exception, final: true);
}
_startedTcs.TrySetResult();
_shutdownTcs.TrySetResult();
return QUIC_STATUS_SUCCESS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,47 @@ await RunClientServer(
});
}

[Fact]
public async Task DisposeAfterCloseCanceled()
{
using var sync = new SemaphoreSlim(0);

await RunClientServer(
async clientConnection =>
{
var cts = new CancellationTokenSource();
cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(async () => await clientConnection.CloseAsync(ExpectedErrorCode, cts.Token));
await clientConnection.DisposeAsync();
sync.Release();
},
async serverConnection =>
{
await sync.WaitAsync();
await serverConnection.DisposeAsync();
});
}

[Fact]
public async Task DisposeAfterCloseTaskStored()
{
using var sync = new SemaphoreSlim(0);

await RunClientServer(
async clientConnection =>
{
var cts = new CancellationTokenSource();
var task = clientConnection.CloseAsync(0).AsTask();
await clientConnection.DisposeAsync();
sync.Release();
},
async serverConnection =>
{
await sync.WaitAsync();
await serverConnection.DisposeAsync();
});
}

[Fact]
public async Task ConnectionClosedByPeer_WithPendingAcceptAndConnect_PendingAndSubsequentThrowConnectionAbortedException()
{
Expand Down
Loading