Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify and Merge protocol test request unit tests codegen logic #447

Merged
merged 5 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public final class SmithyGoDependency {
public static final GoDependency SMITHY_TRANSPORT = smithy("transport", "smithytransport");
public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp");
public static final GoDependency SMITHY_MIDDLEWARE = smithy("middleware");
public static final GoDependency SMITHY_PRIVATE_PROTOCOL = smithy("private/protocol", "smithyprivateprotocol");
public static final GoDependency SMITHY_TIME = smithy("time", "smithytime");
public static final GoDependency SMITHY_HTTP_BINDING = smithy("encoding/httpbinding");
public static final GoDependency SMITHY_JSON = smithy("encoding/json", "smithyjson");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.function.Consumer;
import java.util.logging.Logger;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SymbolUtils;
Expand Down Expand Up @@ -196,7 +197,7 @@ protected void generateTestCaseValues(GoWriter writer, HttpRequestTestCase testC
* @param writer writer to write generated code with.
*/
protected void generateTestBodySetup(GoWriter writer) {
writer.write("var actualReq *http.Request");
writer.write("actualReq := &http.Request{}");
}

/**
Expand All @@ -205,26 +206,6 @@ protected void generateTestBodySetup(GoWriter writer) {
* @param writer writer to write generated code with.
*/
protected void generateTestServerHandler(GoWriter writer) {
writer.write("actualReq = r.Clone(r.Context())");
// Go does not set RawPath on http server if nothing is escaped
writer.openBlock("if len(actualReq.URL.RawPath) == 0 {", "}", () -> {
writer.write("actualReq.URL.RawPath = actualReq.URL.Path");
});
// Go automatically removes Content-Length header setting it to the member.
writer.addUseImports(SmithyGoDependency.STRCONV);
writer.openBlock("if v := actualReq.ContentLength; v != 0 {", "}", () -> {
writer.write("actualReq.Header.Set(\"Content-Length\", strconv.FormatInt(v, 10))");
});

writer.addUseImports(SmithyGoDependency.BYTES);
writer.write("var buf bytes.Buffer");
writer.openBlock("if _, err := io.Copy(&buf, r.Body); err != nil {", "}", () -> {
writer.write("t.Errorf(\"failed to read request body, %v\", err)");
});
writer.addUseImports(SmithyGoDependency.IOUTIL);
writer.write("actualReq.Body = ioutil.NopCloser(&buf)");
writer.write("");

super.generateTestServerHandler(writer);
}

Expand All @@ -236,8 +217,18 @@ protected void generateTestServerHandler(GoWriter writer) {
*/
@Override
protected void generateTestInvokeClientOperation(GoWriter writer, String clientName) {
Symbol stackSymbol = SymbolUtils.createPointableSymbolBuilder("Stack",
SmithyGoDependency.SMITHY_MIDDLEWARE).build();
writer.addUseImports(SmithyGoDependency.CONTEXT);
writer.write("result, err := $L.$T(context.Background(), c.Params)", clientName, opSymbol);
writer.openBlock("result, err := $L.$T(context.Background(), c.Params, func(options *Options) {", "})",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: in the future, we should avoid using a series of openBlocks like this, and just rely on gofmt to do all the formatting. see something like this
https://github.com/aws/smithy-go/blob/main/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/endpoints/EndpointMiddlewareGenerator.java#L266

clientName, opSymbol, () -> {
writer.openBlock("options.APIOptions = append(options.APIOptions, func(stack $P) error {", "})",
stackSymbol, () -> {
writer.write("return $T(stack, actualReq)",
SymbolUtils.createValueSymbolBuilder("AddCaptureRequestMiddleware",
SmithyGoDependency.SMITHY_PRIVATE_PROTOCOL).build());
});
});
}

/**
Expand All @@ -254,6 +245,7 @@ protected void generateTestAssertions(GoWriter writer) {
writeAssertScalarEqual(writer, "c.ExpectURIPath", "actualReq.URL.RawPath", "path");

writeQueryItemBreakout(writer, "actualReq.URL.RawQuery", "queryItems");

writeAssertHasQuery(writer, "c.ExpectQuery", "queryItems");
writeAssertRequireQuery(writer, "c.RequireQuery", "queryItems");
writeAssertForbidQuery(writer, "c.ForbidQuery", "queryItems");
Expand Down Expand Up @@ -282,7 +274,8 @@ protected void generateTestServer(
String name,
Consumer<GoWriter> handler
) {
super.generateTestServer(writer, name, handler);
// We aren't using a test server, but we do need a URL to set.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: actually this comment doesnt do a lot of good when its generated. because the generated code wont have this comment. if we can, we should make the URL value even more obvious that its a placeholder (while still being parse-able by whatever needs it). for example, would http://thisisa.placeholder:4242 work?

writer.write("serverURL := \"http://localhost:8888/\"");
writer.pushState();
writer.putContext("parse", SymbolUtils.createValueSymbolBuilder("Parse", SmithyGoDependency.NET_URL)
.build());
Expand Down
47 changes: 47 additions & 0 deletions private/protocol/middleware_capture_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package protocol

import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"net/http"
"strconv"
)

const captureRequestID = "CaptureProtocolTestRequest"

// AddCaptureRequestMiddleware captures serialized http request during protocol test for check
func AddCaptureRequestMiddleware(stack *middleware.Stack, req *http.Request) error {
return stack.Build.Add(&captureRequestMiddleware{
req: req,
}, middleware.After)
}

type captureRequestMiddleware struct {
req *http.Request
}

func (*captureRequestMiddleware) ID() string {
return captureRequestID
}

func (m *captureRequestMiddleware) HandleBuild(ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler,
) (
output middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
request, ok := input.Request.(*smithyhttp.Request)
if !ok {
return output, metadata, fmt.Errorf("error while retrieving http request")
}

*m.req = *request.Build(ctx)
if len(m.req.URL.RawPath) == 0 {
m.req.URL.RawPath = m.req.URL.Path
}
if v := m.req.ContentLength; v != 0 {
m.req.Header.Set("Content-Length", strconv.FormatInt(v, 10))
}

return next.HandleBuild(ctx, input)
}
115 changes: 115 additions & 0 deletions private/protocol/middleware_capture_request_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package protocol

import (
"context"
"github.com/aws/smithy-go/middleware"
smithytesting "github.com/aws/smithy-go/testing"
smithyhttp "github.com/aws/smithy-go/transport/http"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"testing"
)

// TestAddCaptureRequestMiddleware tests AddCaptureRequestMiddleware
func TestAddCaptureRequestMiddleware(t *testing.T) {
cases := map[string]struct {
Request *http.Request
ExpectRequest *http.Request
ExpectQuery []smithytesting.QueryItem
Stream io.Reader
}{
"normal request": {
Request: &http.Request{
Method: "PUT",
Header: map[string][]string{
"Foo": {"bar", "too"},
"Checksum": {"SHA256"},
},
URL: &url.URL{
Path: "test/path",
RawQuery: "language=us&region=us-west+east",
},
ContentLength: 100,
},
ExpectRequest: &http.Request{
Method: "PUT",
Header: map[string][]string{
"Foo": {"bar", "too"},
"Checksum": {"SHA256"},
"Content-Length": {"100"},
},
URL: &url.URL{
Path: "test/path",
RawPath: "test/path",
},
Body: ioutil.NopCloser(strings.NewReader("hello world.")),
},
ExpectQuery: []smithytesting.QueryItem{
{
Key: "language",
Value: "us",
},
{
Key: "region",
Value: "us-west%20east",
},
},
Stream: strings.NewReader("hello world."),
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
var err error
req := &smithyhttp.Request{
Request: c.Request,
}
if c.Stream != nil {
req, err = req.SetStream(c.Stream)
if err != nil {
t.Fatalf("Got error while retrieving case stream: %v", err)
}
}
capturedRequest := &http.Request{}
m := captureRequestMiddleware{
req: capturedRequest,
}
_, _, err = m.HandleBuild(context.Background(),
middleware.BuildInput{Request: req},
middleware.BuildHandlerFunc(func(ctx context.Context, input middleware.BuildInput) (
out middleware.BuildOutput, metadata middleware.Metadata, err error) {
return out, metadata, nil
}),
)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.ExpectRequest.Method, capturedRequest.Method; e != a {
t.Errorf("expect request method %v found, got %v", e, a)
}
if e, a := c.ExpectRequest.URL.Path, capturedRequest.URL.RawPath; e != a {
t.Errorf("expect %v path, got %v", e, a)
}
if c.ExpectRequest.Body != nil {
expect, err := ioutil.ReadAll(c.ExpectRequest.Body)
if capturedRequest.Body == nil {
t.Errorf("Expect request stream %v captured, get nil", string(expect))
}
actual, err := ioutil.ReadAll(capturedRequest.Body)
if err != nil {
t.Errorf("unable to read captured request body, %v", err)
}
if e, a := string(expect), string(actual); e != a {
t.Errorf("expect request body to be %s, got %s", e, a)
}
}
queryItems := smithytesting.ParseRawQuery(capturedRequest.URL.RawQuery)
smithytesting.AssertHasQuery(t, c.ExpectQuery, queryItems)
smithytesting.AssertHasHeader(t, c.ExpectRequest.Header, capturedRequest.Header)
})
}
}
Loading