From 9d29d1c7d08a8502972999ca7a1b5ef5023bdec9 Mon Sep 17 00:00:00 2001 From: Xin Chen Date: Wed, 11 Jan 2023 10:23:01 -0800 Subject: [PATCH] Support link operation timeout in drain operation. --- .../Amqp/ReceivingAmqpLink.cs | 113 ++++++++++++------ test/TestCases/CancellationTokenTests.cs | 91 +++++++++++++- 2 files changed, 164 insertions(+), 40 deletions(-) diff --git a/Microsoft.Azure.Amqp/Amqp/ReceivingAmqpLink.cs b/Microsoft.Azure.Amqp/Amqp/ReceivingAmqpLink.cs index 8187491c..5474a27f 100644 --- a/Microsoft.Azure.Amqp/Amqp/ReceivingAmqpLink.cs +++ b/Microsoft.Azure.Amqp/Amqp/ReceivingAmqpLink.cs @@ -29,7 +29,7 @@ public sealed class ReceivingAmqpLink : AmqpLink WorkCollection, DisposeAsyncResult, DeliveryState> pendingDispositions; AmqpMessage currentMessage; LinkedList waiterList; - HashSet drainTasks; + HashSet drainTasks; public ReceivingAmqpLink(AmqpLinkSettings settings) : this(null, settings) @@ -273,27 +273,12 @@ public bool EndReceiveMessages(IAsyncResult result, out IEnumerable public Task DrainAsyc(CancellationToken cancellationToken) { - var tcs = new DrainTaskCompletionSource(this); - lock (this.SyncRoot) - { - if (this.drainTasks == null) - { - this.drainTasks = new HashSet(); - } - - this.drainTasks.Add(tcs); - if (!this.Drain) - { - this.SendFlow(false, true, null); - } - - if (cancellationToken.CanBeCanceled) - { - cancellationToken.Register(s => ((DrainTaskCompletionSource)s).Cancel(), tcs, false); - } - } - - return tcs.Task; + return Task.Factory.FromAsync( + (thisPtr, k, c, s) => new DrainAsyncResult(thisPtr, thisPtr.OperationTimeout, k, c, s), + r => DrainAsyncResult.End(r), + this, + cancellationToken, + this); } public Task DisposeMessageAsync(ArraySegment deliveryTag, Outcome outcome, bool batchable, TimeSpan timeout) @@ -515,7 +500,7 @@ protected override void OnReceiveFlow(Flow flow) base.OnReceiveFlow(flow); if (draining && this.LinkCredit == 0) { - HashSet pendingTasks = null; + HashSet pendingTasks = null; lock (this.SyncRoot) { pendingTasks = this.drainTasks; @@ -526,7 +511,7 @@ protected override void OnReceiveFlow(Flow flow) { foreach (var task in pendingTasks) { - task.TrySetResult(true); + task.Signal(false); } } } @@ -898,37 +883,87 @@ private static void RemoveFromWaiterList(ReceiveAsyncResult result) } } - sealed class DrainTaskCompletionSource : TaskCompletionSource + sealed class DrainAsyncResult : TimeoutAsyncResult { readonly ReceivingAmqpLink link; - public DrainTaskCompletionSource(ReceivingAmqpLink link) + public DrainAsyncResult(ReceivingAmqpLink link, + TimeSpan timeout, + CancellationToken cancellationToken, + AsyncCallback callback, + object state) + : base(timeout, cancellationToken, callback, state) { this.link = link; + this.Start(); } - public void Cancel() + public static void End(IAsyncResult result) + { + AsyncResult.End(result); + } + + public void Signal(bool isSynchronous) + { + this.CompleteSelf(isSynchronous); + } + + public override void Cancel(bool isSynchronous) + { + if (this.Remove()) + { + this.CompleteSelf(isSynchronous, new TaskCanceledException()); + } + } + + protected override string Target + { + get { return "drain"; } + } + + protected override void CompleteOnTimer() + { + if (this.Remove()) + { + base.CompleteOnTimer(); + } + } + + void Start() { - bool setCancel = false; lock (this.link.SyncRoot) { - if (this.link.drainTasks != null) + if (this.link.drainTasks == null) { - if (this.link.drainTasks.Remove(this)) - { - setCancel = true; - if (this.link.drainTasks.Count == 0) - { - this.link.drainTasks = null; - } - } + this.link.drainTasks = new HashSet(); + } + + this.link.drainTasks.Add(this); + if (!this.link.Drain) + { + this.link.SendFlow(false, true, null); } } - if (setCancel) + this.StartTracking(); + } + + bool Remove() + { + lock (this.link.SyncRoot) { - this.TrySetCanceled(); + if (this.link.drainTasks == null || !this.link.drainTasks.Remove(this)) + { + return false; + } + + if (this.link.drainTasks.Count == 0) + { + this.link.drainTasks = null; + } } + + return true; } } diff --git a/test/TestCases/CancellationTokenTests.cs b/test/TestCases/CancellationTokenTests.cs index 251163ce..f2a06a58 100644 --- a/test/TestCases/CancellationTokenTests.cs +++ b/test/TestCases/CancellationTokenTests.cs @@ -579,6 +579,85 @@ await Assert.ThrowsAnyAsync(async () => } } + [Fact] + public Task LinkDrainTest() + { + return this.RunLinkDrainTest(false); + } + + [Fact] + public Task LinkDrainCanceledTest() + { + return this.RunLinkDrainTest(true); + } + + [Fact] + public Task LinkDrainTimeoutTest() + { + return this.RunLinkDrainTest(false, 800); + } + + async Task RunLinkDrainTest(bool cancelBefore, int timeoutMilliseconds = 0) + { + var provider = new TestRuntimeProvider() + { + LinkFactory = (s, t) => { t.SettleType = SettleMode.SettleOnDispose; return new TestLink(s, t, flowHang: true); } + }; + AmqpConnectionListener listener = new AmqpConnectionListener(addressUri.AbsoluteUri, provider); + listener.Open(); + + try + { + try + { + var factory = new AmqpConnectionFactory(); + var connection = await factory.OpenConnectionAsync(this.addressUri, AmqpConstants.DefaultTimeout); + var session = connection.CreateSession(new AmqpSessionSettings()); + await session.OpenAsync(AmqpConstants.DefaultTimeout); + + var link = new ReceivingAmqpLink(session, new AmqpLinkSettings() { Role = true, LinkName = "receiver", TotalLinkCredit = 0, Source = new Source(), Target = new Target() }); + await link.OpenAsync(AmqpConstants.DefaultTimeout); + + var cts = new CancellationTokenSource(); + if (cancelBefore && timeoutMilliseconds == 0) + { + cts.Cancel(); + } + + if (timeoutMilliseconds > 0) + { + link.Settings.OperationTimeout = TimeSpan.FromMilliseconds(timeoutMilliseconds); + } + + var task = link.DrainAsyc(cts.Token); + if (!cancelBefore && timeoutMilliseconds == 0) + { + await Task.Yield(); + cts.Cancel(); + } + + await task; + + Assert.True(false, "Exception not thrown"); + } + catch (Exception exception) + { + if (timeoutMilliseconds > 0) + { + Assert.Equal(typeof(TimeoutException), exception.GetType()); + } + else + { + Assert.Equal(typeof(TaskCanceledException), exception.GetType()); + } + } + } + finally + { + listener.Close(); + } + } + [Fact] public async Task CbsSendTokenNoCancelTest() { @@ -714,9 +793,10 @@ class TestLink : AmqpLink readonly bool sendHang; readonly bool receiveHang; readonly bool disposeHang; + readonly bool flowHang; public TestLink(AmqpSession session, AmqpLinkSettings settings, bool openHang = false, bool closeHang = false, - bool sendHang = false, bool receiveHang = false, bool disposeHang = false) + bool sendHang = false, bool receiveHang = false, bool disposeHang = false, bool flowHang = false) : base(session, settings) { this.openHang = openHang; @@ -724,6 +804,7 @@ public TestLink(AmqpSession session, AmqpLinkSettings settings, bool openHang = this.sendHang = sendHang; this.receiveHang = receiveHang; this.disposeHang = disposeHang; + this.flowHang = flowHang; } public override bool CreateDelivery(Transfer transfer, out Delivery delivery) @@ -768,6 +849,14 @@ protected override void OnDisposeDeliveryInternal(Delivery delivery) this.DisposeDelivery(delivery, true, AmqpConstants.AcceptedOutcome); } } + + protected override void OnReceiveFlow(Flow flow) + { + if (!this.flowHang) + { + base.OnReceiveFlow(flow); + } + } } } }