Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add awsJson10 Error Deserialization support. #535

Merged
merged 8 commits into from
Sep 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,29 @@
package software.amazon.smithy.go.codegen.protocol.aws;

import static software.amazon.smithy.go.codegen.ApplicationProtocol.createDefaultHttpApplicationProtocol;
import static software.amazon.smithy.go.codegen.GoWriter.emptyGoTemplate;
import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.protocol.ProtocolUtil.GET_AWS_QUERY_ERROR_CODE;
import static software.amazon.smithy.go.codegen.serde.SerdeUtil.getShapesToSerde;
import static software.amazon.smithy.go.codegen.server.protocol.JsonDeserializerGenerator.getDeserializerName;

import java.util.HashSet;
import java.util.Set;
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait;
import software.amazon.smithy.aws.traits.protocols.AwsQueryCompatibleTrait;
import software.amazon.smithy.go.codegen.ApplicationProtocol;
import software.amazon.smithy.go.codegen.GoStdlibTypes;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SmithyGoTypes;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
import software.amazon.smithy.go.codegen.server.protocol.JsonDeserializerGenerator;
import software.amazon.smithy.go.codegen.server.protocol.JsonSerializerGenerator;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SmithyInternalApi;

@SmithyInternalApi
Expand Down Expand Up @@ -77,6 +87,7 @@ public void generateResponseDeserializers(GenerationContext context) {
writer.write("\n");
}
generateSharedDeserializers(context, writer, ops);
generateErrorDeserializers(context, ops);
}

private void generateSharedDeserializers(GenerationContext context, GoWriter writer, Set<OperationShape> ops) {
Expand All @@ -86,10 +97,161 @@ private void generateSharedDeserializers(GenerationContext context, GoWriter wri
op.getOutputShape()));
shared.addAll(shapes);
}
var errorShapes = generateErrorDeserializers(context, ops);
shared.addAll(errorShapes);

var generator = new JsonDeserializerGenerator(context.getModel(), context.getSymbolProvider());
writer.write(generator.generate(shared));

generateOperationErrorDeserializers(context, writer, ops);

writer.write(getProtocolErrorInfo());

if (context.getService().hasTrait(AwsQueryCompatibleTrait.class)) {
writer.write(GET_AWS_QUERY_ERROR_CODE);
}
}

private Set<Shape> generateErrorDeserializers(GenerationContext context, Set<OperationShape> ops) {
Set<Shape> errorShapes = new HashSet<>();
nathanhit marked this conversation as resolved.
Show resolved Hide resolved
for (var op : ops) {
var errors = op.getErrors();
for (var error : errors) {
Set<Shape> shapes = getShapesToSerde(context.getModel(), context.getModel().expectShape(error));
errorShapes.addAll(shapes);
}
}
return errorShapes;
}

private void generateOperationErrorDeserializers(
GenerationContext context, GoWriter writer, Set<OperationShape> operations) {
for (var operation : operations) {
var errors = context.getService().getErrors()
.stream()
.map(it -> deserializeErrorCase(context, context.getModel().expectShape(it, StructureShape.class)))
.toList();
writer.write(goTemplate("""
func $func:L(resp $smithyhttpResponse:P) error {
payload, err := $readAll:T(resp.Body)
if err != nil {
return &$deserError:T{Err: $fmtErrorf:T("read response body: %w", err)}
}

typ, msg, v, err := getProtocolErrorInfo(payload)
if err != nil {
return &$deserError:T{Err: $fmtErrorf:T("get error info: %w", err)}
}

if len(typ) == 0 {
typ = "UnknownError"
}
if len(msg) == 0 {
msg = "UnknownError"
}

_ = v
switch typ {
$errors:W
default:
$awsQueryCompatible:W
return &$genericAPIError:T{Code: typ, Message: msg}
}
}
""",
MapUtils.of(
"deserError", SmithyGoDependency.SMITHY.pointableSymbol("DeserializationError"),
"fmtErrorf", GoStdlibTypes.Fmt.Errorf,
"func", ProtocolGenerator.getOperationErrorDeserFunctionName(operation,
context.getService(), "awsJson10"),
"genericAPIError", SmithyGoDependency.SMITHY.pointableSymbol("GenericAPIError"),
"readAll", SmithyGoDependency.IO.func("ReadAll"),
"smithyhttpResponse", SmithyGoTypes.Transport.Http.Response,
"awsQueryCompatible", context.getService().hasTrait(AwsQueryCompatibleTrait.class)
? deserializeAwsQueryError()
: emptyGoTemplate(),
"errors", GoWriter.ChainWritable.of(errors).compose(false)
)));
}
}

private GoWriter.Writable deserializeErrorCase(GenerationContext ctx, StructureShape error) {
return goTemplate("""
case $type:S:
verr, err := $deserialize:L(v)
if err != nil {
return &$deserError:T{
Err: $fmtErrorf:T("deserialize $type:L: %w", err),
Snapshot: payload,
}
}
$awsQueryCompatible:W
return verr
""",
MapUtils.of(
"deserError", SmithyGoDependency.SMITHY.pointableSymbol("DeserializationError"),
"deserialize", getDeserializerName(error),
"equalFold", SmithyGoDependency.STRINGS.func("EqualFold"),
"fmtErrorf", GoStdlibTypes.Fmt.Errorf,
"type", error.getId().toString(),
"awsQueryCompatible", ctx.getService().hasTrait(AwsQueryCompatibleTrait.class)
? deserializeModeledAwsQueryError()
: emptyGoTemplate()
));
}

private GoWriter.Writable deserializeAwsQueryError() {
return goTemplate("""
if qtype := getAwsQueryErrorCode(resp); len(qt) > 0 {
typ = qtype
}""");
}

private GoWriter.Writable deserializeModeledAwsQueryError() {
return goTemplate("""
if qtype := getAwsQueryErrorCode(resp); len(qt) > 0 {
verr.ErrorCodeOverride = $T(qtype)
}""", SmithyGoTypes.Ptr.String);
}

private GoWriter.Writable getProtocolErrorInfo() {
return goTemplate("""
func getProtocolErrorInfo(payload []byte) (typ, msg string, v $value:T, err error) {

paid := $reader:T(payload)
jsonDecoder := $decoder:T(paid)
var val interface{}
var jv map[string]interface{}

jsonDecoder.Decode(&val)
nathanhit marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return "", "", val.($value:T), $fmtErrorf:T("decode: %w", err)
}

err = jsonDecoder.Decode(&jv)
if err != nil {
return "", "", val.($value:T), $fmtErrorf:T("decode: %w", err)
}

if jtyp, ok := jv["__type"]; ok {
typ = jtyp.(string)
}

if jmsg, ok := jv["message"]; ok {
msg = jmsg.(string)
}

return typ, msg, val.($value:T), nil
}
""",
MapUtils.of(
"fmtErrorf", GoStdlibTypes.Fmt.Errorf,
"decoder", GoStdlibTypes.Encoding.Json.NewDecoder,
"value", SmithyGoTypes.Encoding.Json.Value,
"reader", GoStdlibTypes.Bytes.NewReader
));
}

@Override
public void generateProtocolDocumentMarshalerMarshalDocument(GenerationContext context) {
// TODO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static software.amazon.smithy.go.codegen.GoWriter.goTemplate;
import static software.amazon.smithy.go.codegen.SmithyGoDependency.SMITHY_HTTP_TRANSPORT;
import static software.amazon.smithy.go.codegen.integration.ProtocolGenerator.getOperationErrorDeserFunctionName;
import static software.amazon.smithy.go.codegen.protocol.ProtocolUtil.hasEventStream;
import static software.amazon.smithy.go.codegen.server.protocol.JsonDeserializerGenerator.getDeserializerName;

Expand Down Expand Up @@ -125,14 +126,15 @@ private GoWriter.Writable handleResponseChecks() {
}

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return out, metadata, &$deserError:T{}
return out, metadata, $errorDeserialized:L(resp)
}

""",
MapUtils.of(
"response", SMITHY_HTTP_TRANSPORT.pointableSymbol("Response"),
"errorf", GoStdlibTypes.Fmt.Errorf,
"deserError", SmithyGoDependency.SMITHY.struct("DeserializationError")
"errorDeserialized", getOperationErrorDeserFunctionName(operation, ctx.getService(),
"awsJson10")
));
}

Expand Down
Loading