Skip to content

Commit

Permalink
* Set cancellation correctly for TaskCompletionSource in AsyncRpcCont…
Browse files Browse the repository at this point in the history
…inuation
  • Loading branch information
lukebakken committed Dec 27, 2024
1 parent f331ddf commit a3e2a69
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,23 @@ internal ConsumerDispatcherChannelBase(Impl.Channel channel, ushort concurrency)

public ValueTask HandleBasicConsumeOkAsync(IAsyncBasicConsumer consumer, string consumerTag, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

if (false == _disposedValue && false == _quiesce)
{
AddConsumer(consumer, consumerTag);
WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag);
return _writer.WriteAsync(work, cancellationToken);
try
{
AddConsumer(consumer, consumerTag);
WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag);

cancellationToken.ThrowIfCancellationRequested();
return _writer.WriteAsync(work, cancellationToken);
}
catch
{
_ = GetAndRemoveConsumer(consumerTag);
throw;
}
}
else
{
Expand All @@ -101,10 +113,14 @@ public ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag,
string exchange, string routingKey, IReadOnlyBasicProperties basicProperties, RentedMemory body,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

if (false == _disposedValue && false == _quiesce)
{
IAsyncBasicConsumer consumer = GetConsumerOrDefault(consumerTag);
var work = WorkStruct.CreateDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, body);

cancellationToken.ThrowIfCancellationRequested();
return _writer.WriteAsync(work, cancellationToken);
}
else
Expand All @@ -115,10 +131,14 @@ public ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag,

public ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

if (false == _disposedValue && false == _quiesce)
{
IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag);
WorkStruct work = WorkStruct.CreateCancelOk(consumer, consumerTag);

cancellationToken.ThrowIfCancellationRequested();
return _writer.WriteAsync(work, cancellationToken);
}
else
Expand All @@ -129,10 +149,14 @@ public ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken

public ValueTask HandleBasicCancelAsync(string consumerTag, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

if (false == _disposedValue && false == _quiesce)
{
IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag);
WorkStruct work = WorkStruct.CreateCancel(consumer, consumerTag);

cancellationToken.ThrowIfCancellationRequested();
return _writer.WriteAsync(work, cancellationToken);
}
else
Expand Down
66 changes: 44 additions & 22 deletions projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ internal abstract class AsyncRpcContinuation<T> : IRpcContinuation

private bool _disposedValue;

public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken cancellationToken)
public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken rpcCancellationToken)
{
/*
* Note: we can't use an ObjectPool for these because the netstandard2.0
Expand Down Expand Up @@ -89,7 +89,7 @@ public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken canc
_tcsConfiguredTaskAwaitable = _tcs.Task.ConfigureAwait(false);

_linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(
_continuationTimeoutCancellationTokenSource.Token, cancellationToken);
_continuationTimeoutCancellationTokenSource.Token, rpcCancellationToken);
}

public CancellationToken CancellationToken
Expand All @@ -105,7 +105,27 @@ public ConfiguredTaskAwaitable<T>.ConfiguredTaskAwaiter GetAwaiter()
return _tcsConfiguredTaskAwaitable.GetAwaiter();
}

public abstract Task HandleCommandAsync(IncomingCommand cmd);
public async Task HandleCommandAsync(IncomingCommand cmd)
{
try
{
await DoHandleCommandAsync(cmd)
.ConfigureAwait(false);
}
catch (OperationCanceledException)
{
if (CancellationToken.IsCancellationRequested)
{
_tcs.SetCanceled();
}
else
{
throw;
}
}
}

protected abstract Task DoHandleCommandAsync(IncomingCommand cmd);

public virtual void HandleChannelShutdown(ShutdownEventArgs reason)
{
Expand Down Expand Up @@ -141,17 +161,17 @@ public ConnectionSecureOrTuneAsyncRpcContinuation(TimeSpan continuationTimeout,
{
}

public override Task HandleCommandAsync(IncomingCommand cmd)
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == ProtocolCommandId.ConnectionSecure)
{
var secure = new ConnectionSecure(cmd.MethodSpan);
_tcs.TrySetResult(new ConnectionSecureOrTune(secure._challenge, default));
_tcs.SetResult(new ConnectionSecureOrTune(secure._challenge, default));
}
else if (cmd.CommandId == ProtocolCommandId.ConnectionTune)
{
var tune = new ConnectionTune(cmd.MethodSpan);
_tcs.TrySetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails
_tcs.SetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails
{
m_channelMax = tune._channelMax,
m_frameMax = tune._frameMax,
Expand All @@ -178,11 +198,11 @@ public SimpleAsyncRpcContinuation(ProtocolCommandId expectedCommandId, TimeSpan
_expectedCommandId = expectedCommandId;
}

public override Task HandleCommandAsync(IncomingCommand cmd)
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == _expectedCommandId)
{
_tcs.TrySetResult(true);
_tcs.SetResult(true);
}
else
{
Expand All @@ -206,14 +226,14 @@ public BasicCancelAsyncRpcContinuation(string consumerTag, IConsumerDispatcher c
_consumerDispatcher = consumerDispatcher;
}

public override async Task HandleCommandAsync(IncomingCommand cmd)
protected override async Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == ProtocolCommandId.BasicCancelOk)
{
_tcs.TrySetResult(true);
Debug.Assert(_consumerTag == new BasicCancelOk(cmd.MethodSpan)._consumerTag);
await _consumerDispatcher.HandleBasicCancelOkAsync(_consumerTag, CancellationToken)
.ConfigureAwait(false);
_tcs.SetResult(true);
}
else
{
Expand All @@ -235,14 +255,16 @@ public BasicConsumeAsyncRpcContinuation(IAsyncBasicConsumer consumer, IConsumerD
_consumerDispatcher = consumerDispatcher;
}

public override async Task HandleCommandAsync(IncomingCommand cmd)
protected override async Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == ProtocolCommandId.BasicConsumeOk)
{
var method = new BasicConsumeOk(cmd.MethodSpan);
_tcs.TrySetResult(method._consumerTag);

await _consumerDispatcher.HandleBasicConsumeOkAsync(_consumer, method._consumerTag, CancellationToken)
.ConfigureAwait(false);

_tcs.SetResult(method._consumerTag);
}
else
{
Expand All @@ -264,7 +286,7 @@ public BasicGetAsyncRpcContinuation(Func<ulong, ulong> adjustDeliveryTag,

internal DateTime StartTime { get; } = DateTime.UtcNow;

public override Task HandleCommandAsync(IncomingCommand cmd)
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == ProtocolCommandId.BasicGetOk)
{
Expand All @@ -280,11 +302,11 @@ public override Task HandleCommandAsync(IncomingCommand cmd)
header,
cmd.Body.ToArray());

_tcs.TrySetResult(result);
_tcs.SetResult(result);
}
else if (cmd.CommandId == ProtocolCommandId.BasicGetEmpty)
{
_tcs.TrySetResult(null);
_tcs.SetResult(null);
}
else
{
Expand Down Expand Up @@ -325,7 +347,7 @@ public override void HandleChannelShutdown(ShutdownEventArgs reason)

public Task OnConnectionShutdownAsync(object? sender, ShutdownEventArgs reason)
{
_tcs.TrySetResult(true);
_tcs.SetResult(true);
return Task.CompletedTask;
}
}
Expand Down Expand Up @@ -377,13 +399,13 @@ public QueueDeclareAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellati
{
}

public override Task HandleCommandAsync(IncomingCommand cmd)
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == ProtocolCommandId.QueueDeclareOk)
{
var method = new Client.Framing.QueueDeclareOk(cmd.MethodSpan);
var result = new QueueDeclareOk(method._queue, method._messageCount, method._consumerCount);
_tcs.TrySetResult(result);
_tcs.SetResult(result);
}
else
{
Expand Down Expand Up @@ -417,12 +439,12 @@ public QueueDeleteAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellatio
{
}

public override Task HandleCommandAsync(IncomingCommand cmd)
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == ProtocolCommandId.QueueDeleteOk)
{
var method = new QueueDeleteOk(cmd.MethodSpan);
_tcs.TrySetResult(method._messageCount);
_tcs.SetResult(method._messageCount);
}
else
{
Expand All @@ -440,12 +462,12 @@ public QueuePurgeAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellation
{
}

public override Task HandleCommandAsync(IncomingCommand cmd)
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
{
if (cmd.CommandId == ProtocolCommandId.QueuePurgeOk)
{
var method = new QueuePurgeOk(cmd.MethodSpan);
_tcs.TrySetResult(method._messageCount);
_tcs.SetResult(method._messageCount);
}
else
{
Expand Down
100 changes: 100 additions & 0 deletions projects/Test/Integration/GH/TestGitHubIssues.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// This source code is dual-licensed under the Apache License, version
// 2.0, and the Mozilla Public License, version 2.0.
//
// The APL v2.0:
//
//---------------------------------------------------------------------------
// Copyright (c) 2007-2024 Broadcom. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//---------------------------------------------------------------------------
//
// The MPL v2.0:
//
//---------------------------------------------------------------------------
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//
// Copyright (c) 2007-2024 Broadcom. All Rights Reserved.
//---------------------------------------------------------------------------

using System.Threading;
using System.Threading.Tasks;
using RabbitMQ.Client;
using RabbitMQ.Client.Events;
using Xunit;
using Xunit.Abstractions;

#nullable enable

namespace Test.Integration.GH
{
public class TestGitHubIssues : IntegrationFixture
{
public TestGitHubIssues(ITestOutputHelper output) : base(output)
{
}

public override Task InitializeAsync()
{
// NB: nothing to do here since each test creates its own factory,
// connections and channels
Assert.Null(_connFactory);
Assert.Null(_conn);
Assert.Null(_channel);
return Task.CompletedTask;
}

[Fact]
public async Task TestBasicConsumeCancellation_GH1750()
{
/*
* Note:
* Testing that the task is actually canceled requires a hacked RabbitMQ server.
* Modify deps/rabbit/src/rabbit_channel.erl, handle_cast for basic.consume_ok
* Before send/2, add timer:sleep(1000), then `make run-broker`
*
* The _output line at the end of the test will print TaskCanceledException
*/
Assert.Null(_connFactory);
Assert.Null(_conn);
Assert.Null(_channel);

_connFactory = CreateConnectionFactory();
_connFactory.AutomaticRecoveryEnabled = false;
_connFactory.TopologyRecoveryEnabled = false;

_conn = await _connFactory.CreateConnectionAsync();
_channel = await _conn.CreateChannelAsync();

QueueDeclareOk q = await _channel.QueueDeclareAsync();

var consumer = new AsyncEventingBasicConsumer(_channel);
consumer.ReceivedAsync += (o, a) =>
{
return Task.CompletedTask;
};

try
{
using var cts = new CancellationTokenSource(5);
await _channel.BasicConsumeAsync(q.QueueName, true, consumer, cts.Token);
}
catch (TaskCanceledException ex)
{
_output.WriteLine("ex: {0}", ex);
}
}
}
}
Loading

0 comments on commit a3e2a69

Please sign in to comment.