diff --git a/rpc/serve.go b/rpc/serve.go index 789099dd..4642dddf 100644 --- a/rpc/serve.go +++ b/rpc/serve.go @@ -8,18 +8,51 @@ import ( "capnproto.org/go/capnp/v3" ) +// serveOpts are options for the Cap'n Proto server. +type serveOpts struct { + newTransport NewTransportFunc +} + +// defaultServeOpts returns the default server opts. +func defaultServeOpts() serveOpts { + return serveOpts{ + newTransport: NewStreamTransport, + } +} + +type ServeOption func(*serveOpts) + +// WithBasicStreamingTransport enables the streaming transport with basic encoding. +func WithBasicStreamingTransport() ServeOption { + return func(opts *serveOpts) { + opts.newTransport = NewStreamTransport + } +} + +// WithPackedStreamingTransport enables the streaming transport with packed encoding. +func WithPackedStreamingTransport() ServeOption { + return func(opts *serveOpts) { + opts.newTransport = NewPackedStreamTransport + } +} + // Serve serves a Cap'n Proto RPC to incoming connections. // // Serve will take ownership of bootstrapClient and release it after the listener closes. // // Serve exits with the listener error if the listener is closed by the owner. -func Serve(lis net.Listener, boot capnp.Client) error { +func Serve(lis net.Listener, boot capnp.Client, opts ...ServeOption) error { if !boot.IsValid() { err := errors.New("bootstrap client is not valid") return err } // Since we took ownership of the bootstrap client, release it after we're done. defer boot.Release() + + options := defaultServeOpts() + for _, o := range opts { + o(&options) + } for { // Accept incoming connections conn, err := lis.Accept() @@ -33,7 +66,7 @@ func Serve(lis net.Listener, boot capnp.Client) error { BootstrapClient: boot.AddRef(), } // For each new incoming connection, create a new RPC transport connection that will serve incoming RPC requests - transport := NewStreamTransport(conn) + transport := options.newTransport(conn) _ = NewConn(transport, &opts) } } @@ -44,7 +77,7 @@ func Serve(lis net.Listener, boot capnp.Client) error { // and "tcp" for regular TCP IP4 or IP6 connections. // // ListenAndServe will take ownership of bootstrapClient and release it on exit. -func ListenAndServe(ctx context.Context, network, addr string, bootstrapClient capnp.Client) error { +func ListenAndServe(ctx context.Context, network, addr string, bootstrapClient capnp.Client, opts ...ServeOption) error { listener, err := net.Listen(network, addr) diff --git a/rpc/serve_test.go b/rpc/serve_test.go index fb60e407..37646255 100644 --- a/rpc/serve_test.go +++ b/rpc/serve_test.go @@ -110,21 +110,43 @@ func TestServeCapability(t *testing.T) { } func TestListenAndServe(t *testing.T) { - var err error - t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - errChannel := make(chan error) - - // Provide a server that listens - srv := testcp.PingPong_ServerToClient(pingPongServer{}) - bootstrapClient := capnp.Client(srv) - go func() { - t.Log("Starting ListenAndServe") - err2 := rpc.ListenAndServe(ctx, "tcp", ":0", bootstrapClient) - errChannel <- err2 - }() - - cancelFunc() - err = <-errChannel // Will hang if server does not return. - assert.ErrorIs(t, err, net.ErrClosed) + cases := []struct { + name string + opts []rpc.ServeOption + }{ + { + name: "basic encoding transport", + opts: []rpc.ServeOption{ + rpc.WithBasicStreamingTransport(), + }, + }, + { + name: "packed encoding transport", + opts: []rpc.ServeOption{ + rpc.WithPackedStreamingTransport(), + }, + }, + } + + for _, tcase := range cases { + t.Run(tcase.name, func(t *testing.T) { + var err error + t.Parallel() + ctx, cancelFunc := context.WithCancel(context.Background()) + errChannel := make(chan error) + + // Provide a server that listens + srv := testcp.PingPong_ServerToClient(pingPongServer{}) + bootstrapClient := capnp.Client(srv) + go func() { + t.Log("Starting ListenAndServe") + err2 := rpc.ListenAndServe(ctx, "tcp", ":0", bootstrapClient, tcase.opts...) + errChannel <- err2 + }() + + cancelFunc() + err = <-errChannel // Will hang if server does not return. + assert.ErrorIs(t, err, net.ErrClosed) + }) + } } diff --git a/rpc/transport.go b/rpc/transport.go index 51f46548..8135ea9d 100644 --- a/rpc/transport.go +++ b/rpc/transport.go @@ -10,6 +10,7 @@ import ( type Codec = transport.Codec type Transport = transport.Transport +type NewTransportFunc func(io.ReadWriteCloser) Transport // NewStreamTransport is an alias for as transport.NewStream func NewStreamTransport(rwc io.ReadWriteCloser) Transport {