From 13e96f65cd3908e46e2021b4cd3635ef26245ee8 Mon Sep 17 00:00:00 2001 From: Sebastiaan Koppe Date: Sun, 21 Aug 2022 20:30:52 +0200 Subject: [PATCH] Support for Fibers --- source/concurrency/fiber.d | 208 ++++++++++++++++++++++++++++++++++ source/concurrency/syncwait.d | 2 +- tests/ut/concurrency/fiber.d | 61 ++++++++++ tests/ut/ut_runner.d | 1 + 4 files changed, 271 insertions(+), 1 deletion(-) create mode 100644 source/concurrency/fiber.d create mode 100644 tests/ut/concurrency/fiber.d diff --git a/source/concurrency/fiber.d b/source/concurrency/fiber.d new file mode 100644 index 0000000..0c335dc --- /dev/null +++ b/source/concurrency/fiber.d @@ -0,0 +1,208 @@ +module concurrency.fiber; + +import concurrency.sender; +import concepts; +import core.thread.fiber; +import core.thread.fiber : Fiber; + +alias Continuation = Object; + +class CancelledException : Exception { + this(string file = __FILE__, size_t line = __LINE__, Throwable next = null) @nogc @safe pure nothrow { + super("Cancelled", file, line, next); + } +} + +package(concurrency) abstract class BaseFiber : Fiber { + private Continuation continuation; + this(void delegate() dg, size_t sz, size_t guardPageSize) @trusted nothrow { + super(dg, sz, guardPageSize); + } + static BaseFiber getThis() @trusted nothrow { + import core.thread.fiber : Fiber; + return cast(BaseFiber)Fiber.getThis(); + } +} + +class OpFiber(Op) : BaseFiber { + import core.memory : pageSize; + + private Op op; + + this(void delegate() shared @safe nothrow dg, size_t sz = pageSize * defaultStackPages, size_t guardPageSize = pageSize) @trusted nothrow { + super(cast(void delegate())dg, sz, guardPageSize); + } +} + +struct FiberSender { + static assert (models!(typeof(this), isSender)); + alias Value = void; + auto connect(Receiver)(return Receiver receiver) @safe return scope { + auto op = FiberSenderOp!(Receiver)(receiver); + return op; + } +} + +struct FiberSenderOp(Receiver) { + Receiver receiver; + alias BaseSender = typeof(receiver.getScheduler().schedule()); + alias Op = OpType!(BaseSender, FiberContinuationReceiver!Receiver); + @disable this(this); + @disable this(ref return scope typeof(this) rhs); + void start() @trusted nothrow scope { + auto fiber = new OpFiber!Op(cast(void delegate()shared nothrow @safe)&run); + cycle(fiber, true); + } + private void schedule(OpFiber!Op fiber) @trusted nothrow { + // TODO: why can't we store the Op here? + fiber.op = receiver.getScheduler.schedule().connect(FiberContinuationReceiver!Receiver(fiber, &cycle, receiver)); + fiber.op.start(); + } + private void cycle(BaseFiber f, bool inline_) @trusted nothrow { + auto fiber = cast(OpFiber!Op)f; + if (!inline_) + return schedule(fiber); + + if (auto throwable = fiber.call!(Fiber.Rethrow.no)) { + receiver.setError(throwable); + return; + } + + if (fiber.continuation !is null) { + auto sender = cast(SenderObjectBase!void)fiber.continuation; + fiber.continuation = null; + try { + // TODO: we could try to reuse this space. + // e.g. inline some space in the FiberSenderOp and storing it there + // and/or otherwise (if too big) dynamically allocate and reuse that + // space. + auto op = sender.connectHeap(FiberContinuationReceiver!Receiver(fiber, &cycle, receiver)); + op.start(); + } catch (Throwable t) { + receiver.setError(t); + return; + } + } else if (fiber.state == Fiber.State.HOLD) { + schedule(fiber); + } else { + // reuse it? + } + } + private void run() nothrow @trusted { + import concurrency.receiver : setValueOrError; + import concurrency.error : clone; + + try { + receiver.setValue(); + } catch (CancelledException e) { + receiver.setDone(); + } catch (Exception e) { + receiver.setError(e); + } catch (Throwable t) { + receiver.setError(t.clone()); + } + } +} + +// Receiver used to continue the Fiber after yielding on a Sender. +struct FiberContinuationReceiver(Receiver) { + import concurrency.receiver : ForwardExtensionPoints; + BaseFiber fiber; + void delegate(BaseFiber, bool) nothrow @trusted cycle; + Receiver receiver; + void setDone() nothrow @safe { + cycle(fiber, true); + } + void setError(Throwable e) nothrow @safe { + cycle(fiber, true); + } + void setValue() nothrow @safe { + cycle(fiber, true); + } + mixin ForwardExtensionPoints!receiver; +} + +void yield() @trusted { + import std.concurrency; + std.concurrency.yield(); +} + +auto yield(Sender)(Sender sender) @trusted { + import concurrency : Result; + import concurrency.operations : onResult; + import concurrency.sender : toSenderObject; + + auto fiber = BaseFiber.getThis(); + + YieldResult!(Sender.Value) local; + void store(Result!(Sender.Value) r) @trusted { + local = YieldResult!(Sender.Value)(r); + } + + fiber.continuation = cast(Object)sender + .onResult(cast(void delegate(Result!(Sender.Value)) @safe shared)&store) + .toSenderObject; + + yield(); + + return local; +} + +import core.attribute : mustuse; +@mustuse struct YieldResult(T) { + import concurrency.syncwait : Completed, Cancelled, Result, isA, match; + import std.sumtype; + + static if (is(T == void)) { + alias Value = Completed; + } else { + alias Value = T; + } + + alias V = SumType!(Value, Exception); + + private V result; + + this(Result!(T) other) { + static if (is(T == void)) + alias valueHandler = (Completed c) => V(c); + else + alias valueHandler = (T t) => V(t); + + result = other.match!( + valueHandler, + (Cancelled c) => V(new CancelledException()), + (Exception e) => V(e), + ); + } + + bool isError() @safe nothrow { + return result.isA!Exception; + } + + bool isCancelled() @safe nothrow { + return std.sumtype.match!( + (Exception e) => (cast(CancelledException)e) !is null, + t => false + )(result); + } + + bool isOk() @safe nothrow { + return result.isA!Value; + } + + auto value() @safe { + static if (is(T == void)) + alias valueHandler = (Completed c) {}; + else + alias valueHandler = (T t) => t; + + return std.sumtype.match!(valueHandler, function T(Exception e) { + throw e; + })(result); + } + + void assumeOk() @safe { + value(); + } +} diff --git a/source/concurrency/syncwait.d b/source/concurrency/syncwait.d index e0cedd1..d521d14 100644 --- a/source/concurrency/syncwait.d +++ b/source/concurrency/syncwait.d @@ -76,7 +76,7 @@ struct Result(T) { alias Value = T; } - alias V = SumType!(Cancelled, Exception, Value); + alias V = SumType!(Value, Cancelled, Exception); V result; this(P)(P p) { diff --git a/tests/ut/concurrency/fiber.d b/tests/ut/concurrency/fiber.d new file mode 100644 index 0000000..062b57f --- /dev/null +++ b/tests/ut/concurrency/fiber.d @@ -0,0 +1,61 @@ +module ut.concurrency.fiber; + +import concurrency.fiber; +import concurrency.operations : then, whenAll; +import concurrency; +import concurrency.sender; +import core.time; + +import unit_threaded; + +@("yield.basic") +@safe unittest { + auto fiber = FiberSender().then(() @trusted shared { + yield(); + }); + whenAll(fiber, fiber).syncWait().assumeOk; +} + +@("yield.delay.single") +@safe unittest { + auto fiber = FiberSender().then(() @trusted shared { + delay(1.msecs).yield().assumeOk; + }); + fiber.syncWait().assumeOk; +} + +@("yield.delay.double") +@safe unittest { + auto fiber = FiberSender().then(() @trusted shared { + delay(1.msecs).yield().assumeOk; + }); + whenAll(fiber, fiber).syncWait().assumeOk; +} + +@("yield.error.basic") +@safe unittest { + FiberSender().then(() @trusted shared { + ThrowingSender().yield().isError.should == true; + }).syncWait(); +} + +@("yield.error.propagate") +@safe unittest { + FiberSender().then(() @trusted shared { + ThrowingSender().yield().assumeOk; + }).syncWait().isError.should == true; +} + +@("yield.cancel.basic") +@safe unittest { + FiberSender().then(() @trusted shared { + DoneSender().yield().isCancelled.should == true; + }).syncWait().assumeOk; +} + +@("yield.cancel.propagate") +@safe unittest { + FiberSender().then(() @trusted shared { + DoneSender().yield().assumeOk; + }).syncWait().isCancelled.should == true; +} diff --git a/tests/ut/ut_runner.d b/tests/ut/ut_runner.d index 274b634..bcd453f 100644 --- a/tests/ut/ut_runner.d +++ b/tests/ut/ut_runner.d @@ -21,5 +21,6 @@ int main(string[] args) { "concurrency.stoptoken", "ut.concurrency.stoptoken", "ut.concurrency.io", + "ut.concurrency.fiber", ); }