diff --git a/alias.go b/alias.go index d66bacdd5..f805ec522 100644 --- a/alias.go +++ b/alias.go @@ -95,6 +95,7 @@ var ( WithStructuredEncoding = http.WithStructuredEncoding WithPort = http.WithPort WithPath = http.WithPath + WithMiddleware = http.WithMiddleware // HTTP Context diff --git a/pkg/cloudevents/transport/http/options.go b/pkg/cloudevents/transport/http/options.go index 226bb9255..2a86ed385 100644 --- a/pkg/cloudevents/transport/http/options.go +++ b/pkg/cloudevents/transport/http/options.go @@ -173,3 +173,20 @@ func WithPath(path string) Option { return nil } } + +// Middleware is a function that takes an existing http.Handler and wraps it in middleware, +// returning the wrapped http.Handler. +type Middleware func(next nethttp.Handler) nethttp.Handler + +// WithMiddleware adds an HTTP middleware to the transport. It may be specified multiple times. +// Middleware is applied to everything before it. For example +// `NewClient(WithMiddleware(foo), WithMiddleware(bar))` would result in `bar(foo(original))`. +func WithMiddleware(middleware Middleware) Option { + return func (t *Transport) error { + if t == nil { + return fmt.Errorf("http middleware option can not set nil transport") + } + t.middleware = append(t.middleware, middleware) + return nil + } +} diff --git a/pkg/cloudevents/transport/http/options_test.go b/pkg/cloudevents/transport/http/options_test.go index 6300e04cc..9f39fa61b 100644 --- a/pkg/cloudevents/transport/http/options_test.go +++ b/pkg/cloudevents/transport/http/options_test.go @@ -629,3 +629,31 @@ func TestWithStructuredEncoding(t *testing.T) { }) } } + +func TestWithMiddleware(t *testing.T) { + testCases := map[string]struct{ + t *Transport + wantErr string + }{ + "nil transport": { + wantErr: "http middleware option can not set nil transport", + }, + "non-nil transport": { + t: &Transport{}, + }, + } + for n, tc := range testCases { + t.Run(n, func(t *testing.T) { + err := tc.t.applyOptions(WithMiddleware(func(next http.Handler) http.Handler { + return next + })) + if tc.wantErr != "" { + if err == nil || err.Error() != tc.wantErr { + t.Fatalf("Expected error '%s'. Actual '%v'", tc.wantErr, err) + } + } else if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + }) + } +} diff --git a/pkg/cloudevents/transport/http/transport.go b/pkg/cloudevents/transport/http/transport.go index 579149fe7..8434e8254 100644 --- a/pkg/cloudevents/transport/http/transport.go +++ b/pkg/cloudevents/transport/http/transport.go @@ -73,6 +73,8 @@ type Transport struct { crMu sync.Mutex // Receive Mutex reMu sync.Mutex + + middleware []Middleware } func New(opts ...Option) (*Transport, error) { @@ -237,7 +239,7 @@ func (t *Transport) StartReceiver(ctx context.Context) error { addr := fmt.Sprintf(":%d", t.GetPort()) t.server = &http.Server{ Addr: addr, - Handler: t.Handler, + Handler: attachMiddleware(t.Handler, t.middleware), } listener, err := net.Listen("tcp", addr) @@ -274,6 +276,14 @@ func (t *Transport) StartReceiver(ctx context.Context) error { } } +// attachMiddleware attaches the HTTP middleware to the specified handler. +func attachMiddleware(h http.Handler, middleware []Middleware) http.Handler { + for _, m := range middleware { + h = m(h) + } + return h +} + type eventError struct { event *cloudevents.Event err error diff --git a/pkg/cloudevents/transport/http/transport_test.go b/pkg/cloudevents/transport/http/transport_test.go index afb8b3bde..e9ac771e7 100644 --- a/pkg/cloudevents/transport/http/transport_test.go +++ b/pkg/cloudevents/transport/http/transport_test.go @@ -1,10 +1,13 @@ package http_test import ( + "bytes" "context" "fmt" + "io/ioutil" "net" "net/http" + "strings" "sync/atomic" "testing" "time" @@ -129,3 +132,93 @@ func TestStableConnectionsToSingleHost(t *testing.T) { } t.Log("sent ", sent) } + +func TestMiddleware(t *testing.T) { + testCases := map[string]struct { + middleware []string + want string + } { + "none": {}, + "one": { + middleware: []string{ "Foo" }, + }, + "nested": { + middleware: []string{ "Foo", "Bar", "Qux" }, + }, + } + for n, tc := range testCases { + t.Run(n, func(t *testing.T) { + m := make([]cehttp.Option, 0, len(tc.middleware) + 2) + m = append(m, cehttp.WithPort(0), cehttp.WithShutdownTimeout(time.Nanosecond)) + for _, ms := range tc.middleware { + ms := ms + m = append(m, cehttp.WithMiddleware(func(next http.Handler) http.Handler { + return &namedHandler{ + name: ms, + next: next, + } + })) + } + tr, err := cehttp.New(m...) + if err != nil { + t.Fatalf("Unable to create transport, %v", err) + } + innermostResponse := "Original" + origResponse := makeRequestToServer(t, tr, innermostResponse) + + // Verify that the response is all the middlewares run in the correct order (as a stack). + response := string(origResponse) + for i := len(tc.middleware) - 1; i >= 0; i-- { + expected := tc.middleware[i] + if !strings.HasPrefix(response, expected) { + t.Fatalf("Incorrect prefix at offset %d. Expected %s. Actual %s", i, tc.middleware[i], string(origResponse)) + } + response = strings.TrimPrefix(response, expected) + } + if response != innermostResponse { + t.Fatalf("Incorrect prefix at last offset. Expected '%s'. Actual %s", innermostResponse, string(origResponse)) + } + }) + } +} + +// makeRequestToServer starts the transport and makes a request to it, pointing at a custom path that will return +// responseText. +func makeRequestToServer(t *testing.T, tr *cehttp.Transport, responseText string) string { + // Create a custom path that will be used to respond with responseText. + tr.Handler = http.NewServeMux() + path := "/123" + tr.Handler.HandleFunc(path, func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(responseText)) + }) + + // Start the server. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go tr.StartReceiver(ctx) + + // Give some time for the receiver to start. One second was chosen arbitrarily. + time.Sleep(time.Second) + + // Make the request. + port := tr.GetPort() + r, err := http.Post(fmt.Sprintf("http://localhost:%d%s", port, path), "text", &bytes.Buffer{}) + if err != nil { + t.Fatalf("Error posting: %v", err) + } + rb, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + return string(rb) +} + +type namedHandler struct { + name string + next http.Handler +} + +func (h *namedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(h.name)) + h.next.ServeHTTP(w, r) +}