From 48d675b10e3b7834be1a9645fba5e30ae44fea0f Mon Sep 17 00:00:00 2001 From: Thomas Leonard Date: Thu, 7 Nov 2024 12:32:00 +0000 Subject: [PATCH] Initial Eio port This switches capnp-rpc from Lwt to Eio. One particularly nice side effect of this is that `Service.return_lwt` has gone, as there is no distinction now between concurrent and non-concurrent service methods. --- CHANGES.md | 15 +- README.md | 325 ++++++------- capnp-rpc-net.opam | 3 +- capnp-rpc-net/auth.mli | 5 +- capnp-rpc-net/capTP_capnp.ml | 100 ++-- capnp-rpc-net/capTP_capnp.mli | 13 +- capnp-rpc-net/capnp_rpc_net.ml | 6 +- capnp-rpc-net/capnp_rpc_net.mli | 24 +- capnp-rpc-net/dune | 4 +- capnp-rpc-net/endpoint.ml | 58 +-- capnp-rpc-net/endpoint.mli | 24 +- capnp-rpc-net/restorer.ml | 93 ++-- capnp-rpc-net/s.ml | 41 +- capnp-rpc-net/tls_wrapper.ml | 81 ++-- capnp-rpc-net/tls_wrapper.mli | 28 +- capnp-rpc-net/two_party_network.ml | 2 +- capnp-rpc-net/vat.ml | 165 +++---- capnp-rpc-unix.opam | 10 +- capnp-rpc.opam | 4 +- capnp-rpc/capability.ml | 38 +- capnp-rpc/capnp_core.ml | 10 +- capnp-rpc/capnp_rpc.ml | 1 + capnp-rpc/capnp_rpc.mli | 67 ++- capnp-rpc/dune | 2 +- capnp-rpc/leak_handler.ml | 54 +++ capnp-rpc/leak_handler.mli | 22 + capnp-rpc/persistence.ml | 21 +- capnp-rpc/proto/capTP.ml | 14 +- capnp-rpc/proto/capTP.mli | 8 +- capnp-rpc/proto/core_types.ml | 3 +- capnp-rpc/proto/dune | 2 +- capnp-rpc/proto/s.ml | 10 +- capnp-rpc/service.ml | 28 +- capnp-rpc/sturdy_ref.ml | 14 +- examples/pipelining/dune | 2 +- examples/pipelining/echo.ml | 25 +- examples/pipelining/main.ml | 32 +- examples/sturdy-refs-2/dune | 2 +- examples/sturdy-refs-2/main.ml | 30 +- examples/sturdy-refs-3/dune | 2 +- examples/sturdy-refs-3/main.ml | 47 +- examples/sturdy-refs-4/db.ml | 13 +- examples/sturdy-refs-4/db.mli | 2 +- examples/sturdy-refs-4/dune | 2 +- examples/sturdy-refs-4/logger.ml | 9 +- examples/sturdy-refs-4/main.ml | 96 ++-- examples/sturdy-refs/dune | 2 +- examples/sturdy-refs/main.ml | 23 +- examples/testlib/calc.ml | 68 +-- examples/testlib/calc.mli | 11 +- examples/testlib/echo.ml | 17 +- examples/testlib/echo.mli | 6 +- examples/testlib/registry.ml | 21 +- examples/testlib/registry.mli | 11 +- examples/testlib/store.ml | 18 +- examples/testlib/store.mli | 6 +- examples/v1/dune | 2 +- examples/v1/echo.ml | 3 +- examples/v1/main.ml | 12 +- examples/v2/dune | 2 +- examples/v2/echo.ml | 22 +- examples/v2/main.ml | 13 +- examples/v3/dune | 2 +- examples/v3/echo.ml | 22 +- examples/v3/main.ml | 28 +- examples/v4/client.ml | 14 +- examples/v4/dune | 2 +- examples/v4/echo.ml | 22 +- examples/v4/server.ml | 26 +- fuzz/fuzz.ml | 6 +- test-bin/calc.ml | 42 +- test-bin/calc_direct.ml | 96 ++-- test-bin/dune | 3 +- test-bin/echo/dune | 2 +- test-bin/echo/echo.ml | 3 +- test-bin/echo/echo_bench.ml | 32 +- test/dune | 4 +- test/proto/testbed/capnp_direct.ml | 2 +- test/proto/testbed/connection.ml | 4 +- test/test.ml | 733 +++++++++++++++-------------- unix/capnp_rpc_unix.ml | 155 +++--- unix/capnp_rpc_unix.mli | 24 +- unix/dune | 3 +- unix/file_store.ml | 48 +- unix/network.ml | 77 ++- unix/network.mli | 11 +- unix/unix_flow.ml | 109 ----- unix/unix_flow.mli | 7 - unix/vat_network.ml | 2 +- 89 files changed, 1581 insertions(+), 1692 deletions(-) create mode 100644 capnp-rpc/leak_handler.ml create mode 100644 capnp-rpc/leak_handler.mli delete mode 100644 unix/unix_flow.ml delete mode 100644 unix/unix_flow.mli diff --git a/CHANGES.md b/CHANGES.md index 3b99f1364..499c28e73 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -8,14 +8,21 @@ - Add `Capnp_rpc.Std` with some common module aliases, to reduce the need to `open Capnp_rpc` (which is rather large). +- Convert API from Lwt to Eio. + To update to the new API: -1. Replace `open Capnp_rpc_lwt` with `open Capnp_rpc.Std`. -2. Replace all other uses of `Capnp_rpc_lwt` with just `Capnp_rpc`. -3. In `dune` and `opam` files, replace `capnp-rpc-lwt` with `capnp-rpc`. -4. Some modules are in `Capnp_rpc` but not the `Capnp_rpc.Std` subset. +1. Use [lwt_eio][] during the migration to allow using Eio and Lwt together in your application. +2. Replace `open Capnp_rpc_lwt` with `open Capnp_rpc.Std`. +3. Replace all other uses of `Capnp_rpc_lwt` with just `Capnp_rpc`. +4. In `dune` and `opam` files, replace `capnp-rpc-lwt` with `capnp-rpc`. +5. Some modules are in `Capnp_rpc` but not the `Capnp_rpc.Std` subset. Those should now be fully qualified (e.g. replace `Persistence` with `Capnp_rpc.Persistence`). +6. Replace `Service.return_lwt` with `Lwt_eio.run_lwt`. +7. Once all Lwt code is gone, `lwt_eio` can be removed. + +[lwt_eio]: https://github.com/ocaml-multicore/lwt_eio ### v1.2.3 diff --git a/README.md b/README.md index 1f09f178c..a6c9ab7be 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # OCaml Cap'n Proto RPC library Copyright 2017 Docker, Inc. -Copyright 2019 Thomas Leonard. +Copyright 2024 Thomas Leonard. See [LICENSE.md](LICENSE.md) for details. [API documentation][api] @@ -71,10 +71,10 @@ This library should be used with the [capnp-ocaml][] schema compiler, which gene RPC Level 2 is complete, with encryption and authentication using TLS and support for persistence. The library has unit tests and AFL fuzz tests that cover most of the core logic. -It is used as the RPC system in [ocaml-ci][]. +It is used as the RPC system in [ocaml-ci][] and [ocluster][]. The default network provided supports TCP and Unix-domain sockets, both with or without TLS. -For two-party networking, you can provide any bi-directional byte stream (satisfying the Mirage flow signature) +For two-party networking, you can provide any bi-directional byte stream (satisfying the `Eio.Flow.two_way` signature) to the library to create a connection. You can also define your own network types. @@ -84,26 +84,21 @@ Until that is implemented, Carol can ask Bob for a persistent reference (sturdy ## Installing -To install, you will need a platform with the capnproto package available (e.g. Debian >= 9). Then: +To install, you will need a platform with the capnproto package available (e.g. Debian >= 9). Then (using opam 2.1 or later): opam install capnp-rpc-unix -(note: if you are using opam < 2.1, direct install is not possible, so do the following): - - opam depext -i capnp-rpc-unix - ## Structure of the library -**Note:** This README documents the newer (unreleased) API. For the 1.x API, see an older version of the README. The main change is that `Capnp_rpc_lwt` is now just `Capnp_rpc`. +**Note:** This README documents the newer (unreleased) Eio API. For the 1.x Lwt API, see an older version of the README. The main change is that `Capnp_rpc_lwt` is now just `Capnp_rpc`. See the [CHANGES.md](./CHANGES.md) file for help migrating to 2.0. The code is split into several packages: -- `capnp-rpc` defines the main API, using the Cap'n Proto serialisation for messages and Lwt for concurrency. +- `capnp-rpc` allows you to define and use services. - `capnp-rpc-net` adds networking support, including TLS. - `capnp-rpc-unix` adds helper functions for parsing command-line arguments and setting up connections over Unix sockets. - The tests in `test-lwt` test this by sending Cap'n Proto messages over a Unix-domain socket. **Libraries** that consume or provide Cap'n Proto services should normally depend only on `capnp-rpc`, since they shouldn't care whether the services they use are local or accessed over some kind of network. @@ -172,7 +167,6 @@ For the server, you should inherit from the generated `Api.Service.Echo.service` ```ocaml module Api = Echo_api.MakeRPC(Capnp_rpc) -open Lwt.Infix open Capnp_rpc.Std let local = @@ -206,7 +200,7 @@ There's a bit of ugly boilerplate here, but it's quite simple: should always free them anyway. - `Service.Response.create Results.init_pointer` creates a new response message, using `Ping.Results.init_pointer` to initialise the payload contents. - `response` is the complete message to be sent back, and `results` is the data part of it. -- `Service.return` returns the results immediately (like `Lwt.return`). +- `Service.return` returns the results immediately (rather than returning a promise). The client implementation is similar, but uses `Api.Client` instead of `Api.Service`. Here, we have a *builder* for the parameters and a *reader* for the results. @@ -220,7 +214,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get ``` `Capability.call_for_value_exn` sends the request message to the service and waits for the response to arrive. @@ -234,19 +228,17 @@ With the boilerplate out of the way, we can now write a `main.ml` to test it: ```ocaml -open Lwt.Infix +open Eio.Std let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let () = - Lwt_main.run begin - let service = Echo.local in - Echo.ping service "foo" >>= fun reply -> - Fmt.pr "Got reply %S@." reply; - Lwt.return_unit - end + Eio_main.run @@ fun _ -> + let service = Echo.local in + let reply = Echo.ping service "foo" in + traceln "Got reply %S" reply ```

@@ -259,7 +251,7 @@ Here's a suitable `dune` file to compile the schema file and then the generated ``` (executable (name main) - (libraries lwt.unix capnp-rpc logs.fmt)) + (libraries eio_main capnp-rpc logs.fmt)) (rule (targets echo_api.ml echo_api.mli) @@ -280,12 +272,10 @@ The service is now usable: ```bash $ opam install capnp-rpc ``` -(note: or `$ opam depext -i capnp-rpc` for opam < 2.1) - ```bash $ dune exec ./main.exe -Got reply "echo:foo" ++Got reply "echo:foo" ``` This isn't very exciting, so let's add some capabilities to the protocol... @@ -323,33 +313,31 @@ The new `heartbeat_impl` method looks like this: match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~delay msg) ``` Note that all parameters in Cap'n Proto are optional, so we have to check for `callback` not being set (data parameters such as `msg` get a default value from the schema, which is `""` for strings if not set explicitly). -`Service.return_lwt fn` runs `fn ()` and replies to the `heartbeat` call when it finishes. -Here, the whole of the rest of the method is the argument to `return_lwt`, which is a common pattern. +You'll need to add a `~delay` argument to `local` too, to configure the time between messages. `Capability.with_ref x f` calls `f x` and then releases `x` (capabilities are ref-counted). -`notify callback msg` just sends a few messages to `callback` in a loop: +`notify ~delay msg callback` just sends a few messages to `callback` in a loop: ```ocaml -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~delay msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.Timeout.sleep delay; + loop (i - 1) in loop 3 ``` @@ -376,24 +364,27 @@ In `main.ml`, we can now wrap a regular OCaml function as the callback: ```ocaml +open Eio.Std open Capnp_rpc.Std +let delay = if Sys.getenv_opt "CI" = None then 1.0 else 0.0 + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let () = - Lwt_main.run begin - let service = Echo.local in - run_client service - end + Eio_main.run @@ fun env -> + let delay = Eio.Time.Timeout.seconds env#mono_clock delay in + let service = Echo.local ~delay in + run_client service ``` Step 1: The client creates the callback: @@ -419,12 +410,12 @@ Exercise: implement `Callback.local fn` (hint: it's similar to the original `pin And testing it should give (three times, at one second intervals): - + ```sh $ dune exec -- ./main.exe -Callback got "foo" -Callback got "foo" -Callback got "foo" ++Callback got "foo" ++Callback got "foo" ++Callback got "foo" ``` Note that the client gives the echo service permission to call its callback service by sending a message containing the callback to the service. @@ -443,15 +434,17 @@ Here's the new `main.ml` (the top half is the same as before): ```ocaml -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std +let delay = if Sys.getenv_opt "CI" = None then 1.0 else 0.0 + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> @@ -460,37 +453,40 @@ let run_client service = let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw ~delay net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let restore = Capnp_rpc_net.Restorer.single service_id (Echo.local ~delay) in + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let delay = Eio.Time.Timeout.seconds env#mono_clock delay in + let uri = start_server ~sw ~delay env#net in + traceln "Connecting to echo service at: %a" Uri.pp_hum uri; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client ```

-You'll need to edit your `dune` file to add a dependency on `capnp-rpc-unix` in the `(libraries ...` line and also: +You'll need to edit your `dune` file to add a dependencies +on `capnp-rpc-unix` and `mirage-crypto-rng-eio` in the `(libraries ...` line and also: ```sh -$ opam depext -i capnp-rpc-unix +$ opam install capnp-rpc-unix mirage-crypto-rng-eio ``` Running this will give something like: - + ```sh $ dune exec ./main.exe Connecting to echo service at: capnp://sha-256:3Tj5y5Q2qpqN3Sbh0GRPxgORZw98_NtrU2nLI0-Tn6g@127.0.0.1:7000/eBIndzZyoVDxaJdZ8uh_xBx5V1lfXWTJCDX-qEkgNZ4 @@ -545,10 +541,11 @@ In `start_server`: and the name. This means that the ID will be stable as long as the server's key doesn't change. The name used ("main" here) isn't important - it just needs to be unique. -- `let restore = Restorer.single service_id Echo.local` configures a simple "restorer" that - answers requests for `service_id` with our `Echo.local` service. +- `let restore = Capnp_rpc_net.Restorer.single service_id (Echo.local ~delay)` + configures a simple "restorer" that answers requests for `service_id` with + our `Echo.local` service. -- `Capnp_rpc_unix.serve config ~restore` creates the service vat using the +- `Capnp_rpc_unix.serve ~sw ~net ~restore config` creates the service vat using the previous configuration items and starts it listening for incoming connections. - `Capnp_rpc_unix.Vat.sturdy_uri vat service_id` returns a "capnp://" URI for @@ -572,7 +569,7 @@ Edit the `dune` file to build a client and server: ``` (executables (names client server) - (libraries lwt.unix capnp-rpc logs.fmt capnp-rpc-unix)) + (libraries eio_main capnp-rpc logs.fmt capnp-rpc-unix mirage-crypto-rng-eio)) (rule (targets echo_api.ml echo_api.mli) @@ -584,9 +581,11 @@ Here's a suitable `server.ml`: ```ocaml -open Lwt.Infix +open Eio.Std open Capnp_rpc_net +let delay = if Sys.getenv_opt "CI" = None then 1.0 else 0.0 + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) @@ -594,16 +593,18 @@ let () = let cap_file = "echo.cap" let serve config = - Lwt_main.run begin - let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >>= fun vat -> - match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with - | Error `Msg m -> failwith m - | Ok () -> - Fmt.pr "Server running. Connect using %S.@." cap_file; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + let delay = Eio.Time.Timeout.seconds env#mono_clock delay in + Switch.run @@ fun sw -> + let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in + let restore = Restorer.single service_id (Echo.local ~delay) in + let vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with + | Error `Msg m -> failwith m + | Ok () -> + traceln "Server running. Connect using %S." cap_file; + Fiber.await_cancel () open Cmdliner @@ -623,6 +624,7 @@ And here's the corresponding `client.ml`: ```ocaml +open Eio.Std open Capnp_rpc.Std let () = @@ -630,18 +632,19 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let connect uri = - Lwt_main.run begin - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Capnp_rpc_unix.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Capnp_rpc_unix.with_cap_exn sr run_client open Cmdliner @@ -752,7 +755,7 @@ We can test it as follows: ```ocaml let run_client service = let logger = Echo.get_logger service in - Echo.Callback.log logger "Message from client" >|= function + match Echo.Callback.log logger "Message from client" with | Ok () -> () | Error (`Capnp err) -> Fmt.epr "Server's logger failed: %a" Capnp_rpc.Error.pp err @@ -767,8 +770,8 @@ This should print (in the server's output) something like: ```sh $ dune exec ./main.exe -[client] Connecting to echo service... -[server] Received "Message from client" ++[client] Connecting to echo service... ++[server] Received "Message from client" ``` In this case, we didn't wait for the `getLogger` call to return before using the logger. @@ -836,32 +839,35 @@ let make_service ~config ~services name = Restorer.Table.add services id service; name, id -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in let restore = Restorer.of_table services in let services = List.map (make_service ~config ~services) ["alice"; "bob"] in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in services |> List.iter (fun (name, id) -> let cap_file = name ^ ".cap" in Capnp_rpc_unix.Cap_file.save_service vat id cap_file |> or_fail; Printf.printf "[server] saved %S\n%!" cap_file ) -let run_client cap_file msg = - let vat = Capnp_rpc_unix.client_only_vat () in +let run_client ~net cap_file msg = + Switch.run @@ fun sw -> + let vat = Capnp_rpc_unix.client_only_vat ~sw net in let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in Printf.printf "[client] loaded %S\n%!" cap_file; Sturdy_ref.with_cap_exn sr @@ fun cap -> Logger.log cap msg let () = - Lwt_main.run begin - start_server () >>= fun () -> - run_client "./alice.cap" "Message from Alice" >>= fun () -> - run_client "./bob.cap" "Message from Bob" - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let net = env#net in + start_server ~sw net; + run_client ~net "./alice.cap" "Message from Alice"; + run_client ~net "./bob.cap" "Message from Bob" ``` @@ -899,17 +905,19 @@ We can use the new API like this: ```ocaml let () = - Lwt_main.run begin - start_server () >>= fun root_uri -> - let vat = Capnp_rpc_unix.client_only_vat () in - let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in - Sturdy_ref.with_cap_exn root_sr @@ fun root -> - Logger.log root "Message from Admin" >>= fun () -> - Capability.with_ref (Logger.sub root "alice") @@ fun for_alice -> - Capability.with_ref (Logger.sub root "bob") @@ fun for_bob -> - Logger.log for_alice "Message from Alice" >>= fun () -> - Logger.log for_bob "Message from Bob" - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let net = env#net in + let root_uri = start_server ~sw net in + let vat = Capnp_rpc_unix.client_only_vat ~sw net in + let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in + Sturdy_ref.with_cap_exn root_sr @@ fun root -> + Logger.log root "Message from Admin"; + Capability.with_ref (Logger.sub root "alice") @@ fun for_alice -> + Capability.with_ref (Logger.sub root "bob") @@ fun for_bob -> + Logger.log for_alice "Message from Alice"; + Logger.log for_bob "Message from Bob" ``` @@ -934,13 +942,11 @@ the admin can request the sturdy ref like this: ```ocaml - (* The admin creates a logger for Alice and saves it: *) - Capability.with_ref (Logger.sub root "alice") (fun for_alice -> - Capnp_rpc.Persistence.save_exn for_alice >|= fun uri -> - Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail - ) >>= fun () -> - (* Alice uses it: *) - run_client "alice.cap" + (* The admin creates a logger for Alice and saves it: *) + let uri = Capability.with_ref (Logger.sub root "alice") Capnp_rpc.Persistence.save_exn in + Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail; + (* Alice uses it: *) + run_client ~net "alice.cap" ``` If successful, the client can use this sturdy ref to connect directly to the logger in future: @@ -1006,7 +1012,7 @@ include Restorer.LOADER type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restorer.resolution (** A function to create a new in-memory logger with the given label and sturdy-ref. *) -val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> string -> t * loader Lwt.u +val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> _ Eio.Path.t -> t * loader Eio.Promise.u (** [create ~make_sturdy dir] is a database that persists services in [dir] and a resolver to let you set the loader (we're not ready to set the loader when we create the database). *) @@ -1039,7 +1045,7 @@ We can use this with `File_store` to implement `Db`: ```ocaml -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std open Capnp_rpc_net @@ -1050,7 +1056,7 @@ type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restore type t = { store : Store.Reader.SavedService.struct_t File_store.t; - loader : loader Lwt.t; + loader : loader Promise.t; make_sturdy : Restorer.Id.t -> Uri.t; } @@ -1073,17 +1079,18 @@ let save_new t ~label = let load t sr digest = match File_store.load t.store ~digest with - | None -> Lwt.return Restorer.unknown_service_id + | None -> Restorer.unknown_service_id | Some saved_service -> let logger = Store.Reader.SavedService.logger_get saved_service in let label = Store.Reader.SavedLogger.label_get logger in let sr = Capnp_rpc.Sturdy_ref.cast sr in - t.loader >|= fun loader -> + let loader = Promise.await t.loader in loader sr ~label let create ~make_sturdy dir = - let loader, set_loader = Lwt.wait () in - if not (Sys.file_exists dir) then Unix.mkdir dir 0o755; + let loader, set_loader = Promise.create () in + if not (Eio.Path.is_directory dir) then + Eio.Path.mkdir dir ~perm:0o755; let store = File_store.create dir in {store; loader; make_sturdy}, set_loader ``` @@ -1096,33 +1103,35 @@ The main `start_server` function then uses `Db` to create the table: ```ocaml let serve config = - Lwt_main.run begin - (* Create the on-disk store *) - let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in - let db, set_loader = Db.create ~make_sturdy "./store" in - (* Create the restorer *) - let services = Restorer.Table.of_loader (module Db) db in - let restore = Restorer.of_table services in - (* Add the root service *) - let persist_new ~label = - let id = Db.save_new db ~label in - Capnp_rpc_net.Restorer.restore restore id - in - let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in - let root = - let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in - Logger.local ~persist_new sr "root" - in - Restorer.Table.add services root_id root; - (* Tell the database how to restore saved loggers *) - Lwt.wakeup set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); - (* Run the server *) - Capnp_rpc_unix.serve config ~restore >>= fun _vat -> - let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in - Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; - print_endline "Wrote admin.cap"; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + (* Create the on-disk store *) + let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in + let db, set_loader = Db.create ~make_sturdy (env#cwd / "store") in + (* Create the restorer *) + let services = Restorer.Table.of_loader ~sw (module Db) db in + Switch.on_release sw (fun () -> Restorer.Table.clear services); + let restore = Restorer.of_table services in + (* Add the root service *) + let persist_new ~label = + let id = Db.save_new db ~label in + Capnp_rpc_net.Restorer.restore restore id + in + let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in + let root = + let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in + Logger.local ~persist_new sr "root" + in + Restorer.Table.add services root_id root; + (* Tell the database how to restore saved loggers *) + Promise.resolve set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); + (* Run the server *) + let _vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in + Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; + print_endline "Wrote admin.cap"; + Fiber.await_cancel () ``` The server implementation of the `sub` method gets the label from the parameters @@ -1135,14 +1144,13 @@ and uses `persist_new` to save the new logger to the database: let sub_label = Params.label_get params in release_param_caps (); let label = Printf.sprintf "%s/%s" label sub_label in - Service.return_lwt @@ fun () -> - persist_new ~label >|= function - | Error e -> Error (`Capnp (`Exception e)) + match persist_new ~label with + | Error e -> Service.error (`Exception e) | Ok logger -> let response, results = Service.Response.create Results.init_pointer in Results.logger_set results (Some logger); Capability.dec_ref logger; - Ok response + Service.return response ``` @@ -1256,12 +1264,12 @@ The solution here is to construct `Frontend` with a *promise* for the sturdy ref ```ocaml -let run_frontend backend_uri = - let backend_promise, resolver = Lwt.wait () in +let run_frontend ~sw ~net backend_uri = + let backend_promise, resolver = Promise.create () in let frontend = Frontend.make backend_promise in let restore = Restorer.single id frontend in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> - Lwt.wakeup resolver (Capnp_rpc_unix.Vat.import_exn vat backend_uri) + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in + Promise.resolve resolver (Capnp_rpc_unix.Vat.import_exn vat backend_uri) ``` ### How can I release other resources when my service is released? @@ -1392,7 +1400,7 @@ like regular OCaml method calls, but also over the network to remote objects. The network is made up of communicating "vats" of objects. You can think of a Unix process as a single vat. The vats are peers - there is no difference between a "client" and a "server" at the protocol level. -However, some vats may not be listening for incoming network connections, and you might like to think of such vats as clients. +However, some vats may not be listening for incoming network connections, and you might like to think of such vats as clients. When a connection is established between two vats, each can choose to ask the other for access to some service. Services are usually identified by a long random secret (a "Swiss number") so that only authorised clients can get access to them. @@ -1417,7 +1425,6 @@ To build: git clone https://github.com/mirage/capnp-rpc.git cd capnp-rpc opam pin add -ny . - opam depext -t capnp-rpc-unix capnp-rpc-mirage opam install --deps-only -t . make test @@ -1425,7 +1432,7 @@ If you have trouble building, you can use the Dockerfile shown in the CI logs (c ### Testing -Running `make test` will run through the tests in `test-lwt/test.ml`, which run some in-process examples. +Running `make test` will run through the tests in the `test` directory, which run some in-process examples. The calculator example can also be run across two Unix processes. @@ -1453,7 +1460,8 @@ In that case, the client URL would be `capnp://insecure@/tmp/calc.socket`. ### Fuzzing -Running `make fuzz` will run the AFL fuzz tester. You will need to use a version of the OCaml compiler with AFL support (e.g. `opam sw 4.04.0+afl`). +Running `make fuzz` will run the AFL fuzz tester. You will need to use a version of the OCaml compiler with AFL support +(e.g. `opam switch create 5.2-afl ocaml-variants.5.2.0+options ocaml-option-afl`). The fuzzing code is in the `fuzz` directory. The tests set up some vats in a single process and then have them perform operations based on input from the fuzzer. @@ -1486,6 +1494,7 @@ We should also test with some malicious vats (that don't follow the protocol cor [pycapnp]: http://jparyani.github.io/pycapnp/ [Persistence API]: https://github.com/capnproto/capnproto/blob/master/c%2B%2B/src/capnp/persistent.capnp [ocaml-ci]: https://github.com/ocurrent/ocaml-ci +[ocluster]: https://github.com/ocurrent/ocluster [api]: https://mirage.github.io/capnp-rpc/ [NETWORK]: https://mirage.github.io/capnp-rpc/capnp-rpc-net/Capnp_rpc_net/S/module-type-NETWORK/index.html [calc_direct.ml]: ./test-bin/calc_direct.ml diff --git a/capnp-rpc-net.opam b/capnp-rpc-net.opam index c111daf5f..28f85fae8 100644 --- a/capnp-rpc-net.opam +++ b/capnp-rpc-net.opam @@ -20,8 +20,7 @@ depends: [ "logs" "asetmap" "cstruct" {>= "6.0.0"} - "mirage-flow" {>= "4.0.2"} - "tls" {>= "1.0.2"} + "tls-eio" {>= "1.0.2"} "base64" {>= "3.0.0"} "uri" {>= "1.6.0"} "ptime" diff --git a/capnp-rpc-net/auth.mli b/capnp-rpc-net/auth.mli index 59f820771..b9aa30624 100644 --- a/capnp-rpc-net/auth.mli +++ b/capnp-rpc-net/auth.mli @@ -55,9 +55,8 @@ module Secret_key : sig val generate : unit -> t (** [generate ()] is a fresh secret key. - You must call the relevant entropy initialization function - (e.g. {!Mirage_crypto_rng_lwt.initialize}) before using this, or it - will raise an error if you forget. *) + You must use e.g. {!Mirage_crypto_rng_eio.run} to set a source of + randomness before using this (it will raise an error if you forget). *) val digest : ?hash:hash -> t -> Digest.t (** [digest ~hash t] is the digest of [t]'s public key, using [hash]. *) diff --git a/capnp-rpc-net/capTP_capnp.ml b/capnp-rpc-net/capTP_capnp.ml index 9eb9bdf2a..18b94316e 100644 --- a/capnp-rpc-net/capTP_capnp.ml +++ b/capnp-rpc-net/capTP_capnp.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std module Metrics = struct open Prometheus @@ -42,6 +42,7 @@ module Make (Network : S.NETWORK) = struct module Serialise = Serialise.Make(Endpoint_types) type t = { + sw : Switch.t; endpoint : Endpoint.t; conn : Conn.t; xmit_queue : Capnp.Message.rw Capnp.BytesMessage.Message.t Queue.t; @@ -50,16 +51,6 @@ module Make (Network : S.NETWORK) = struct let bootstrap t id = Conn.bootstrap t.conn id |> Capnp_rpc.Cast.cap_of_raw - let async_tagged label fn = - Lwt.async - (fun () -> - Lwt.catch fn - (fun ex -> - Log.warn (fun f -> f "Uncaught async exception in %S: %a" label Fmt.exn ex); - Lwt.return_unit - ) - ) - let pp_msg f call = let open Reader in let call = Capnp_rpc.Private.Msg.Request.readable call in @@ -75,30 +66,32 @@ module Make (Network : S.NETWORK) = struct (* [flush ~xmit_queue endpoint] writes each message in the queue until it is empty. Invariant: - Whenever Lwt blocks or switches threads, a flush thread is running iff the + Whenever Eio blocks or switches threads, a flush thread is running iff the queue is non-empty. *) let rec flush ~xmit_queue endpoint = (* We keep the item on the queue until it is transmitted, as the queue state tells us whether there is a [flush] currently running. *) let next = Queue.peek xmit_queue in - Endpoint.send endpoint next >>= function + match Endpoint.send endpoint next with | Error `Closed -> - Endpoint.disconnect endpoint >|= fun () -> (* We'll read a close soon *) + Endpoint.disconnect endpoint; (* We'll read a close soon *) drop_queue xmit_queue - | Error e -> - Log.warn (fun f -> f "Error sending messages: %a (will shutdown connection)" Endpoint.pp_error e); - Endpoint.disconnect endpoint >|= fun () -> + | Error (`Msg msg) -> + Log.warn (fun f -> f "Error sending messages: %s (will shutdown connection)" msg); + Endpoint.disconnect endpoint; drop_queue xmit_queue | Ok () -> Prometheus.Counter.inc_one Metrics.messages_outbound_sent_total; ignore (Queue.pop xmit_queue); if not (Queue.is_empty xmit_queue) then flush ~xmit_queue endpoint - else (* queue is empty and flush thread is done *) - Lwt.return_unit + (* else queue is empty and flush thread is done *) + | exception ex -> + drop_queue xmit_queue; + raise ex (* Enqueue [message] in [xmit_queue] and ensure the flush thread is running. *) - let queue_send ~xmit_queue endpoint message = + let queue_send ~sw ~xmit_queue endpoint message = Log.debug (fun f -> let module M = Capnp_rpc.Private.Schema.MessageWrapper.Message in f "queue_send: %d/%d allocated bytes in %d segs" @@ -108,19 +101,19 @@ module Make (Network : S.NETWORK) = struct let was_idle = Queue.is_empty xmit_queue in Queue.add message xmit_queue; Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total; - if was_idle then async_tagged "Message sender thread" (fun () -> flush ~xmit_queue endpoint) + if was_idle then Eio.Fiber.fork ~sw (fun () -> flush ~xmit_queue endpoint) let return_not_implemented t x = Log.debug (fun f -> f ~tags:(tags t) "Returning Unimplemented"); let open Builder in let m = Message.init_root () in let _ : Builder.Message.t = Message.unimplemented_set_reader m x in - queue_send ~xmit_queue:t.xmit_queue t.endpoint (Message.to_message m) + queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Message.to_message m) let listen t = let rec loop () = - Endpoint.recv t.endpoint >>= function - | Error e -> Lwt.return e + match Endpoint.recv t.endpoint with + | Error e -> e | Ok msg -> let open Reader.Message in let msg = of_message msg in @@ -134,8 +127,8 @@ module Make (Network : S.NETWORK) = struct | `Abort _ -> t.disconnecting <- true; Conn.handle_msg t.conn msg; - Endpoint.disconnect t.endpoint >>= fun () -> - Lwt.return `Aborted + Endpoint.disconnect t.endpoint; + `Aborted | _ -> Conn.handle_msg t.conn msg; loop () @@ -153,48 +146,53 @@ module Make (Network : S.NETWORK) = struct in loop () + let send_abort t ex = + queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Serialise.message (`Abort ex)) + let disconnect t ex = if not t.disconnecting then ( t.disconnecting <- true; - queue_send ~xmit_queue:t.xmit_queue t.endpoint (Serialise.message (`Abort ex)); - Endpoint.disconnect t.endpoint >|= fun () -> + send_abort t ex; + Endpoint.disconnect t.endpoint; Conn.disconnect t.conn ex - ) else ( - Lwt.return_unit ) let disconnecting t = t.disconnecting - let connect ~restore ?(tags=Logs.Tag.empty) endpoint = + let connect ~sw ~restore ?(tags=Logs.Tag.empty) endpoint = let xmit_queue = Queue.create () in - let queue_send msg = queue_send ~xmit_queue endpoint (Serialise.message msg) in + let queue_send msg = queue_send ~sw ~xmit_queue endpoint (Serialise.message msg) in let restore = Restorer.fn restore in - let conn = Conn.create ~restore ~tags ~queue_send in - let t = { + let fork = Fiber.fork ~sw in + let conn = Conn.create ~restore ~tags ~fork ~queue_send in + { + sw; conn; endpoint; xmit_queue; disconnecting = false; - } in + } + + let listen t = Prometheus.Gauge.inc_one Metrics.connections; - Lwt.async (fun () -> - Lwt.catch - (fun () -> - listen t >|= fun (`Closed | `Aborted) -> () - ) - (fun ex -> - Log.warn (fun f -> - f ~tags "Uncaught exception handling CapTP connection: %a (dropping connection)" Fmt.exn ex - ); - queue_send @@ `Abort (Capnp_rpc.Exception.v ~ty:`Failed (Printexc.to_string ex)); - Lwt.return_unit - ) - >>= fun () -> - Log.info (fun f -> f ~tags "Connection closed"); - Prometheus.Gauge.dec_one Metrics.connections; + let tags = Conn.tags t.conn in + begin + match listen t with + | `Closed | `Aborted -> () + | exception Eio.Cancel.Cancelled ex -> + Log.debug (fun f -> f ~tags "Cancelled: %a" Fmt.exn ex) + | exception ex -> + Log.warn (fun f -> + f ~tags "Uncaught exception handling CapTP connection: %a (dropping connection)" Fmt.exn ex + ); + send_abort t (Capnp_rpc.Exception.v ~ty:`Failed (Printexc.to_string ex)) + end; + Log.info (fun f -> f ~tags "Connection closed"); + Prometheus.Gauge.dec_one Metrics.connections; + Eio.Cancel.protect (fun () -> disconnect t (Capnp_rpc.Exception.v ~ty:`Disconnected "Connection closed") ); - t + Fiber.check () let dump f t = Conn.dump f t.conn end diff --git a/capnp-rpc-net/capTP_capnp.mli b/capnp-rpc-net/capTP_capnp.mli index d0ad8de25..e8a9e185a 100644 --- a/capnp-rpc-net/capTP_capnp.mli +++ b/capnp-rpc-net/capTP_capnp.mli @@ -4,17 +4,22 @@ module Make : S.NETWORK -> sig type t (** A Cap'n Proto RPC protocol handler. *) - val connect : restore:Restorer.t -> ?tags:Logs.Tag.set -> Endpoint.t -> t - (** [connect ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and + val connect : sw:Eio.Switch.t -> restore:Restorer.t -> ?tags:Logs.Tag.set -> Endpoint.t -> t + (** [connect ~sw ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and receives messages using [endpoint]. [restore] is used to respond to "Bootstrap" messages. - If the connection fails then [endpoint] will be disconnected. *) + If the connection fails then [endpoint] will be disconnected. + You must call {!listen} to run the loop handling messages. + @param sw Used to run methods and to run the transmit thread. *) + + val listen : t -> unit + (** [listen t] reads and handles incoming messages until the connection is finished. *) val bootstrap : t -> string -> 'a Capnp_rpc.Capability.t (** [bootstrap t object_id] is the peer's bootstrap object [object_id], if any. Use [object_id = ""] for the main, public object. *) - val disconnect : t -> Capnp_rpc.Exception.t -> unit Lwt.t + val disconnect : t -> Capnp_rpc.Exception.t -> unit (** [disconnect t reason] releases all resources used by the connection. *) val disconnecting : t -> bool diff --git a/capnp-rpc-net/capnp_rpc_net.ml b/capnp-rpc-net/capnp_rpc_net.ml index 4d49564dd..f11cf8929 100644 --- a/capnp-rpc-net/capnp_rpc_net.ml +++ b/capnp-rpc-net/capnp_rpc_net.ml @@ -11,11 +11,9 @@ module type VAT_NETWORK = S.VAT_NETWORK with type service_id := Restorer.Id.t and type 'a sturdy_ref := 'a Sturdy_ref.t -module Networking (N : S.NETWORK) (F : Mirage_flow.S) = struct - type flow = F.flow - +module Networking (N : S.NETWORK) = struct module Network = N - module Vat = Vat.Make (N) (F) + module Vat = Vat.Make (N) module CapTP = Vat.CapTP end diff --git a/capnp-rpc-net/capnp_rpc_net.mli b/capnp-rpc-net/capnp_rpc_net.mli index 7bdc2e57b..f4cec9b40 100644 --- a/capnp-rpc-net/capnp_rpc_net.mli +++ b/capnp-rpc-net/capnp_rpc_net.mli @@ -1,7 +1,7 @@ -(** This package adds networking support, including TLS. It contains code common - to capnp-rpc-unix and capnp-rpc-mirage. Libraries should not need to link against - this package (just use capnp-rpc-lwt instead), since they generally shouldn't - care whether services are local or remote. *) +(** This package adds networking support, including TLS. + Libraries should not need to link against this package (just use capnp-rpc + instead), since they generally shouldn't care whether services are local or + remote. *) open Capnp_rpc.Std @@ -90,7 +90,7 @@ module Restorer : sig (** [make_sturdy t id] converts an ID to a full URI, by adding the hosting vat's address and fingerprint. *) - val load : t -> 'a Sturdy_ref.t -> string -> resolution Lwt.t + val load : t -> 'a Sturdy_ref.t -> string -> resolution (** [load t sr digest] is called to restore the service with key [digest]. [sr] is a sturdy ref that refers to the service, which the service might want to hand out to clients. @@ -109,9 +109,10 @@ module Restorer : sig [make_sturdy id] converts an ID to a full URI, by adding the hosting vat's address and fingerprint. *) - val of_loader : (module LOADER with type t = 'loader) -> 'loader -> t - (** [of_loader (module Loader) l] is a new caching table that uses - [Loader.load l sr (Loader.hash id)] to restore services that aren't in the cache. *) + val of_loader : sw:Eio.Switch.t -> (module LOADER with type t = 'loader) -> 'loader -> t + (** [of_loader ~sw (module Loader) l] is a new caching table that uses + [Loader.load l sr (Loader.hash id)] to restore services that aren't in the cache. + The load function runs in a new fiber in [sw]. *) val add : t -> Id.t -> 'a Capability.t -> unit (** [add t id cap] adds a mapping to [t]. @@ -130,7 +131,7 @@ module Restorer : sig val of_table : Table.t -> t - val restore : t -> Id.t -> ('a Capability.t, Capnp_rpc.Exception.t) result Lwt.t + val restore : t -> Id.t -> ('a Capability.t, Capnp_rpc.Exception.t) result (** [restore t id] restores [id] using [t]. You don't normally need to call this directly, as the Vat will do it automatically. *) end @@ -141,8 +142,7 @@ module type VAT_NETWORK = S.VAT_NETWORK with type service_id := Restorer.Id.t and type 'a sturdy_ref := 'a Sturdy_ref.t -module Networking (N : S.NETWORK) (Flow : Mirage_flow.S) : VAT_NETWORK with - module Network = N and - type flow = Flow.flow +module Networking (N : S.NETWORK) : VAT_NETWORK with + module Network = N module Capnp_address = Capnp_address diff --git a/capnp-rpc-net/dune b/capnp-rpc-net/dune index 0c545769f..77afc8b4d 100644 --- a/capnp-rpc-net/dune +++ b/capnp-rpc-net/dune @@ -1,5 +1,5 @@ (library (name capnp_rpc_net) (public_name capnp-rpc-net) - (libraries astring capnp capnp-rpc fmt logs mirage-flow mirage-crypto mirage-crypto-rng - tls-mirage base64 uri ptime prometheus)) + (libraries astring capnp capnp-rpc fmt logs mirage-crypto-rng + tls-eio base64 uri ptime prometheus)) diff --git a/capnp-rpc-net/endpoint.ml b/capnp-rpc-net/endpoint.ml index 1ae0e03de..140f42e35 100644 --- a/capnp-rpc-net/endpoint.ml +++ b/capnp-rpc-net/endpoint.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std let src = Logs.Src.create "endpoint" ~doc:"Send and receive Cap'n'Proto messages" module Log = (val Logs.src_log src: Logs.LOG) @@ -7,21 +7,20 @@ let compression = `None let record_sent_messages = false -type flow = Flow : (module Mirage_flow.S with type flow = 'a) * 'a -> flow +type flow = Eio.Flow.two_way_ty r type t = { flow : flow; decoder : Capnp.Codecs.FramedStream.t; - switch : Lwt_switch.t; peer_id : Auth.Digest.t; } let peer_id t = t.peer_id -let of_flow (type flow) ~switch ~peer_id (module F : Mirage_flow.S with type flow = flow) (flow:flow) = - let generic_flow = Flow ((module F), flow) in +let of_flow ~peer_id flow = let decoder = Capnp.Codecs.FramedStream.empty compression in - { flow = generic_flow; decoder; switch; peer_id } + let flow = (flow :> flow) in + { flow; decoder; peer_id } let dump_msg = let next = ref 0 in @@ -34,37 +33,42 @@ let dump_msg = close_out ch let send t msg = - let (Flow ((module F), flow)) = t.flow in let data = Capnp.Codecs.serialize ~compression msg in if record_sent_messages then dump_msg data; - F.write flow (Cstruct.of_string data) >|= function - | Ok () - | Error `Closed as e -> e - | Error e -> Error (`Msg (Fmt.to_to_string F.pp_write_error e)) + match Eio.Flow.copy_string data t.flow with + | () + | exception End_of_file -> Ok () + | exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) -> + Log.info (fun f -> f "%a" Eio.Exn.pp ex); + Error `Closed + | exception ex -> + Eio.Fiber.check (); + Error (`Msg (Printexc.to_string ex)) let rec recv t = - let (Flow ((module F), flow)) = t.flow in match Capnp.Codecs.FramedStream.get_next_frame t.decoder with - | _ when not (Lwt_switch.is_on t.switch) -> Lwt.return @@ Error `Closed - | Ok msg -> Lwt.return (Ok (Capnp.BytesMessage.Message.readonly msg)) + | Ok msg -> Ok (Capnp.BytesMessage.Message.readonly msg) | Error Capnp.Codecs.FramingError.Unsupported -> failwith "Unsupported Cap'n'Proto frame received" | Error Capnp.Codecs.FramingError.Incomplete -> Log.debug (fun f -> f "Incomplete; waiting for more data..."); - F.read flow >>= function - | Ok (`Data data) -> - Log.debug (fun f -> f "Read %d bytes" (Cstruct.length data)); - Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string data); + let buf = Cstruct.create 4096 in (* TODO: make this efficient *) + match Eio.Flow.single_read t.flow buf with + | got -> + Log.debug (fun f -> f "Read %d bytes" got); + Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string buf ~len:got); recv t - | Ok `Eof -> + | exception End_of_file -> Log.info (fun f -> f "Connection closed"); - Lwt_switch.turn_off t.switch >|= fun () -> Error `Closed - | Error ex when Lwt_switch.is_on t.switch -> Capnp_rpc.Debug.failf "recv: %a" F.pp_error ex - | Error _ -> Lwt.return (Error `Closed) + | exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) -> + Log.info (fun f -> f "%a" Eio.Exn.pp ex); + Error `Closed let disconnect t = - Lwt_switch.turn_off t.switch - -let pp_error f = function - | `Closed -> Fmt.string f "Connection closed" - | `Msg m -> Fmt.string f m + try + Eio.Flow.shutdown t.flow `All + with + | Invalid_argument _ + | Eio.Io (Eio.Net.E Connection_reset _, _) -> + (* TCP connection already shut down, so TLS shutdown failed. Ignore. *) + () diff --git a/capnp-rpc-net/endpoint.mli b/capnp-rpc-net/endpoint.mli index 1fe0b9e01..674a6a296 100644 --- a/capnp-rpc-net/endpoint.mli +++ b/capnp-rpc-net/endpoint.mli @@ -6,27 +6,19 @@ val src : Logs.src type t (** A wrapper for a byte-stream (flow). *) -val send : t -> 'a Capnp.BytesMessage.Message.t -> (unit, [`Closed | `Msg of string]) result Lwt.t +val send : t -> 'a Capnp.BytesMessage.Message.t -> (unit, [`Closed | `Msg of string]) result (** [send t msg] transmits [msg]. *) -val recv : t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result Lwt.t +val recv : t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result (** [recv t] reads the next message from the remote peer. - It returns [Error `Closed] if the connection to the peer is lost - (this will also happen if the switch is turned off). *) + It returns [Error `Closed] if the connection to the peer is lost. *) -val of_flow : switch:Lwt_switch.t -> peer_id:Auth.Digest.t -> - (module Mirage_flow.S with type flow = 'flow) -> 'flow -> t -(** [of_flow ~switch ~peer_id (module F) flow] sends and receives on [flow]. - The caller should arrange for [flow] to be closed when the switch is turned off. - If the flow is closed, the switch will be turned off. - If the flow returns an error when the switch is off, the endpoint will return [`Closed] - instead of the underlying error. *) +val of_flow : peer_id:Auth.Digest.t -> _ Eio.Flow.two_way -> t +(** [of_flow ~peer_id flow] sends and receives on [flow]. *) val peer_id : t -> Auth.Digest.t (** [peer_id t] is the fingerprint of the peer's public key, - or [Auth.Digest.insecure] TLS isn't being used. *) + or [Auth.Digest.insecure] if TLS isn't being used. *) -val disconnect : t -> unit Lwt.t -(** [disconnect t] turns off [t]'s switch. *) - -val pp_error : [< `Closed | `Msg of string] Fmt.t +val disconnect : t -> unit +(** [disconnect t] shuts down the underlying flow. *) diff --git a/capnp-rpc-net/restorer.ml b/capnp-rpc-net/restorer.ml index 11607aa94..bdc920fad 100644 --- a/capnp-rpc-net/restorer.ml +++ b/capnp-rpc-net/restorer.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc module Core_types = Private.Capnp_core.Core_types @@ -33,10 +33,10 @@ module type LOADER = sig type t val hash : t -> Auth.hash val make_sturdy : t -> Id.t -> Uri.t - val load : t -> 'a Sturdy_ref.t -> string -> resolution Lwt.t + val load : t -> 'a Sturdy_ref.t -> string -> resolution end -type t = Id.t -> resolution Lwt.t +type t = Id.t -> resolution let grant x : resolution = Ok (Cast.cap_to_raw x) let reject ex = Error ex @@ -45,21 +45,19 @@ let unknown_service_id = reject (Capnp_rpc.Exception.v "Unknown persistent servi let fn (r:t) = fun k object_id -> - Lwt.async (fun () -> - Lwt.try_bind - (fun () -> r object_id) - (fun r -> k r; Lwt.return_unit) - (fun ex -> - Log.err (fun f -> f "Uncaught exception restoring object: %a" Fmt.exn ex); - k (reject (Capnp_rpc.Exception.v "Internal error restoring object")); - Lwt.return_unit - ) - ) - -let restore (f:t) x = f x |> Lwt_result.map Cast.cap_of_raw + match r object_id with + | r -> k r + | exception (Eio.Cancel.Cancelled _ as ex) -> + k (reject Capnp_rpc.Exception.cancelled); + raise ex + | exception ex -> + Log.err (fun f -> f "Uncaught exception restoring object: %a" Fmt.exn ex); + k (reject (Capnp_rpc.Exception.v "Internal error restoring object")) + +let restore (f:t) x = f x |> Result.map Cast.cap_of_raw let none : t = fun _ -> - Lwt.return @@ Error (Capnp_rpc.Exception.v "This vat has no restorer") + Error (Capnp_rpc.Exception.v "This vat has no restorer") let single id cap = let cap = Cast.cap_to_raw cap in @@ -69,20 +67,20 @@ let single id cap = let requested_id = Digestif.SHA256.digest_string requested_id |> Digestif.SHA256.to_raw_string in if String.equal id requested_id then ( Core_types.inc_ref cap; - Lwt.return (Ok cap) - ) else Lwt.return unknown_service_id + Ok cap + ) else unknown_service_id module Table = struct type digest = string type entry = - | Cached of resolution Lwt.t + | Cached of resolution Promise.or_exn | Manual of Core_types.cap (* We hold a ref on the cap *) type t = { hash : Digestif.hash'; cache : (digest, entry) Hashtbl.t; - load : Id.t -> digest -> resolution Lwt.t; + load : Id.t -> digest -> resolution Promise.or_exn; make_sturdy : Id.t -> Uri.t; } @@ -91,7 +89,7 @@ module Table = struct let create make_sturdy = let hash = `SHA256 in let cache = Hashtbl.create 53 in - let load _ _ = Lwt.return unknown_service_id in + let load _ _ = Promise.create_resolved (Ok unknown_service_id) in { hash; cache; load; make_sturdy } let hash t id = @@ -102,43 +100,44 @@ module Table = struct match Hashtbl.find t.cache digest with | Manual cap -> Core_types.inc_ref cap; - Lwt.return @@ Ok cap + Ok cap | Cached res -> - begin res >>= function - | Error _ as e -> Lwt.return e + begin match Promise.await_exn res with + | Error _ as e -> e | Ok cap -> Core_types.inc_ref cap; - Lwt.pause () >|= fun () -> + Fiber.yield (); Ok cap end | exception Not_found -> let cap = t.load id digest in Hashtbl.add t.cache digest (Cached cap); - Lwt.try_bind - (fun () -> cap) - (fun result -> - begin match result with - | Error _ -> Hashtbl.remove t.cache digest - | Ok cap -> cap#when_released (fun () -> Hashtbl.remove t.cache digest) - end; - (* Ensure all [inc_ref]s are done before handing over to the user. *) - Lwt.pause () >|= fun () -> - result - ) - (fun ex -> - Hashtbl.remove t.cache digest; - Lwt.fail ex - ) - - let of_loader (type l) (module L : LOADER with type t = l) loader = + match Promise.await_exn cap with + | exception ex -> + Hashtbl.remove t.cache digest; + raise ex + | result -> + begin match result with + | Error _ -> Hashtbl.remove t.cache digest + | Ok cap -> + cap#when_released (fun () -> Hashtbl.remove t.cache digest); + (* Ensure all [inc_ref]s are done before handing over to the user. *) + try Fiber.yield () + with ex -> Core_types.dec_ref cap; raise ex + end; + result + + let of_loader (type l) ~sw (module L : LOADER with type t = l) loader = let hash = (L.hash loader :> Digestif.hash') in let cache = Hashtbl.create 53 in let rec load id digest = - let sr : Private.Capnp_core.sturdy_ref = object - method connect = resolve t id - method to_uri_with_secrets = L.make_sturdy loader id - end in - L.load loader (Cast.sturdy_of_raw sr) digest + Fiber.fork_promise ~sw (fun () -> + let sr : Private.Capnp_core.sturdy_ref = object + method connect = resolve t id + method to_uri_with_secrets = L.make_sturdy loader id + end in + L.load loader (Cast.sturdy_of_raw sr) digest + ) and t = { hash; cache; load; make_sturdy = L.make_sturdy loader } in t diff --git a/capnp-rpc-net/s.ml b/capnp-rpc-net/s.ml index 915806ef1..b73251ec1 100644 --- a/capnp-rpc-net/s.ml +++ b/capnp-rpc-net/s.ml @@ -27,15 +27,14 @@ module type NETWORK = sig val connect : t -> - switch:Lwt_switch.t -> + sw:Eio.Switch.t -> secret_key:Auth.Secret_key.t Lazy.t -> Address.t -> - (Endpoint.t, [> `Msg of string]) result Lwt.t - (** [connect t ~switch ~secret_key address] connects to [address], proves ownership of + (Endpoint.t, [> `Msg of string]) result + (** [connect t ~sw ~secret_key address] connects to [address], proves ownership of [secret_key] (if TLS is being used), and returns the resulting endpoint. Returns an error if no connection can be established or the target fails - to authenticate itself. - If [switch] is turned off, the connection should be terminated. *) + to authenticate itself. *) val parse_third_party_cap_id : Capnp_rpc.Private.Schema.Reader.pointer_t -> Types.third_party_cap_id end @@ -47,9 +46,6 @@ module type VAT_NETWORK = sig type +'a capability (** An ['a capability] is a capability reference to a service of type ['a]. *) - type flow - (** A bi-directional byte-stream. *) - type restorer (** A function for restoring persistent capabilities from sturdy ref service IDs. *) @@ -67,17 +63,22 @@ module type VAT_NETWORK = sig type t (** A CapTP connection to a remote peer. *) - val connect : restore:restorer -> ?tags:Logs.Tag.set -> Endpoint.t -> t - (** [connect ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and + val connect : sw:Eio.Switch.t -> restore:restorer -> ?tags:Logs.Tag.set -> Endpoint.t -> t + (** [connect ~sw ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and receives messages using [endpoint]. [restore] is used to respond to "Bootstrap" messages. - If the connection fails then [endpoint] will be disconnected. *) + If the connection fails then [endpoint] will be disconnected. + You must call {!listen} to run the loop handling messages. + @param sw Used to run methods and to run the transmit thread. *) + + val listen : t -> unit + (** [listen t] reads and handles incoming messages until the connection is finished. *) val bootstrap : t -> service_id -> 'a capability (** [bootstrap t object_id] is the peer's bootstrap object [object_id], if any. Use [object_id = ""] for the main, public object. *) - val disconnect : t -> Capnp_rpc.Exception.t -> unit Lwt.t + val disconnect : t -> Capnp_rpc.Exception.t -> unit (** [disconnect reason] closes the connection, sending [reason] to the peer to explain why. Capabilities and questions at both ends will break, with [reason] as the problem. *) @@ -97,26 +98,28 @@ module type VAT_NETWORK = sig (** A local Vat. *) val create : - ?switch:Lwt_switch.t -> ?tags:Logs.Tag.set -> ?restore:restorer -> ?address:Network.Address.t -> + sw:Eio.Switch.t -> secret_key:Auth.Secret_key.t Lazy.t -> Network.t -> t - (** [create ~switch ~restore ~address ~secret_key network] is a new Vat that + (** [create ~sw ~restore ~address ~secret_key network] is a new Vat that uses [restore] to restore sturdy refs hosted at this vat to live capabilities for peers. The Vat will suggest that other parties connect to it using [address]. Turning off the switch will disconnect any active connections. *) - val add_connection : t -> switch:Lwt_switch.t -> mode:[`Accept|`Connect] -> Endpoint.t -> CapTP.t Lwt.t - (** [add_connection t ~switch ~mode endpoint] runs the CapTP protocol over [endpoint], + val run_connection : t -> mode:[`Accept|`Connect] -> Endpoint.t -> (CapTP.t -> unit) -> unit + (** [run_connection t ~mode endpoint r] runs the protocol over [endpoint], which is a connection to another vat. - When the connection ends, [switch] will be turned off, and turning off [switch] will - end the connection. + Once connected, [r conn] is called with the new connection. + When [run_connection] returns, [endpoint] can be closed. [mode] is used if two Vats connect to each other at the same time to decide which connection to drop. Use [`Connect] if [t] initiated the new - connection. Note that [add_connection] may return an existing connection. *) + connection. + If there is already a connection to [endpoint], [run_connection] may + call [r] on that instead and then return. *) val public_address : t -> Network.Address.t option (** [public_address t] is the address that peers should use when connecting diff --git a/capnp-rpc-net/tls_wrapper.ml b/capnp-rpc-net/tls_wrapper.ml index 25e615ebc..f2cc73faa 100644 --- a/capnp-rpc-net/tls_wrapper.ml +++ b/capnp-rpc-net/tls_wrapper.ml @@ -1,57 +1,40 @@ module Log = Capnp_rpc.Debug.Log -open Lwt.Infix open Auth let error fmt = fmt |> Fmt.kstr @@ fun msg -> Error (`Msg msg) -module Make (Underlying : Mirage_flow.S) = struct - module Flow = struct - include Tls_mirage.Make(Underlying) - - let read flow = - read flow >|= function - | Error (`Write `Closed) -> Ok `Eof (* This can happen, despite being a write error on a read! *) - | x -> x - - let writev flow bufs = - writev flow bufs >|= function - | Error (`Write `Closed) -> Error `Closed - | x -> x - - let write flow buf = writev flow [buf] - end - - let plain_endpoint ~switch flow = - Endpoint.of_flow ~switch ~peer_id:Auth.Digest.insecure (module Underlying) flow - - let connect_as_server ~switch flow secret_key = - match secret_key with - | None -> Lwt.return @@ Ok (plain_endpoint ~switch flow) - | Some key -> - Log.info (fun f -> f "Doing TLS server-side handshake..."); - let tls_config = Secret_key.tls_server_config key in - Flow.server_of_flow tls_config flow >|= function - | Error e -> error "TLS connection failed: %a" Flow.pp_write_error e - | Ok flow -> - match Flow.epoch flow with - | Error () -> failwith "Unknown error getting TLS epoch data" - | Ok data -> - match data.Tls.Core.peer_certificate with - | None -> error "No client certificate found" - | Some client_cert -> - let peer_id = Digest.of_certificate client_cert in - Ok (Endpoint.of_flow ~switch ~peer_id (module Flow) flow) - - let connect_as_client ~switch flow secret_key auth = - match Digest.authenticator auth with - | None -> Lwt.return @@ Ok (plain_endpoint ~switch flow) - | Some authenticator -> - let tls_config = Secret_key.tls_client_config ~authenticator (Lazy.force secret_key) in - Log.info (fun f -> f "Doing TLS client-side handshake..."); - Flow.client_of_flow tls_config flow >|= function - | Error e -> error "TLS connection failed: %a" Flow.pp_write_error e - | Ok flow -> Ok (Endpoint.of_flow ~switch ~peer_id:auth (module Flow) flow) -end +let plain_endpoint flow = + Endpoint.of_flow ~peer_id:Auth.Digest.insecure flow + +let connect_as_server flow secret_key = + match secret_key with + | None -> Ok (plain_endpoint flow) + | Some key -> + Log.info (fun f -> f "Doing TLS server-side handshake..."); + let tls_config = Secret_key.tls_server_config key in + match Tls_eio.server_of_flow tls_config flow with + | exception (Failure msg) -> error "TLS connection failed: %s" msg + | exception ex -> Eio.Fiber.check (); error "TLS connection failed: %a" Fmt.exn ex + | flow -> + match Tls_eio.epoch flow with + | Error () -> failwith "Unknown error getting TLS epoch data" + | Ok data -> + match data.Tls.Core.peer_certificate with + | None -> error "No client certificate found" + | Some client_cert -> + let peer_id = Digest.of_certificate client_cert in + Ok (Endpoint.of_flow ~peer_id flow) + +let connect_as_client flow secret_key auth = + match Digest.authenticator auth with + | None -> Ok (plain_endpoint flow) + | Some authenticator -> + let tls_config = Secret_key.tls_client_config ~authenticator (Lazy.force secret_key) in + Log.info (fun f -> f "Doing TLS client-side handshake..."); + match Tls_eio.client_of_flow tls_config flow with + | exception (Failure msg) -> error "TLS connection failed: %s" msg + | exception ex -> Eio.Fiber.check (); error "TLS connection failed: %a" Fmt.exn ex + | flow -> Ok (Endpoint.of_flow ~peer_id:auth flow) diff --git a/capnp-rpc-net/tls_wrapper.mli b/capnp-rpc-net/tls_wrapper.mli index f99c7b562..81c214f39 100644 --- a/capnp-rpc-net/tls_wrapper.mli +++ b/capnp-rpc-net/tls_wrapper.mli @@ -1,17 +1,13 @@ open Auth - -module Make (Underlying : Mirage_flow.S) : sig - (** Make an [Endpoint] from an [Underlying.flow], using TLS if appropriate. *) - - val connect_as_server : - switch:Lwt_switch.t -> Underlying.flow -> Auth.Secret_key.t option -> - (Endpoint.t, [> `Msg of string]) result Lwt.t - - val connect_as_client : - switch:Lwt_switch.t -> Underlying.flow -> Auth.Secret_key.t Lazy.t -> Digest.t -> - (Endpoint.t, [> `Msg of string]) result Lwt.t - (** [connect_as_client ~switch underlying key digest] is an endpoint using flow [underlying]. - If [digest] requires TLS, it performs a TLS handshake. It uses [key] as its private key - and checks that the server is the one required by [auth]. *) -end - +open Eio.Std + +val connect_as_server : + [> Eio.Flow.two_way_ty | Eio.Resource.close_ty] r -> Auth.Secret_key.t option -> + (Endpoint.t, [> `Msg of string]) result + +val connect_as_client : + [> Eio.Flow.two_way_ty | Eio.Resource.close_ty] r -> Auth.Secret_key.t Lazy.t -> Digest.t -> + (Endpoint.t, [> `Msg of string]) result +(** [connect_as_client underlying key digest] is an endpoint using flow [underlying]. + If [digest] requires TLS, it performs a TLS handshake. It uses [key] as its private key + and checks that the server is the one required by [auth]. *) diff --git a/capnp-rpc-net/two_party_network.ml b/capnp-rpc-net/two_party_network.ml index d800c9f40..a407d96d5 100644 --- a/capnp-rpc-net/two_party_network.ml +++ b/capnp-rpc-net/two_party_network.ml @@ -22,4 +22,4 @@ type t = unit let parse_third_party_cap_id _ = `Two_party_only -let connect () ~switch:_ ~secret_key:_ _ = assert false +let connect () ~sw:_ ~secret_key:_ _ = assert false diff --git a/capnp-rpc-net/vat.ml b/capnp-rpc-net/vat.ml index 498c4a7a0..64827d8e6 100644 --- a/capnp-rpc-net/vat.ml +++ b/capnp-rpc-net/vat.ml @@ -1,83 +1,75 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std module Log = Capnp_rpc.Debug.Log module ID_map = Auth.Digest.Map -module Make (Network : S.NETWORK) (Underlying : Mirage_flow.S) = struct +module Make (Network : S.NETWORK) = struct module CapTP = CapTP_capnp.Make (Network) let hash = `SHA256 (* Only support a single hash for now *) - type connection_attempt = (CapTP.t, Capnp_rpc.Exception.t) result Lwt.t + type connection_attempt = (CapTP.t, Capnp_rpc.Exception.t) result Eio.Promise.t type t = { + sw : Eio.Switch.t; network : Network.t; - switch : Lwt_switch.t option; secret_key : Auth.Secret_key.t Lazy.t; address : Network.Address.t option; restore : Restorer.t; tags : Logs.Tag.set; - connection_removed : unit Lwt_condition.t; (* Fires when a connection is removed *) + connection_removed : Eio.Condition.t; (* Fires when a connection is removed *) mutable connecting : connection_attempt ID_map.t; (* Out-going connections being attempted. *) mutable connections : CapTP.t ID_map.t; (* Accepted connections *) mutable anon_connections : CapTP.t list; (* Connections not using TLS. *) } - let create ?switch ?(tags=Logs.Tag.empty) ?(restore=Restorer.none) ?address ~secret_key network = - let t = { + let create ?(tags=Logs.Tag.empty) ?(restore=Restorer.none) ?address ~sw ~secret_key network = + Fiber.fork_daemon ~sw Capnp_rpc.Leak_handler.run; + { + sw; network; - switch; secret_key; address; restore; tags; - connection_removed = Lwt_condition.create (); + connection_removed = Eio.Condition.create (); connecting = ID_map.empty; connections = ID_map.empty; anon_connections = []; - } in - Lwt_switch.add_hook switch (fun () -> - let ex = Capnp_rpc.Exception.v ~ty:`Disconnected "Vat shut down" in - ID_map.bindings t.connections |> Lwt_list.iter_p (fun (_, c) -> CapTP.disconnect c ex) >>= fun () -> - t.connections <- ID_map.empty; - Lwt_list.iter_p (fun c -> CapTP.disconnect c ex) t.anon_connections >|= fun () -> - t.anon_connections <- []; - ID_map.iter (fun _ th -> Lwt.cancel th) t.connecting; - t.connecting <- ID_map.empty; - ); - t - - let add_tls_connection t ~switch endpoint = - let conn = CapTP.connect ~tags:t.tags ~restore:t.restore endpoint in + } + + let run_connection_generic t ~add ~remove endpoint = + let conn = CapTP.connect ~sw:t.sw ~tags:t.tags ~restore:t.restore endpoint in + add conn; + Fun.protect (fun () -> CapTP.listen conn) + ~finally:(fun () -> + remove conn; + Eio.Condition.broadcast t.connection_removed + ) + + let run_connection_tls t endpoint r = let peer_id = Endpoint.peer_id endpoint in - t.connections <- ID_map.add peer_id conn t.connections; - Lwt_switch.add_hook (Some switch) (fun () -> - begin match ID_map.find peer_id t.connections with - | Some x when x == conn -> t.connections <- ID_map.remove peer_id t.connections - | Some _ (* Already replaced by a new one? *) - | None -> () - end; - CapTP.disconnect conn (Capnp_rpc.Exception.v ~ty:`Disconnected "Switch turned off") >|= fun () -> - Lwt_condition.broadcast t.connection_removed () - ); - conn - - let add_connection t ~switch ~(mode:[`Accept|`Connect]) endpoint = - let tags = t.tags in + run_connection_generic t endpoint + ~add:(fun conn -> t.connections <- ID_map.add peer_id conn t.connections; r conn) + ~remove:(fun conn -> + match ID_map.find peer_id t.connections with + | Some x when x == conn -> t.connections <- ID_map.remove peer_id t.connections + | Some _ (* Already replaced by a new one? *) + | None -> () + ) + + (* Run CapTP on [endpoint], calling [r conn] with the connection (possibly reusing an existing one). + If a new connection is used, it is also stored in [t] while running. *) + let run_connection t ~(mode:[`Accept|`Connect]) endpoint r = let peer_id = Endpoint.peer_id endpoint in if peer_id = Auth.Digest.insecure then ( - let conn = CapTP.connect ~tags ~restore:t.restore endpoint in - t.anon_connections <- conn :: t.anon_connections; - Lwt_switch.add_hook (Some switch) (fun () -> - t.anon_connections <- List.filter ((!=) conn) t.anon_connections; - CapTP.disconnect conn (Capnp_rpc.Exception.v ~ty:`Disconnected "Switch turned off") >|= fun () -> - Lwt_condition.broadcast t.connection_removed () - ); - Lwt.return conn + run_connection_generic t endpoint + ~add:(fun conn -> t.anon_connections <- conn :: t.anon_connections; r conn) + ~remove:(fun conn -> t.anon_connections <- List.filter ((!=) conn) t.anon_connections) ) else match ID_map.find peer_id t.connections with - | None -> Lwt.return @@ add_tls_connection t ~switch endpoint + | None -> run_connection_tls t endpoint r | Some existing -> Log.info (fun f -> f ~tags:t.tags "Trying to add a connection, but we already have one for this vat"); (* This can happen if two vats call each other at exactly the same time. @@ -127,70 +119,63 @@ module Make (Network : S.NETWORK) (Underlying : Mirage_flow.S) = struct let my_id = Auth.Secret_key.digest ~hash (Lazy.force t.secret_key) in let keep_new = (my_id > peer_id) = (mode = `Connect) in if keep_new then ( - let conn = add_tls_connection t ~switch endpoint in let reason = Capnp_rpc.Exception.v "Closing duplicate connection" in - CapTP.disconnect existing reason >|= fun () -> - conn + CapTP.disconnect existing reason; + run_connection_tls t endpoint r ) else ( - Lwt_switch.turn_off switch >|= fun () -> - existing + Endpoint.disconnect endpoint; + r existing ) let public_address t = t.address - let connect_anon t addr ~service = - let switch = Lwt_switch.create () in - Network.connect t.network ~switch ~secret_key:t.secret_key addr >>= function - | Error (`Msg m) -> Lwt.return @@ Error (Capnp_rpc.Exception.v m) - | Ok ep -> - add_connection t ~switch ep ~mode:`Connect >|= fun conn -> - Ok (CapTP.bootstrap conn service) - - let initiate_connection t remote_id addr service = - (* We need to start a new connection attempt. *) - let switch = Lwt_switch.create () in - let conn = - Network.connect t.network ~switch ~secret_key:t.secret_key addr >>= function - | Error (`Msg m) -> Lwt.return @@ Error (Capnp_rpc.Exception.v m) - | Ok ep -> add_connection t ~switch ep ~mode:`Connect >|= fun conn -> Ok conn - in - t.connecting <- ID_map.add remote_id conn t.connecting; - conn >|= fun conn -> - t.connecting <- ID_map.remove remote_id t.connecting; - match conn with - | Ok conn -> Ok (CapTP.bootstrap conn service) - | Error _ as e -> e - - let rec connect_auth t remote_id addr ~service = + (* Make a new connection to remote service [addr] and request [service] from it. *) + let initiate_connection t addr service = + let remote_id = Network.Address.digest addr in + let p, r = Promise.create () in + let tracked = remote_id <> Auth.Digest.insecure in + if tracked then t.connecting <- ID_map.add remote_id p t.connecting; + Fun.protect + ~finally:(fun () -> + if tracked then t.connecting <- ID_map.remove remote_id t.connecting; + if not (Promise.is_resolved p) then Promise.resolve_error r Capnp_rpc.Exception.cancelled + ) + (fun () -> + Fiber.fork_daemon ~sw:t.sw (fun () -> + Switch.run (fun sw -> + match Network.connect ~sw t.network ~secret_key:t.secret_key addr with + | Error (`Msg m) -> Promise.resolve_error r (Capnp_rpc.Exception.v m) + | Ok ep -> run_connection t ep (Promise.resolve_ok r) ~mode:`Connect + ); + `Stop_daemon + ); + Promise.await p + ) + |> Result.map (fun conn -> CapTP.bootstrap conn service) + + (* Get a connection to [addr] and request [service] from it. *) + let rec connect t (addr, service) = + let remote_id = Network.Address.digest addr in let my_id = Auth.Secret_key.digest ~hash (Lazy.force t.secret_key) in if Auth.Digest.equal remote_id my_id then Restorer.restore t.restore service else match ID_map.find remote_id t.connections with | Some conn when CapTP.disconnecting conn -> - Lwt_condition.wait t.connection_removed >>= fun () -> - connect_auth t remote_id addr ~service + Eio.Condition.await_no_mutex t.connection_removed; + connect t (addr, service) | Some conn -> (* Already connected; use that. *) - Lwt.return @@ Ok (CapTP.bootstrap conn service) + Ok (CapTP.bootstrap conn service) | None -> match ID_map.find remote_id t.connecting with - | None -> initiate_connection t remote_id addr service + | None -> initiate_connection t addr service | Some conn -> (* We're already trying to establish a connection, wait for that. *) - conn >|= function - | Ok conn -> Ok (CapTP.bootstrap conn service) - | Error _ as e -> e + Promise.await conn |> Result.map (fun conn -> CapTP.bootstrap conn service) let make_sturdy_ref t sr = Capnp_rpc.Cast.sturdy_of_raw @@ object (_ : Capnp_rpc.Private.Capnp_core.sturdy_ref) - method connect = - let (addr, service) = sr in - let remote_id = Network.Address.digest addr in - Lwt_result.map Capnp_rpc.Cast.cap_to_raw ( - if remote_id = Auth.Digest.insecure then connect_anon t addr ~service - else connect_auth t remote_id addr ~service - ) - + method connect = Result.map Capnp_rpc.Cast.cap_to_raw (connect t sr) method to_uri_with_secrets = Network.Address.to_uri sr end diff --git a/capnp-rpc-unix.opam b/capnp-rpc-unix.opam index f9e2ade4b..1d4f31a40 100644 --- a/capnp-rpc-unix.opam +++ b/capnp-rpc-unix.opam @@ -13,7 +13,7 @@ depends: [ "ocaml" {>= "4.08.0"} "capnp-rpc-net" {= version} "cmdliner" {>= "1.1.0"} - "cstruct-lwt" + "cstruct" {>= "6.2.0"} "astring" "fmt" {>= "0.8.7"} "logs" @@ -21,11 +21,11 @@ depends: [ "base64" {>= "3.0.0"} "dune" {>= "3.16"} "alcotest" {>= "1.6.0" & with-test} - "alcotest-lwt" { >= "1.6.0" & with-test} - "mirage-crypto-rng-lwt" {>= "0.11.0"} - "mdx" {>= "2.2.1" & with-test} - "lwt" {>= "5.6.1"} + "mirage-crypto-rng-eio" {>= "1.1.0" & with-test} + "mdx" {>= "2.4.1" & with-test} "asetmap" {with-test} + "eio_main" {with-test} + "eio" {>= "1.1"} ] conflicts: [ "jbuilder" diff --git a/capnp-rpc.opam b/capnp-rpc.opam index ee9c3fa27..1ebd914ef 100644 --- a/capnp-rpc.opam +++ b/capnp-rpc.opam @@ -3,7 +3,7 @@ synopsis: "Cap'n Proto is a capability-based RPC system with bindings for many languages" description: """ This package provides a version of the Cap'n Proto RPC system using the Cap'n -Proto serialisation format and Lwt for concurrency.""" +Proto serialisation format and Eio for concurrency.""" maintainer: "Thomas Leonard " authors: "Thomas Leonard " license: "Apache-2.0" @@ -15,7 +15,7 @@ depends: [ "conf-capnproto" {build} "capnp" {>= "3.6.0"} "stdint" {>= "0.6.0"} - "lwt" {>= "5.6.1"} + "eio" {>= "1.1"} "astring" "fmt" {>= "0.8.7"} "logs" diff --git a/capnp-rpc/capability.ml b/capnp-rpc/capability.ml index dcf8ef2a3..94101d735 100644 --- a/capnp-rpc/capability.ml +++ b/capnp-rpc/capability.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_core module Log = Capnp_rpc_proto.Debug.Log @@ -14,9 +14,9 @@ let inc_ref = Core_types.inc_ref let dec_ref = Core_types.dec_ref let with_ref t fn = - Lwt.finalize + Fun.protect (fun () -> fn t) - (fun () -> dec_ref t; Lwt.return_unit) + ~finally:(fun () -> dec_ref t) let pp f x = x#pp f @@ -26,10 +26,10 @@ let when_released (x:Core_types.cap) f = x#when_released f let problem x = x#problem let wait_until_settled (x : _ t) = - let result, set_result = Lwt.wait () in + let result, set_result = Promise.create () in let rec aux x = if x#blocker = None then ( - Lwt.wakeup set_result () + Promise.resolve set_result () ) else ( x#when_more_resolved (fun x -> Core_types.dec_ref x; @@ -38,16 +38,16 @@ let wait_until_settled (x : _ t) = ) in aux x; - result + Promise.await result let await_settled t = - wait_until_settled t >|= fun () -> + wait_until_settled t; match problem t with | None -> Ok () | Some ex -> Error ex let await_settled_exn t = - wait_until_settled t >|= fun () -> + wait_until_settled t; match problem t with | None -> () | Some e -> Fmt.failwith "%a" Capnp_rpc_proto.Exception.pp e @@ -72,31 +72,33 @@ let call (target : 't capability_t) (m : ('t, 'a, 'b) method_t) (req : 'a Reques results let call_and_wait cap (m : ('t, 'a, 'b StructStorage.reader_t) method_t) req = - let p, r = Lwt.task () in + let p, r = Promise.create () in let result = call cap m req in let finish = lazy (Core_types.dec_ref result) in - Lwt.on_cancel p (fun () -> Lazy.force finish); result#when_resolved (function - | Error e -> Lwt.wakeup r (Error (`Capnp e)) + | Error e -> Promise.resolve_error r (`Capnp e) | Ok resp -> Lazy.force finish; let payload = Msg.Response.readable resp in let release_response_caps () = Core_types.Response_payload.release resp in let contents = Schema.Reader.Payload.content_get payload |> Schema.Reader.of_pointer in - Lwt.wakeup r @@ Ok (contents, release_response_caps) + Promise.resolve_ok r (contents, release_response_caps) ); - p + try Promise.await p + with ex -> + Lazy.force finish; + raise ex let call_for_value cap m req = - call_and_wait cap m req >|= function + match call_and_wait cap m req with | Error _ as response -> response | Ok (response, release_response_caps) -> release_response_caps (); Ok response let call_for_value_exn cap m req = - call_for_value cap m req >>= function - | Ok x -> Lwt.return x + match call_for_value cap m req with + | Ok x -> x | Error (`Capnp e) -> Log.debug (fun f -> f "Error calling %t(%a): %a" cap#pp @@ -105,11 +107,11 @@ let call_for_value_exn cap m req = Fmt.failwith "%a: %a" Capnp.RPC.MethodID.pp m Capnp_rpc_proto.Error.pp e let call_for_unit cap m req = - call_for_value cap m req >|= function + match call_for_value cap m req with | Ok _ -> Ok () | Error _ as e -> e -let call_for_unit_exn cap m req = call_for_value_exn cap m req >|= ignore +let call_for_unit_exn cap m req = call_for_value_exn cap m req |> ignore let call_for_caps cap m req fn = let q = call cap m req in diff --git a/capnp-rpc/capnp_core.ml b/capnp-rpc/capnp_core.ml index 733c78147..fcddd1654 100644 --- a/capnp-rpc/capnp_core.ml +++ b/capnp-rpc/capnp_core.ml @@ -1,13 +1,7 @@ -open Lwt.Infix - module Capnp_content = struct include Msg - let ref_leak_detected fn = - Lwt.async (fun () -> - Lwt.pause () >|= fun () -> - fn () - ) + let ref_leak_detected = Leak_handler.ref_leak_detected end module Core_types = Capnp_rpc_proto.Core_types(Capnp_content) @@ -19,6 +13,6 @@ module type ENDPOINT = Capnp_rpc_proto.Message_types.ENDPOINT with module Core_types = Core_types class type sturdy_ref = object - method connect : (Core_types.cap, Capnp_rpc_proto.Exception.t) result Lwt.t + method connect : (Core_types.cap, Capnp_rpc_proto.Exception.t) result method to_uri_with_secrets : Uri.t end diff --git a/capnp-rpc/capnp_rpc.ml b/capnp-rpc/capnp_rpc.ml index 5f64fe040..6743dcce6 100644 --- a/capnp-rpc/capnp_rpc.ml +++ b/capnp-rpc/capnp_rpc.ml @@ -7,6 +7,7 @@ module Error = Capnp_rpc_proto.Error module Log = Capnp_rpc_proto.Debug.Log module RO_array = Capnp_rpc_proto.RO_array module Debug = Capnp_rpc_proto.Debug +module Leak_handler = Leak_handler module Capability = Capability diff --git a/capnp-rpc/capnp_rpc.mli b/capnp-rpc/capnp_rpc.mli index 5ffd08dd0..63591f22f 100644 --- a/capnp-rpc/capnp_rpc.mli +++ b/capnp-rpc/capnp_rpc.mli @@ -1,4 +1,4 @@ -(** Cap'n Proto RPC using the Cap'n Proto serialisation and Lwt for concurrency. *) +(** Cap'n Proto core API for defining and using services. *) include (module type of Capnp.BytesMessage) (** @closed *) @@ -57,7 +57,7 @@ module Capability : sig believed to be healthy. Once a capability is broken, it will never work again and any calls made on it will fail with exception [ex]. *) - val await_settled : 'a t -> (unit, Exception.t) Lwt_result.t + val await_settled : 'a t -> (unit, Exception.t) result (** [await_settled t] resolves once [t] is a "settled" (non-promise) reference. If [t] is a near, far or broken reference, this returns immediately. If it is currently a local or remote promise, it waits until it isn't. @@ -66,13 +66,10 @@ module Capability : sig @return [Ok ()] on success, or [Error _] if [t] failed. @since 1.2 *) - val await_settled_exn : 'a t -> unit Lwt.t + val await_settled_exn : 'a t -> unit (** Like [await_settled], but raises an exception on error. @since 1.2 *) - val wait_until_settled : 'a t -> unit Lwt.t - [@@deprecated "Use await_settled instead."] - val equal : 'a t -> 'a t -> (bool, [`Unsettled]) result (** [equal a b] indicates whether [a] and [b] designate the same settled service. Returns [Error `Unsettled] if [a] or [b] is still a promise (and they therefore @@ -106,7 +103,7 @@ module Capability : sig instead for a simpler interface). *) val call_and_wait : 't t -> ('t, 'a, 'b StructStorage.reader_t) Capnp.RPC.MethodID.t -> - 'a Request.t -> (('b StructStorage.reader_t * (unit -> unit)), [> `Capnp of Error.t]) Lwt_result.t + 'a Request.t -> (('b StructStorage.reader_t * (unit -> unit)), [> `Capnp of Error.t]) result (** [call_and_wait t m req] does [call t m req] and waits for the response. This is simpler than using [call], but doesn't support pipelining (you can't use any capabilities in the response in another message until the @@ -116,26 +113,25 @@ module Capability : sig contain that you didn't use (remembering that future versions of the protocol might add new optional capabilities you don't know about yet). If you don't need any capabilities from the result, consider using [call_for_value] instead. - Doing [Lwt.cancel] on the result will send a cancel message to the target - for remote calls. *) + Cancelling the fiber will send a cancel message to the target for remote calls. *) val call_for_value : 't t -> ('t, 'a, 'b StructStorage.reader_t) Capnp.RPC.MethodID.t -> - 'a Request.t -> ('b StructStorage.reader_t, [> `Capnp of Error.t]) Lwt_result.t + 'a Request.t -> ('b StructStorage.reader_t, [> `Capnp of Error.t]) result (** [call_for_value t m req] is similar to [call_and_wait], but automatically releases any capabilities in the response before returning. Use this if you aren't expecting any capabilities in the response. *) val call_for_value_exn : 't t -> ('t, 'a, 'b StructStorage.reader_t) Capnp.RPC.MethodID.t -> - 'a Request.t -> 'b StructStorage.reader_t Lwt.t - (** Wrapper for [call_for_value] that turns errors into Lwt failures. *) + 'a Request.t -> 'b StructStorage.reader_t + (** Wrapper for [call_for_value] that turns errors into exceptions. *) val call_for_unit : 't t -> ('t, 'a, 'b StructStorage.reader_t) Capnp.RPC.MethodID.t -> - 'a Request.t -> (unit, [> `Capnp of Error.t]) Lwt_result.t + 'a Request.t -> (unit, [> `Capnp of Error.t]) result (** Wrapper for [call_for_value] that ignores the result. *) val call_for_unit_exn : 't t -> ('t, 'a, 'b StructStorage.reader_t) Capnp.RPC.MethodID.t -> - 'a Request.t -> unit Lwt.t - (** Wrapper for [call_for_unit] that turns errors into Lwt failures. *) + 'a Request.t -> unit + (** Wrapper for [call_for_unit] that turns errors into exceptions. *) val call_for_caps : 't t -> ('t, 'a, 'b StructStorage.reader_t) Capnp.RPC.MethodID.t -> 'a Request.t -> ('b StructRef.t -> 'c) -> 'c @@ -168,7 +164,7 @@ module Capability : sig peer. Any time you extract a capability from a struct or struct promise, it must eventually be freed by calling [dec_ref] on it. *) - val with_ref : 'a t -> ('a t -> 'b Lwt.t) -> 'b Lwt.t + val with_ref : 'a t -> ('a t -> 'b) -> 'b (** [with_ref t fn] runs [fn t] and then calls [dec_ref t] (whether [fn] succeeds or not). *) @@ -185,22 +181,22 @@ module Sturdy_ref : sig - A way to authenticate the hosting vat (e.g. a fingerprint of the vat's public key) - A way to identify the target service within the vat and prove permission to access it (e.g. a "Swiss number") - *) + *) - val connect : 'a t -> ('a Capability.t, Exception.t) result Lwt.t + val connect : 'a t -> ('a Capability.t, Exception.t) result (** [connect t] returns a live reference to [t]'s service. *) - val connect_exn : 'a t -> 'a Capability.t Lwt.t - (** [connect_exn] is a wrapper for [connect] that returns a failed Lwt thread on error. *) + val connect_exn : 'a t -> 'a Capability.t + (** [connect_exn] is a wrapper for [connect] that raises an exception on error. *) val with_cap : 'a t -> - ('a Capability.t -> ('b, [> `Capnp of Exception.t] as 'e) Lwt_result.t) -> - ('b, 'e) Lwt_result.t + ('a Capability.t -> ('b, [> `Capnp of Exception.t] as 'e) result) -> + ('b, 'e) result (** [with_cap t f] uses [connect t] to get a live-ref [x], then does [Capability.with_ref x f]. *) - val with_cap_exn : 'a t -> ('a Capability.t -> 'b Lwt.t) -> 'b Lwt.t + val with_cap_exn : 'a t -> ('a Capability.t -> 'b) -> 'b (** [with_cap_exn t f] uses [connect_exn t] to get a live-ref [x], then does [Capability.with_ref x f]. *) @@ -253,22 +249,11 @@ module Service : sig val return_empty : unit -> 'a StructRef.t (** [return_empty ()] is a promise for a response with no payload. *) - val return_lwt : (unit -> ('a Response.t, [< `Capnp of Error.t]) Lwt_result.t) -> 'a StructRef.t - (** [return_lwt fn] is a local promise for the result of Lwt thread [fn ()]. - If [fn ()] fails, the error is logged and an "Internal error" returned to the caller. - If it returns an [Error] value then that error is returned to the caller. - Note that this does not support pipelining (any messages sent to the response - will be queued locally until it [fn] has produced a result), so it may be better - to return immediately a result containing a promise in some cases. *) + val error : Error.t -> 'a StructRef.t + (** [error e] is a broken promise for a struct, with error [e]. *) val fail : ?ty:Exception.ty -> ('a, Format.formatter, unit, 'b StructRef.t) format4 -> 'a - (** [fail msg] is an exception with reason [msg]. *) - - val fail_lwt : - ?ty:Exception.ty -> - ('a, Format.formatter, unit, (_, [> `Capnp of Error.t]) Lwt_result.t) format4 -> - 'a - (** [fail_lwt msg] is like [fail msg], but can be used with [return_lwt]. *) + (** [fail msg] is an exception {!error} with reason [msg]. *) end (** Some aliases for common modules. @@ -280,6 +265,8 @@ module Std : sig module Service = Service end +module Leak_handler = Leak_handler + (**/**) module Untyped : sig @@ -329,7 +316,7 @@ module Debug = Capnp_rpc_proto.Debug module Persistence : sig class type ['a] persistent = object - method save : ('a Sturdy_ref.t, Exception.t) result Lwt.t + method save : ('a Sturdy_ref.t, Exception.t) result end val with_persistence : @@ -348,11 +335,11 @@ module Persistence : sig (** [with_sturdy_ref sr Service.Foo.local obj] is like [Service.Foo.local obj], but responds to [save] calls by returning [sr]. *) - val save : 'a Capability.t -> (Uri.t, [> `Capnp of Error.t]) Lwt_result.t + val save : 'a Capability.t -> (Uri.t, [> `Capnp of Error.t]) result (** [save cap] calls the persistent [save] method on [cap]. Note that not all capabilities can be saved. todo: this should return an ['a Sturdy_ref.t]; see {!Sturdy_ref.reader}. *) - val save_exn : 'a Capability.t -> Uri.t Lwt.t + val save_exn : 'a Capability.t -> Uri.t (** [save_exn] is a wrapper for [save] that returns a failed thread on error. *) end diff --git a/capnp-rpc/dune b/capnp-rpc/dune index b69d2cef7..e0cb0cee2 100644 --- a/capnp-rpc/dune +++ b/capnp-rpc/dune @@ -1,7 +1,7 @@ (library (name capnp_rpc) (public_name capnp-rpc) - (libraries astring capnp capnp-rpc.proto fmt logs lwt uri)) + (libraries astring capnp capnp-rpc.proto fmt logs eio uri)) (rule (targets rpc_schema.ml rpc_schema.mli) diff --git a/capnp-rpc/leak_handler.ml b/capnp-rpc/leak_handler.ml new file mode 100644 index 000000000..bf7326812 --- /dev/null +++ b/capnp-rpc/leak_handler.ml @@ -0,0 +1,54 @@ +module M = Map.Make(Int) + +module Log = Capnp_rpc_proto.Debug.Log + +(* A map from thread IDs to (n, q) pairs. + [q] is a queue of callbacks waiting to be run in the thread + and [n] is the number of loops consuming [q] (typically 1). *) +let handlers : (int * (unit -> unit) Eio.Stream.t) M.t Atomic.t = Atomic.make M.empty + +(* [add_handler id] increments the counter for thread [id] and returns the queue. + If there isn't one yet, it creates a new one. *) +let rec add_handler id = + let old = Atomic.get handlers in + let handler = + match M.find_opt id old with + | None -> (1, Eio.Stream.create max_int) + | Some (n, q) -> (n + 1, q) + in + let next = M.add id handler old in + if Atomic.compare_and_set handlers old next then snd handler + else add_handler id + +let rec remove_handler id = + let old = Atomic.get handlers in + let n, q = M.find id old in + let next = + if n > 1 then M.add id (n - 1, q) old + else M.remove id old + in + if not (Atomic.compare_and_set handlers old next) then remove_handler id + +let run () = + let id = Thread.(id (self ())) in + let q = add_handler id in + try + while true do + let fn = Eio.Stream.take q in + try + fn () + with ex -> + let bt = Printexc.get_raw_backtrace () in + Eio.Fiber.check (); + Log.warn (fun f -> f "Uncaught exception handling ref-leak: %a" Fmt.exn_backtrace (ex, bt)) + done + with ex -> + remove_handler id; + raise ex + +let ref_leak_detected thread fn = + match M.find_opt thread (Atomic.get handlers) with + | Some (_, q) -> Eio.Stream.add q fn + | None -> + Capnp_rpc_proto.Debug.Log.debug + (fun f -> f "Leak detected, but no leak reporter is running so ignoring") diff --git a/capnp-rpc/leak_handler.mli b/capnp-rpc/leak_handler.mli new file mode 100644 index 000000000..f54b3f73b --- /dev/null +++ b/capnp-rpc/leak_handler.mli @@ -0,0 +1,22 @@ +(** Handle references that got GC'd with a non-zero ref-count. + + If an application forgets to release a resource and it gets GC'd then we want to + log a warning and clean up (so forgotten refs don't build up over time). + + Because GC finalizers can run at any time and from any thread, + we need to pass the cleanup callback to a fiber running in the owning thread. *) + +val run : unit -> 'a +(** [run ()] registers a leak handler for the current thread and + runs a loop that waits for callbacks and runs them. + If the fiber is cancelled, the handler is removed. + + Each vat runs this in a daemon fiber. + It is safe to have multiple such fibers running in a single systhread. *) + +val ref_leak_detected : int -> (unit -> unit) -> unit +(** [ref_leak_detected thread_id fn] should be called from a GC finalizer if + the resource was not properly released. + + If a handler for [thread_id] is running (see {!run}) then it will schedule + [fn] to run at a safe point in that thread. If not, [fn] is ignored. *) diff --git a/capnp-rpc/persistence.ml b/capnp-rpc/persistence.ml index 060cb1b1f..bce41ac07 100644 --- a/capnp-rpc/persistence.ml +++ b/capnp-rpc/persistence.ml @@ -1,9 +1,7 @@ -open Lwt.Infix - module Api = Persistent.Make(Capnp.BytesMessage) class type ['a] persistent = object - method save : ('a Sturdy_ref.t, Capnp_rpc_proto.Exception.t) result Lwt.t + method save : ('a Sturdy_ref.t, Capnp_rpc_proto.Exception.t) result end let with_persistence @@ -16,13 +14,12 @@ let with_persistence if method_id = Capnp.RPC.MethodID.method_id Api.Client.Persistent.Save.method_id then ( let open Api.Service.Persistent.Save in release_params (); - Service.return_lwt @@ fun () -> - persistent#save >|= function - | Error e -> Error (`Capnp (`Exception e)) + match persistent#save with + | Error e -> Service.error (`Exception e) | Ok sr -> let resp, results = Service.Response.create Results.init_pointer in Sturdy_ref.builder Results.sturdy_ref_get results sr; - Ok resp + Service.return resp ) else ( release_params (); Service.fail ~ty:`Unimplemented "Unknown persistence method %d" method_id @@ -39,18 +36,18 @@ let with_persistence let with_sturdy_ref sr local impl = let persistent = object - method save = Lwt.return (Ok sr) + method save = Ok sr end in with_persistence persistent local impl let save cap = let open Api.Client.Persistent.Save in let request = Capability.Request.create_no_args () in - Capability.call_for_value cap method_id request >|= function + match Capability.call_for_value cap method_id request with | Error _ as e -> e | Ok response -> Ok (Sturdy_ref.reader Results.sturdy_ref_get response) let save_exn cap = - save cap >>= function - | Error (`Capnp e) -> Lwt.fail_with (Fmt.to_to_string Capnp_rpc_proto.Error.pp e) - | Ok x -> Lwt.return x + match save cap with + | Error (`Capnp e) -> failwith (Fmt.to_to_string Capnp_rpc_proto.Error.pp e) + | Ok x -> x diff --git a/capnp-rpc/proto/capTP.ml b/capnp-rpc/proto/capTP.ml index 160669fa1..58e41184f 100644 --- a/capnp-rpc/proto/capTP.ml +++ b/capnp-rpc/proto/capTP.ml @@ -705,6 +705,7 @@ module Make (EP : Message_types.ENDPOINT) = struct tags : Logs.Tag.set; embargoes : (EmbargoId.t * Cap_proxy.resolver_cap) Embargoes.t; restore : restorer; + fork : (unit -> unit) -> unit; questions : Question.t Questions.t; answers : Answer.t Answers.t; @@ -740,11 +741,12 @@ module Make (EP : Message_types.ENDPOINT) = struct let default_restore k _object_id = k @@ Error (Exception.v "This vat has no restorer") - let create ?(restore=default_restore) ~tags ~queue_send = + let create ?(restore=default_restore) ~tags ~fork ~queue_send = { queue_send = (queue_send :> EP.Out.t -> unit); tags; restore = restore; + fork; questions = Questions.make (); answers = Answers.make (); imports = Imports.make (); @@ -1132,6 +1134,7 @@ module Make (EP : Message_types.ENDPOINT) = struct let make ~(release:unit Lazy.t) ~settled ~strong_proxy init = object (self : #Core_types.cap) val id = Debug.OID.next () + val thread_id = Thread.(id (self ())) val mutable state = Unset { rc = RC.one; handler = init; on_set = Queue.create (); on_release = Queue.create () } @@ -1215,7 +1218,7 @@ module Make (EP : Message_types.ENDPOINT) = struct | Gc -> begin match state with | Unset x -> - Core_types.Wire.ref_leak_detected (fun () -> + Core_types.Wire.ref_leak_detected thread_id (fun () -> if RC.is_zero x.rc then ( Log.warn (fun f -> f "@[Reference GC'd with non-zero ref-count!@,%t@,\ But, ref-count is now zero, so a previous GC leak must have fixed it.@]" @@ -1465,8 +1468,10 @@ module Make (EP : Message_types.ENDPOINT) = struct | `Local target -> Log.debug (fun f -> f ~tags:t.tags "Handling call: (%t).call %a" target#pp Core_types.Request_payload.pp msg); - target#call answer_resolver msg; (* Takes ownership of [caps]. *) - dec_ref target + t.fork (fun () -> + target#call answer_resolver msg; (* Takes ownership of [caps]. *) + dec_ref target + ) | #message_target_cap as target -> Log.debug (fun f -> f ~tags:t.tags "Forwarding call: (%a).call %a" pp_message_target_cap target Core_types.Request_payload.pp msg); @@ -1476,6 +1481,7 @@ module Make (EP : Message_types.ENDPOINT) = struct let promise, answer_resolver = Local_struct_promise.make () in let answer = Answer.create id ~answer:promise in Answers.set t.answers id answer; + t.fork @@ fun () -> object_id |> t.restore @@ fun service -> if Answer.needs_return answer && t.disconnected = None then ( let results = diff --git a/capnp-rpc/proto/capTP.mli b/capnp-rpc/proto/capTP.mli index 93eb5fa0f..1b47b8238 100644 --- a/capnp-rpc/proto/capTP.mli +++ b/capnp-rpc/proto/capTP.mli @@ -12,11 +12,13 @@ module Make (EP : Message_types.ENDPOINT) : sig capability. *) val create : ?restore:restorer -> tags:Logs.Tag.set -> + fork:((unit -> unit) -> unit) -> queue_send:([> EP.Out.t] -> unit) -> t - (** [create ~bootstrap ~tags ~queue_send] is a handler for a connection to a remote peer. + (** [create ~restore ~tags ~fork ~queue_send] is a handler for a connection to a remote peer. Messages will be sent to the peer by calling [queue_send] (which MUST deliver them in order). - If the remote peer asks for the bootstrap object, it will be given a reference to [bootstrap]. - Log messages will be tagged with [tags]. *) + If the remote peer asks for a bootstrap object, [restore] will be used to get it. + Log messages will be tagged with [tags]. + @param fork is used when dispatching a local method handler. *) val bootstrap : t -> string -> EP.Core_types.cap (** [bootstrap t object_id] returns a reference to the remote peer's bootstrap object, if any. diff --git a/capnp-rpc/proto/core_types.ml b/capnp-rpc/proto/core_types.ml index d6ef5ac1b..045c20c13 100644 --- a/capnp-rpc/proto/core_types.ml +++ b/capnp-rpc/proto/core_types.ml @@ -51,6 +51,7 @@ module Make(Wire : S.WIRE) = struct class virtual ref_counted = object (self : #base_ref) + val thread_id = Thread.(id (self ())) val mutable ref_count = RC.one method private virtual release : unit method virtual pp : Format.formatter -> unit @@ -72,7 +73,7 @@ module Make(Wire : S.WIRE) = struct method sealed_dispatch : type a. a S.brand -> a option = function | Gc -> if not (RC.is_zero ref_count) then ( - ref_leak_detected (fun () -> + ref_leak_detected thread_id (fun () -> if RC.is_zero ref_count then ( Log.warn (fun f -> f "@[Reference GC'd with non-zero ref-count!@,%t@,\ But, ref-count is now zero, so a previous GC leak must have fixed it.@]" diff --git a/capnp-rpc/proto/dune b/capnp-rpc/proto/dune index dcf77fc2e..1ba98d6e6 100644 --- a/capnp-rpc/proto/dune +++ b/capnp-rpc/proto/dune @@ -1,4 +1,4 @@ (library (name capnp_rpc_proto) (public_name capnp-rpc.proto) - (libraries astring fmt logs stdint asetmap)) + (libraries astring fmt logs stdint asetmap threads)) diff --git a/capnp-rpc/proto/s.ml b/capnp-rpc/proto/s.ml index a846fcebb..0e6e215cd 100644 --- a/capnp-rpc/proto/s.ml +++ b/capnp-rpc/proto/s.ml @@ -57,16 +57,16 @@ module type WIRE = sig (** The (empty) content for the reply to the bootstrap message. *) end - val ref_leak_detected : (unit -> unit) -> unit - (** [ref_leak_detected fn] is called when a promise or capability is GC'd while + val ref_leak_detected : int -> (unit -> unit) -> unit + (** [ref_leak_detected thread_id fn] is called when a promise or capability is GC'd while its ref-count is non-zero, indicating that resources may have been leaked. [fn ()] will log a warning about this and free the resources itself. The reason for going via [ref_leak_detected] rather than calling [fn] directly is because the OCaml GC may detect the problem at any point (e.g. while we're sending another message). The implementation should arrange for [fn] to be - called at a safe point (e.g. when returning to the main loop). - Unit-tests may wish to call [fn] immediately to show the error and then - fail the test. *) + called at a safe point in thread [thread_id] (e.g. when returning to the + thread's main loop). Unit-tests may wish to call [fn] immediately to show + the error and then fail the test. *) end module type PAYLOAD = sig diff --git a/capnp-rpc/service.ml b/capnp-rpc/service.ml index 00f8cd72a..01b0e40ec 100644 --- a/capnp-rpc/service.ml +++ b/capnp-rpc/service.ml @@ -1,5 +1,4 @@ open Capnp_core -open Lwt.Infix module Log = Capnp_rpc_proto.Debug.Log @@ -47,6 +46,10 @@ let local (s:#generic) = Payload.content_get p |> Schema.ReaderOps.deref_opt_struct_pointer |> Schema.ReaderOps.cast_struct in match m contents release_params with | r -> results#resolve r + | exception (Eio.Cancel.Cancelled _ as ex) -> + release_params (); + Core_types.resolve_payload results (Error `Cancelled); + raise ex | exception ex -> release_params (); Log.warn (fun f -> f "Uncaught exception handling %a: %a" pp_method (interface_id, method_id) Fmt.exn ex); @@ -61,27 +64,6 @@ let return resp = let return_empty () = return @@ Response.create_empty () -(* A convenient way to implement a simple blocking local function, where - pipelining is not supported (messages sent to the result promise will be - queued up at this host until it returns). *) -let return_lwt fn = - let result, resolver = Local_struct_promise.make () in - Lwt.async (fun () -> - Lwt.catch (fun () -> - fn () >|= function - | Ok resp -> Core_types.resolve_ok resolver @@ Response.finish resp; - | Error (`Capnp e) -> Core_types.resolve_payload resolver (Error e) - ) - (fun ex -> - Log.warn (fun f -> f "Uncaught exception: %a" Fmt.exn ex); - Core_types.resolve_exn resolver @@ Capnp_rpc_proto.Exception.v "Internal error"; - Lwt.return_unit - ); - ); - result - let fail = Core_types.fail -let fail_lwt ?ty fmt = - fmt |> Fmt.kstr @@ fun msg -> - Lwt_result.fail (`Capnp (`Exception (Capnp_rpc_proto.Exception.v ?ty msg))) +let error = Core_types.broken_struct diff --git a/capnp-rpc/sturdy_ref.ml b/capnp-rpc/sturdy_ref.ml index a172b3d50..3841ea672 100644 --- a/capnp-rpc/sturdy_ref.ml +++ b/capnp-rpc/sturdy_ref.ml @@ -1,13 +1,11 @@ -open Lwt.Infix - class type [+'a] t = Capnp_core.sturdy_ref let connect t = t#connect let connect_exn t = - connect t >>= function - | Ok x -> Lwt.return x - | Error e -> Lwt.fail_with (Fmt.to_to_string Capnp_rpc_proto.Exception.pp e) + match connect t with + | Ok x -> x + | Error e -> failwith (Fmt.to_to_string Capnp_rpc_proto.Exception.pp e) let reader fn s = fn s |> Schema.ReaderOps.string_of_pointer |> Uri.of_string @@ -18,10 +16,10 @@ let builder fn (s : 'a Capnp.BytesMessage.StructStorage.builder_t) (sr : 'a t) = let cast t = t let with_cap t f = - connect t >>= function + match connect t with | Ok x -> Capability.with_ref x f - | Error e -> Lwt_result.fail (`Capnp e) + | Error e -> Error (`Capnp e) let with_cap_exn t f = - connect_exn t >>= fun x -> + let x = connect_exn t in Capability.with_ref x f diff --git a/examples/pipelining/dune b/examples/pipelining/dune index 78c2c8869..c3b83a62c 100644 --- a/examples/pipelining/dune +++ b/examples/pipelining/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt)) + (libraries eio_main capnp-rpc-unix mirage-crypto-rng-eio logs.fmt)) (rule (targets echo_api.ml echo_api.mli) diff --git a/examples/pipelining/echo.ml b/examples/pipelining/echo.ml index a30a310d3..e05728db2 100644 --- a/examples/pipelining/echo.ml +++ b/examples/pipelining/echo.ml @@ -1,6 +1,6 @@ module Api = Echo_api.MakeRPC(Capnp_rpc) -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std module Callback = struct @@ -26,23 +26,23 @@ module Callback = struct Capability.call_for_unit t method_id request end -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~delay msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.Timeout.sleep delay; + loop (i - 1) in loop 3 let service_logger = - Callback.local (Printf.printf "[server] Received %S\n%!") + Callback.local (traceln "[server] Received %S") -let local = +let local ~delay = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -63,8 +63,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~delay msg) (* $MDX part-begin=server-get-logger *) method get_logger_impl _ release_params = @@ -82,7 +81,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get let heartbeat t msg callback = let open Echo.Heartbeat in diff --git a/examples/pipelining/main.ml b/examples/pipelining/main.ml index 073b0459d..123e735fa 100644 --- a/examples/pipelining/main.ml +++ b/examples/pipelining/main.ml @@ -1,17 +1,19 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std +let delay = if Sys.getenv_opt "CI" = None then 1.0 else 0.0 + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg (* $MDX part-begin=run-client *) let run_client service = let logger = Echo.get_logger service in - Echo.Callback.log logger "Message from client" >|= function + match Echo.Callback.log logger "Message from client" with | Ok () -> () | Error (`Capnp err) -> Fmt.epr "Server's logger failed: %a" Capnp_rpc.Error.pp err @@ -20,18 +22,22 @@ let run_client service = let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw ~delay net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let service = Echo.local ~delay in + Switch.on_release sw (fun () -> Capability.dec_ref service); + let restore = Capnp_rpc_net.Restorer.single service_id service in + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "[client] Connecting to echo service...@."; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let delay = Eio.Time.Timeout.seconds env#mono_clock delay in + let uri = start_server ~sw ~delay env#net in + traceln "[client] Connecting to echo service..."; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client diff --git a/examples/sturdy-refs-2/dune b/examples/sturdy-refs-2/dune index 96c0988f1..cfebddc5a 100644 --- a/examples/sturdy-refs-2/dune +++ b/examples/sturdy-refs-2/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt)) + (libraries eio_main capnp-rpc-unix mirage-crypto-rng-eio logs.fmt)) (rule (targets api.ml api.mli) diff --git a/examples/sturdy-refs-2/main.ml b/examples/sturdy-refs-2/main.ml index 57b33d1d0..9104cf391 100644 --- a/examples/sturdy-refs-2/main.ml +++ b/examples/sturdy-refs-2/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std module Restorer = Capnp_rpc_net.Restorer @@ -14,7 +14,7 @@ let or_fail = function | Ok x -> x | Error (`Msg m) -> failwith m -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in @@ -22,20 +22,22 @@ let start_server () = let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in let root = Logger.local "root" in Restorer.Table.add services root_id root; - Capnp_rpc_unix.serve config ~restore >|= fun _vat -> + let _vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat_config.sturdy_uri config root_id (* $MDX part-begin=main *) let () = - Lwt_main.run begin - start_server () >>= fun root_uri -> - let vat = Capnp_rpc_unix.client_only_vat () in - let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in - Sturdy_ref.with_cap_exn root_sr @@ fun root -> - Logger.log root "Message from Admin" >>= fun () -> - Capability.with_ref (Logger.sub root "alice") @@ fun for_alice -> - Capability.with_ref (Logger.sub root "bob") @@ fun for_bob -> - Logger.log for_alice "Message from Alice" >>= fun () -> - Logger.log for_bob "Message from Bob" - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let net = env#net in + let root_uri = start_server ~sw net in + let vat = Capnp_rpc_unix.client_only_vat ~sw net in + let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in + Sturdy_ref.with_cap_exn root_sr @@ fun root -> + Logger.log root "Message from Admin"; + Capability.with_ref (Logger.sub root "alice") @@ fun for_alice -> + Capability.with_ref (Logger.sub root "bob") @@ fun for_bob -> + Logger.log for_alice "Message from Alice"; + Logger.log for_bob "Message from Bob" (* $MDX part-end *) diff --git a/examples/sturdy-refs-3/dune b/examples/sturdy-refs-3/dune index 96c0988f1..cfebddc5a 100644 --- a/examples/sturdy-refs-3/dune +++ b/examples/sturdy-refs-3/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt)) + (libraries eio_main capnp-rpc-unix mirage-crypto-rng-eio logs.fmt)) (rule (targets api.ml api.mli) diff --git a/examples/sturdy-refs-3/main.ml b/examples/sturdy-refs-3/main.ml index 1a962f80e..f02a37ba3 100644 --- a/examples/sturdy-refs-3/main.ml +++ b/examples/sturdy-refs-3/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std module Restorer = Capnp_rpc_net.Restorer @@ -14,11 +14,11 @@ let or_fail = function | Ok x -> x | Error (`Msg m) -> failwith m -let start_server ~switch () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in - Lwt_switch.add_hook (Some switch) (fun () -> Restorer.Table.clear services; Lwt.return_unit); + Switch.on_release sw (fun () -> Restorer.Table.clear services); let restore = Restorer.of_table services in (* $MDX part-begin=root *) let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in @@ -28,31 +28,30 @@ let start_server ~switch () = in (* $MDX part-end *) Restorer.Table.add services root_id root; - Capnp_rpc_unix.serve ~switch config ~restore >|= fun _vat -> + let _vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat_config.sturdy_uri config root_id -let run_client cap_file = - Lwt_switch.with_switch @@ fun switch -> - let vat = Capnp_rpc_unix.client_only_vat ~switch () in +let run_client ~net cap_file = + Switch.run @@ fun sw -> + let vat = Capnp_rpc_unix.client_only_vat ~sw net in let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in Sturdy_ref.with_cap_exn sr @@ fun for_alice -> Logger.log for_alice "Message from Alice" let () = - Lwt_main.run begin - Lwt_switch.with_switch @@ fun switch -> - start_server ~switch () >>= fun root_uri -> - let vat = Capnp_rpc_unix.client_only_vat ~switch () in - let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in - Sturdy_ref.with_cap_exn root_sr @@ fun root -> - Logger.log root "Message from Admin" >>= fun () -> - (* $MDX part-begin=save *) - (* The admin creates a logger for Alice and saves it: *) - Capability.with_ref (Logger.sub root "alice") (fun for_alice -> - Capnp_rpc.Persistence.save_exn for_alice >|= fun uri -> - Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail - ) >>= fun () -> - (* Alice uses it: *) - run_client "alice.cap" - (* $MDX part-end *) - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let net = env#net in + let root_uri = start_server ~sw net in + let vat = Capnp_rpc_unix.client_only_vat ~sw net in + let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in + Sturdy_ref.with_cap_exn root_sr @@ fun root -> + Logger.log root "Message from Admin"; + (* $MDX part-begin=save *) + (* The admin creates a logger for Alice and saves it: *) + let uri = Capability.with_ref (Logger.sub root "alice") Capnp_rpc.Persistence.save_exn in + Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail; + (* Alice uses it: *) + run_client ~net "alice.cap" + (* $MDX part-end *) diff --git a/examples/sturdy-refs-4/db.ml b/examples/sturdy-refs-4/db.ml index ba820a827..b9bb244ba 100644 --- a/examples/sturdy-refs-4/db.ml +++ b/examples/sturdy-refs-4/db.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std open Capnp_rpc_net @@ -9,7 +9,7 @@ type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restore type t = { store : Store.Reader.SavedService.struct_t File_store.t; - loader : loader Lwt.t; + loader : loader Promise.t; make_sturdy : Restorer.Id.t -> Uri.t; } @@ -32,16 +32,17 @@ let save_new t ~label = let load t sr digest = match File_store.load t.store ~digest with - | None -> Lwt.return Restorer.unknown_service_id + | None -> Restorer.unknown_service_id | Some saved_service -> let logger = Store.Reader.SavedService.logger_get saved_service in let label = Store.Reader.SavedLogger.label_get logger in let sr = Capnp_rpc.Sturdy_ref.cast sr in - t.loader >|= fun loader -> + let loader = Promise.await t.loader in loader sr ~label let create ~make_sturdy dir = - let loader, set_loader = Lwt.wait () in - if not (Sys.file_exists dir) then Unix.mkdir dir 0o755; + let loader, set_loader = Promise.create () in + if not (Eio.Path.is_directory dir) then + Eio.Path.mkdir dir ~perm:0o755; let store = File_store.create dir in {store; loader; make_sturdy}, set_loader diff --git a/examples/sturdy-refs-4/db.mli b/examples/sturdy-refs-4/db.mli index 201950340..349ea573c 100644 --- a/examples/sturdy-refs-4/db.mli +++ b/examples/sturdy-refs-4/db.mli @@ -6,7 +6,7 @@ include Restorer.LOADER type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restorer.resolution (** A function to create a new in-memory logger with the given label and sturdy-ref. *) -val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> string -> t * loader Lwt.u +val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> _ Eio.Path.t -> t * loader Eio.Promise.u (** [create ~make_sturdy dir] is a database that persists services in [dir] and a resolver to let you set the loader (we're not ready to set the loader when we create the database). *) diff --git a/examples/sturdy-refs-4/dune b/examples/sturdy-refs-4/dune index db3eeccd6..364c6ebba 100644 --- a/examples/sturdy-refs-4/dune +++ b/examples/sturdy-refs-4/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt cmdliner)) + (libraries eio_main capnp-rpc-unix mirage-crypto-rng-eio logs.fmt cmdliner)) (rule (targets api.ml api.mli) diff --git a/examples/sturdy-refs-4/logger.ml b/examples/sturdy-refs-4/logger.ml index d53516173..af66c8a69 100644 --- a/examples/sturdy-refs-4/logger.ml +++ b/examples/sturdy-refs-4/logger.ml @@ -1,5 +1,3 @@ -open Lwt.Infix - module Api = Api.MakeRPC(Capnp_rpc) open Capnp_rpc.Std @@ -22,14 +20,13 @@ let local ~persist_new sr label = let sub_label = Params.label_get params in release_param_caps (); let label = Printf.sprintf "%s/%s" label sub_label in - Service.return_lwt @@ fun () -> - persist_new ~label >|= function - | Error e -> Error (`Capnp (`Exception e)) + match persist_new ~label with + | Error e -> Service.error (`Exception e) | Ok logger -> let response, results = Service.Response.create Results.init_pointer in Results.logger_set results (Some logger); Capability.dec_ref logger; - Ok response + Service.return response (* $MDX part-end *) method! pp f = diff --git a/examples/sturdy-refs-4/main.ml b/examples/sturdy-refs-4/main.ml index 563250f7c..a09f9ea7a 100644 --- a/examples/sturdy-refs-4/main.ml +++ b/examples/sturdy-refs-4/main.ml @@ -1,8 +1,10 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std module Restorer = Capnp_rpc_net.Restorer +let ( / ) = Eio.Path.( / ) + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) @@ -13,56 +15,58 @@ let or_fail = function (* $MDX part-begin=server *) let serve config = - Lwt_main.run begin - (* Create the on-disk store *) - let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in - let db, set_loader = Db.create ~make_sturdy "./store" in - (* Create the restorer *) - let services = Restorer.Table.of_loader (module Db) db in - let restore = Restorer.of_table services in - (* Add the root service *) - let persist_new ~label = - let id = Db.save_new db ~label in - Capnp_rpc_net.Restorer.restore restore id - in - let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in - let root = - let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in - Logger.local ~persist_new sr "root" - in - Restorer.Table.add services root_id root; - (* Tell the database how to restore saved loggers *) - Lwt.wakeup set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); - (* Run the server *) - Capnp_rpc_unix.serve config ~restore >>= fun _vat -> - let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in - Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; - print_endline "Wrote admin.cap"; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + (* Create the on-disk store *) + let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in + let db, set_loader = Db.create ~make_sturdy (env#cwd / "store") in + (* Create the restorer *) + let services = Restorer.Table.of_loader ~sw (module Db) db in + Switch.on_release sw (fun () -> Restorer.Table.clear services); + let restore = Restorer.of_table services in + (* Add the root service *) + let persist_new ~label = + let id = Db.save_new db ~label in + Capnp_rpc_net.Restorer.restore restore id + in + let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in + let root = + let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in + Logger.local ~persist_new sr "root" + in + Restorer.Table.add services root_id root; + (* Tell the database how to restore saved loggers *) + Promise.resolve set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); + (* Run the server *) + let _vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in + Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; + print_endline "Wrote admin.cap"; + Fiber.await_cancel () (* $MDX part-end *) let log cap_file msg = - Lwt_main.run begin - let vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in - Sturdy_ref.with_cap_exn sr @@ fun logger -> - Logger.log logger msg - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in + Sturdy_ref.with_cap_exn sr @@ fun logger -> + Logger.log logger msg let sub cap_file label = - Lwt_main.run begin - let sub_file = label ^ ".cap" in - if Sys.file_exists sub_file then Fmt.failwith "%S already exists!" sub_file; - let vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in - Sturdy_ref.with_cap_exn sr @@ fun logger -> - Capability.with_ref (Logger.sub logger label) @@ fun sub -> - Capnp_rpc.Persistence.save_exn sub >>= fun uri -> - Capnp_rpc_unix.Cap_file.save_uri uri sub_file |> or_fail; - Printf.printf "Wrote %S\n%!" sub_file; - Lwt.return_unit - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let sub_file = label ^ ".cap" in + if Sys.file_exists sub_file then Fmt.failwith "%S already exists!" sub_file; + let vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in + Sturdy_ref.with_cap_exn sr @@ fun logger -> + let uri = Capability.with_ref (Logger.sub root "alice") Capnp_rpc.Persistence.save_exn in + Capnp_rpc_unix.Cap_file.save_uri uri sub_file |> or_fail; + Printf.printf "Wrote %S\n%!" sub_file; open Cmdliner diff --git a/examples/sturdy-refs/dune b/examples/sturdy-refs/dune index 96c0988f1..cfebddc5a 100644 --- a/examples/sturdy-refs/dune +++ b/examples/sturdy-refs/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt)) + (libraries eio_main capnp-rpc-unix mirage-crypto-rng-eio logs.fmt)) (rule (targets api.ml api.mli) diff --git a/examples/sturdy-refs/main.ml b/examples/sturdy-refs/main.ml index c5b59dd8b..8d8fb6578 100644 --- a/examples/sturdy-refs/main.ml +++ b/examples/sturdy-refs/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std module Restorer = Capnp_rpc_net.Restorer @@ -21,30 +21,33 @@ let make_service ~config ~services name = Restorer.Table.add services id service; name, id -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in let restore = Restorer.of_table services in let services = List.map (make_service ~config ~services) ["alice"; "bob"] in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in services |> List.iter (fun (name, id) -> let cap_file = name ^ ".cap" in Capnp_rpc_unix.Cap_file.save_service vat id cap_file |> or_fail; Printf.printf "[server] saved %S\n%!" cap_file ) -let run_client cap_file msg = - let vat = Capnp_rpc_unix.client_only_vat () in +let run_client ~net cap_file msg = + Switch.run @@ fun sw -> + let vat = Capnp_rpc_unix.client_only_vat ~sw net in let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in Printf.printf "[client] loaded %S\n%!" cap_file; Sturdy_ref.with_cap_exn sr @@ fun cap -> Logger.log cap msg let () = - Lwt_main.run begin - start_server () >>= fun () -> - run_client "./alice.cap" "Message from Alice" >>= fun () -> - run_client "./bob.cap" "Message from Bob" - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let net = env#net in + start_server ~sw net; + run_client ~net "./alice.cap" "Message from Alice"; + run_client ~net "./bob.cap" "Message from Bob" (* $MDX part-end *) diff --git a/examples/testlib/calc.ml b/examples/testlib/calc.ml index 1b0e3b728..170a4d918 100644 --- a/examples/testlib/calc.ml +++ b/examples/testlib/calc.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std module Api = Calculator.MakeRPC(Capnp_rpc) @@ -85,10 +85,10 @@ module Value = struct let read v = let open Api.Client.Calculator.Value.Read in let req = Capability.Request.create_no_args () in - Capability.call_for_value_exn v method_id req >|= Results.value_get + Capability.call_for_value_exn v method_id req |> Results.value_get let final_read v = - read v >|= fun result -> + let result = read v in Capability.dec_ref v; result @@ -114,26 +114,31 @@ let call_fn fn args = let open Api.Client.Calculator.Function.Call in let req, p = Capability.Request.create Params.init_pointer in ignore (Params.params_set_list p args); - Capability.call_for_value_exn fn method_id req >|= Results.value_get + Capability.call_for_value_exn fn method_id req |> Results.value_get -let pp_result_lwt f x = - match Lwt.state x with - | Lwt.Return v -> Fmt.float f v - | Lwt.Fail ex -> Fmt.exn f ex - | Lwt.Sleep -> Fmt.string f "(still calculating)" +let pp_result_promise f x = + match Promise.peek x with + | Some (Ok v) -> Fmt.float f v + | Some (Error ex) -> Fmt.exn f ex + | None -> Fmt.string f "(still calculating)" -(* Evaluate an expression, where some sub-expressions may require remote calls. *) -let rec eval ?(args=[||]) : _ -> Api.Reader.Calculator.Value.t Capability.t = +(* Evaluate an expression, where some sub-expressions may require remote calls. + Immediately returns a service for the result, while the calculation continues in [sw]. *) +let rec eval ~sw ?(args=[||]) : _ -> Api.Reader.Calculator.Value.t Capability.t = let open Expr in function | Float f -> Value.local f | Prev v -> Capability.inc_ref v; v | Param p -> Value.local args.(p) | Call (f, params) -> - let params = params |> Lwt_list.map_p (fun p -> - let value = eval ~args p in - Value.final_read value - ) in - let result = params >>= call_fn f in + let result = Fiber.fork_promise ~sw (fun () -> + params + |> Fiber.List.map (fun p -> + let value = eval ~sw ~args p in + Value.final_read value + ) + |> call_fn f + ) + in let open Api.Service.Calculator in Value.local @@ object inherit Value.service @@ -141,17 +146,15 @@ let rec eval ?(args=[||]) : _ -> Api.Reader.Calculator.Value.t Capability.t = val id = Capnp_rpc.Debug.OID.next () method! pp f = - Fmt.pf f "EvalResultValue(%a) = %a" Capnp_rpc.Debug.OID.pp id pp_result_lwt result + Fmt.pf f "EvalResultValue(%a) = %a" Capnp_rpc.Debug.OID.pp id pp_result_promise result method read_impl _ release_params = let open Value.Read in release_params (); - Service.return_lwt (fun () -> - result >|= fun result -> - let resp, c = Service.Response.create Results.init_pointer in - Results.value_set c result; - Ok resp - ) + let result = Promise.await_exn result in + let resp, c = Service.Response.create Results.init_pointer in + Results.value_set c result; + Service.return resp end module Fn = struct @@ -168,15 +171,14 @@ module Fn = struct let open Function.Call in let args = Params.params_get_array params in assert (Array.length args = n_args); - let value = eval ~args body in - release_params (); (* Functions return floats, not Value objects, so we have to wait here. *) - Service.return_lwt (fun () -> - Value.final_read value >|= fun value -> - let resp, r = Service.Response.create ~message_size:200 Results.init_pointer in - Results.value_set r value; - Ok resp - ) + Switch.run @@ fun sw -> + let value = eval ~sw ~args body in + release_params (); + let value = Value.final_read value in + let resp, r = Service.Response.create ~message_size:200 Results.init_pointer in + Results.value_set r value; + Service.return resp end let local_binop op : Api.Builder.Calculator.Function.t Capability.t = @@ -204,7 +206,7 @@ module Fn = struct end (* The main calculator service *) -let local = +let local ~sw = let module Calculator = Api.Service.Calculator in Calculator.local @@ object inherit Calculator.service @@ -224,7 +226,7 @@ let local = let open Calculator.Evaluate in let expr = Expr.parse (Params.expression_get params) in release_params (); - let value_obj = eval expr in + let value_obj = eval ~sw expr in Expr.release expr; let resp, results = Service.Response.create ~message_size:200 Results.init_pointer in Results.value_set results (Some value_obj); diff --git a/examples/testlib/calc.mli b/examples/testlib/calc.mli index f019d8368..115de8ff3 100644 --- a/examples/testlib/calc.mli +++ b/examples/testlib/calc.mli @@ -7,10 +7,10 @@ type t = [`Calculator_97983392df35cc36] Capability.t module rec Value : sig type t = [`Value_c3e69d34d3ee48d2] Capability.t - val read : t -> float Lwt.t + val read : t -> float (** [read t] reads the value of the remote value object. *) - val final_read : t -> float Lwt.t + val final_read : t -> float (** [final_read t] reads the value and dec_ref's [t]. *) val local : float -> t @@ -20,7 +20,7 @@ end and Fn : sig type t = [`Function_ede83a3d96840394] Capability.t - val call : t -> float list -> float Lwt.t + val call : t -> float list -> float (** [call fn args] does [fn args]. *) val local : int -> Expr.t -> Fn.t @@ -58,5 +58,6 @@ val evaluate : t -> Expr.t -> Value.t val getOperator : t -> [`Add | `Subtract | `Multiply | `Divide] -> Fn.t (** [getOperator t op] is a remote operator function provided by [t]. *) -val local : t -(** A capability to a local calculator service *) +val local : sw:Eio.Switch.t -> t +(** A capability to a local calculator service. + It may immediately return a promise of a result, while continuing the calculation in [sw]. *) diff --git a/examples/testlib/echo.ml b/examples/testlib/echo.ml index c52cdd690..b99d1dd13 100644 --- a/examples/testlib/echo.ml +++ b/examples/testlib/echo.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std type t = Api.Service.Echo.t Capability.t @@ -10,7 +10,7 @@ let local () = Echo.local @@ object inherit Echo.service - val mutable blocked = Lwt.wait () + val mutable blocked = Promise.create () val mutable count = 0 val id = Capnp_rpc.Debug.OID.next () @@ -25,16 +25,15 @@ let local () = Results.reply_set results (Fmt.str "got:%d:%s" count msg); count <- count + 1; if Params.slow_get params then ( - Service.return_lwt (fun () -> - fst blocked >|= fun () -> Ok resp - ) + Promise.await (fst blocked); + Service.return resp ) else Service.return resp method unblock_impl _ release_params = release_params (); - Lwt.wakeup (snd blocked) (); - blocked <- Lwt.wait (); + Promise.resolve (snd blocked) (); + blocked <- Promise.create (); Service.return_empty () end @@ -45,14 +44,14 @@ let ping t ?(slow=false) msg = let req, p = Capability.Request.create Params.init_pointer in Params.slow_set p slow; Params.msg_set p msg; - Capability.call_for_value_exn t method_id req >|= Results.reply_get + Capability.call_for_value_exn t method_id req |> Results.reply_get let ping_result t ?(slow=false) msg = let open Echo.Ping in let req, p = Capability.Request.create Params.init_pointer in Params.slow_set p slow; Params.msg_set p msg; - Capability.call_for_value t method_id req >|= function + match Capability.call_for_value t method_id req with | Ok x -> Ok (Results.reply_get x) | Error _ as e -> e diff --git a/examples/testlib/echo.mli b/examples/testlib/echo.mli index 03bcd8c87..dfc54381d 100644 --- a/examples/testlib/echo.mli +++ b/examples/testlib/echo.mli @@ -5,13 +5,13 @@ type t = [`Echo_bb48258560861cec] Capability.t val local : unit -> t (** [local ()] is a capability to a new local echo service. *) -val ping : t -> ?slow:bool -> string -> string Lwt.t +val ping : t -> ?slow:bool -> string -> string (** [ping t msg] sends [msg] to [t] and returns its response. If [slow] is given, the service will wait until [unblock] is called before replying. *) -val ping_result : t -> ?slow:bool -> string -> (string, [> `Capnp of Capnp_rpc.Error.t]) Lwt_result.t +val ping_result : t -> ?slow:bool -> string -> (string, [> `Capnp of Capnp_rpc.Error.t]) result (** [ping t msg] sends [msg] to [t] and returns its response. If [slow] is given, the service will wait until [unblock] is called before replying. *) -val unblock : t -> unit Lwt.t +val unblock : t -> unit (** [unblock t] tells the service to return any blocked ping responses. *) diff --git a/examples/testlib/registry.ml b/examples/testlib/registry.ml index c1cf6cc36..b7bc24566 100644 --- a/examples/testlib/registry.ml +++ b/examples/testlib/registry.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std type t = Api.Service.Registry.t Capability.t @@ -17,12 +17,12 @@ let version_service = Service.return resp end -let local () = +let local ~sw () = let module Registry = Api.Service.Registry in Registry.local @@ object inherit Registry.service - val mutable blocked = Lwt.wait () + val mutable blocked = Promise.create () val mutable echo_service = Echo.local () method! release = Capability.dec_ref echo_service @@ -45,9 +45,8 @@ let local () = let open Registry.EchoService in let resp, results = Service.Response.create Results.init_pointer in Results.service_set results (Some echo_service); - Service.return_lwt (fun () -> - fst blocked >|= fun () -> Ok resp - ) + Promise.await (fst blocked); + Service.return resp method echo_service_promise_impl _params release_params = release_params (); @@ -56,8 +55,8 @@ let local () = let promise, resolver = Capability.promise () in Results.service_set results (Some promise); Capability.dec_ref promise; - Lwt.async (fun () -> - fst blocked >|= fun () -> + Fiber.fork ~sw (fun () -> + Promise.await (fst blocked); Capability.inc_ref echo_service; Capability.resolve_ok resolver echo_service ); @@ -65,8 +64,8 @@ let local () = method unblock_impl _ release_params = release_params (); - Lwt.wakeup (snd blocked) (); - blocked <- Lwt.wait (); + Promise.resolve (snd blocked) (); + blocked <- Promise.create (); Service.return_empty () method complex_impl _ release_params = @@ -131,5 +130,5 @@ module Version = struct let read t = let open Version.Read in let req = Capability.Request.create_no_args () in - Capability.call_for_value_exn t method_id req >|= Results.version_get + Capability.call_for_value_exn t method_id req |> Results.version_get end diff --git a/examples/testlib/registry.mli b/examples/testlib/registry.mli index dede96a2c..7948b3e45 100644 --- a/examples/testlib/registry.mli +++ b/examples/testlib/registry.mli @@ -1,14 +1,15 @@ +open Eio.Std open Capnp_rpc.Std module Version : sig type t = [`Version_ed7d11372e0a7243] Capability.t - val read : t -> string Lwt.t + val read : t -> string end type t = [`Registry_d9975f668b337b6d] Capability.t -val set_echo_service : t -> Echo.t -> unit Lwt.t +val set_echo_service : t -> Echo.t -> unit val echo_service : t -> Echo.t (** Waits until unblocked before returning. *) @@ -17,10 +18,10 @@ val echo_service_promise : t -> Echo.t (** Returns a promise immediately. Resolves promise when unblocked. (should appear to work the same as [echo_service] to users) *) -val unblock : t -> unit Lwt.t +val unblock : t -> unit val complex : t -> Echo.t * Version.t (** [complex t] returns two capabilities in a single, somewhat complex, message. *) -val local : unit -> t -(** [local ()] is a new local registry. *) +val local : sw:Switch.t -> unit -> t +(** [local ~sw ()] is a new local registry. *) diff --git a/examples/testlib/store.ml b/examples/testlib/store.ml index 7038bcd43..1b5c94c1c 100644 --- a/examples/testlib/store.ml +++ b/examples/testlib/store.ml @@ -1,4 +1,3 @@ -open Lwt.Infix open Capnp_rpc.Std open Capnp_rpc_net @@ -57,7 +56,7 @@ module File = struct let get t = let open Api.Client.File.Get in let request = Capability.Request.create_no_args () in - Capability.call_for_value_exn t method_id request >|= Results.data_get + Capability.call_for_value_exn t method_id request |> Results.data_get let local (db:DB.t) sr digest = let module File = Api.Service.File in @@ -92,14 +91,14 @@ module File = struct let load t sr digest = if DB.mem t.db digest then ( let sr = Sturdy_ref.cast sr in - Lwt.return @@ Restorer.grant @@ local t.db sr digest + Restorer.grant @@ local t.db sr digest ) else ( - Lwt.return Restorer.unknown_service_id + Restorer.unknown_service_id ) end - let table ~make_sturdy db = - Restorer.Table.of_loader (module Loader) {Loader.db; make_sturdy} + let table ~sw ~make_sturdy db = + Restorer.Table.of_loader ~sw (module Loader) {Loader.db; make_sturdy} end type t = Api.Client.Store.t Capability.t @@ -121,12 +120,11 @@ let local ~restore db = let open Store.CreateFile in release_params (); let id = DB.add db in - Service.return_lwt @@ fun () -> - Restorer.restore restore id >|= function - | Error e -> Error (`Capnp (`Exception e)) + match Restorer.restore restore id with + | Error e -> Service.error (`Exception e) | Ok x -> let resp, results = Service.Response.create Results.init_pointer in Results.file_set results (Some x); Capability.dec_ref x; - Ok resp + Service.return resp end diff --git a/examples/testlib/store.mli b/examples/testlib/store.mli index 05a52ad47..487359654 100644 --- a/examples/testlib/store.mli +++ b/examples/testlib/store.mli @@ -16,13 +16,13 @@ end module File : sig type t = [`File_aec5916d9557ed0e] Capability.t - val set : t -> string -> unit Lwt.t + val set : t -> string -> unit (** [set t data] saves [data] as [t]'s contents. *) - val get : t -> string Lwt.t + val get : t -> string (** [get t] is the current contents of [t]. *) - val table : make_sturdy:(Restorer.Id.t -> Uri.t) -> DB.t -> Restorer.Table.t + val table : sw:Eio.Switch.t -> make_sturdy:(Restorer.Id.t -> Uri.t) -> DB.t -> Restorer.Table.t (** [table ~make_sturdy db] is a table of file services, backed by [db]. [make_sturdy] is used to generate sturdy URIs for files. *) end diff --git a/examples/v1/dune b/examples/v1/dune index b7e41d1c0..dbb916768 100644 --- a/examples/v1/dune +++ b/examples/v1/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc logs.fmt)) + (libraries eio_main capnp-rpc logs.fmt)) (rule (targets echo_api.ml echo_api.mli) diff --git a/examples/v1/echo.ml b/examples/v1/echo.ml index 31b9117ca..b4547b7bb 100644 --- a/examples/v1/echo.ml +++ b/examples/v1/echo.ml @@ -1,7 +1,6 @@ (* $MDX part-begin=server *) module Api = Echo_api.MakeRPC(Capnp_rpc) -open Lwt.Infix open Capnp_rpc.Std let local = @@ -26,5 +25,5 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get (* $MDX part-end *) diff --git a/examples/v1/main.ml b/examples/v1/main.ml index a8a396bec..a15068fcb 100644 --- a/examples/v1/main.ml +++ b/examples/v1/main.ml @@ -1,13 +1,11 @@ -open Lwt.Infix +open Eio.Std let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let () = - Lwt_main.run begin - let service = Echo.local in - Echo.ping service "foo" >>= fun reply -> - Fmt.pr "Got reply %S@." reply; - Lwt.return_unit - end + Eio_main.run @@ fun _ -> + let service = Echo.local in + let reply = Echo.ping service "foo" in + traceln "Got reply %S" reply diff --git a/examples/v2/dune b/examples/v2/dune index b7e41d1c0..dbb916768 100644 --- a/examples/v2/dune +++ b/examples/v2/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc logs.fmt)) + (libraries eio_main capnp-rpc logs.fmt)) (rule (targets echo_api.ml echo_api.mli) diff --git a/examples/v2/echo.ml b/examples/v2/echo.ml index 4ff150635..1784dffa8 100644 --- a/examples/v2/echo.ml +++ b/examples/v2/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc) -open Lwt.Infix open Capnp_rpc.Std module Callback = struct @@ -27,21 +26,21 @@ module Callback = struct end (* $MDX part-begin=notify *) -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~delay msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.Timeout.sleep delay; + loop (i - 1) in loop 3 (* $MDX part-end *) -let local = +let local ~delay = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -63,8 +62,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~delay msg) (* $MDX part-end *) end @@ -74,7 +72,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get (* $MDX part-begin=client-heartbeat *) let heartbeat t msg callback = diff --git a/examples/v2/main.ml b/examples/v2/main.ml index b44aad6c5..8fee79e13 100644 --- a/examples/v2/main.ml +++ b/examples/v2/main.ml @@ -1,18 +1,21 @@ +open Eio.Std open Capnp_rpc.Std +let delay = if Sys.getenv_opt "CI" = None then 1.0 else 0.0 + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let () = - Lwt_main.run begin - let service = Echo.local in - run_client service - end + Eio_main.run @@ fun env -> + let delay = Eio.Time.Timeout.seconds env#mono_clock delay in + let service = Echo.local ~delay in + run_client service diff --git a/examples/v3/dune b/examples/v3/dune index 78c2c8869..c3b83a62c 100644 --- a/examples/v3/dune +++ b/examples/v3/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt)) + (libraries eio_main capnp-rpc-unix mirage-crypto-rng-eio logs.fmt)) (rule (targets echo_api.ml echo_api.mli) diff --git a/examples/v3/echo.ml b/examples/v3/echo.ml index 4ff150635..1784dffa8 100644 --- a/examples/v3/echo.ml +++ b/examples/v3/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc) -open Lwt.Infix open Capnp_rpc.Std module Callback = struct @@ -27,21 +26,21 @@ module Callback = struct end (* $MDX part-begin=notify *) -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~delay msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.Timeout.sleep delay; + loop (i - 1) in loop 3 (* $MDX part-end *) -let local = +let local ~delay = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -63,8 +62,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~delay msg) (* $MDX part-end *) end @@ -74,7 +72,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get (* $MDX part-begin=client-heartbeat *) let heartbeat t msg callback = diff --git a/examples/v3/main.ml b/examples/v3/main.ml index 8d854e2c5..58e0e245d 100644 --- a/examples/v3/main.ml +++ b/examples/v3/main.ml @@ -1,12 +1,14 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std +let delay = if Sys.getenv_opt "CI" = None then 1.0 else 0.0 + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> @@ -15,18 +17,20 @@ let run_client service = let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw ~delay net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let restore = Capnp_rpc_net.Restorer.single service_id (Echo.local ~delay) in + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let delay = Eio.Time.Timeout.seconds env#mono_clock delay in + let uri = start_server ~sw ~delay env#net in + traceln "Connecting to echo service at: %a" Uri.pp_hum uri; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client diff --git a/examples/v4/client.ml b/examples/v4/client.ml index b13405ff9..a35f62fc3 100644 --- a/examples/v4/client.ml +++ b/examples/v4/client.ml @@ -1,3 +1,4 @@ +open Eio.Std open Capnp_rpc.Std let () = @@ -5,18 +6,19 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let connect uri = - Lwt_main.run begin - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Capnp_rpc_unix.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Capnp_rpc_unix.with_cap_exn sr run_client open Cmdliner diff --git a/examples/v4/dune b/examples/v4/dune index cbd177f9c..64753b864 100644 --- a/examples/v4/dune +++ b/examples/v4/dune @@ -1,6 +1,6 @@ (executables (names client server) - (libraries lwt.unix capnp-rpc logs.fmt capnp-rpc-unix)) + (libraries eio_main capnp-rpc logs.fmt capnp-rpc-unix mirage-crypto-rng-eio)) (rule (targets echo_api.ml echo_api.mli) diff --git a/examples/v4/echo.ml b/examples/v4/echo.ml index d46038dfe..d624e96d0 100644 --- a/examples/v4/echo.ml +++ b/examples/v4/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc) -open Lwt.Infix open Capnp_rpc.Std module Callback = struct @@ -26,20 +25,20 @@ module Callback = struct Capability.call_for_unit t method_id request end -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~delay msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.Timeout.sleep delay; + loop (i - 1) in loop 3 -let local = +let local ~delay = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -60,8 +59,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~delay msg) end module Echo = Api.Client.Echo @@ -70,7 +68,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get let heartbeat t msg callback = let open Echo.Heartbeat in diff --git a/examples/v4/server.ml b/examples/v4/server.ml index 6439c6e11..32e8071da 100644 --- a/examples/v4/server.ml +++ b/examples/v4/server.ml @@ -1,6 +1,8 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_net +let delay = if Sys.getenv_opt "CI" = None then 1.0 else 0.0 + let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) @@ -8,16 +10,18 @@ let () = let cap_file = "echo.cap" let serve config = - Lwt_main.run begin - let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >>= fun vat -> - match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with - | Error `Msg m -> failwith m - | Ok () -> - Fmt.pr "Server running. Connect using %S.@." cap_file; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + let delay = Eio.Time.Timeout.seconds env#mono_clock delay in + Switch.run @@ fun sw -> + let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in + let restore = Restorer.single service_id (Echo.local ~delay) in + let vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with + | Error `Msg m -> failwith m + | Ok () -> + traceln "Server running. Connect using %S." cap_file; + Fiber.await_cancel () open Cmdliner diff --git a/fuzz/fuzz.ml b/fuzz/fuzz.ml index c6577efb7..93089cd93 100644 --- a/fuzz/fuzz.ml +++ b/fuzz/fuzz.ml @@ -143,7 +143,7 @@ module Msg = struct type request = Request.contents type response = Response.contents - let ref_leak_detected fn = + let ref_leak_detected _ fn = fn (); failwith "ref_leak_detected" @@ -214,9 +214,11 @@ module Endpoint = struct let check t = Conn.check t.conn + let fork f = f () + let create ~restore ~tags ~dump ~local_id ~remote_id xmit_queue recv_queue = let queue_send x = Queue.add x xmit_queue in - let conn = Conn.create ~restore ~tags ~queue_send in + let conn = Conn.create ~restore ~tags ~queue_send ~fork in { local_id; remote_id; diff --git a/test-bin/calc.ml b/test-bin/calc.ml index 881ad436e..502664235 100644 --- a/test-bin/calc.ml +++ b/test-bin/calc.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std module Vat = Capnp_rpc_unix.Vat module Calc = Testlib.Calc @@ -30,29 +30,31 @@ let reporter = (* Run as server *) let serve vat_config = - Lwt_main.run begin - let service_id = Capnp_rpc_net.Restorer.Id.public "" in - let restore = Capnp_rpc_net.Restorer.single service_id Calc.local in - Capnp_rpc_unix.serve vat_config ~restore >>= fun vat -> - let sr = Vat.sturdy_uri vat service_id in - Fmt.pr "Waiting for incoming connections at:@.%a@." Uri.pp_hum sr; - fst @@ Lwt.wait () - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let service_id = Capnp_rpc_net.Restorer.Id.public "" in + let service = Calc.local ~sw in + let restore = Capnp_rpc_net.Restorer.single service_id service in + let vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore vat_config in + let sr = Vat.sturdy_uri vat service_id in + traceln "Waiting for incoming connections at:@.%a" Uri.pp_hum sr; + Fiber.await_cancel () (* Run as client *) let connect addr = - Lwt_main.run begin - let vat = Capnp_rpc_unix.client_only_vat () in - let sr = Vat.import_exn vat addr in - Capnp_rpc_unix.with_cap_exn sr @@ fun calc -> - Logs.info (fun f -> f "Evaluating expression..."); - let remote_add = Calc.getOperator calc `Add in - let result = Calc.evaluate calc Calc.Expr.(Call (remote_add, [Float 40.0; Float 2.0])) in - Calc.Value.read result >>= fun v -> - Fmt.pr "Result: %f@." v; - Lwt.return_unit - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Vat.import_exn vat addr in + Capnp_rpc_unix.with_cap_exn sr @@ fun calc -> + Logs.info (fun f -> f "Evaluating expression..."); + let remote_add = Calc.getOperator calc `Add in + let result = Calc.evaluate calc Calc.Expr.(Call (remote_add, [Float 40.0; Float 2.0])) in + let v = Calc.Value.read result in + traceln "Result: %f" v (* Command-line parsing *) diff --git a/test-bin/calc_direct.ml b/test-bin/calc_direct.ml index 1f75b66bc..acf1a4322 100644 --- a/test-bin/calc_direct.ml +++ b/test-bin/calc_direct.ml @@ -1,7 +1,7 @@ (* Run the calc service as a child process, connecting directly over a socketpair. Unlike a normal connection, there is no encryption or use of sturdy refs here. *) -open Lwt.Infix +open Eio.Std module Calc = Testlib.Calc @@ -32,75 +32,59 @@ end module Parent = struct let run socket = Logging.init "parent"; + Switch.run @@ fun sw -> + (* Normally the vat runs a leak handler to free resources that get GC'd + with a non-zero reference count. We're not using a vat, so run it ourselves. *) + Fiber.fork_daemon ~sw Capnp_rpc.Leak_handler.run; (* Run Cap'n Proto RPC protocol on [socket]: *) - Lwt_switch.with_switch @@ fun switch -> - let p = Lwt_unix.of_unix_file_descr socket - |> Capnp_rpc_unix.Unix_flow.connect ~switch - |> Capnp_rpc_net.Endpoint.of_flow (module Capnp_rpc_unix.Unix_flow) - ~peer_id:Capnp_rpc_net.Auth.Digest.insecure - ~switch in + let p = Capnp_rpc_net.Endpoint.of_flow socket ~peer_id:Capnp_rpc_net.Auth.Digest.insecure in Logs.info (fun f -> f "Connecting to child process..."); - let conn = Capnp_rpc_unix.CapTP.connect ~restore:Capnp_rpc_net.Restorer.none p in + let conn = Capnp_rpc_unix.CapTP.connect ~sw ~restore:Capnp_rpc_net.Restorer.none p in + Fiber.fork_daemon ~sw (fun () -> Capnp_rpc_unix.CapTP.listen conn; `Stop_daemon); (* Get the child's service object: *) let calc = Capnp_rpc_unix.CapTP.bootstrap conn service_name in (* Use the service: *) Logs.app (fun f -> f "Sending request..."); let remote_mul = Calc.getOperator calc `Multiply in let result = Calc.evaluate calc Calc.Expr.(Call (remote_mul, [Float 21.0; Float 2.0])) in - Calc.Value.read result >>= fun v -> + let v = Calc.Value.read result in Logs.app (fun f -> f "Result: %f" v); - Logs.app (fun f -> f "Shutting down..."); - Lwt.return_unit + Logs.app (fun f -> f "Shutting down...") end module Child = struct - let service = Calc.local - let run socket = Logging.init "child"; - Lwt_main.run begin - Lwt_switch.with_switch @@ fun switch -> - let restore = Capnp_rpc_net.Restorer.single service_name service in - (* Run Cap'n Proto RPC protocol on [socket]: *) - let endpoint = Capnp_rpc_unix.Unix_flow.connect (Lwt_unix.of_unix_file_descr socket) - |> Capnp_rpc_net.Endpoint.of_flow (module Capnp_rpc_unix.Unix_flow) - ~peer_id:Capnp_rpc_net.Auth.Digest.insecure - ~switch - in - let _ : Capnp_rpc_unix.CapTP.t = Capnp_rpc_unix.CapTP.connect ~restore endpoint in - Logs.info (fun f -> f "Serving requests..."); - fst (Lwt.wait ()) (* Wait forever *) - end + Switch.run @@ fun sw -> + Fiber.fork_daemon ~sw Capnp_rpc.Leak_handler.run; + let socket = Eio_unix.Net.import_socket_stream ~sw ~close_unix:false socket in + let service = Calc.local ~sw in + let restore = Capnp_rpc_net.Restorer.single service_name service in + (* Run Cap'n Proto RPC protocol on [socket]: *) + let endpoint = Capnp_rpc_net.Endpoint.of_flow socket + ~peer_id:Capnp_rpc_net.Auth.Digest.insecure + in + let conn = Capnp_rpc_unix.CapTP.connect ~sw ~restore endpoint in + Logs.info (fun f -> f "Serving requests..."); + Capnp_rpc_unix.CapTP.listen conn end -let find_our_path prog = - if Sys.file_exists prog then prog - else ( - (* Hack for running under "dune exec" *) - let prog = "./_build/default/" ^ prog in - if Sys.file_exists prog then prog - else Fmt.failwith "Can't find path to own binary %S from %S" prog (Sys.getcwd ()) - ) - let () = - Lwt_main.run begin - match Sys.argv with - | [| prog |] -> - (* We are the parent. *) - let prog = find_our_path prog in - let p, c = Unix.(socketpair PF_UNIX SOCK_STREAM 0 ~cloexec:true) in - Unix.clear_close_on_exec c; - (* Run the child, passing the socket as its stdin. *) - let child = Lwt_process.open_process_none ~stdin:(`FD_move c) ("", [| prog; "--child" |]) in - Parent.run p >>= fun () -> - Logs.info (fun f -> f "Waiting for child to exit..."); - child#terminate; - child#status >>= fun _ -> - Logs.info (fun f -> f "Done"); - Lwt.return_unit - | [| _prog; "--child" |] -> - (* We are the child. Our socket is on stdin. *) - Child.run Unix.stdin - | _ -> - failwith "Run this command without arguments." - end + Eio_main.run @@ fun env -> + let prog_mgr = env#process_mgr in + match Sys.argv with + | [| prog |] -> + (* We are the parent. *) + Switch.run @@ fun sw -> + let prog = if Filename.is_implicit prog then "./" ^ prog else prog in + let p, c = Eio_unix.Net.socketpair_stream ~sw () in + (* Run the child, passing the socket as its stdin. *) + let _child = Eio.Process.spawn ~sw prog_mgr [prog; "--child"] ~stdin:c in + Eio.Net.close c; + Parent.run p; + Logs.info (fun f -> f "Done") + | [| _prog; "--child" |] -> + (* We are the child. Our socket is on stdin. *) + Child.run Unix.stdin + | _ -> + failwith "Run this command without arguments." diff --git a/test-bin/dune b/test-bin/dune index 9c46081cf..f51935d10 100644 --- a/test-bin/dune +++ b/test-bin/dune @@ -1,3 +1,4 @@ (executables (names calc calc_direct) - (libraries testlib cmdliner astring logs.fmt fmt.tty capnp-rpc-unix)) + (libraries testlib cmdliner astring logs.fmt fmt.tty capnp-rpc-unix eio_main + mirage-crypto-rng-eio)) diff --git a/test-bin/echo/dune b/test-bin/echo/dune index f84b208aa..3cc3789b3 100644 --- a/test-bin/echo/dune +++ b/test-bin/echo/dune @@ -1,6 +1,6 @@ (executable (name echo_bench) - (libraries lwt.unix capnp-rpc capnp-rpc-net capnp-rpc-unix logs.fmt)) + (libraries eio_main capnp-rpc capnp-rpc-net capnp-rpc-unix mirage-crypto-rng-eio logs.fmt)) (rule (targets echo_api.ml echo_api.mli) diff --git a/test-bin/echo/echo.ml b/test-bin/echo/echo.ml index c25f9c1ef..f521a644b 100755 --- a/test-bin/echo/echo.ml +++ b/test-bin/echo/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc) -open Lwt.Infix open Capnp_rpc.Std (*-- Server ----------------------------------------*) @@ -29,4 +28,4 @@ let ping t msg = let message_size = 200 + String.length msg in (* (rough estimate) *) let request, params = Capability.Request.create ~message_size Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get diff --git a/test-bin/echo/echo_bench.ml b/test-bin/echo/echo_bench.ml index afe93f84f..47ccac9a6 100755 --- a/test-bin/echo/echo_bench.ml +++ b/test-bin/echo/echo_bench.ml @@ -1,5 +1,4 @@ - -open Lwt.Infix +open Eio.Std open Capnp_rpc.Std @@ -8,36 +7,37 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let run_client service = - let n = 100000 in + (* let n = 100000 in *) (* XXX: improve speed *) + let n = 1000 in let ops = List.init n (fun i -> let payload = Int.to_string i in let desired_result = "echo:" ^ payload in fun () -> - Echo.ping service payload >|= fun res -> + let res = Echo.ping service payload in assert (res = desired_result) ) in let st = Unix.gettimeofday () in - Lwt_stream.of_list ops |> Lwt_stream.iter_n ~max_concurrency:12 (fun v -> v ()) >>= fun () -> + ops |> Fiber.List.iter ~max_fibers:12 (fun v -> v ()); let ed = Unix.gettimeofday () in let rate = (Int.to_float n) /. (ed -. st) in - Logs.info (fun m -> m "rate = %f" rate ); - Lwt.return_unit + Logs.info (fun m -> m "rate = %f" rate ) let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key ~serve_tls:false listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + Switch.run @@ fun sw -> + let uri = start_server ~sw env#net in + Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client diff --git a/test/dune b/test/dune index bbe21982f..4c951ae6e 100644 --- a/test/dune +++ b/test/dune @@ -1,5 +1,5 @@ (test (package capnp-rpc-unix) (name test) - (libraries capnp-rpc capnp-rpc-unix alcotest-lwt testlib logs.fmt - testbed)) + (libraries capnp-rpc capnp-rpc-unix testlib logs.fmt + mirage-crypto-rng-eio testbed eio_main)) diff --git a/test/proto/testbed/capnp_direct.ml b/test/proto/testbed/capnp_direct.ml index 183911d23..9539530a7 100644 --- a/test/proto/testbed/capnp_direct.ml +++ b/test/proto/testbed/capnp_direct.ml @@ -42,7 +42,7 @@ module String_content = struct module Response = Request - let ref_leak_detected fn = + let ref_leak_detected _ fn = fn (); incr ref_leaks end diff --git a/test/proto/testbed/connection.ml b/test/proto/testbed/connection.ml index 1e03b50bc..f161c3df8 100644 --- a/test/proto/testbed/connection.ml +++ b/test/proto/testbed/connection.ml @@ -81,13 +81,15 @@ module Endpoint (EP : Capnp_direct.ENDPOINT) = struct | _ -> k @@ Error (Capnp_rpc_proto.Exception.v "Only a main interface is available") ) + let fork fn = fn () + let create ?bootstrap ~tags (xmit_queue:[EP.Out.t | `Unimplemented of EP.In.t] Queue.t) (recv_queue:[EP.In.t | `Unimplemented of EP.Out.t] Queue.t) = let queue_send x = Queue.add (x :> [EP.Out.t | `Unimplemented of EP.In.t]) xmit_queue in let bootstrap = (bootstrap :> EP.Core_types.cap option) in let restore = restore_single bootstrap in - let conn = Conn.create ?restore ~tags ~queue_send in + let conn = Conn.create ?restore ~tags ~queue_send ~fork in { conn; recv_queue; diff --git a/test/test.ml b/test/test.ml index db85f63c3..41f5c5f2c 100644 --- a/test/test.ml +++ b/test/test.ml @@ -1,6 +1,6 @@ +open Eio.Std open Astring open Testlib -open Lwt.Infix open Capnp_rpc.Std open Capnp_rpc_net @@ -8,23 +8,27 @@ module Test_utils = Testbed.Test_utils module Vat = Capnp_rpc_unix.Vat module CapTP = Capnp_rpc_unix.CapTP -module Unix_flow = Capnp_rpc_unix.Unix_flow -module Tls_wrapper = Capnp_rpc_net.Tls_wrapper.Make(Unix_flow) +module Tls_wrapper = Capnp_rpc_net.Tls_wrapper module Exception = Capnp_rpc.Exception +exception Simulated_failure + +let ( let/ ) x f = f (x ()) +let ( and/ ) x y () = Fiber.pair x y + +let _debug () = + Logs.Src.set_level Capnp_rpc.Debug.src (Some Logs.Debug) + type cs = { client : Vat.t; + client_cancel : unit -> unit; server : Vat.t; - server_switch : Lwt_switch.t; + server_cancel : unit -> unit; } -let ensure_removed path = - try Unix.unlink path - with Unix.Unix_error(Unix.ENOENT, _, _) -> () - let next_port = ref 8000 -let get_test_address ~switch name = +let get_test_address name = match Sys.os_type with | "Win32" -> (* No Unix-domain sockets on Windows *) @@ -32,9 +36,7 @@ let get_test_address ~switch name = incr next_port; `TCP ("127.0.0.1", port) | _ -> - let socket_path = Filename.(concat (Filename.get_temp_dir_name ())) name in - Lwt_switch.add_hook (Some switch) (fun () -> Lwt.return @@ ensure_removed socket_path); - `Unix socket_path + `Unix (Filename.(concat (Filename.get_temp_dir_name ())) name) (* Have the client ask the server for its bootstrap object, and return the resulting client-side proxy to it. *) @@ -61,210 +63,226 @@ let cap_equal_exn a b = let cap = Alcotest.testable Capability.pp cap_equal_exn let () = Logs.(set_level (Some Logs.Warning)) -let server_key = Auth.Secret_key.generate () -let client_key = Auth.Secret_key.generate () -let bad_key = Auth.Secret_key.generate () +let server_key = lazy (Auth.Secret_key.generate ()) +let client_key = lazy (Auth.Secret_key.generate ()) +let bad_key = lazy (Auth.Secret_key.generate ()) let () = Logs.(set_level (Some Logs.Info)) -let server_pem = `PEM (Auth.Secret_key.to_pem_data server_key) - -let make_vats_full ?(serve_tls=false) ~client_switch ~server_switch ~restore () = - let server_config = - let addr = get_test_address ~switch:server_switch "capnp-rpc-test-server" in - Capnp_rpc_unix.Vat_config.create ~secret_key:server_pem ~serve_tls addr +let server_pem = lazy (`PEM (Auth.Secret_key.to_pem_data (Lazy.force server_key))) + +(* Run [fn ~sw] in a daemon fiber with a sub-switch. + Return a function to cancel the switch and the result of [fn]. *) +let fork_with_cancel ~sw ~tags fn = + let x, set_x = Promise.create () in + Fiber.fork_daemon ~sw (fun () -> + let is_cancelled = ref false in + try + Switch.run @@ fun sw -> + let cancel () = is_cancelled := true; Switch.fail sw Simulated_failure in + Promise.resolve set_x (cancel, fn ~sw); + Fiber.await_cancel () + with Simulated_failure when !is_cancelled -> + Logs.info (fun f -> f ~tags "Vat shut down by simulated failure"); + `Stop_daemon + ); + Promise.await x + +let make_vats_full ?(serve_tls=false) ~sw ~net ~restore () = + let server_cancel, server = + let server_config = + let addr = get_test_address "capnp-rpc-test-server" in + Capnp_rpc_unix.Vat_config.create ~secret_key:(Lazy.force server_pem) ~serve_tls addr + in + let tags = Test_utils.server_tags in + fork_with_cancel ~sw ~tags (Capnp_rpc_unix.serve ~net ~tags ~restore server_config) in - Capnp_rpc_unix.serve ~switch:server_switch ~tags:Test_utils.server_tags ~restore server_config >>= fun server -> - Lwt.return { - client = Vat.create ~switch:client_switch ~tags:Test_utils.client_tags ~secret_key:(lazy client_key) (); + let client_cancel, client = + let tags = Test_utils.client_tags in + fork_with_cancel ~sw ~tags (Vat.create ~tags ~secret_key:client_key net) + in + { + client; + client_cancel; server; - server_switch; + server_cancel; } -let make_vats ?serve_tls ~switch ~service () = - let server_switch = Lwt_switch.create () in - Lwt_switch.add_hook (Some switch) (fun () -> Lwt_switch.turn_off server_switch); +let with_vats ?serve_tls ~net ~service fn = + Switch.run @@ fun sw -> let id = Restorer.Id.public "" in let restore = Restorer.single id service in - Lwt_switch.add_hook (Some switch) (fun () -> Capability.dec_ref service; Lwt.return_unit); - make_vats_full ?serve_tls ~client_switch:switch ~server_switch ~restore () + Switch.on_release sw (fun () -> Capability.dec_ref service); + fn @@ make_vats_full ?serve_tls ~sw ~net ~restore (); + Logs.info (fun f -> f "Test finished; shutting down vats..."); + (* Check for leaks while the vat is still running. *) + Gc.full_major (); + Fiber.yield () -(* Generic Lwt running for Alcotest. *) -let run_lwt name ?(expected_warnings=0) fn = - Alcotest_lwt.test_case name `Quick @@ fun sw () -> +(* Generic runner for Alcotest. *) +let run_eio ~net name ?(expected_warnings=0) fn = + Alcotest.test_case name `Quick @@ fun () -> let warnings_at_start = Logs.(err_count () + warn_count ()) in Logs.info (fun f -> f "Start test-case"); - let finished = ref false in - Lwt_switch.add_hook (Some sw) (fun () -> - if not !finished then !Lwt.async_exception_hook (Failure "Switch turned off early"); - Lwt.return_unit - ); - fn sw >>= fun () -> finished := true; - Lwt_switch.turn_off sw >|= fun () -> - Gc.full_major (); - Lwt.wakeup_paused (); - Gc.full_major (); - Lwt.wakeup_paused (); + fn ~net; Gc.full_major (); let warnings_at_end = Logs.(err_count () + warn_count ()) in Alcotest.(check int) "Check log for warnings" expected_warnings (warnings_at_end - warnings_at_start) -let test_simple switch ~serve_tls = - make_vats ~switch ~serve_tls ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - Echo.ping service "ping" >>= fun reply -> +let test_simple ~net ~serve_tls = + with_vats ~net ~serve_tls ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + let reply = Echo.ping service "ping" in Alcotest.(check string) "Ping response" "got:0:ping" reply; - Capability.dec_ref service; - Lwt.return () + Capability.dec_ref service -let test_bad_crypto switch = - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun cs -> +let test_bad_crypto ~net = + with_vats ~net ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> let id = Restorer.Id.public "" in let uri = Vat.sturdy_uri cs.server id in - let bad_digest = Auth.Secret_key.digest ~hash:`SHA256 bad_key in + let bad_digest = Auth.Secret_key.digest ~hash:`SHA256 (Lazy.force bad_key) in let uri = Auth.Digest.add_to_uri bad_digest uri in let sr = Capnp_rpc_unix.Vat.import_exn cs.client uri in let old_warnings = Logs.warn_count () in - Sturdy_ref.connect sr >>= function + match Sturdy_ref.connect sr with | Ok _ -> Alcotest.fail "Wrong TLS key should have been rejected" | Error e -> let msg = Fmt.to_to_string Capnp_rpc.Exception.pp e in - assert (String.is_prefix ~affix:"Failed: TLS connection failed: authentication failure" msg); - (* Wait for server to log warning *) - let rec wait () = - if Logs.warn_count () = old_warnings then Lwt.pause () >>= wait - else Lwt.return_unit - in - wait () - -let test_parallel switch = - make_vats ~switch ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - let reply1 = Echo.ping service ~slow:true "ping1" in - Echo.ping service "ping2" >|= Alcotest.(check string) "Ping2 response" "got:1:ping2" >>= fun () -> - assert (Lwt.state reply1 = Lwt.Sleep); - Echo.unblock service >>= fun () -> - reply1 >|= Alcotest.(check string) "Ping1 response" "got:0:ping1" >>= fun () -> - Capability.dec_ref service; - Lwt.return () + assert (String.is_prefix ~affix:"Failed: TLS connection failed: TLS failure: authentication failure" msg); + Logs.info (fun f -> f "Wait for server to log warning..."); + while Logs.warn_count () = old_warnings do + Fiber.yield () + done + +let test_parallel ~net = + with_vats ~net ~service:(Echo.local ()) @@ fun cs -> + Switch.run @@ fun sw -> + let service = get_bootstrap cs in + let reply1 = Fiber.fork_promise ~sw (fun () -> Echo.ping service ~slow:true "ping1") in + Echo.ping service "ping2" |> Alcotest.(check string) "Ping2 response" "got:1:ping2"; + assert (Promise.peek reply1 = None); + Echo.unblock service; + Promise.await_exn reply1 |> Alcotest.(check string) "Ping1 response" "got:0:ping1"; + Capability.dec_ref service -let test_registry switch = - let registry_impl = Registry.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> +let test_registry ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in Capability.with_ref (Registry.echo_service registry) @@ fun echo_service -> - Registry.unblock registry >>= fun () -> - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - Capability.dec_ref registry; - Lwt.return () + Registry.unblock registry; + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + Capability.dec_ref registry -let test_embargo switch = - let registry_impl = Registry.local () in +let test_embargo ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in let local_echo = Echo.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> - Registry.set_echo_service registry local_echo >>= fun () -> + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in + Registry.set_echo_service registry local_echo; Capability.dec_ref local_echo; let echo_service = Registry.echo_service registry in - let reply1 = Echo.ping echo_service "ping" in - Registry.unblock registry >>= fun () -> - reply1 >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> + let reply1 = Fiber.fork_promise ~sw (fun () -> Echo.ping echo_service "ping") in + Registry.unblock registry; + Promise.await_exn reply1 |> Alcotest.(check string) "Ping response" "got:0:ping"; (* Flush, to ensure we resolve the echo_service's location. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:1:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:1:ping"; (* Test local connection. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:2:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:2:ping"; Capability.dec_ref echo_service; - Capability.dec_ref registry; - Lwt.return () + Capability.dec_ref registry -let test_resolve switch = - let registry_impl = Registry.local () in +let test_resolve ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in let local_echo = Echo.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> - Registry.set_echo_service registry local_echo >>= fun () -> + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in + Registry.set_echo_service registry local_echo; Capability.dec_ref local_echo; let echo_service = Registry.echo_service_promise registry in - let reply1 = Echo.ping echo_service "ping" in - Registry.unblock registry >>= fun () -> - reply1 >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> + let reply1 = Fiber.fork_promise ~sw (fun () -> Echo.ping echo_service "ping") in + Registry.unblock registry; + Promise.await_exn reply1 |> Alcotest.(check string) "Ping response" "got:0:ping"; (* Flush, to ensure we resolve the echo_service's location. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:1:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:1:ping"; (* Test local connection. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:2:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:2:ping"; Capability.dec_ref echo_service; - Capability.dec_ref registry; - Lwt.return () - -let test_cancel switch = - make_vats ~switch ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - let reply1 = Echo.ping service ~slow:true "ping1" in - assert (Lwt.state reply1 = Lwt.Sleep); - Lwt.cancel reply1; - Lwt.try_bind - (fun () -> reply1) - (fun _ -> Alcotest.fail "Should have been cancelled!") - (function - | Lwt.Canceled -> Lwt.return () - | ex -> Lwt.fail ex + Capability.dec_ref registry + +(* todo: we stop waiting and we send a finish message, but we don't currently + abort the service operation. *) +let test_cancel ~net = + with_vats ~net ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + Fiber.first + (fun () -> + ignore (Echo.ping service ~slow:true "ping1" : string); + assert false ) - >>= fun () -> - Echo.unblock service >|= fun () -> + (fun () -> + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:1:ping" + ); + Echo.unblock service; + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:2:ping"; Capability.dec_ref service let float = Alcotest.testable Fmt.float (=) -let test_calculator switch = +let test_calculator ~net = let open Calc in - Capability.inc_ref Calc.local; - make_vats ~switch ~service:Calc.local () >>= fun cs -> - get_bootstrap cs >>= fun c -> - Calc.evaluate c (Float 1.) |> Value.final_read >|= Alcotest.check float "Simple calc" 1. >>= fun () -> + Switch.run @@ fun sw -> + let service = Calc.local ~sw in + with_vats ~net ~service @@ fun cs -> + let c = get_bootstrap cs in + Calc.evaluate c (Float 1.) |> Value.final_read |> Alcotest.check float "Simple calc" 1.; let local_add = Calc.Fn.add in let expr = Expr.(Call (local_add, [Float 1.; Float 2.])) in - Calc.evaluate c expr |> Value.final_read >|= Alcotest.check float "Complex with local fn" 3. >>= fun () -> + Calc.evaluate c expr |> Value.final_read |> Alcotest.check float "Complex with local fn" 3.; let remote_add = Calc.getOperator c `Add in - Calc.Fn.call remote_add [5.; 3.] >|= Alcotest.check float "Check fn" 8. >>= fun () -> + Calc.Fn.call remote_add [5.; 3.] |> Alcotest.check float "Check fn" 8.; let expr = Expr.(Call (remote_add, [Float 1.; Float 2.])) in - Calc.evaluate c expr |> Value.final_read >|= Alcotest.check float "Complex with remote fn" 3. >>= fun () -> + Calc.evaluate c expr |> Value.final_read |> Alcotest.check float "Complex with remote fn" 3.; Capability.dec_ref remote_add; - Capability.dec_ref c; - Lwt.return () + Capability.dec_ref c -let test_calculator2 switch = +let test_calculator2 ~net = let open Calc in - Capability.inc_ref Calc.local; - make_vats ~switch ~service:Calc.local () >>= fun cs -> - get_bootstrap cs >>= fun c -> + Switch.run @@ fun sw -> + let service = Calc.local ~sw in + with_vats ~net ~service @@ fun cs -> + let c = get_bootstrap cs in let remote_add = Calc.getOperator c `Add in let remote_mul = Calc.getOperator c `Multiply in let expr = Expr.(Call (remote_mul, [Float 4.; Float 6.])) in let result = Calc.evaluate c expr in - let expr = Expr.(Call (remote_add, [Prev result; Float 3.])) in - let add3 = Calc.evaluate c expr |> Value.final_read in - let expr = Expr.(Call (remote_add, [Prev result; Float 5.])) in - let add5 = Calc.evaluate c expr |> Value.final_read in - add3 >>= fun add3 -> - add5 >>= fun add5 -> + let/ add3 () = + let expr = Expr.(Call (remote_add, [Prev result; Float 3.])) in + Calc.evaluate c expr |> Value.final_read + and/ add5 () = + let expr = Expr.(Call (remote_add, [Prev result; Float 5.])) in + Calc.evaluate c expr |> Value.final_read + in Alcotest.check float "First" 27.0 add3; Alcotest.check float "Second" 29.0 add5; Capability.dec_ref result; Capability.dec_ref remote_add; Capability.dec_ref remote_mul; - Capability.dec_ref c; - Lwt.return () + Capability.dec_ref c -let test_indexing switch = - let registry_impl = Registry.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> +let test_indexing ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in let echo_service, version = Registry.complex registry in - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - Registry.Version.read version >|= Alcotest.(check string) "Version response" "0.1" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + Registry.Version.read version |> Alcotest.(check string) "Version response" "0.1"; Capability.dec_ref registry; Capability.dec_ref echo_service; - Capability.dec_ref version; - Lwt.return () + Capability.dec_ref version let cmd_result t = let pp f (x : ('a Cmdliner.Cmd.eval_ok, Cmdliner.Cmd.eval_error) result) = @@ -332,17 +350,16 @@ let test_sturdy_uri () = let sr = (`Unix "/sock", auth), "main" in check "Secure Unix" "capnp://sha-256:s16WV4JeGusAL_nTjvICiQOFqm3LqYrDj3K-HXdMi8s@/sock/bWFpbg" sr -let test_sturdy_self switch = +let test_sturdy_self ~net = let service = Echo.local () in Capability.inc_ref service; - make_vats ~switch ~serve_tls:true ~service () >>= fun cs -> + with_vats ~net ~serve_tls:true ~service @@ fun cs -> let id = Restorer.Id.public "" in let sr = Vat.sturdy_uri cs.server id |> Vat.import_exn cs.server in - Sturdy_ref.connect_exn sr >>= fun service2 -> + let service2 = Sturdy_ref.connect_exn sr in Alcotest.check cap "Restore from same vat" service service2; Capability.dec_ref service2; - Capability.dec_ref service; - Lwt.return () + Capability.dec_ref service let expect_non_exn = function | Ok x -> x @@ -351,7 +368,8 @@ let expect_non_exn = function let except = Alcotest.testable Capnp_rpc.Exception.pp (=) let except_ty = Alcotest.testable Capnp_rpc.Exception.pp_ty (=) -let test_table_restorer _switch = +let test_table_restorer ~net:_ = + Switch.run @@ fun sw -> let make_sturdy id = Uri.make ~path:(Restorer.Id.to_string id) () in let table = Restorer.Table.create make_sturdy in let echo_id = Restorer.Id.public "echo" in @@ -359,120 +377,120 @@ let test_table_restorer _switch = let broken_id = Restorer.Id.public "broken" in let unknown_id = Restorer.Id.public "unknown" in Restorer.Table.add table echo_id @@ Echo.local (); - Restorer.Table.add table registry_id @@ Registry.local (); + Restorer.Table.add table registry_id @@ Registry.local ~sw (); Restorer.Table.add table broken_id @@ Capability.broken (Capnp_rpc.Exception.v "broken"); let r = Restorer.of_table table in - Restorer.restore r echo_id >|= expect_non_exn >>= fun a1 -> - Echo.ping a1 "ping" >>= fun reply -> + let a1 = Restorer.restore r echo_id |> expect_non_exn in + let reply = Echo.ping a1 "ping" in Alcotest.(check string) "Ping response" "got:0:ping" reply; - Restorer.restore r echo_id >|= expect_non_exn >>= fun a2 -> + let a2 = Restorer.restore r echo_id |> expect_non_exn in Alcotest.check cap "Same cap" a1 a2; - Restorer.restore r registry_id >|= expect_non_exn >>= fun r1 -> + let r1 = Restorer.restore r registry_id |> expect_non_exn in assert (a1 <> r1); - Restorer.restore r broken_id >|= expect_non_exn >>= fun x -> + let x = Restorer.restore r broken_id |> expect_non_exn in let expected = Some (Capnp_rpc.Exception.v "broken") in Alcotest.(check (option except)) "Broken response" expected (Capability.problem x); - Restorer.restore r unknown_id >>= fun x -> + let x = Restorer.restore r unknown_id in let expected = Error (Capnp_rpc.Exception.v "Unknown persistent service ID") in Alcotest.(check (result reject except)) "Missing mapping" expected x; Capability.dec_ref a1; Capability.dec_ref a2; Capability.dec_ref r1; Restorer.Table.remove table echo_id; - Restorer.Table.clear table; - Lwt.return () + Restorer.Table.clear table module Loader = struct - type t = string -> Restorer.resolution Lwt.t + type t = string -> Restorer.resolution let hash _ = `SHA256 let make_sturdy _ id = Uri.make ~path:(Restorer.Id.to_string id) () let load t _sr digest = t digest end -let test_fn_restorer _switch = +let test_fn_restorer ~net:_ = + Switch.run @@ fun sw -> let cap = Alcotest.testable Capability.pp (=) in let a = Restorer.Id.public "a" in let b = Restorer.Id.public "b" in let c = Restorer.Id.public "c" in let current_c = ref (Restorer.reject (Exception.v "Broken C")) in - let delay = Lwt_condition.create () in + let delay = Eio.Condition.create () in let digest = Restorer.Id.digest (Loader.hash ()) in let load d = - if d = digest a then Lwt.return @@ Restorer.grant @@ Echo.local () - else if d = digest b then Lwt_condition.wait delay >|= fun () -> Restorer.grant @@ Echo.local () - else if d = digest c then Lwt_condition.wait delay >|= fun () -> !current_c - else Lwt.return @@ Restorer.unknown_service_id + if d = digest a then Restorer.grant @@ Echo.local () + else if d = digest b then (Eio.Condition.await_no_mutex delay; Restorer.grant @@ Echo.local ()) + else if d = digest c then (Eio.Condition.await_no_mutex delay; !current_c) + else Restorer.unknown_service_id in - let table = Restorer.Table.of_loader (module Loader) load in + let table = Restorer.Table.of_loader ~sw (module Loader) load in let restorer = Restorer.of_table table in let restore x = Restorer.restore restorer x in (* Check that restoring the same ID twice caches the capability. *) - restore a >|= expect_non_exn >>= fun a1 -> - restore a >|= expect_non_exn >>= fun a2 -> + let a1 = restore a |> expect_non_exn in + let a2 = restore a |> expect_non_exn in Alcotest.check cap "Restore cached" a1 a2; Capability.dec_ref a1; Capability.dec_ref a2; (* But if it's released, the next lookup loads a fresh one. *) - restore a >|= expect_non_exn >>= fun a3 -> + let a3 = restore a |> expect_non_exn in if a1 = a3 then Alcotest.fail "Returned released cap!"; Capability.dec_ref a3; (* Doing two lookups in parallel only does one load. *) - let b1 = restore b in - let b2 = restore b in - assert (Lwt.state b1 = Lwt.Sleep); - Lwt_condition.broadcast delay (); - b1 >|= expect_non_exn >>= fun b1 -> - b2 >|= expect_non_exn >>= fun b2 -> + let b1 = Fiber.fork_promise ~sw (fun () -> restore b) in + let b2 = Fiber.fork_promise ~sw (fun () -> restore b) in + assert (Promise.peek b1 = None); + Eio.Condition.broadcast delay; + let b1 = Promise.await_exn b1 |> expect_non_exn in + let b2 = Promise.await_exn b2 |> expect_non_exn in Alcotest.check cap "Restore delayed cached" b1 b2; Restorer.Table.clear table; (* (should have no effect) *) Capability.dec_ref b1; Capability.dec_ref b2; (* Failed lookups aren't cached. *) - let c1 = restore c in - Lwt_condition.broadcast delay (); - c1 >>= fun c1 -> + let c1 = Fiber.fork_promise ~sw (fun () -> restore c) in + Eio.Condition.broadcast delay; + let c1 = Promise.await_exn c1 in let reject = Alcotest.result cap except in Alcotest.check reject "C initially fails" (Error (Exception.v "Broken C")) c1; - let c2 = restore c in + let c2 = Fiber.fork_promise ~sw (fun () -> restore c) in let c_service = Echo.local () in current_c := Restorer.grant c_service; - Lwt_condition.broadcast delay (); - c2 >|= expect_non_exn >>= fun c2 -> + Eio.Condition.broadcast delay; + let c2 = Promise.await_exn c2 |> expect_non_exn in Alcotest.check cap "C now works" c_service c2; Capability.dec_ref c2; (* Two users; one frees the cap immediately *) let b1 = - restore b >|= expect_non_exn >|= fun b1 -> + Fiber.fork_promise ~sw @@ fun () -> + restore b |> expect_non_exn |> fun b1 -> Capability.dec_ref b1; b1 in - let b2 = restore b in - Lwt_condition.broadcast delay (); - b1 >>= fun b1 -> - b2 >|= expect_non_exn >>= fun b2 -> + let b2 = Fiber.fork_promise ~sw (fun () -> restore b) in + Eio.Condition.broadcast delay; + let b1 = Promise.await_exn b1 in + let b2 = Promise.await_exn b2 |> expect_non_exn in Alcotest.check cap "Cap not freed" b1 b2; - Capability.dec_ref b2; - Lwt.return_unit - -let test_broken switch = - make_vats ~switch ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - Echo.ping service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - let problem, set_problem = Lwt.wait () in - Capability.when_broken (fun x -> Lwt.wakeup set_problem x) service; + Capability.dec_ref b2 + +let test_broken ~net = + with_vats ~net ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + let problem, set_problem = Promise.create () in + Capability.when_broken (fun x -> Promise.resolve set_problem x) service; Alcotest.check (Alcotest.option except) "Still OK" None @@ Capability.problem service; - assert (Lwt.state problem = Lwt.Sleep); + assert (Promise.peek problem = None); Logs.info (fun f -> f "Turning off server..."); - Lwt_switch.turn_off cs.server_switch >>= fun () -> - problem >>= fun problem -> + cs.server_cancel (); + let problem = Promise.await problem in Alcotest.check except_ty "Broken callback ran" `Disconnected problem.ty; assert (Capability.problem service <> None); - Lwt.catch - (fun () -> Echo.ping service "ping" >|= fun _ -> Alcotest.fail "Should have failed!") - (fun _ -> Lwt.return ()) - >|= fun () -> - Capability.dec_ref service + try + ignore (Echo.ping service "ping" : string); + Alcotest.fail "Should have failed!" + with Failure _ -> + Capability.dec_ref service (* [when_broken] follows promises. *) let test_broken2 () = @@ -501,42 +519,37 @@ let test_broken4 () = Capability.dec_ref promise; Alcotest.check (Alcotest.option except) "Released, not called" None !problem -let test_parallel_connect switch = - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun cs -> - let service = get_bootstrap cs in - let service2 = get_bootstrap cs in - service >>= fun service -> - service2 >>= fun service2 -> - Capability.await_settled_exn service >>= fun () -> - Capability.await_settled_exn service2 >>= fun () -> +let test_parallel_connect ~net = + with_vats ~net ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> + let/ service () = get_bootstrap cs + and/ service2 () = get_bootstrap cs in + Capability.await_settled_exn service; + Capability.await_settled_exn service2; Alcotest.check cap "Shared connection" service service2; Capability.dec_ref service; - Capability.dec_ref service2; - Lwt.return_unit - -let test_parallel_fails switch = - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun cs -> - let service = get_bootstrap cs in - let service2 = get_bootstrap cs in - service >>= fun service -> - service2 >>= fun service2 -> - Lwt_switch.turn_off cs.server_switch >>= fun () -> - Capability.await_settled_exn service >>= fun () -> - Capability.await_settled_exn service2 >>= fun () -> + Capability.dec_ref service2 + +let test_parallel_fails ~net = + with_vats ~net ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> + let/ service () = get_bootstrap cs + and/ service2 () = get_bootstrap cs in + cs.server_cancel (); + ignore (Capability.await_settled service : _ result); + ignore (Capability.await_settled service2 : _ result); Alcotest.check cap "Shared failure" service service2; Capability.dec_ref service; Capability.dec_ref service2; (* Restart server (ignore new client) *) - Lwt.pause () >>= fun () -> - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun _cs2 -> - get_bootstrap cs >>= fun service -> - Echo.ping service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - Capability.dec_ref service; - Lwt.return_unit + Fiber.yield (); + with_vats ~net ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + Capability.dec_ref service -let test_crossed_calls switch = +let test_crossed_calls ~net = (* Would be good to control the ordering here, to test the various cases. Currently, it's not certain which path is actually tested. *) + Switch.run @@ fun sw -> let id = Restorer.Id.public "" in let make_vat ~secret_key ~tags addr = let service = Echo.local () in @@ -544,106 +557,106 @@ let test_crossed_calls switch = let config = let secret_key = `PEM (Auth.Secret_key.to_pem_data secret_key) in let name = Fmt.str "capnp-rpc-test-%s" addr in - Capnp_rpc_unix.Vat_config.create ~secret_key (get_test_address ~switch name) + Capnp_rpc_unix.Vat_config.create ~secret_key (get_test_address name) in - Capnp_rpc_unix.serve ~switch ~tags ~restore config >>= fun vat -> - Lwt_switch.add_hook (Some switch) (fun () -> Capability.dec_ref service; Lwt.return_unit); - Lwt.return vat + let vat = Capnp_rpc_unix.serve ~net ~sw ~tags ~restore config in + Switch.on_release sw (fun () -> Capability.dec_ref service); + vat in - make_vat ~secret_key:client_key ~tags:Test_utils.client_tags "client" >>= fun client -> - make_vat ~secret_key:server_key ~tags:Test_utils.server_tags "server" >>= fun server -> + let client = make_vat ~secret_key:(Lazy.force client_key) ~tags:Test_utils.client_tags "client" in + let server = make_vat ~secret_key:(Lazy.force server_key) ~tags:Test_utils.server_tags "server" in let sr_to_client = Capnp_rpc_unix.Vat.sturdy_uri client id |> Vat.import_exn server in let sr_to_server = Capnp_rpc_unix.Vat.sturdy_uri server id |> Vat.import_exn client in - let to_client = Sturdy_ref.connect_exn sr_to_client in - let to_server = Sturdy_ref.connect_exn sr_to_server in - to_client >>= fun to_client -> - to_server >>= fun to_server -> + let/ to_client () = Sturdy_ref.connect_exn sr_to_client + and/ to_server () = Sturdy_ref.connect_exn sr_to_server in Logs.info (fun f -> f ~tags:Test_utils.client_tags "%a" Capnp_rpc_unix.Vat.dump client); Logs.info (fun f -> f ~tags:Test_utils.server_tags "%a" Capnp_rpc_unix.Vat.dump server); - let s_got = Echo.ping_result to_client "ping" in - let c_got = Echo.ping_result to_server "ping" in - s_got >>= fun s_got -> - c_got >>= fun c_got -> - begin match c_got, s_got with - | Ok x, Ok y -> Lwt.return (x, y) + let/ s_got () = Echo.ping_result to_client "ping" + and/ c_got () = Echo.ping_result to_server "ping" in + let c_got, s_got = + match c_got, s_got with + | Ok x, Ok y -> (x, y) | Ok x, Error _ -> (* Server got an error. Try client again. *) - Sturdy_ref.connect_exn sr_to_client >>= fun to_client -> + let to_client = Sturdy_ref.connect_exn sr_to_client in Capability.with_ref to_client @@ fun to_client -> - Echo.ping to_client "ping" >|= fun s_got -> (x, s_got) + Echo.ping to_client "ping" |> fun s_got -> (x, s_got) | Error _, Ok y -> (* Client got an error. Try server again. *) - Sturdy_ref.connect_exn sr_to_server >>= fun to_server -> + let to_server = Sturdy_ref.connect_exn sr_to_server in Capability.with_ref to_server @@ fun to_server -> - Echo.ping to_server "ping" >|= fun c_got -> (c_got, y) + Echo.ping to_server "ping" |> fun c_got -> (c_got, y) | Error (`Capnp e1), Error (`Capnp e2) -> Fmt.failwith "@[Both connections failed!@,%a@,%a@]" Capnp_rpc.Error.pp e1 Capnp_rpc.Error.pp e2 - end >>= fun (c_got, s_got) -> + in Alcotest.(check string) "Client's ping response" "got:0:ping" c_got; Alcotest.(check string) "Server's ping response" "got:0:ping" s_got; Capability.dec_ref to_client; - Capability.dec_ref to_server; - Lwt.return_unit + Capability.dec_ref to_server (* Run test_crossed_calls several times to try to trigger the various behaviours. *) -let test_crossed_calls _switch = - let rec aux i = - if i = 0 then Lwt.return_unit - else ( - Lwt_switch.with_switch test_crossed_calls >>= fun () -> - aux (i - 1) - ) - in - aux 10 +let test_crossed_calls ~net = + for _ = 1 to 10 do + test_crossed_calls ~net + done -let test_store switch = +let test_store ~net = + Switch.run @@ fun sw -> (* Persistent server configuration *) let db = Store.DB.create () in let config = - let addr = get_test_address ~switch "capnp-rpc-test-server" in - Capnp_rpc_unix.Vat_config.create ~secret_key:server_pem addr + let addr = get_test_address "capnp-rpc-test-server" in + Capnp_rpc_unix.Vat_config.create ~secret_key:(Lazy.force server_pem) addr in let main_id = Restorer.Id.generate () in - let start_server ~switch () = + let start_server ~sw () = let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in - let table = Store.File.table ~make_sturdy db in - Lwt_switch.add_hook (Some switch) (fun () -> Restorer.Table.clear table; Lwt.return_unit); + let table = Store.File.table ~sw ~make_sturdy db in + Switch.on_release sw (fun () -> Restorer.Table.clear table); let restore = Restorer.of_table table in let service = Store.local ~restore db in Restorer.Table.add table main_id service; - Capnp_rpc_unix.serve ~switch ~restore ~tags:Test_utils.server_tags config + Capnp_rpc_unix.serve ~sw ~net ~restore ~tags:Test_utils.server_tags config in (* Start server *) - let server_switch = Lwt_switch.create () in - start_server ~switch:server_switch () >>= fun server -> - let store_uri = Capnp_rpc_unix.Vat.sturdy_uri server main_id in - (* Set up client *) - let client = Capnp_rpc_unix.client_only_vat ~tags:Test_utils.client_tags ~switch () in - let sr = Capnp_rpc_unix.Vat.import_exn client store_uri in - Sturdy_ref.with_cap_exn sr @@ fun store -> - (* Try creating a file *) - let file = Store.create_file store in - Store.File.set file "Hello" >>= fun () -> - Capnp_rpc.Persistence.save_exn file >>= fun file_sr -> - let file_sr = Vat.import_exn client file_sr in (* todo: get rid of this step *) - (* Shut down server *) - Lwt.async (fun () -> Lwt_switch.turn_off server_switch); - let broken, set_broken = Lwt.wait () in - Capability.when_broken (Lwt.wakeup set_broken) file; - broken >>= fun _ex -> + let file, file_sr = + Switch.run (fun server_switch -> + let server = start_server ~sw:server_switch () in + let store_uri = Capnp_rpc_unix.Vat.sturdy_uri server main_id in + (* Set up client *) + let client = Capnp_rpc_unix.client_only_vat ~tags:Test_utils.client_tags ~sw net in + let sr = Capnp_rpc_unix.Vat.import_exn client store_uri in + Sturdy_ref.with_cap_exn sr @@ fun store -> + (* Try creating a file *) + let file = Store.create_file store in + Store.File.set file "Hello"; + let file_sr = Capnp_rpc.Persistence.save_exn file in + let file_sr = Vat.import_exn client file_sr in (* todo: get rid of this step *) + file, file_sr + ) + in + let broken, set_broken = Promise.create () in + Capability.when_broken (Promise.resolve set_broken) file; + ignore (Promise.await broken : Exception.t); assert (Capability.problem file <> None); (* Restart server *) - start_server ~switch () >>= fun _server -> + let _server = start_server ~sw () in (* Reconnect client *) Sturdy_ref.with_cap_exn file_sr @@ fun file -> - Store.File.get file >>= fun data -> - Alcotest.(check string) "Read file" "Hello" data; - Lwt.return_unit + let data = Store.File.get file in + Alcotest.(check string) "Read file" "Hello" data + +let ( / ) = Eio.Path.( / ) + +let with_temp_dir path fn = + Eio.Path.mkdir path ~perm:0o700; + Fun.protect (fun () -> Eio.Path.with_open_dir path fn) + ~finally:(fun () -> Eio.Path.rmtree path) -let test_file_store _switch = - Lwt_io.with_temp_dir ~prefix:"capnp-tests-" @@ fun tmpdir -> +let test_file_store ~dir ~net:_ = + with_temp_dir (dir / "capnp-tests") @@ fun tmpdir -> let module S = Capnp_rpc_unix.File_store in let s = S.create tmpdir in Alcotest.(check (option reject)) "Missing file" None @@ S.load s ~digest:"missing"; @@ -655,84 +668,94 @@ let test_file_store _switch = Builder.to_reader b in S.save s ~digest:"!/.." data; - Alcotest.(check (option string)) "Restored" (Some "Test") @@ Option.map Reader.text_get (S.load s ~digest:"!/.."); - Lwt.return_unit + Alcotest.(check (option string)) "Restored" (Some "Test") @@ Option.map Reader.text_get (S.load s ~digest:"!/..") let capnp_error = Alcotest.of_pp Capnp_rpc.Exception.pp -let test_await_settled _switch = +let test_await_settled ~net:_ = (* Ok *) + Switch.run @@ fun sw -> let p, r = Capability.promise () in - let check = Capability.await_settled p in + let check = Fiber.fork_promise ~sw (fun () -> Capability.await_settled p) in Capability.resolve_ok r @@ Echo.local (); - check >>= fun check -> + let check = Promise.await_exn check in Alcotest.(check (result unit capnp_error)) "Check await success" (Ok ()) check; Capability.dec_ref p; (* Error *) let p, r = Capability.promise () in - let check = Capability.await_settled p in + let check = Fiber.fork_promise ~sw (fun () -> Capability.await_settled p) in let err = Capnp_rpc.Exception.v "Test" in Capability.resolve_exn r err; - check >>= fun check -> - Alcotest.(check (result unit capnp_error)) "Check await failure" (Error err) check; - Lwt.return_unit + let check = Promise.await_exn check in + Alcotest.(check (result unit capnp_error)) "Check await failure" (Error err) check (* The client disconnects before the server has finished loading the bootstrap object. *) -let test_late_bootstrap switch = - let connected, set_connected = Lwt.wait () in - let service, set_service = Lwt.wait () in +let test_late_bootstrap ~net = + Switch.run @@ fun sw -> + let connected, set_connected = Promise.create () in + let service, set_service = Promise.create () in let module Loader = struct type t = unit let hash () = `SHA256 let make_sturdy () _id = assert false let load () _sr _name = - Lwt.wakeup_later set_connected (); - service + Promise.resolve set_connected (); + Promise.await service; + Capnp_rpc_net.Restorer.grant @@ Echo.local () end in - let table = Capnp_rpc_net.Restorer.Table.of_loader (module Loader) () in + let table = Capnp_rpc_net.Restorer.Table.of_loader ~sw (module Loader) () in let restore = Restorer.of_table table in - let client_switch = Lwt_switch.create () in - make_vats_full ~client_switch ~server_switch:switch ~restore () >>= fun cs -> + let cs = make_vats_full ~sw ~restore ~net () in let service = get_bootstrap cs in - connected >>= fun () -> - Lwt_switch.turn_off client_switch >>= fun () -> - Lwt.wakeup set_service @@ Capnp_rpc_net.Restorer.grant @@ Echo.local (); - service >>= fun _ -> - Lwt.return () - -let run name fn = Alcotest_lwt.test_case_sync name `Quick fn - -let rpc_tests = [ - run_lwt "Simple" (test_simple ~serve_tls:false); - run_lwt "Crypto" (test_simple ~serve_tls:true); - run_lwt "Bad crypto" test_bad_crypto ~expected_warnings:1; - run_lwt "Parallel" test_parallel; - run_lwt "Embargo" test_embargo; - run_lwt "Resolve" test_resolve; - run_lwt "Registry" test_registry; - run_lwt "Calculator" test_calculator; - run_lwt "Calculator 2" test_calculator2; - run_lwt "Cancel" test_cancel; - run_lwt "Indexing" test_indexing; - run "Options" test_options; - run "Sturdy URI" test_sturdy_uri; - run_lwt "Sturdy self" test_sturdy_self; - run_lwt "Table restorer" test_table_restorer; - run_lwt "Fn restorer" test_fn_restorer; - run_lwt "Broken ref" test_broken; - run "Broken ref 2" test_broken2; - run "Broken ref 3" test_broken3; - run "Broken ref 4" test_broken4; - run_lwt "Parallel connect" test_parallel_connect; - run_lwt "Parallel fails" test_parallel_fails; - run_lwt "Crossed calls" test_crossed_calls; - run_lwt "Store" test_store; - run_lwt "File store" test_file_store; - run_lwt "Await settled" test_await_settled; - run_lwt "Late bootstrap" test_late_bootstrap; -] + Promise.await connected; + cs.client_cancel (); + let service = Capability.await_settled service |> Result.get_error in + Logs.info (fun f -> f "client got: %a" Capnp_rpc.Exception.pp service); + assert (service.Capnp_rpc.Exception.ty = `Disconnected); + Promise.resolve set_service (); + (* The restorer yields once before returning the cap, + so we wait too, to ensure it's done. *) + Fiber.yield () + +let run name fn = Alcotest.test_case name `Quick fn + +let rpc_tests ~net ~dir = + let net = Capnp_rpc_unix.Network.v net in + let run_eio = run_eio ~net in + [ + run_eio "Simple" (test_simple ~serve_tls:false); + run_eio "Crypto" (test_simple ~serve_tls:true); + run_eio "Bad crypto" test_bad_crypto ~expected_warnings:1; + run_eio "Parallel" test_parallel; + run_eio "Embargo" test_embargo; + run_eio "Resolve" test_resolve; + run_eio "Registry" test_registry; + run_eio "Calculator" test_calculator; + run_eio "Calculator 2" test_calculator2; + run_eio "Cancel" test_cancel; + run_eio "Indexing" test_indexing; + run "Options" test_options; + run "Sturdy URI" test_sturdy_uri; + run_eio "Sturdy self" test_sturdy_self; + run_eio "Table restorer" test_table_restorer; + run_eio "Fn restorer" test_fn_restorer; + run_eio "Broken ref" test_broken; + run "Broken ref 2" test_broken2; + run "Broken ref 3" test_broken3; + run "Broken ref 4" test_broken4; + run_eio "Parallel connect" test_parallel_connect; + run_eio "Parallel fails" test_parallel_fails; + run_eio "Crossed calls" test_crossed_calls; + run_eio "Store" test_store; + run_eio "File store" (test_file_store ~dir); + run_eio "Await settled" test_await_settled; + run_eio "Late bootstrap" test_late_bootstrap; + ] let () = - Alcotest_lwt.run ~and_exit:false "capnp-rpc" [ - "lwt", rpc_tests; - ] |> Lwt_main.run + Eio_main.run @@ fun env -> + Mirage_crypto_rng_eio.run (module Mirage_crypto_rng.Fortuna) env @@ fun () -> + (* Eio_unix.Ctf.with_tracing "/tmp/trace.ctf" @@ fun () -> *) + Alcotest.run ~and_exit:false "capnp-rpc" [ + "eio", rpc_tests ~net:env#net ~dir:env#cwd; + ] diff --git a/unix/capnp_rpc_unix.ml b/unix/capnp_rpc_unix.ml index da6ab43fa..1114ff7c9 100644 --- a/unix/capnp_rpc_unix.ml +++ b/unix/capnp_rpc_unix.ml @@ -1,12 +1,7 @@ +open Eio.Std open Astring -open Lwt.Infix module Log = Capnp_rpc.Debug.Log -module Unix_flow = Unix_flow - -let () = Mirage_crypto_rng_lwt.initialize (module Mirage_crypto_rng.Fortuna) - -type flow = Unix_flow.flow module CapTP = Vat_network.CapTP module Vat = Vat_network.Vat @@ -95,8 +90,8 @@ module Console = struct clear (); messages := msg :: !messages; show (); - Lwt.finalize f - (fun () -> + Fun.protect f + ~finally:(fun () -> clear (); let rec remove_first = function | [] -> assert false @@ -104,8 +99,7 @@ module Console = struct | x :: xs -> x :: remove_first xs in messages := remove_first !messages; - show (); - Lwt.return_unit + show () ) end @@ -122,7 +116,7 @@ let rec connect_with_progress ?(mode=`Auto) sr = let did_log = ref false in Log.info (fun f -> did_log := true; f "Connecting to %a..." pp sr); if !did_log then ( - Sturdy_ref.connect sr >|= function + match Sturdy_ref.connect sr with | Ok _ as x -> Log.info (fun f -> f "Connected to %a" pp sr); x | Error _ as e -> e ) else ( @@ -133,108 +127,91 @@ let rec connect_with_progress ?(mode=`Auto) sr = ) | `Batch -> Fmt.epr "Connecting to %a... %!" pp sr; - begin Sturdy_ref.connect sr >|= function + begin match Sturdy_ref.connect sr with | Ok _ as x -> Fmt.epr "OK@."; x | Error _ as x -> Fmt.epr "ERROR@."; x end | `Console -> - let x = Sturdy_ref.connect sr in - Lwt.choose [Lwt_unix.sleep 0.5; Lwt.map ignore x] >>= fun () -> - if Lwt.is_sleeping x then ( - Console.with_msg (Fmt.str "[ connecting to %a ]" pp sr) - (fun () -> x) - ) else x + Switch.run ~name:"connect_with_progress" @@ fun sw -> + Fiber.fork_daemon ~sw (fun () -> + Eio_unix.sleep 0.5; + Console.with_msg (Fmt.str "[ connecting to %a ]" pp sr) Fiber.await_cancel + ); + Sturdy_ref.connect sr | `Silent -> Sturdy_ref.connect sr let with_cap_exn ?progress sr f = - connect_with_progress ?mode:progress sr >>= function + match connect_with_progress ?mode:progress sr with | Error ex -> Fmt.failwith "%a" Capnp_rpc.Exception.pp ex | Ok x -> Capnp_rpc.Capability.with_ref x f let handle_connection ?tags ~secret_key vat client = - Lwt.catch (fun () -> - let switch = Lwt_switch.create () in - let raw_flow = Unix_flow.connect ~switch client in - Network.accept_connection ~switch ~secret_key raw_flow >>= function - | Error (`Msg msg) -> - Log.warn (fun f -> f ?tags "Rejecting new connection: %s" msg); - Lwt.return_unit - | Ok ep -> - Vat.add_connection vat ~switch ~mode:`Accept ep >|= fun (_ : CapTP.t) -> - () - ) - (fun ex -> - Log.err (fun f -> f "Uncaught exception handling connection: %a" Fmt.exn ex); - Lwt.return_unit - ) + match Network.accept_connection ~secret_key client with + | Error (`Msg msg) -> + Log.warn (fun f -> f ?tags "Rejecting new connection: %s" msg) + | Ok ep -> Vat.run_connection vat ~mode:`Accept ep ignore -let addr_of_host host = - match Unix.gethostbyname host with - | exception Not_found -> - Capnp_rpc.Debug.failf "Unknown host %S" host - | addr -> - if Array.length addr.Unix.h_addr_list = 0 then - Capnp_rpc.Debug.failf "No addresses found for host name %S" host - else - addr.Unix.h_addr_list.(0) - -let serve ?switch ?tags ?restore config = +let create_server ?tags ?restore ~sw ~net config = let {Vat_config.backlog; secret_key = _; serve_tls; listen_address; public_address} = config in let vat = let auth = Vat_config.auth config in let secret_key = lazy (fst (Lazy.force config.secret_key)) in - Vat.create ?switch ?tags ?restore ~address:(public_address, auth) ~secret_key () + Vat.create ?tags ?restore ~sw ~address:(public_address, auth) ~secret_key net in let socket = match listen_address with - | `Unix path -> - begin match Unix.lstat path with - | { Unix.st_kind = Unix.S_SOCK; _ } -> Unix.unlink path - | _ -> () - | exception Unix.Unix_error(Unix.ENOENT, _, _) -> () - end; - let socket = Unix.(socket PF_UNIX SOCK_STREAM 0) in - Unix.bind socket (Unix.ADDR_UNIX path); - socket + | `Unix _ as addr -> Eio.Net.listen ~sw ~backlog ~reuse_addr:true net addr | `TCP (host, port) -> - let socket = Unix.(socket PF_INET SOCK_STREAM 0) in - Unix.setsockopt socket Unix.SO_REUSEADDR true; - Unix.setsockopt socket Unix.SO_KEEPALIVE true; - Keepalive.try_set_idle socket 60; - Unix.bind socket (Unix.ADDR_INET (addr_of_host host, port)); - socket + match Eio.Net.getaddrinfo_stream net host ~service:(string_of_int port) with + | [] -> Capnp_rpc.Debug.failf "No addresses found for host name %S" host + | addr :: _ -> + let socket = Eio.Net.listen ~sw ~backlog ~reuse_addr:true net addr in + let unix_socket = Eio_unix.Resource.fd_opt socket |> Option.get in + Eio_unix.Fd.use_exn "keep-alive" unix_socket @@ fun unix_socket -> + Unix.setsockopt unix_socket Unix.SO_KEEPALIVE true; + Keepalive.try_set_idle unix_socket 60; + socket in - Unix.listen socket backlog; Log.info (fun f -> f ?tags "Waiting for %s connections on %a" - (if serve_tls then "(encrypted)" else "UNENCRYPTED") - Vat_config.Listen_address.pp listen_address); - let lwt_socket = Lwt_unix.of_unix_file_descr socket in - let rec loop () = - Lwt_switch.check switch; - Lwt_unix.accept lwt_socket >>= fun (client, _addr) -> - Log.info (fun f -> f ?tags "Accepting new connection"); - let secret_key = if serve_tls then Some (Vat_config.secret_key config) else None in - Lwt.async (fun () -> handle_connection ?tags ~secret_key vat client); - loop () - in - Lwt.async (fun () -> - Lwt.catch - (fun () -> - let th = loop () in - Lwt_switch.add_hook switch (fun () -> Lwt.cancel th; Lwt.return_unit); - th - ) - (function - | Lwt.Canceled -> Lwt.return_unit - | ex -> Lwt.fail ex - ) - >>= fun () -> - Lwt_unix.close lwt_socket + (if serve_tls then "(encrypted)" else "UNENCRYPTED") + Vat_config.Listen_address.pp listen_address); + vat, socket + +let listen ?tags ~sw (config, vat, socket) = + while true do + (* This is like [Eio.Net.accept_fork], but using [fork_daemon] instead of [fork]. *) + let child_started = ref false in + let client, addr = Eio.Net.accept ~sw socket in + Fun.protect ~finally:(fun () -> if !child_started = false then Eio.Net.close client) + (fun () -> + Log.info (fun f -> f ?tags "Accepting new connection from %a" Eio.Net.Sockaddr.pp addr); + Fiber.fork_daemon ~sw (fun () -> + match + child_started := true; + let secret_key = if config.Vat_config.serve_tls then Some (Vat_config.secret_key config) else None in + handle_connection ?tags ~secret_key vat client + with + | () -> Eio.Net.close client; `Stop_daemon + | exception ex -> + Eio.Net.close client; + Fiber.check (); + Log.info (fun f -> f ?tags "Error handling connection from %a: %a" Eio.Net.Sockaddr.pp addr Eio.Exn.pp ex); + `Stop_daemon + ) + ) + done + +let serve ?tags ?restore ~sw ~net config = + let net = (net :> [`Generic] Eio.Net.ty r) in + let (vat, socket) = create_server ?tags ?restore ~sw ~net config in + Fiber.fork_daemon ~sw (fun () -> + listen ?tags ~sw (config, vat, socket) ); - Lwt.return vat + vat -let client_only_vat ?switch ?tags ?restore () = +let client_only_vat ?tags ?restore ~sw net = + let net = (net :> [`Generic] Eio.Net.ty r) in let secret_key = lazy (Capnp_rpc_net.Auth.Secret_key.generate ()) in - Vat.create ?switch ?tags ?restore ~secret_key () + Vat.create ?tags ?restore ~secret_key ~sw net let manpage_capnp_options = Vat_config.docs diff --git a/unix/capnp_rpc_unix.mli b/unix/capnp_rpc_unix.mli index 05a3c59fd..cde4d5266 100644 --- a/unix/capnp_rpc_unix.mli +++ b/unix/capnp_rpc_unix.mli @@ -3,10 +3,7 @@ open Capnp_rpc.Std open Capnp_rpc_net -module Unix_flow = Unix_flow - include Capnp_rpc_net.VAT_NETWORK with - type flow = Unix_flow.flow and module Network = Network (** Configuration for a {!Vat}. *) @@ -66,7 +63,7 @@ module File_store : sig type 'a t (** A store of values of type ['a]. *) - val create : string -> 'a t + val create : _ Eio.Path.t -> 'a t (** [create dir] is a store for Cap'n Proto structs. Items are stored inside [dir]. *) @@ -102,7 +99,7 @@ val sturdy_uri : Uri.t Cmdliner.Arg.conv val connect_with_progress : ?mode:[`Auto | `Log | `Batch | `Console | `Silent] -> - 'a Sturdy_ref.t -> ('a Capability.t, Capnp_rpc.Exception.t) Lwt_result.t + 'a Sturdy_ref.t -> ('a Capability.t, Capnp_rpc.Exception.t) result (** [connect_with_progress sr] is like [Sturdy_ref.connect], but shows that a connection is in progress. Note: On failure, it does {e not} display the error, which should instead be handled by the caller. @param mode Controls how progress is displayed: @@ -116,26 +113,27 @@ val connect_with_progress : val with_cap_exn : ?progress:[`Auto | `Log | `Batch | `Console | `Silent] -> 'a Sturdy_ref.t -> - ('a Capability.t -> 'b Lwt.t) -> - 'b Lwt.t + ('a Capability.t -> 'b) -> + 'b (** Like [Sturdy_ref.with_cap_exn], but using [connect_with_progress] to show progress. *) val serve : - ?switch:Lwt_switch.t -> ?tags:Logs.Tag.set -> ?restore:Restorer.t -> + sw:Eio.Switch.t -> + net:_ Eio.Net.t -> Vat_config.t -> - Vat.t Lwt.t -(** [serve ~restore vat_config] is a new vat that is listening for new connections + Vat.t +(** [serve ~restore ~sw ~net vat_config] is a new vat that is listening for new connections as specified by [vat_config]. After connecting to it, clients can get access to services using [restore]. *) val client_only_vat : - ?switch:Lwt_switch.t -> ?tags:Logs.Tag.set -> ?restore:Restorer.t -> - unit -> Vat.t -(** [client_only_vat ()] is a new vat that does not listen for incoming connections. *) + sw:Eio.Switch.t -> + _ Eio.Net.t -> Vat.t +(** [client_only_vat net] is a new vat that does not listen for incoming connections. *) val manpage_capnp_options : string (** [manpage_capnp_options] is the title of the section of the man-page containing the Cap'n Proto options. diff --git a/unix/dune b/unix/dune index 240722463..57ef2fbc7 100644 --- a/unix/dune +++ b/unix/dune @@ -1,5 +1,4 @@ (library (name capnp_rpc_unix) (public_name capnp-rpc-unix) - (libraries lwt.unix astring capnp-rpc capnp-rpc-net fmt logs - mirage-crypto-rng-lwt cmdliner cstruct-lwt extunix)) + (libraries eio.unix astring capnp-rpc capnp-rpc-net fmt logs cmdliner cstruct extunix)) diff --git a/unix/file_store.ml b/unix/file_store.ml index 0e242eee3..2d0427b89 100644 --- a/unix/file_store.ml +++ b/unix/file_store.ml @@ -2,53 +2,43 @@ open Capnp_rpc module ReaderOps = Capnp.Runtime.ReaderInc.Make(Capnp_rpc) +let ( / ) = Eio.Path.( / ) + type 'a t = { - dir : string; + dir : Eio.Fs.dir_ty Eio.Path.t; } -let create dir = { dir } +let create dir = { dir = (dir :> Eio.Fs.dir_ty Eio.Path.t) } -let path_of_digest t digest = - match Base64.encode ~alphabet:Base64.uri_safe_alphabet ~pad:false digest with - | Ok filename -> Filename.concat t.dir filename - | Error (`Msg m) -> failwith m (* Encoding can't really fail *) +let leaf_of_digest digest = + Base64.encode_exn ~alphabet:Base64.uri_safe_alphabet ~pad:false digest let segments_of_reader = function | None -> [] | Some ss -> Message.to_storage ss.StructStorage.data.Slice.msg let save t ~digest data = - let path = path_of_digest t digest in - let tmp_path = path ^ ".new" in - let ch = open_out_bin tmp_path in - Fun.protect ~finally:(fun () -> close_out ch) (fun () -> + let leaf = leaf_of_digest digest in + let tmp_leaf = leaf ^ ".new" in + Eio.Path.with_open_out ~create:(`Or_truncate 0o644) (t.dir / tmp_leaf) (fun flow -> let segments = segments_of_reader data in segments |> List.iter (fun {Message.segment; bytes_consumed} -> - output ch segment 0 bytes_consumed + let buf = Cstruct.of_bytes segment ~len:bytes_consumed in + Eio.Flow.write flow [buf] ); ); - Unix.rename tmp_path path + Eio.Path.rename (t.dir / tmp_leaf) (t.dir / leaf) let remove t ~digest = - let path = path_of_digest t digest in - Unix.unlink path + Eio.Path.unlink (t.dir / leaf_of_digest digest) let load t ~digest = - let path = path_of_digest t digest in - if Sys.file_exists path then ( - let ch = open_in_bin path in - let segment = - Fun.protect ~finally:(fun () -> close_in ch) (fun () -> - let len = in_channel_length ch in - let segment = Bytes.create len in - really_input ch segment 0 len; - segment - ) - in - let msg = Message.of_storage [segment] in + let leaf = leaf_of_digest digest in + match Eio.Path.load (t.dir / leaf) with + | segment -> + let msg = Message.of_storage [Bytes.unsafe_of_string segment] in let reader = ReaderOps.get_root_struct (Message.readonly msg) in Some reader - ) else ( - Logs.info (fun f -> f "File %S not found" path); + | exception Eio.Io (Eio.Fs.E Not_found _, _) -> + Logs.info (fun f -> f "File %S not found" leaf); None - ) diff --git a/unix/network.ml b/unix/network.ml index 6e397f339..f104c54d7 100644 --- a/unix/network.ml +++ b/unix/network.ml @@ -1,7 +1,7 @@ -open Lwt.Infix +open Eio.Std module Log = Capnp_rpc.Debug.Log -module Tls_wrapper = Capnp_rpc_net.Tls_wrapper.Make(Unix_flow) +module Tls_wrapper = Capnp_rpc_net.Tls_wrapper module Location = struct open Astring @@ -50,7 +50,7 @@ module Types = struct type join_key_part end -type t = unit +type t = [`Generic] Eio.Net.ty r let error fmt = fmt |> Fmt.kstr @@ fun msg -> @@ -58,45 +58,32 @@ let error fmt = let parse_third_party_cap_id _ = `Two_party_only -let addr_of_host host = - match Unix.gethostbyname host with - | exception Not_found -> - Capnp_rpc.Debug.failf "Unknown host %S" host - | addr -> - if Array.length addr.Unix.h_addr_list = 0 then - Capnp_rpc.Debug.failf "No addresses found for host name %S" host - else - addr.Unix.h_addr_list.(0) - -let connect_socket = function - | `Unix path -> - Log.info (fun f -> f "Connecting to %S..." path); - let socket = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in - Lwt.catch - (fun () -> Lwt_unix.connect socket (Unix.ADDR_UNIX path) >|= fun () -> socket) - (fun ex -> Lwt_unix.close socket >>= fun () -> Lwt.fail ex) - | `TCP (host, port) -> - Log.info (fun f -> f "Connecting to %s:%d..." host port); - let socket = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in - Lwt.catch - (fun () -> - Lwt_unix.setsockopt socket Unix.SO_KEEPALIVE true; - Keepalive.try_set_idle (Lwt_unix.unix_file_descr socket) 60; - Lwt_unix.connect socket (Unix.ADDR_INET (addr_of_host host, port)) >|= fun () -> - socket - ) - (fun ex -> Lwt_unix.close socket >>= fun () -> Lwt.fail ex) - -let connect () ~switch ~secret_key (addr, auth) = - Lwt.try_bind - (fun () -> connect_socket addr) - (fun socket -> - let flow = Unix_flow.connect ~switch socket in - Tls_wrapper.connect_as_client ~switch flow secret_key auth - ) - (fun ex -> - Lwt.return @@ error "@[Network connection for %a failed:@,%a@]" Location.pp addr Fmt.exn ex - ) - -let accept_connection ~switch ~secret_key flow = - Tls_wrapper.connect_as_server ~switch flow secret_key +let connect net ~sw ~secret_key (addr, auth) = + let eio_addr = + match addr with + | `Unix _ as x -> x + | `TCP (host, port) -> + match Eio.Net.getaddrinfo_stream net host ~service:(string_of_int port) with + | [] -> Capnp_rpc.Debug.failf "No addresses found for host name %S" host + | addr :: _ -> addr + in + Log.info (fun f -> f "Connecting to %a..." Eio.Net.Sockaddr.pp eio_addr); + match Eio.Net.connect ~sw net eio_addr with + | socket -> + begin match addr with + | `Unix _ -> () + | `TCP _ -> + let socket = Eio_unix.Resource.fd_opt socket |> Option.get in + Eio_unix.Fd.use_exn "keep-alive" socket @@ fun socket -> + Unix.setsockopt socket Unix.SO_KEEPALIVE true; + Keepalive.try_set_idle socket 60 + end; + Tls_wrapper.connect_as_client socket secret_key auth + | exception ex -> + Fiber.check (); + error "@[Network connection for %a failed:@,%a@]" Location.pp addr Fmt.exn ex + +let accept_connection ~secret_key flow = + Tls_wrapper.connect_as_server flow secret_key + +let v t = (t :> [`Generic] Eio.Net.ty r) diff --git a/unix/network.mli b/unix/network.mli index 7ba6d427c..17cd939f6 100644 --- a/unix/network.mli +++ b/unix/network.mli @@ -1,5 +1,7 @@ (** A network using TCP and Unix-domain sockets. *) +open Eio.Std + module Location : sig type t = [ | `Unix of string @@ -25,14 +27,15 @@ module Location : sig end include Capnp_rpc_net.S.NETWORK with - type t = unit and + type t = [`Generic] Eio.Net.ty Eio.Resource.t and type Address.t = Location.t * Capnp_rpc_net.Auth.Digest.t +val v : _ Eio.Net.t -> t + val accept_connection : - switch:Lwt_switch.t -> secret_key:Capnp_rpc_net.Auth.Secret_key.t option -> - Unix_flow.flow -> - (Capnp_rpc_net.Endpoint.t, [> `Msg of string]) result Lwt.t + [> Eio.Flow.two_way_ty | Eio.Resource.close_ty] r -> + (Capnp_rpc_net.Endpoint.t, [> `Msg of string]) result (** [accept_connection ~switch ~secret_key flow] is a new endpoint for [flow]. If [secret_key] is not [None], it is used to perform a TLS server-side handshake. Otherwise, the connection is not encrypted. *) diff --git a/unix/unix_flow.ml b/unix/unix_flow.ml deleted file mode 100644 index 6d5d9c35c..000000000 --- a/unix/unix_flow.ml +++ /dev/null @@ -1,109 +0,0 @@ -open Lwt.Infix - -(* Slightly rude to set signal handlers in a library, but SIGPIPE makes no sense - in a modern application. *) -let () = if not Sys.win32 then Sys.(set_signal sigpipe Signal_ignore) - -type flow = { - fd : Lwt_unix.file_descr; - mutable current_write : int Lwt.t option; - mutable current_read : int Lwt.t option; - mutable closed : bool; -} -type error = [`Exception of exn] -type write_error = [`Closed | `Exception of exn] - -let opt_cancel = function - | None -> () - | Some x -> Lwt.cancel x - -let close t = - if t.closed then Lwt.return_unit - else ( - t.closed <- true; - opt_cancel t.current_read; - opt_cancel t.current_write; - Lwt.catch - (fun () -> Lwt_unix.close t.fd) - (function - | Unix.Unix_error (Unix.ECONNRESET, _, _) -> Lwt.return_unit (* FreeBSD *) - | ex -> raise ex - ) - ) - -let pp_error f = function - | `Exception ex -> Fmt.exn f ex - | `Closed -> Fmt.string f "Closed" - -let pp_write_error = pp_error - -let write t buf = - let rec aux buf = - if t.closed then Lwt.return (Error `Closed) - else ( - assert (t.current_write = None); - let write_thread = Lwt_cstruct.write t.fd buf in - t.current_write <- Some write_thread; - write_thread >>= fun wrote -> - t.current_write <- None; - if wrote = Cstruct.length buf then Lwt.return (Ok ()) - else aux (Cstruct.shift buf wrote) - ) - in - Lwt.catch - (fun () -> aux buf) - (function - | Unix.Unix_error (Unix.ECONNRESET, _, _) - | Unix.Unix_error (Unix.ENOTCONN, _, _) (* macos *) - | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return @@ Error `Closed - | ex -> Lwt.return @@ Error (`Exception ex)) - -let rec writev t = function - | [] -> Lwt.return (Ok ()) - | x :: xs -> - write t x >>= function - | Ok () -> writev t xs - | Error _ as e -> Lwt.return e - -let read t = - let len = 4096 in - let buf = Cstruct.create_unsafe len in - Lwt.try_bind - (fun () -> - assert (t.current_read = None); - if t.closed then raise Lwt.Canceled; - let read_thread = Lwt_cstruct.read t.fd buf in - t.current_read <- Some read_thread; - read_thread - ) - (function - | 0 -> - Lwt.return @@ Ok `Eof - | got -> - t.current_read <- None; - Lwt.return @@ Ok (`Data (Cstruct.sub buf 0 got)) - ) - (function - | Lwt.Canceled - | Unix.Unix_error (Unix.EPIPE, _, _) - | Unix.Unix_error (Unix.ECONNRESET, _, _) -> Lwt_result.return `Eof - | ex -> Lwt.return @@ Error (`Exception ex) - ) - -let connect ?switch fd = - let t = { fd; closed = false; current_read = None; current_write = None } in - Lwt_switch.add_hook switch (fun () -> close t); - t - -let socketpair ?switch () = - let a, b = Lwt_unix.(socketpair PF_UNIX SOCK_STREAM 0) in - connect ?switch a, connect ?switch b - -let shutdown t cmd = - Lwt_unix.shutdown t.fd - (match cmd with - | `read -> SHUTDOWN_RECEIVE - | `read_write -> SHUTDOWN_ALL - | `write -> SHUTDOWN_SEND - ); - Lwt.return_unit diff --git a/unix/unix_flow.mli b/unix/unix_flow.mli deleted file mode 100644 index 78d85bfee..000000000 --- a/unix/unix_flow.mli +++ /dev/null @@ -1,7 +0,0 @@ -(** Wraps a Unix [file_descr] to provide the Mirage flow API. *) - -include Mirage_flow.S - -val connect : ?switch:Lwt_switch.t -> Lwt_unix.file_descr -> flow - -val socketpair : ?switch:Lwt_switch.t -> unit -> flow * flow diff --git a/unix/vat_network.ml b/unix/vat_network.ml index d5d810f8e..922b2eda7 100644 --- a/unix/vat_network.ml +++ b/unix/vat_network.ml @@ -1 +1 @@ -include Capnp_rpc_net.Networking (Network) (Unix_flow) +include Capnp_rpc_net.Networking (Network)