diff --git a/source/concurrency/sequence.d b/source/concurrency/sequence.d index 06134a1..474eabd 100644 --- a/source/concurrency/sequence.d +++ b/source/concurrency/sequence.d @@ -815,6 +815,70 @@ struct FlattenSequenceReceiver(Sequence, Receiver) { } } +auto flatMap(Sequence, Fun)(Sequence s, Fun f) { + // TOD: probably flatMap is just .then(f).flatten() + return FlatMapSequence!(Sequence, Fun)(s, f); +} + +struct FlatMapSequence(Sequence, Fun) { + import std.traits : ReturnType; + + alias Value = void; + alias Element = ReturnType!(Fun).Value; + Sequence s; + Fun f; + auto connect(Receiver)(return Receiver receiver) @safe return scope { + auto op = FlatMapSequenceOp!(Sequence, Fun, Receiver)(s, f, receiver); + return op; + } +} + +struct FlatMapSequenceOp(Sequence, Fun, Receiver) { + import concurrency.sender : OpType; + + alias Op = OpType!(Sequence, FlatMapSequenceReceiver!(Sequence, Fun, Receiver)); + Fun fun; + Receiver receiver; + Op op; + + @disable this(ref return scope typeof(this) rhs); + @disable this(this); + this(Sequence sequence, Fun fun, Receiver receiver) { + this.fun = fun; + this.receiver = receiver; + op = sequence.connect(FlatMapSequenceReceiver!(Sequence, Fun, Receiver)(this)); + } + void start() nothrow { + op.start(); + } +} + +struct FlatMapSequenceReceiver(Sequence, Fun, Receiver) { + FlatMapSequenceOp!(Sequence, Fun, Receiver)* op; + this(ref FlatMapSequenceOp!(Sequence, Fun, Receiver) op) { + this.op = &op; + } + auto setNext(Sender)(Sender sender) { + import concurrency.operations : then; + return op.receiver.setNext(sender.then(op.fun).flatten); + } + auto setValue() { + op.receiver.setValue(); + } + auto setDone() nothrow @safe { + op.receiver.setDone(); + } + auto setError(Throwable t) nothrow @safe { + op.receiver.setError(t); + } + auto getStopToken() nothrow @trusted { + return op.receiver.getStopToken(); + } + auto getScheduler() nothrow @safe { + return op.receiver.getScheduler(); + } +} + struct ScanSequenceTransformer(Fun, Seed) { Fun fun; Seed seed; diff --git a/tests/ut/concurrency/sequence.d b/tests/ut/concurrency/sequence.d index 206e6bb..6f255db 100644 --- a/tests/ut/concurrency/sequence.d +++ b/tests/ut/concurrency/sequence.d @@ -119,6 +119,11 @@ import unit_threaded; [VoidSender()].sequence.flatten.toList().syncWait.isOk.should == true; } +@("flatMap.just") +@safe unittest { + import core.time : msecs; + [1,2,3].sequence.flatMap((int i) => just(i*3)).toList().syncWait.value.should == [3,6,9]; +} @("scan") @safe unittest {