From 9a47e68e74b84ea0b149e229196744f6e29da574 Mon Sep 17 00:00:00 2001 From: Tianyi Wang Date: Fri, 11 Aug 2023 16:52:16 -0400 Subject: [PATCH 1/4] Modify and Merge protocol test request unit tests codegen logic --- .../HttpProtocolUnitTestRequestGenerator.java | 56 ++++----- transport/http/middleware_capture_request.go | 46 ++++++++ .../http/middleware_capture_request_test.go | 109 ++++++++++++++++++ 3 files changed, 179 insertions(+), 32 deletions(-) create mode 100644 transport/http/middleware_capture_request.go create mode 100644 transport/http/middleware_capture_request_test.go diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index 97bdcb1b4..d1f738ab6 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import java.util.logging.Logger; +import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.go.codegen.GoWriter; import software.amazon.smithy.go.codegen.SmithyGoDependency; import software.amazon.smithy.go.codegen.SymbolUtils; @@ -195,9 +196,7 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC * * @param writer writer to write generated code with. */ - protected void generateTestBodySetup(GoWriter writer) { - writer.write("var actualReq *http.Request"); - } + protected void generateTestBodySetup(GoWriter writer) {} /** * Hook to generate the HTTP response body of the protocol test. @@ -205,26 +204,6 @@ protected void generateTestBodySetup(GoWriter writer) { * @param writer writer to write generated code with. */ protected void generateTestServerHandler(GoWriter writer) { - writer.write("actualReq = r.Clone(r.Context())"); - // Go does not set RawPath on http server if nothing is escaped - writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> { - writer.write("actualReq.URL.RawPath = actualReq.URL.Path"); - }); - // Go automatically removes Content-Length header setting it to the member. - writer.addUseImports(SmithyGoDependency.STRCONV); - writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> { - writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))"); - }); - - writer.addUseImports(SmithyGoDependency.BYTES); - writer.write("var buf bytes.Buffer"); - writer.openBlock("if _, err := io.Copy(&buf, r.Body); err != nil {", "}", () -> { - writer.write("t.Errorf(\"failed to read request body, %v\", err)"); - }); - writer.addUseImports(SmithyGoDependency.IOUTIL); - writer.write("actualReq.Body = ioutil.NopCloser(&buf)"); - writer.write(""); - super.generateTestServerHandler(writer); } @@ -236,8 +215,19 @@ protected void generateTestServerHandler(GoWriter writer) { */ @Override protected void generateTestInvokeClientOperation(GoWriter writer, String clientName) { + Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack", + SmithyGoDependency.SMITHY_MIDDLEWARE).build(); writer.addUseImports(SmithyGoDependency.CONTEXT); - writer.write("result, err := $L.$T(context.Background(), c.Params)", clientName, opSymbol); + writer.write("capturedReq := &http.Request{}"); + writer.openBlock("result, err := $L.$T(context.Background(), c.Params, func(options *Options) {", "})", + clientName, opSymbol, () -> { + writer.openBlock("options.APIOptions = append(options.APIOptions, func(stack $P) error {", "})", + stackSymbol, () -> { + writer.write("return $T(stack, capturedReq)", + SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware", + SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build()); + }); + }); } /** @@ -250,20 +240,21 @@ protected void generateTestAssertions(GoWriter writer) { writeAssertNil(writer, "err"); writeAssertNotNil(writer, "result"); - writeAssertScalarEqual(writer, "c.ExpectMethod", "actualReq.Method", "method"); - writeAssertScalarEqual(writer, "c.ExpectURIPath", "actualReq.URL.RawPath", "path"); + writeAssertScalarEqual(writer, "c.ExpectMethod", "capturedReq.Method", "method"); + writeAssertScalarEqual(writer, "c.ExpectURIPath", "capturedReq.URL.RawPath", "path"); + + writeQueryItemBreakout(writer, "capturedReq.URL.RawQuery", "queryItems"); - writeQueryItemBreakout(writer, "actualReq.URL.RawQuery", "queryItems"); writeAssertHasQuery(writer, "c.ExpectQuery", "queryItems"); writeAssertRequireQuery(writer, "c.RequireQuery", "queryItems"); writeAssertForbidQuery(writer, "c.ForbidQuery", "queryItems"); - writeAssertHasHeader(writer, "c.ExpectHeader", "actualReq.Header"); - writeAssertRequireHeader(writer, "c.RequireHeader", "actualReq.Header"); - writeAssertForbidHeader(writer, "c.ForbidHeader", "actualReq.Header"); + writeAssertHasHeader(writer, "c.ExpectHeader", "capturedReq.Header"); + writeAssertRequireHeader(writer, "c.RequireHeader", "capturedReq.Header"); + writeAssertForbidHeader(writer, "c.ForbidHeader", "capturedReq.Header"); writer.openBlock("if c.BodyAssert != nil {", "}", () -> { - writer.openBlock("if err := c.BodyAssert(actualReq.Body); err != nil {", "}", () -> { + writer.openBlock("if err := c.BodyAssert(capturedReq.Body); err != nil {", "}", () -> { writer.write("t.Errorf(\"expect body equal, got %v\", err)"); }); }); @@ -282,7 +273,8 @@ protected void generateTestServer( String name, Consumer handler ) { - super.generateTestServer(writer, name, handler); + // We aren't using a test server, but we do need a URL to set. + writer.write("serverURL := \"http://localhost:8888/\""); writer.pushState(); writer.putContext("parse", SymbolUtils.createValueSymbolBuilder("Parse", SmithyGoDependency.NET_URL) .build()); diff --git a/transport/http/middleware_capture_request.go b/transport/http/middleware_capture_request.go new file mode 100644 index 000000000..0d44e72fe --- /dev/null +++ b/transport/http/middleware_capture_request.go @@ -0,0 +1,46 @@ +package http + +import ( + "context" + "fmt" + "github.com/aws/smithy-go/middleware" + "net/http" + "strconv" +) + +const captureRequestID = "CaptureProtocolTestRequest" + +// AddCaptureRequestMiddleware captures serialized http request during protocol test for check +func AddCaptureRequestMiddleware(stack *middleware.Stack, req *http.Request) error { + return stack.Build.Add(&captureRequestMiddleware{ + req: req, + }, middleware.After) +} + +type captureRequestMiddleware struct { + req *http.Request +} + +func (*captureRequestMiddleware) ID() string { + return captureRequestID +} + +func (m *captureRequestMiddleware) HandleBuild(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler, +) ( + output middleware.BuildOutput, metadata middleware.Metadata, err error, +) { + request, ok := input.Request.(*Request) + if !ok { + return output, metadata, fmt.Errorf("error while retrieving http request") + } + + *m.req = *request.Build(ctx) + if len(m.req.URL.RawPath) == 0 { + m.req.URL.RawPath = m.req.URL.Path + } + if v := m.req.ContentLength; v != 0 { + m.req.Header.Set("Content-Length", strconv.FormatInt(v, 10)) + } + + return next.HandleBuild(ctx, input) +} diff --git a/transport/http/middleware_capture_request_test.go b/transport/http/middleware_capture_request_test.go new file mode 100644 index 000000000..5d9457eb7 --- /dev/null +++ b/transport/http/middleware_capture_request_test.go @@ -0,0 +1,109 @@ +package http + +import ( + "context" + "github.com/aws/smithy-go/middleware" + smithytesting "github.com/aws/smithy-go/testing" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" +) + +// TestAddCaptureRequestMiddleware tests AddCaptureRequestMiddleware +func TestAddCaptureRequestMiddleware(t *testing.T) { + cases := map[string]struct { + Request *http.Request + ExpectRequest *http.Request + ExpectQuery []smithytesting.QueryItem + Stream io.Reader + }{ + "normal request": { + Request: &http.Request{ + Method: "PUT", + Header: map[string][]string{ + "Foo": {"bar", "too"}, + "Checksum": {"SHA256"}, + }, + URL: &url.URL{ + Path: "test/path", + RawQuery: "language=us®ion=us-west+east", + }, + ContentLength: 100, + }, + ExpectRequest: &http.Request{ + Method: "PUT", + Header: map[string][]string{ + "Foo": {"bar", "too"}, + "Checksum": {"SHA256"}, + "Content-Length": {"100"}, + }, + URL: &url.URL{ + Path: "test/path", + RawPath: "test/path", + }, + Body: io.NopCloser(strings.NewReader("hello world.")), + }, + ExpectQuery: []smithytesting.QueryItem{ + { + Key: "language", + Value: "us", + }, + { + Key: "region", + Value: "us-west%20east", + }, + }, + Stream: strings.NewReader("hello world."), + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var err error + req := &Request{ + Request: c.Request, + stream: c.Stream, + } + capturedRequest := &http.Request{} + m := captureRequestMiddleware{ + req: capturedRequest, + } + _, _, err = m.HandleBuild(context.Background(), + middleware.BuildInput{Request: req}, + middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) ( + out middleware.BuildOutput, metadata middleware.Metadata, err error) { + return out, metadata, nil + }), + ) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := c.ExpectRequest.Method, capturedRequest.Method; e != a { + t.Errorf("expect request method %v found, got %v", e, a) + } + if e, a := c.ExpectRequest.URL.Path, capturedRequest.URL.RawPath; e != a { + t.Errorf("expect %v path, got %v", e, a) + } + if c.ExpectRequest.Body != nil { + expect, err := io.ReadAll(c.ExpectRequest.Body) + if capturedRequest.Body == nil { + t.Errorf("Expect request stream %v captured, get nil", string(expect)) + } + actual, err := ioutil.ReadAll(capturedRequest.Body) + if err != nil { + t.Errorf("unable to read captured request body, %v", err) + } + if e, a := string(expect), string(actual); e != a { + t.Errorf("expect request body to be %s, got %s", e, a) + } + } + queryItems := smithytesting.ParseRawQuery(capturedRequest.URL.RawQuery) + smithytesting.AssertHasQuery(t, c.ExpectQuery, queryItems) + smithytesting.AssertHasHeader(t, c.ExpectRequest.Header, capturedRequest.Header) + }) + } +} From ba8995c3b3948b4f1a274f29a99cadd64b16a85d Mon Sep 17 00:00:00 2001 From: Tianyi Wang Date: Fri, 11 Aug 2023 17:07:05 -0400 Subject: [PATCH 2/4] Modify and Merge protocol test generator syntax --- .../integration/HttpProtocolUnitTestRequestGenerator.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index d1f738ab6..1fc76227a 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -196,7 +196,8 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC * * @param writer writer to write generated code with. */ - protected void generateTestBodySetup(GoWriter writer) {} + protected void generateTestBodySetup(GoWriter writer) { + } /** * Hook to generate the HTTP response body of the protocol test. From b5640fca08276c8e6af9c1bcea523f180f66691a Mon Sep 17 00:00:00 2001 From: Tianyi Wang Date: Fri, 11 Aug 2023 17:26:28 -0400 Subject: [PATCH 3/4] Modify and Merge some unit test syntax --- transport/http/middleware_capture_request_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transport/http/middleware_capture_request_test.go b/transport/http/middleware_capture_request_test.go index 5d9457eb7..4b0f3d518 100644 --- a/transport/http/middleware_capture_request_test.go +++ b/transport/http/middleware_capture_request_test.go @@ -44,7 +44,7 @@ func TestAddCaptureRequestMiddleware(t *testing.T) { Path: "test/path", RawPath: "test/path", }, - Body: io.NopCloser(strings.NewReader("hello world.")), + Body: ioutil.NopCloser(strings.NewReader("hello world.")), }, ExpectQuery: []smithytesting.QueryItem{ { @@ -89,7 +89,7 @@ func TestAddCaptureRequestMiddleware(t *testing.T) { t.Errorf("expect %v path, got %v", e, a) } if c.ExpectRequest.Body != nil { - expect, err := io.ReadAll(c.ExpectRequest.Body) + expect, err := ioutil.ReadAll(c.ExpectRequest.Body) if capturedRequest.Body == nil { t.Errorf("Expect request stream %v captured, get nil", string(expect)) } From 12e7bce77a40128c4d2016ac70952a436adacd45 Mon Sep 17 00:00:00 2001 From: Tianyi Wang Date: Mon, 14 Aug 2023 17:12:11 -0400 Subject: [PATCH 4/4] Modify and Merge protocol test codegen code --- .../smithy/go/codegen/SmithyGoDependency.java | 1 + .../HttpProtocolUnitTestRequestGenerator.java | 20 +++++++++---------- .../protocol}/middleware_capture_request.go | 5 +++-- .../middleware_capture_request_test.go | 12 ++++++++--- 4 files changed, 23 insertions(+), 15 deletions(-) rename {transport/http => private/protocol}/middleware_capture_request.go (89%) rename {transport/http => private/protocol}/middleware_capture_request_test.go (91%) diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java index f62ad9c1d..2d206e909 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java @@ -49,6 +49,7 @@ public final class SmithyGoDependency { public static final GoDependency SMITHY_TRANSPORT = smithy("transport", "smithytransport"); public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp"); public static final GoDependency SMITHY_MIDDLEWARE = smithy("middleware"); + public static final GoDependency SMITHY_PRIVATE_PROTOCOL = smithy("private/protocol", "smithyprivateprotocol"); public static final GoDependency SMITHY_TIME = smithy("time", "smithytime"); public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding"); public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson"); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java index 1fc76227a..15d4adf6a 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolUnitTestRequestGenerator.java @@ -197,6 +197,7 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC * @param writer writer to write generated code with. */ protected void generateTestBodySetup(GoWriter writer) { + writer.write("actualReq := &http.Request{}"); } /** @@ -219,14 +220,13 @@ protected void generateTestInvokeClientOperation(GoWriter writer, String clientN Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack", SmithyGoDependency.SMITHY_MIDDLEWARE).build(); writer.addUseImports(SmithyGoDependency.CONTEXT); - writer.write("capturedReq := &http.Request{}"); writer.openBlock("result, err := $L.$T(context.Background(), c.Params, func(options *Options) {", "})", clientName, opSymbol, () -> { writer.openBlock("options.APIOptions = append(options.APIOptions, func(stack $P) error {", "})", stackSymbol, () -> { - writer.write("return $T(stack, capturedReq)", + writer.write("return $T(stack, actualReq)", SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware", - SmithyGoDependency.SMITHY_HTTP_TRANSPORT).build()); + SmithyGoDependency.SMITHY_PRIVATE_PROTOCOL).build()); }); }); } @@ -241,21 +241,21 @@ protected void generateTestAssertions(GoWriter writer) { writeAssertNil(writer, "err"); writeAssertNotNil(writer, "result"); - writeAssertScalarEqual(writer, "c.ExpectMethod", "capturedReq.Method", "method"); - writeAssertScalarEqual(writer, "c.ExpectURIPath", "capturedReq.URL.RawPath", "path"); + writeAssertScalarEqual(writer, "c.ExpectMethod", "actualReq.Method", "method"); + writeAssertScalarEqual(writer, "c.ExpectURIPath", "actualReq.URL.RawPath", "path"); - writeQueryItemBreakout(writer, "capturedReq.URL.RawQuery", "queryItems"); + writeQueryItemBreakout(writer, "actualReq.URL.RawQuery", "queryItems"); writeAssertHasQuery(writer, "c.ExpectQuery", "queryItems"); writeAssertRequireQuery(writer, "c.RequireQuery", "queryItems"); writeAssertForbidQuery(writer, "c.ForbidQuery", "queryItems"); - writeAssertHasHeader(writer, "c.ExpectHeader", "capturedReq.Header"); - writeAssertRequireHeader(writer, "c.RequireHeader", "capturedReq.Header"); - writeAssertForbidHeader(writer, "c.ForbidHeader", "capturedReq.Header"); + writeAssertHasHeader(writer, "c.ExpectHeader", "actualReq.Header"); + writeAssertRequireHeader(writer, "c.RequireHeader", "actualReq.Header"); + writeAssertForbidHeader(writer, "c.ForbidHeader", "actualReq.Header"); writer.openBlock("if c.BodyAssert != nil {", "}", () -> { - writer.openBlock("if err := c.BodyAssert(capturedReq.Body); err != nil {", "}", () -> { + writer.openBlock("if err := c.BodyAssert(actualReq.Body); err != nil {", "}", () -> { writer.write("t.Errorf(\"expect body equal, got %v\", err)"); }); }); diff --git a/transport/http/middleware_capture_request.go b/private/protocol/middleware_capture_request.go similarity index 89% rename from transport/http/middleware_capture_request.go rename to private/protocol/middleware_capture_request.go index 0d44e72fe..c812036f0 100644 --- a/transport/http/middleware_capture_request.go +++ b/private/protocol/middleware_capture_request.go @@ -1,9 +1,10 @@ -package http +package protocol import ( "context" "fmt" "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" "net/http" "strconv" ) @@ -29,7 +30,7 @@ func (m *captureRequestMiddleware) HandleBuild(ctx context.Context, input middle ) ( output middleware.BuildOutput, metadata middleware.Metadata, err error, ) { - request, ok := input.Request.(*Request) + request, ok := input.Request.(*smithyhttp.Request) if !ok { return output, metadata, fmt.Errorf("error while retrieving http request") } diff --git a/transport/http/middleware_capture_request_test.go b/private/protocol/middleware_capture_request_test.go similarity index 91% rename from transport/http/middleware_capture_request_test.go rename to private/protocol/middleware_capture_request_test.go index 4b0f3d518..0579260a3 100644 --- a/transport/http/middleware_capture_request_test.go +++ b/private/protocol/middleware_capture_request_test.go @@ -1,9 +1,10 @@ -package http +package protocol import ( "context" "github.com/aws/smithy-go/middleware" smithytesting "github.com/aws/smithy-go/testing" + smithyhttp "github.com/aws/smithy-go/transport/http" "io" "io/ioutil" "net/http" @@ -63,9 +64,14 @@ func TestAddCaptureRequestMiddleware(t *testing.T) { for name, c := range cases { t.Run(name, func(t *testing.T) { var err error - req := &Request{ + req := &smithyhttp.Request{ Request: c.Request, - stream: c.Stream, + } + if c.Stream != nil { + req, err = req.SetStream(c.Stream) + if err != nil { + t.Fatalf("Got error while retrieving case stream: %v", err) + } } capturedRequest := &http.Request{} m := captureRequestMiddleware{