Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WithConditionalHandlerOptions for conditional options #538

Merged
merged 4 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,35 +128,35 @@ func TestHandlerCompressionOptionTest(t *testing.T) {

t.Run("defaults", func(t *testing.T) {
t.Parallel()
config := newHandlerConfig(testProc, nil)
config := newHandlerConfig(testProc, StreamTypeUnary, nil)
assert.Equal(t, config.CompressionNames, []string{compressionGzip})
checkPools(t, config)
})
t.Run("WithCompression", func(t *testing.T) {
t.Parallel()
opts := []HandlerOption{WithCompression("foo", dummyDecompressCtor, dummyCompressCtor)}
config := newHandlerConfig(testProc, opts)
config := newHandlerConfig(testProc, StreamTypeUnary, opts)
assert.Equal(t, config.CompressionNames, []string{compressionGzip, "foo"})
checkPools(t, config)
})
t.Run("WithCompression-empty-name-noop", func(t *testing.T) {
t.Parallel()
opts := []HandlerOption{WithCompression("", dummyDecompressCtor, dummyCompressCtor)}
config := newHandlerConfig(testProc, opts)
config := newHandlerConfig(testProc, StreamTypeUnary, opts)
assert.Equal(t, config.CompressionNames, []string{compressionGzip})
checkPools(t, config)
})
t.Run("WithCompression-nil-ctors-noop", func(t *testing.T) {
t.Parallel()
opts := []HandlerOption{WithCompression("foo", nil, nil)}
config := newHandlerConfig(testProc, opts)
config := newHandlerConfig(testProc, StreamTypeUnary, opts)
assert.Equal(t, config.CompressionNames, []string{compressionGzip})
checkPools(t, config)
})
t.Run("WithCompression-nil-ctors-unregisters", func(t *testing.T) {
t.Parallel()
opts := []HandlerOption{WithCompression("gzip", nil, nil)}
config := newHandlerConfig(testProc, opts)
config := newHandlerConfig(testProc, StreamTypeUnary, opts)
assert.Equal(t, config.CompressionNames, nil)
checkPools(t, config)
})
Expand Down
8 changes: 7 additions & 1 deletion connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,13 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
readMaxBytes := 1024
mux.Handle(pingv1connect.NewPingServiceHandler(
pingServer{},
connect.WithReadMaxBytes(readMaxBytes),
connect.WithConditionalHandlerOptions(func(spec connect.Spec) []connect.HandlerOption {
var options []connect.HandlerOption
if spec.Procedure == pingv1connect.PingServicePingProcedure {
options = append(options, connect.WithReadMaxBytes(readMaxBytes))
}
return options
}),
))
readMaxBytesMatrix := func(t *testing.T, client pingv1connect.PingServiceClient, compressed bool) {
t.Helper()
Expand Down
3 changes: 2 additions & 1 deletion error_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ type ErrorWriter struct {
// NewErrorWriter constructs an ErrorWriter. To properly recognize supported
// RPC Content-Types in net/http middleware, you must pass the same
// HandlerOptions to NewErrorWriter and any wrapped Connect handlers.
// Options supplied via [WithConditionalHandlerOptions] are ignored.
func NewErrorWriter(opts ...HandlerOption) *ErrorWriter {
config := newHandlerConfig("", opts)
config := newHandlerConfig("", StreamTypeUnary, opts)
writer := &ErrorWriter{
bufferPool: config.BufferPool,
protobuf: newReadOnlyCodecs(config.Codecs).Protobuf(),
Expand Down
24 changes: 13 additions & 11 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewUnaryHandler[Req, Res any](
}
return res, err
})
config := newHandlerConfig(procedure, options)
config := newHandlerConfig(procedure, StreamTypeUnary, options)
if interceptor := config.Interceptor; interceptor != nil {
untyped = interceptor.WrapUnary(untyped)
}
Expand Down Expand Up @@ -87,9 +87,9 @@ func NewUnaryHandler[Req, Res any](
return conn.Send(response.Any())
}

protocolHandlers := config.newProtocolHandlers(StreamTypeUnary)
protocolHandlers := config.newProtocolHandlers()
return &Handler{
spec: config.newSpec(StreamTypeUnary),
spec: config.newSpec(),
implementation: implementation,
protocolHandlers: mappedMethodHandlers(protocolHandlers),
allowMethod: sortedAllowMethodValue(protocolHandlers),
Expand Down Expand Up @@ -253,9 +253,10 @@ type handlerConfig struct {
BufferPool *bufferPool
ReadMaxBytes int
SendMaxBytes int
StreamType StreamType
}

func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig {
func newHandlerConfig(procedure string, streamType StreamType, options []HandlerOption) *handlerConfig {
protoPath := extractProtoPath(procedure)
config := handlerConfig{
Procedure: protoPath,
Expand All @@ -264,6 +265,7 @@ func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig
HandleGRPC: true,
HandleGRPCWeb: true,
BufferPool: newBufferPool(),
StreamType: streamType,
}
withProtoBinaryCodec().applyToHandler(&config)
withProtoJSONCodecs().applyToHandler(&config)
Expand All @@ -274,15 +276,15 @@ func newHandlerConfig(procedure string, options []HandlerOption) *handlerConfig
return &config
}

func (c *handlerConfig) newSpec(streamType StreamType) Spec {
func (c *handlerConfig) newSpec() Spec {
return Spec{
Procedure: c.Procedure,
StreamType: streamType,
StreamType: c.StreamType,
IdempotencyLevel: c.IdempotencyLevel,
}
}

func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHandler {
func (c *handlerConfig) newProtocolHandlers() []protocolHandler {
protocols := []protocol{&protocolConnect{}}
if c.HandleGRPC {
protocols = append(protocols, &protocolGRPC{web: false})
Expand All @@ -298,7 +300,7 @@ func (c *handlerConfig) newProtocolHandlers(streamType StreamType) []protocolHan
)
for _, protocol := range protocols {
handlers = append(handlers, protocol.NewHandler(&protocolHandlerParams{
Spec: c.newSpec(streamType),
Spec: c.newSpec(),
Codecs: codecs,
CompressionPools: compressors,
CompressMinBytes: c.CompressMinBytes,
Expand All @@ -318,13 +320,13 @@ func newStreamHandler(
implementation StreamingHandlerFunc,
options ...HandlerOption,
) *Handler {
config := newHandlerConfig(procedure, options)
config := newHandlerConfig(procedure, streamType, options)
if ic := config.Interceptor; ic != nil {
implementation = ic.WrapStreamingHandler(implementation)
}
protocolHandlers := config.newProtocolHandlers(streamType)
protocolHandlers := config.newProtocolHandlers()
return &Handler{
spec: config.newSpec(streamType),
spec: config.newSpec(),
implementation: implementation,
protocolHandlers: mappedMethodHandlers(protocolHandlers),
allowMethod: sortedAllowMethodValue(protocolHandlers),
Expand Down
11 changes: 11 additions & 0 deletions interceptor_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,14 @@ func ExampleWithInterceptors() {
// inner interceptor: after call
// outer interceptor: after call
}

func ExampleWithConditionalHandlerOptions() {
connect.WithConditionalHandlerOptions(func(spec connect.Spec) []connect.HandlerOption {
var options []connect.HandlerOption
if spec.Procedure == pingv1connect.PingServicePingProcedure {
options = append(options, connect.WithReadMaxBytes(1024))
}
return options
})
// Output:
}
20 changes: 20 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ func WithRequireConnectProtocolHeader() HandlerOption {
return &requireConnectProtocolHeaderOption{}
}

// WithConditionalHandlerOptions accepts a function that returns a HandlerOption.
// It's used to conditionally apply HandlerOption to a Handler based on the Spec.
func WithConditionalHandlerOptions(conditional func(spec Spec) []HandlerOption) HandlerOption {
return &conditionalHandlerOptions{conditional: conditional}
}

// Option implements both [ClientOption] and [HandlerOption], so it can be
// applied both client-side and server-side.
type Option interface {
Expand Down Expand Up @@ -557,3 +563,17 @@ func withProtoJSONCodecs() HandlerOption {
WithCodec(&protoJSONCodec{codecNameJSONCharsetUTF8}),
)
}

type conditionalHandlerOptions struct {
conditional func(spec Spec) []HandlerOption
}

func (o *conditionalHandlerOptions) applyToHandler(config *handlerConfig) {
spec := config.newSpec()
if spec.Procedure == "" {
return // ignore empty specs
}
for _, option := range o.conditional(spec) {
option.applyToHandler(config)
}
}
Loading