diff --git a/src/Fx/Singleton.cs b/src/Fx/Singleton.cs index f09058e4..a45890f6 100644 --- a/src/Fx/Singleton.cs +++ b/src/Fx/Singleton.cs @@ -88,7 +88,7 @@ protected virtual void Dispose(bool disposing) { this.disposed = true; var thisTaskCompletionSource = this.taskCompletionSource; - if (thisTaskCompletionSource != null && thisTaskCompletionSource.Task.Status == TaskStatus.RanToCompletion) + if (thisTaskCompletionSource != null && thisTaskCompletionSource.Task.Status == TaskStatus.RanToCompletion && this.TryRemove()) { OnSafeClose(thisTaskCompletionSource.Task.Result); } @@ -126,6 +126,11 @@ public async Task GetOrCreateAsync(TimeSpan timeout) { TValue value = await this.OnCreateAsync(timeout).ConfigureAwait(false); tcs.SetResult(value); + + if (this.disposed && this.TryRemove()) + { + OnSafeClose(value); + } } catch (Exception ex) when (!Fx.IsFatal(ex)) { diff --git a/test/Test.Microsoft.Amqp/TestCases/SingletonTests.cs b/test/Test.Microsoft.Amqp/TestCases/SingletonTests.cs new file mode 100644 index 00000000..4b4fecff --- /dev/null +++ b/test/Test.Microsoft.Amqp/TestCases/SingletonTests.cs @@ -0,0 +1,63 @@ +using System; +using System.Threading.Tasks; +using Microsoft.Azure.Amqp; +using Xunit; + +namespace Test.Microsoft.Azure.Amqp +{ + public class SingletonTests + { + [Fact] + public async Task SingletonConcurrentCloseOpenTests() + { + var createTcs = new TaskCompletionSource(); + var closeTcs = new TaskCompletionSource(); + + var singleton = new SingletonTester(createTcs.Task, closeTcs.Task); + + var creating = singleton.GetOrCreateAsync(TimeSpan.FromSeconds(1)); + var closing = singleton.CloseAsync(); + + closeTcs.SetResult(new object()); + await closing; + + createTcs.SetResult(new object()); + await creating; + + await Assert.ThrowsAsync(() => singleton.GetOrCreateAsync(TimeSpan.FromSeconds(1))); + + var createdObj = GetInternalProperty(singleton, "Value"); + Assert.Null(createdObj); + } + + private class SingletonTester : Singleton + { + private readonly Task _onCreateComplete; + private readonly Task _onSafeCloseComplete; + + public SingletonTester(Task onCreateComplete, Task onSafeCloseComplete) + { + _onCreateComplete = onCreateComplete; + _onSafeCloseComplete = onSafeCloseComplete; + } + + protected override async Task OnCreateAsync(TimeSpan timeout) + { + await _onCreateComplete; + return new object(); + } + + protected override void OnSafeClose(object value) + { + _onSafeCloseComplete.GetAwaiter().GetResult(); + } + } + + private T GetInternalProperty(object from, string propertyName) + where T : class + { + var prop = from.GetType().GetProperty(propertyName, System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + return prop.GetValue(from) as T; + } + } +}