From 6b0d5d441fd094bef1f55291a366b66ed1ea9bb3 Mon Sep 17 00:00:00 2001 From: Sebastiaan Koppe Date: Sun, 23 Jun 2024 21:12:17 +0200 Subject: [PATCH] Add then with Result support --- source/concurrency/operations/then.d | 48 ++++++++++++++++++++++++---- tests/ut/concurrency/operations.d | 20 ++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/source/concurrency/operations/then.d b/source/concurrency/operations/then.d index b419228..4df2b95 100644 --- a/source/concurrency/operations/then.d +++ b/source/concurrency/operations/then.d @@ -18,11 +18,26 @@ private struct ThenReceiver(Receiver, Value, Fun) { Fun fun; static if (is(Value == void)) { void setValue() @safe { - static if (is(ReturnType!Fun == void)) { - fun(); - receiver.setValue(); - } else - receiver.setValue(fun()); + static if (is(ReturnType!Fun == Result!T, T)) { + auto r = fun(); + r.match!((Cancelled c) { + receiver.setDone(); + }, (Exception e) { + receiver.setError(e); + }, (Result!(T).Value v) { + static if (is(typeof(v) == Completed)) { + receiver.setValue(); + } else { + receiver.setValue(v); + } + }); + } else { + static if (is(ReturnType!Fun == void)) { + fun(); + receiver.setValue(); + } else + receiver.setValue(fun()); + } } } else { import std.typecons : isTuple; @@ -30,7 +45,23 @@ private struct ThenReceiver(Receiver, Value, Fun) { fun(Value.init.expand); }); void setValue(Value value) @safe { - static if (is(ReturnType!Fun == void)) { + static if (is(ReturnType!Fun == Result!T, T)) { + static if (isExpandable) + auto r = fun(value.expand); + else + auto r = fun(value); + r.match!((Cancelled c) { + receiver.setDone(); + }, (Exception e) { + receiver.setError(e); + }, (Result!(T).Value v) { + static if (is(typeof(v) == Completed)) { + receiver.setValue(); + } else { + receiver.setValue(v); + } + }); + } else static if (is(ReturnType!Fun == void)) { static if (isExpandable) fun(value.expand); else @@ -60,7 +91,10 @@ private struct ThenReceiver(Receiver, Value, Fun) { struct ThenSender(Sender, Fun) if (models!(Sender, isSender)) { import std.traits : ReturnType; static assert(models!(typeof(this), isSender)); - alias Value = ReturnType!fun; + static if (is(ReturnType!fun == Result!T, T)) + alias Value = T; + else + alias Value = ReturnType!fun; Sender sender; Fun fun; auto connect(Receiver)(return Receiver receiver) @safe return scope { diff --git a/tests/ut/concurrency/operations.d b/tests/ut/concurrency/operations.d index 736cdff..16ea161 100644 --- a/tests/ut/concurrency/operations.d +++ b/tests/ut/concurrency/operations.d @@ -165,6 +165,26 @@ unittest { .shouldEqual(3); } +@("then.result.value") @safe +unittest { + just(3).then((int i) => Result!int(i)).syncWait.value.should == 3; +} + +@("then.result.completed") @safe +unittest { + just(3).then((int i) => Result!void(Completed())).syncWait.isOk.should == true; +} + +@("then.result.cancelled") @safe +unittest { + just(3).then((int i) => Result!int(Cancelled())).syncWait.isCancelled; +} + +@("then.result.error") @safe +unittest { + just(3).then((int i) => Result!int(new Exception("stuff"))).syncWait.value.shouldThrowWithMessage("stuff"); +} + @("whenAll.basic") @safe unittest { whenAll(ValueSender!int(1), ValueSender!int(2)).syncWait.value.should