diff --git a/.changelog/121053e34770402bb72d4cf2c1057b2e.json b/.changelog/121053e34770402bb72d4cf2c1057b2e.json new file mode 100644 index 00000000000..6c73f9c1aa2 --- /dev/null +++ b/.changelog/121053e34770402bb72d4cf2c1057b2e.json @@ -0,0 +1,8 @@ +{ + "id": "121053e3-4770-402b-b72d-4cf2c1057b2e", + "type": "feature", + "description": "Add `middleware.WithHeaderComment` API, which explicitly re-adds behavior that was previously unintentially possible through `middleware.AddUserAgentKey`.", + "modules": [ + "aws/middleware" + ] +} diff --git a/aws/middleware/header.go b/aws/middleware/header.go new file mode 100644 index 00000000000..c826cfe0870 --- /dev/null +++ b/aws/middleware/header.go @@ -0,0 +1,82 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +// WithHeaderComment instruments a middleware stack to append an HTTP field +// comment to the given header as specified in RFC 9110 +// (https://www.rfc-editor.org/rfc/rfc9110#name-comments). +// +// The header is case-insensitive. If the provided header exists when the +// middleware runs, the content will be inserted as-is enclosed in parentheses. +// +// Note that per the HTTP specification, comments are only allowed in fields +// containing "comment" as part of their field value definition, but this API +// will NOT verify whether the provided header is one of them. +// +// WithHeaderComment MAY be applied more than once to a middleware stack and/or +// more than once per header. +func WithHeaderComment(header, content string) func(*middleware.Stack) error { + return func(s *middleware.Stack) error { + m, err := getOrAddHeaderComment(s) + if err != nil { + return err + } + + m.values.Add(header, content) + return nil + } +} + +type headerCommentMiddleware struct { + values http.Header // hijack case-insensitive access APIs +} + +func (*headerCommentMiddleware) ID() string { + return "headerComment" +} + +func (m *headerCommentMiddleware) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + r, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unknown transport type %T", in.Request) + } + + for h, contents := range m.values { + for _, c := range contents { + if existing := r.Header.Get(h); existing != "" { + r.Header.Set(h, fmt.Sprintf("%s (%s)", existing, c)) + } + } + } + + return next.HandleBuild(ctx, in) +} + +func getOrAddHeaderComment(s *middleware.Stack) (*headerCommentMiddleware, error) { + id := (*headerCommentMiddleware)(nil).ID() + m, ok := s.Build.Get(id) + if !ok { + m := &headerCommentMiddleware{values: http.Header{}} + if err := s.Build.Add(m, middleware.After); err != nil { + return nil, err + } + + return m, nil + } + + hc, ok := m.(*headerCommentMiddleware) + if !ok { + return nil, fmt.Errorf("existing middleware w/ id %s is not *headerCommentMiddleware", id) + } + + return hc, nil +} diff --git a/aws/middleware/header_test.go b/aws/middleware/header_test.go new file mode 100644 index 00000000000..c7c4257e689 --- /dev/null +++ b/aws/middleware/header_test.go @@ -0,0 +1,114 @@ +package middleware + +import ( + "context" + "net/http" + "testing" + + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" +) + +func TestWithHeaderComment_CaseInsensitive(t *testing.T) { + stack, err := newTestStack( + WithHeaderComment("foo", "bar"), + ) + if err != nil { + t.Errorf("expected no error on new stack, got %v", err) + } + + r := injectBuildRequest(stack) + r.Header.Set("Foo", "baz") + + if err := handle(stack); err != nil { + t.Errorf("expected no error on handle, got %v", err) + } + + expectHeader(t, r.Header, "Foo", "baz (bar)") +} + +func TestWithHeaderComment_Noop(t *testing.T) { + stack, err := newTestStack( + WithHeaderComment("foo", "bar"), + ) + if err != nil { + t.Errorf("expected no error on new stack, got %v", err) + } + + r := injectBuildRequest(stack) + + if err := handle(stack); err != nil { + t.Errorf("expected no error on handle, got %v", err) + } + + expectHeader(t, r.Header, "Foo", "") +} + +func TestWithHeaderComment_MultiCaseInsensitive(t *testing.T) { + stack, err := newTestStack( + WithHeaderComment("foo", "c1"), + WithHeaderComment("Foo", "c2"), + WithHeaderComment("baz", "c3"), + WithHeaderComment("Baz", "c4"), + ) + if err != nil { + t.Errorf("expected no error on new stack, got %v", err) + } + + r := injectBuildRequest(stack) + r.Header.Set("Foo", "1") + r.Header.Set("Baz", "2") + + if err := handle(stack); err != nil { + t.Errorf("expected no error on handle, got %v", err) + } + + expectHeader(t, r.Header, "Foo", "1 (c1) (c2)") + expectHeader(t, r.Header, "Baz", "2 (c3) (c4)") +} + +func newTestStack(fns ...func(*middleware.Stack) error) (*middleware.Stack, error) { + s := middleware.NewStack("", smithyhttp.NewStackRequest) + for _, fn := range fns { + if err := fn(s); err != nil { + return nil, err + } + } + return s, nil +} + +func handle(stack *middleware.Stack) error { + _, _, err := middleware.DecorateHandler( + middleware.HandlerFunc( + func(ctx context.Context, input interface{}) ( + interface{}, middleware.Metadata, error, + ) { + return nil, middleware.Metadata{}, nil + }, + ), + stack, + ).Handle(context.Background(), nil) + return err +} + +func injectBuildRequest(s *middleware.Stack) *smithyhttp.Request { + r := smithyhttp.NewStackRequest() + s.Build.Add( + middleware.BuildMiddlewareFunc( + "injectBuildRequest", + func(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) ( + middleware.BuildOutput, middleware.Metadata, error, + ) { + return next.HandleBuild(ctx, middleware.BuildInput{Request: r}) + }, + ), + middleware.Before, + ) + return r.(*smithyhttp.Request) +} + +func expectHeader(t *testing.T, header http.Header, h, ev string) { + if av := header.Get(h); ev != av { + t.Errorf("expected header '%s: %s', got '%s'", h, ev, av) + } +}