Skip to content

Commit

Permalink
Add QuicStream.WaitForWriteCompletionAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK committed Aug 31, 2021
1 parent 1d32e0a commit 75be833
Show file tree
Hide file tree
Showing 7 changed files with 387 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ public override void Flush() { }
public override void SetLength(long value) { }
public void Shutdown() { }
public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Collections.Concurrent;
using System.Collections.Generic;

namespace System.Net.Quic.Implementations.Mock
{
Expand Down Expand Up @@ -244,6 +246,9 @@ internal MockStream OpenStream(long streamId, bool bidirectional)
}

MockStream.StreamState streamState = new MockStream.StreamState(streamId, bidirectional);
// TODO Streams are never removed from a connection. Consider cleaning up in the future.
state._streams[streamState._streamId] = streamState;

Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
streamChannel.Writer.TryWrite(streamState);

Expand Down Expand Up @@ -320,6 +325,12 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell
state._serverErrorCode = errorCode;
DrainAcceptQueue(errorCode, -1);
}

foreach (KeyValuePair<long, MockStream.StreamState> kvp in state._streams)
{
kvp.Value._outboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
kvp.Value._inboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
}
}

Dispose();
Expand Down Expand Up @@ -474,8 +485,9 @@ public PeerStreamLimit(int maxUnidirectional, int maxBidirectional)
internal sealed class ConnectionState
{
public readonly SslApplicationProtocol _applicationProtocol;
public Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
public readonly Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
public readonly Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
public readonly ConcurrentDictionary<long, MockStream.StreamState> _streams;

public PeerStreamLimit? _clientStreamLimit;
public PeerStreamLimit? _serverStreamLimit;
Expand All @@ -490,6 +502,7 @@ public ConnectionState(SslApplicationProtocol applicationProtocol)
_clientInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
_serverInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
_clientErrorCode = _serverErrorCode = -1;
_streams = new ConcurrentDictionary<long, MockStream.StreamState>();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool e
if (endStream)
{
streamBuffer.EndWrite();
WritesCompletedTcs.TrySetResult();
}
}

Expand Down Expand Up @@ -206,10 +207,12 @@ internal override void AbortRead(long errorCode)
if (_isInitiator)
{
_streamState._outboundWriteErrorCode = errorCode;
_streamState._inboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
}
else
{
_streamState._inboundWriteErrorCode = errorCode;
_streamState._outboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
}

ReadStreamBuffer?.AbortRead();
Expand All @@ -220,10 +223,12 @@ internal override void AbortWrite(long errorCode)
if (_isInitiator)
{
_streamState._outboundReadErrorCode = errorCode;
_streamState._outboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
}
else
{
_streamState._inboundReadErrorCode = errorCode;
_streamState._inboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
}

WriteStreamBuffer?.EndWrite();
Expand Down Expand Up @@ -251,6 +256,8 @@ internal override void Shutdown()
{
_connection.LocalStreamLimit!.Bidirectional.Decrement();
}

WritesCompletedTcs.TrySetResult();
}

private void CheckDisposed()
Expand Down Expand Up @@ -283,6 +290,17 @@ public override ValueTask DisposeAsync()
return default;
}

internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
{
CheckDisposed();

return new ValueTask(WritesCompletedTcs.Task);
}

private TaskCompletionSource WritesCompletedTcs => _isInitiator
? _streamState._outboundWritesCompletedTcs
: _streamState._inboundWritesCompletedTcs;

internal sealed class StreamState
{
public readonly long _streamId;
Expand All @@ -292,6 +310,8 @@ internal sealed class StreamState
public long _inboundReadErrorCode;
public long _outboundWriteErrorCode;
public long _inboundWriteErrorCode;
public TaskCompletionSource _outboundWritesCompletedTcs;
public TaskCompletionSource _inboundWritesCompletedTcs;

private const int InitialBufferSize =
#if DEBUG
Expand All @@ -310,6 +330,8 @@ public StreamState(long streamId, bool bidirectional)
_streamId = streamId;
_outboundStreamBuffer = new StreamBuffer(initialBufferSize: InitialBufferSize, maxBufferSize: MaxBufferSize);
_inboundStreamBuffer = (bidirectional ? new StreamBuffer() : null);
_outboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
_inboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ private sealed class State
// Resettable completions to be used for multiple calls to send.
public readonly ResettableCompletionSource<uint> SendResettableCompletionSource = new ResettableCompletionSource<uint>();

public ShutdownWriteState ShutdownWriteState;

// Set once writes have been shutdown.
public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

public ShutdownState ShutdownState;
// The value makes sure that we release the handles only once.
public int ShutdownDone;
Expand Down Expand Up @@ -577,12 +582,26 @@ internal override void AbortWrite(long errorCode)
return;
}

bool shouldComplete = false;

lock (_state)
{
if (_state.SendState < SendState.Aborted)
{
_state.SendState = SendState.Aborted;
}

if (_state.ShutdownWriteState == ShutdownWriteState.None)
{
_state.ShutdownWriteState = ShutdownWriteState.Canceled;
shouldComplete = true;
}
}

if (shouldComplete)
{
_state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted.")));
}

StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode);
Expand Down Expand Up @@ -629,6 +648,23 @@ internal override async ValueTask ShutdownCompleted(CancellationToken cancellati
await _state.ShutdownCompletionSource.Task.ConfigureAwait(false);
}

internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
{
// TODO: What should happen if this is called for a unidirectional stream and there are no writes?

ThrowIfDisposed();

lock (_state)
{
if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed)
{
throw GetConnectionAbortedException(_state);
}
}

return new ValueTask(_state.ShutdownWriteCompletionSource.Task.WaitAsync(cancellationToken));
}

internal override void Shutdown()
{
ThrowIfDisposed();
Expand Down Expand Up @@ -861,6 +897,11 @@ private static uint HandleEvent(State state, ref StreamEvent evt)
// Peer has stopped receiving data, don't send anymore.
case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED:
return HandleEventPeerRecvAborted(state, ref evt);
// Occurs when shutdown is completed for the send side.
// This only happens for shutdown on sending, not receiving
// Receive shutdown can only be abortive.
case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE:
return HandleEventSendShutdownComplete(state, ref evt);
// Shutdown for both sending and receiving is completed.
case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE:
return HandleEventShutdownComplete(state, ref evt);
Expand Down Expand Up @@ -993,23 +1034,37 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt)

private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)
{
bool shouldComplete = false;
bool shouldSendComplete = false;
bool shouldShutdownWriteComplete = false;
lock (state)
{
if (state.SendState == SendState.None || state.SendState == SendState.Pending)
{
shouldComplete = true;
shouldSendComplete = true;
}

if (state.ShutdownWriteState == ShutdownWriteState.None)
{
state.ShutdownWriteState = ShutdownWriteState.Canceled;
shouldShutdownWriteComplete = true;
}

state.SendState = SendState.Aborted;
state.SendErrorCode = (long)evt.Data.PeerReceiveAborted.ErrorCode;
}

if (shouldComplete)
if (shouldSendComplete)
{
state.SendResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
}

if (shouldShutdownWriteComplete)
{
state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
}

return MsQuicStatusCodes.Success;
}

Expand All @@ -1021,6 +1076,38 @@ private static uint HandleEventStartComplete(State state, ref StreamEvent evt)
return MsQuicStatusCodes.Success;
}

private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt)
{
// Graceful will be false in three situations:
// 1. The peer aborted reads and the PEER_RECEIVE_ABORTED event was raised.
// ShutdownWriteCompletionSource is already complete with an error.
// 2. We aborted writes.
// ShutdownWriteCompletionSource is already complete with an error.
// 3. The connection was closed.
// SHUTDOWN_COMPLETE event will be raised immediately after this event. It will handle completing with an error.
//
// Only use this event with sends gracefully completed.
if (evt.Data.SendShutdownComplete.Graceful != 0)
{
bool shouldComplete = false;
lock (state)
{
if (state.ShutdownWriteState == ShutdownWriteState.None)
{
state.ShutdownWriteState = ShutdownWriteState.Finished;
shouldComplete = true;
}
}

if (shouldComplete)
{
state.ShutdownWriteCompletionSource.SetResult();
}
}

return MsQuicStatusCodes.Success;
}

private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt)
{
StreamEventDataShutdownComplete shutdownCompleteEvent = evt.Data.ShutdownComplete;
Expand All @@ -1031,6 +1118,7 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
}

bool shouldReadComplete = false;
bool shouldShutdownWriteComplete = false;
bool shouldShutdownComplete = false;

lock (state)
Expand All @@ -1040,6 +1128,15 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt

shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted);

if (state.ShutdownWriteState == ShutdownWriteState.None)
{
// TODO: We can get to this point if the stream is unidirectional and there are no writes.
// Consider what is the best behavior here with write shutdown and the read side of
// unidirecitonal streams in the future.
state.ShutdownWriteState = ShutdownWriteState.Finished;
shouldShutdownWriteComplete = true;
}

if (state.ShutdownState == ShutdownState.None)
{
state.ShutdownState = ShutdownState.Finished;
Expand All @@ -1052,6 +1149,11 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
state.ReceiveResettableCompletionSource.Complete(0);
}

if (shouldShutdownWriteComplete)
{
state.ShutdownWriteCompletionSource.SetResult();
}

if (shouldShutdownComplete)
{
state.ShutdownCompletionSource.SetResult();
Expand Down Expand Up @@ -1361,6 +1463,7 @@ private static uint HandleEventConnectionClose(State state)

bool shouldCompleteRead = false;
bool shouldCompleteSend = false;
bool shouldCompleteShutdownWrite = false;
bool shouldCompleteShutdown = false;

lock (state)
Expand All @@ -1373,6 +1476,12 @@ private static uint HandleEventConnectionClose(State state)
}
state.SendState = SendState.ConnectionClosed;

if (state.ShutdownWriteState == ShutdownWriteState.None)
{
shouldCompleteShutdownWrite = true;
}
state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed;

if (state.ShutdownState == ShutdownState.None)
{
shouldCompleteShutdown = true;
Expand All @@ -1392,6 +1501,12 @@ private static uint HandleEventConnectionClose(State state)
ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
}

if (shouldCompleteShutdownWrite)
{
state.ShutdownWriteCompletionSource.SetException(
ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
}

if (shouldCompleteShutdown)
{
state.ShutdownCompletionSource.SetException(
Expand Down Expand Up @@ -1493,6 +1608,14 @@ private enum ReadState
Closed
}

private enum ShutdownWriteState
{
None = 0,
Canceled,
Finished,
ConnectionClosed
}

private enum ShutdownState
{
None = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable

internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default);

internal abstract ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default);

internal abstract void Shutdown();

internal abstract void Flush();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ public override int WriteTimeout

public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken);

public ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) => _provider.WaitForWriteCompletionAsync(cancellationToken);

public void Shutdown() => _provider.Shutdown();

protected override void Dispose(bool disposing)
Expand Down
Loading

0 comments on commit 75be833

Please sign in to comment.