From a5050596f8d3c76421d9ff55279bbcda1a5a0411 Mon Sep 17 00:00:00 2001 From: Jimmy Byrd Date: Sun, 14 Jan 2024 12:41:30 -0500 Subject: [PATCH] fixing caching of cancelled cached tasks (#1221) --- src/FsAutoComplete.Core/AdaptiveExtensions.fs | 80 +++++++++++-------- 1 file changed, 48 insertions(+), 32 deletions(-) diff --git a/src/FsAutoComplete.Core/AdaptiveExtensions.fs b/src/FsAutoComplete.Core/AdaptiveExtensions.fs index 7d9ba5fce..26820e83d 100644 --- a/src/FsAutoComplete.Core/AdaptiveExtensions.fs +++ b/src/FsAutoComplete.Core/AdaptiveExtensions.fs @@ -28,6 +28,20 @@ module AdaptiveExtensions = with _ -> () + type TaskCompletionSource<'a> with + + /// https://github.com/dotnet/runtime/issues/47998 + member tcs.TrySetFromTask(real: Task<'a>) = + task { + try + let! r = real + tcs.TrySetResult r |> ignore + with + | :? OperationCanceledException as x -> tcs.TrySetCanceled(x.CancellationToken) |> ignore + | ex -> tcs.TrySetException ex |> ignore + } + |> ignore> + type ChangeableHashMap<'Key, 'Value> with /// @@ -366,33 +380,44 @@ type internal RefCountingTaskCreator<'a>(create: CancellationToken -> Task<'a>) /// and AdaptiveCancellableTask<'a>(cancel: unit -> unit, real: Task<'a>) = let cts = new CancellationTokenSource() + let mutable cached: Task<'a> = null - let output = - if real.IsCompleted then - real - else - let tcs = new TaskCompletionSource<'a>() + let getTask () = + let createCached () = + if real.IsCompleted then + real + else + task { + let tcs = new TaskCompletionSource<'a>() + use _s = cts.Token.Register(fun () -> tcs.TrySetCanceled(cts.Token) |> ignore) - let s = cts.Token.Register(fun () -> tcs.TrySetCanceled() |> ignore) + tcs.TrySetFromTask real - real.ContinueWith(fun (t: Task<'a>) -> - s.Dispose() + return! tcs.Task + } - if t.IsFaulted then tcs.TrySetException(t.Exception) - elif t.IsCanceled then tcs.TrySetCanceled() - else tcs.TrySetResult(t.Result)) - |> ignore + cached <- + match cached with + | null -> createCached () + | x when x.IsCompleted && not real.IsCompleted -> + // When the real task isn't finished, we create a new task to attach to the real task + // so we can cancel this new task immediately without waiting for the real task to cancel (as other tasks might depend on it and we use ref counting) + // However, if the cached task is completed (cancelled or faulted) but the real task is not, + // we should re-attach to the original task instead of assuming we want to cache the cancelled/faulted task. + createCached () + | o -> o - tcs.Task + cached /// Will run the cancel function passed into the constructor and set the output Task to cancelled state. member x.Cancel() = - cancel () - cts.TryCancel() + lock x (fun () -> + cancel () + cts.TryCancel()) /// The output of the passed in task to the constructor. /// - member x.Task = output + member x.Task = lock x getTask type asyncaval<'a> = inherit IAdaptiveObject @@ -575,10 +600,7 @@ module AsyncAVal = let ref = RefCountingTaskCreator( cancellableTask { - let! ct = CancellableTask.getCancellationToken () - let it = input.GetValue t - use _s = ct.Register(fun () -> it.Cancel()) - let! i = it + let! i = input.GetValue t return! mapping i } ) @@ -633,8 +655,8 @@ module AsyncAVal = ta.Cancel() tb.Cancel()) - let! va = ta - let! vb = tb + let! va = ta.Task + let! vb = tb.Task return! mapping va vb } ) @@ -672,11 +694,8 @@ module AsyncAVal = let outerTask = RefCountingTaskCreator( cancellableTask { - let it = value.GetValue t + let! i = value.GetValue t let! ct = CancellableTask.getCancellationToken () - use _s = ct.Register(fun () -> it.Cancel()) - - let! i = it let inner = mapping i ct return inner @@ -690,14 +709,11 @@ module AsyncAVal = let ref = RefCountingTaskCreator( cancellableTask { - let innerCellTask = outerTask.New() - let! ct = CancellableTask.getCancellationToken () - use _s = ct.Register(fun () -> innerCellTask.Cancel()) - let! inner = innerCellTask - let innerTask = inner.GetValue t + let! inner = outerTask.New() lock inners (fun () -> inners.Value <- HashSet.add inner inners.Value) + let innerTask = inner.GetValue t use _s2 = ct.Register(fun () -> @@ -705,7 +721,7 @@ module AsyncAVal = lock inners (fun () -> inners.Value <- HashSet.remove inner inners.Value) inner.Outputs.Remove x |> ignore) - return! innerTask + return! innerTask.Task } )