From ffd57e9aea8ee9452e2d808a12326971995c8fe3 Mon Sep 17 00:00:00 2001 From: Sebastiaan Koppe Date: Sat, 17 Aug 2024 09:14:15 +0200 Subject: [PATCH] Add fiber, add io ops to object scheduler, fix wakeup --- source/concurrency/fiber.d | 36 +++++++++++++---- source/concurrency/io/iouring.d | 72 +++++++++++++++++++++------------ source/concurrency/io/package.d | 9 +---- source/concurrency/io/socket.d | 20 ++++----- source/concurrency/scheduler.d | 44 ++++++++++++++++++++ source/concurrency/stoptoken.d | 1 + tests/ut/concurrency/io.d | 31 ++++++++++++-- 7 files changed, 157 insertions(+), 56 deletions(-) diff --git a/source/concurrency/fiber.d b/source/concurrency/fiber.d index 0c335dc..0aa6a57 100644 --- a/source/concurrency/fiber.d +++ b/source/concurrency/fiber.d @@ -15,6 +15,7 @@ class CancelledException : Exception { package(concurrency) abstract class BaseFiber : Fiber { private Continuation continuation; + private Throwable nextError; this(void delegate() dg, size_t sz, size_t guardPageSize) @trusted nothrow { super(dg, sz, guardPageSize); } @@ -34,6 +35,11 @@ class OpFiber(Op) : BaseFiber { } } +auto fiber(Fun)(Fun fun) { + import concurrency.operations : then; + return FiberSender().then(fun); +} + struct FiberSender { static assert (models!(typeof(this), isSender)); alias Value = void; @@ -114,6 +120,7 @@ struct FiberContinuationReceiver(Receiver) { cycle(fiber, true); } void setError(Throwable e) nothrow @safe { + fiber.nextError = e; cycle(fiber, true); } void setValue() nothrow @safe { @@ -127,9 +134,9 @@ void yield() @trusted { std.concurrency.yield(); } -auto yield(Sender)(Sender sender) @trusted { +auto yield(Sender)(return Sender sender) @trusted { import concurrency : Result; - import concurrency.operations : onResult; + import concurrency.operations : onResult, then; import concurrency.sender : toSenderObject; auto fiber = BaseFiber.getThis(); @@ -139,12 +146,27 @@ auto yield(Sender)(Sender sender) @trusted { local = YieldResult!(Sender.Value)(r); } - fiber.continuation = cast(Object)sender - .onResult(cast(void delegate(Result!(Sender.Value)) @safe shared)&store) - .toSenderObject; + SenderObjectBase!void object; + + auto base = sender + .onResult(cast(void delegate(Result!(Sender.Value)) @safe shared)&store); + + static if (is(Sender.Value == void)) { + object = base.toSenderObject(); + } else { + object = base.then((Sender.Value v){}).toSenderObject(); + } + + fiber.continuation = cast(Object)object; yield(); + if (fiber.nextError) { + auto error = fiber.nextError; + fiber.nextError = null; + throw error; + } + return local; } @@ -191,13 +213,13 @@ import core.attribute : mustuse; return result.isA!Value; } - auto value() @safe { + auto value() @trusted scope { static if (is(T == void)) alias valueHandler = (Completed c) {}; else alias valueHandler = (T t) => t; - return std.sumtype.match!(valueHandler, function T(Exception e) { + return std.sumtype.match!(valueHandler, function T(Exception e) @trusted { throw e; })(result); } diff --git a/source/concurrency/io/iouring.d b/source/concurrency/io/iouring.d index fa1eb24..739f622 100644 --- a/source/concurrency/io/iouring.d +++ b/source/concurrency/io/iouring.d @@ -4,7 +4,7 @@ version(linux): import concurrency.data.queue.mpsc; import concurrency.stoptoken; -import concurrency.receiver : setErrno; +import concurrency.receiver : setErrno, setValueOrError; import during; import core.stdc.errno : ECANCELED; @@ -83,7 +83,7 @@ struct IOUringContext { private void wakeup() @trusted nothrow shared { import core.sys.posix.unistd; - ubyte wakeup = 1; + size_t wakeup = 1; // TODO: check return value core.sys.posix.unistd.write(event, &wakeup, wakeup.sizeof); } @@ -115,8 +115,8 @@ struct IOUringContext { private int run(scope shared StopToken stopToken) @safe nothrow { pending.append(requests.popAll()); + putEventFdChannel(); while (!stopToken.isStopRequested() || !pending.empty() || !io.empty()) { - putEventFdChannel(); putPending(); // TODO: might have to flip this around. scheduleTimers(); @@ -206,7 +206,7 @@ struct IOUringContext { private void putEventFdChannel() @safe nothrow { io.putWith!((ref SubmissionEntry e, IOUringContext* context) { - e.prepRead(context.event, context.buffer[], 0); + e.prepRead(context.event, context.buffer[0..8], 0); })(&this); } @@ -214,9 +214,10 @@ struct IOUringContext { while (!pending.empty && !io.full()) { auto item = pending.pop(); SubmissionEntry entry; - item.submit(entry); - entry.setUserDataRaw(item); - io.put(entry); + if (item.submit(entry)) { + entry.setUserDataRaw(item); + io.put(entry); + } } } @@ -227,6 +228,8 @@ struct IOUringContext { auto item = entry.userDataAs!(Item*); if (item !is null) item.complete(entry); + else + putEventFdChannel(); io.popFront(); } } @@ -252,17 +255,27 @@ struct RunOp(Sender, Receiver) { alias RunSender = JustFromSender!(void delegate() @trusted shared); alias SenderWithScheduler = WithSchedulerSender!(Sender, IOUringScheduler); - alias ValueSender = DoFinallySender!(SenderWithScheduler, bool delegate() @safe nothrow shared); + alias ValueSender = DoFinallySender!(SenderWithScheduler, void delegate() @safe nothrow shared); alias CombinedSender = WhenAllSender!(ValueSender, RunSender); alias Op = OpType!(CombinedSender, Receiver); IOUringContext* context; - Op op; shared StopSource stopSource; + Op op; + + @disable + this(ref return scope typeof(this) rhs); + @disable + this(this); + this(IOUringContext* context, Sender sender, return Receiver receiver) @trusted return scope { this.context = context; + shared IOUringContext* sharedContext = cast(shared)context; op = whenAll( - sender.withScheduler(IOUringScheduler(cast(shared)context)).doFinally(() @safe shared => stopSource.stop()), + sender.withScheduler(IOUringScheduler(cast(shared)context)).doFinally(() @safe nothrow shared { + stopSource.stop(); + sharedContext.wakeup(); + }), justFrom(&(cast(shared)this).run), ).connect(receiver); } @@ -282,7 +295,7 @@ struct RunOp(Sender, Receiver) { struct Item { // TODO: we are storing 2 this pointers here - void delegate(ref SubmissionEntry sqe) @safe nothrow submit; + bool delegate(ref SubmissionEntry sqe) @safe nothrow submit; void delegate(ref const CompletionEntry cqe) @safe nothrow complete; Item* next; } @@ -312,13 +325,20 @@ struct CancellableOperation(Operation) { } } - private void submit(ref SubmissionEntry entry) @trusted nothrow { - import core.atomic; - ops.atomicFetchAdd!(MemoryOrder.raw)(1); - auto stopToken = operation.receiver.getStopToken(); - cb.register(stopToken, &(cast(shared)this).onStop); + // TODO: shouldn't submit be shared? + private bool submit(ref SubmissionEntry entry) @trusted nothrow { + try { + import core.atomic; + ops.atomicFetchAdd!(MemoryOrder.raw)(1); + auto stopToken = operation.receiver.getStopToken(); + cb.register(stopToken, &(cast(shared)this).onStop); - operation.submit(entry); + operation.submit(entry); + return true; + } catch (Throwable e) { + operation.receiver.setError(e); + return false; + } } private void complete(const ref CompletionEntry entry) @safe nothrow { @@ -352,8 +372,9 @@ struct CancellableOperation(Operation) { } } - private void submitStop(ref SubmissionEntry entry) nothrow @safe { + private bool submitStop(ref SubmissionEntry entry) nothrow @safe { entry.prepCancel(item); + return true; } private ref assumeThreadSafe() nothrow @trusted shared { @@ -422,7 +443,7 @@ struct ReadOperation(Receiver) { } void complete(const ref CompletionEntry entry) @safe nothrow { if (entry.res >= 0) { - receiver.setValue(buffer[offset..$]); + receiver.setValueOrError(buffer[offset..entry.res]); } else { receiver.setErrno("Read failed", -entry.res); } @@ -430,7 +451,7 @@ struct ReadOperation(Receiver) { } struct AcceptSender { - import concurrency.io : Client; + import concurrency.scheduler : Client; import std.socket : socket_t; alias Value = Client; shared IOUringContext* context; @@ -448,7 +469,7 @@ struct AcceptSender { struct AcceptOperation(Receiver) { import core.sys.posix.sys.socket : sockaddr, socklen_t; import core.sys.posix.netinet.in_; - import concurrency.io : Client; + import concurrency.scheduler : Client; import std.socket : socket_t; socket_t fd; @@ -461,7 +482,7 @@ struct AcceptOperation(Receiver) { void complete(const ref CompletionEntry entry) @safe nothrow { import std.socket : socket_t; if (entry.res >= 0) { - receiver.setValue(Client(cast(socket_t)entry.res, addr, addrlen)); + receiver.setValueOrError(Client(cast(socket_t)entry.res, addr, addrlen)); } else { receiver.setErrno("Accept failed", -entry.res); } @@ -470,7 +491,7 @@ struct AcceptOperation(Receiver) { struct ConnectSender { import std.socket : socket_t; - alias Value = int; + alias Value = socket_t; shared IOUringContext* context; socket_t fd; string address; @@ -483,7 +504,6 @@ struct ConnectSender { ); return op; } - } struct ConnectOperation(Receiver) { @@ -519,7 +539,7 @@ struct ConnectOperation(Receiver) { } void complete(const ref CompletionEntry entry) @safe nothrow { if (entry.res >= 0) { - receiver.setValue(cast(socket_t)entry.res); + receiver.setValueOrError(cast(socket_t)entry.res); } else { receiver.setErrno("Connect failed", -entry.res); } @@ -554,7 +574,7 @@ struct WriteOperation(Receiver) { } void complete(const ref CompletionEntry entry) @safe nothrow { if (entry.res >= 0) { - receiver.setValue(entry.res); + receiver.setValueOrError(entry.res); } else { receiver.setErrno("Write failed", -entry.res); } diff --git a/source/concurrency/io/package.d b/source/concurrency/io/package.d index 3f5c712..eaaf001 100644 --- a/source/concurrency/io/package.d +++ b/source/concurrency/io/package.d @@ -1,6 +1,7 @@ module concurrency.io; import concurrency.io.iouring; +import concurrency.scheduler : Client; import std.socket : socket_t; @@ -28,14 +29,6 @@ auto acceptAsync(socket_t fd) @safe nothrow @nogc { return AcceptAsyncSender(fd); } -struct Client { - import std.socket : socket_t; - import core.sys.posix.sys.socket : sockaddr, socklen_t; - socket_t fd; - sockaddr addr; - socklen_t addrlen; -} - struct AcceptAsyncSender { alias Value = Client; socket_t fd; diff --git a/source/concurrency/io/socket.d b/source/concurrency/io/socket.d index adf7e2b..87d67b8 100644 --- a/source/concurrency/io/socket.d +++ b/source/concurrency/io/socket.d @@ -1,7 +1,7 @@ module concurrency.io.socket; import std.socket : socket_t; -auto getSocket() @trusted { +auto tcpSocket() @trusted { import std.socket : socket_t; version(Windows) { import core.sys.windows.windows; @@ -20,15 +20,19 @@ auto getSocket() @trusted { } socket_t sock = cast(socket_t) socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); - int on = 1; + if (sock == -1) + throw new Exception("socket"); + int on = 1; setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, on.sizeof); setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &on, on.sizeof); + version(Posix) // on windows REUSEADDR includes REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &on, on.sizeof); return sock; } -auto listen(string address = "", ushort port = 0, int backlog = 128) @trusted { +auto listenTcp(string address = "", ushort port = 0, int backlog = 128) @trusted { import core.stdc.stdio : fprintf, stderr; import std.socket : socket_t; version(Windows) { @@ -48,9 +52,7 @@ auto listen(string address = "", ushort port = 0, int backlog = 128) @trusted { } import core.stdc.errno; - socket_t sock = cast(socket_t) socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); - if (sock == -1) - throw new Exception("socket"); + socket_t sock = tcpSocket(); sockaddr_in addr; addr.sin_family = AF_INET; @@ -69,12 +71,6 @@ auto listen(string address = "", ushort port = 0, int backlog = 128) @trusted { } else addr.sin_addr.s_addr = INADDR_ANY; - int on = 1; - setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, on.sizeof); - setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &on, on.sizeof); - version(Posix) // on windows REUSEADDR includes REUSEPORT - setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &on, on.sizeof); - if (bind(sock, cast(sockaddr*) &addr, addr.sizeof) == -1) { closeSocket(sock); throw new Exception("bind"); diff --git a/source/concurrency/scheduler.d b/source/concurrency/scheduler.d index 5e006ec..189a417 100644 --- a/source/concurrency/scheduler.d +++ b/source/concurrency/scheduler.d @@ -17,10 +17,24 @@ void checkScheduler(T)() { enum isScheduler(T) = is(typeof(checkScheduler!T)); +struct Client { + import std.socket : socket_t; + import core.sys.posix.sys.socket : sockaddr, socklen_t; + socket_t fd; + sockaddr addr; + socklen_t addrlen; +} + /// polymorphic Scheduler interface SchedulerObjectBase { + import std.socket : socket_t; SenderObjectBase!void schedule() @safe; SenderObjectBase!void scheduleAfter(Duration d) @safe; + // TODO: do these belong here? + SenderObjectBase!(ubyte[]) read(socket_t fd, return ubyte[] buffer, long offset = 0) @safe; + SenderObjectBase!(Client) accept(socket_t fd) @safe; + SenderObjectBase!(socket_t) connect(socket_t fd, return string address, ushort port) @safe; + SenderObjectBase!(int) write(socket_t fd, return ubyte[] buffer, long offset = 0) @safe; } @@ -64,6 +78,36 @@ class SchedulerObject(S) : SchedulerObjectBase { SenderObjectBase!void scheduleAfter(Duration d) @safe { return scheduler.scheduleAfter(d).toSenderObject(); } + SenderObjectBase!(ubyte[]) read(socket_t fd, return ubyte[] buffer, long offset = 0) @safe { + static if (__traits(hasMember, S, "read")) { + return scheduler.read(fd, buffer, offset).toSenderObject(); + } else { + throw new Exception("`read` not implemented on "~S.stringof); + } + } + SenderObjectBase!(Client) accept(socket_t fd) @safe { + static if (__traits(hasMember, S, "accept")) { + return scheduler.accept(fd).toSenderObject(); + } else { + throw new Exception("`accept` not implemented on "~S.stringof); + } + } + // TODO: is trusted because of scope string address + SenderObjectBase!(socket_t) connect(socket_t fd, return string address, ushort port) @trusted { + static if (__traits(hasMember, S, "connect")) { + string adr = address; + return scheduler.connect(fd, adr, port).toSenderObject(); + } else { + throw new Exception("`connect` not implemented on "~S.stringof); + } + } + SenderObjectBase!(int) write(socket_t fd, return ubyte[] buffer, long offset = 0) @safe { + static if (__traits(hasMember, S, "write")) { + return scheduler.write(fd, buffer, offset).toSenderObject(); + } else { + throw new Exception("`write` not implemented on "~S.stringof); + } + } } SchedulerObjectBase toSchedulerObject(S)(S scheduler) { diff --git a/source/concurrency/stoptoken.d b/source/concurrency/stoptoken.d index 119587e..de8a871 100644 --- a/source/concurrency/stoptoken.d +++ b/source/concurrency/stoptoken.d @@ -389,6 +389,7 @@ private struct StopState { atomicStore(head.state, blank); head = next; } + atomicStore(head, null); unlock(); assert(false, "StopSource has lingering callbacks"); } diff --git a/tests/ut/concurrency/io.d b/tests/ut/concurrency/io.d index 024c459..2aa1fd0 100644 --- a/tests/ut/concurrency/io.d +++ b/tests/ut/concurrency/io.d @@ -49,11 +49,11 @@ unittest { } @safe -@("acceptAsync.connectAsync") +@("acceptAsync.connectAsync.basic") unittest { import concurrency.io.socket; - auto fd = listen("127.0.0.1", 0); - auto socket = getSocket(); + auto fd = listenTcp("127.0.0.1", 0); + auto socket = tcpSocket(); auto port = fd.getPort(); auto io = IOContext.construct(12); @@ -70,3 +70,28 @@ unittest { closeSocket(socket); closeSocket(fd); } + +@safe +@("acceptAsync.connectAsync.fiber") +unittest { + import concurrency.io.socket; + import concurrency.fiber; + + auto io = IOContext.construct(12); + io.run(fiber({ + auto fd = listenTcp("127.0.0.1", 0); + auto socket = tcpSocket(); + auto port = fd.getPort(); + + auto result = whenAll( + acceptAsync(fd), + connectAsync(socket, "127.0.0.1", port), + ).yield().value; + + auto client = result[0]; + + closeSocket(client.fd); + closeSocket(socket); + closeSocket(fd); + })).syncWait.assumeOk; +} \ No newline at end of file