Skip to content

Commit

Permalink
Merge branch 'main' into emcfarlane/handler-optional
Browse files Browse the repository at this point in the history
  • Loading branch information
akshayjshah authored Jul 13, 2023
2 parents 11df7c4 + e3fcf6f commit 4ef063d
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ _not_ make breaking changes in the 1.x series of releases.
Offered under the [Apache 2 license][license].

[APIv2]: https://blog.golang.org/protobuf-apiv2
[Buf Studio]: https://studio.buf.build/
[Buf Studio]: https://buf.build/studio
[Getting Started]: https://connect.build/docs/go/getting-started
[blog]: https://buf.build/blog/connect-a-better-grpc
[connect-crosstest]: https://github.com/bufbuild/connect-crosstest
Expand Down
178 changes: 156 additions & 22 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ func TestGRPCMissingTrailersError(t *testing.T) {
assert.Equal(t, connectErr.Code(), connect.CodeInternal)
assert.True(
t,
strings.HasSuffix(connectErr.Message(), "gRPC protocol error: no Grpc-Status trailer"),
strings.HasSuffix(connectErr.Message(), "protocol error: no Grpc-Status trailer: unexpected EOF"),
)
}

Expand Down Expand Up @@ -2099,51 +2099,181 @@ func TestStreamUnexpectedEOF(t *testing.T) {
return
}
_, _ = io.Copy(io.Discard, request.Body)
header := responseWriter.Header()
header.Set("Content-Type", "application/connect+json")
testcase(responseWriter, request)
})
server := httptest.NewUnstartedServer(mux)
server.EnableHTTP2 = true
server.StartTLS()
t.Cleanup(server.Close)

client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
connect.WithProtoJSON(),
)
head := [5]byte{}
payload := []byte(`{"number": 42}`)
binary.BigEndian.PutUint32(head[1:], uint32(len(payload)))
testcases := []struct {
name string
handler http.HandlerFunc
options []connect.ClientOption
expectCode connect.Code
expectMsg string
}{{
name: "stream_unexpected_eof",
handler: func(responseWriter http.ResponseWriter, request *http.Request) {
_, _ = responseWriter.Write(head[:])
name: "connect_missing_end",
options: []connect.ClientOption{connect.WithProtoJSON()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/connect+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: unexpected EOF",
}, {
name: "grpc_missing_end",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "grpc-web_missing_end",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, _ = responseWriter.Write(payload)
assert.Nil(t, err)
},
expectCode: connect.CodeUnknown,
expectMsg: "unknown: unexpected EOF",
expectCode: connect.CodeInternal,
expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF",
}, {
name: "stream_partial_payload",
handler: func(responseWriter http.ResponseWriter, request *http.Request) {
_, _ = responseWriter.Write(head[:])
_, _ = responseWriter.Write(payload[:len(payload)-1])
name: "connect_partial_payload",
options: []connect.ClientOption{connect.WithProtoJSON()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/connect+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload[:len(payload)-1])
assert.Nil(t, err)
},
expectCode: connect.CodeInvalidArgument,
expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1),
}, {
name: "grpc_partial_payload",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload[:len(payload)-1])
assert.Nil(t, err)
},
expectCode: connect.CodeInvalidArgument,
expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1),
}, {
name: "stream_partial_frame",
handler: func(responseWriter http.ResponseWriter, request *http.Request) {
_, _ = responseWriter.Write(head[:4])
name: "grpc-web_partial_payload",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload[:len(payload)-1])
assert.Nil(t, err)
},
expectCode: connect.CodeInvalidArgument,
expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1),
}, {
name: "connect_partial_frame",
options: []connect.ClientOption{connect.WithProtoJSON()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/connect+json")
_, err := responseWriter.Write(head[:4])
assert.Nil(t, err)
},
expectCode: connect.CodeInvalidArgument,
expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF",
}, {
name: "grpc_partial_frame",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc+json")
_, err := responseWriter.Write(head[:4])
assert.Nil(t, err)
},
expectCode: connect.CodeInvalidArgument,
expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF",
}, {
name: "grpc-web_partial_frame",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
header := responseWriter.Header()
header.Set("Content-Type", "application/grpc-web+json")
_, err := responseWriter.Write(head[:4])
assert.Nil(t, err)
},
expectCode: connect.CodeInvalidArgument,
expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF",
}, {
name: "connect_excess_eof",
options: []connect.ClientOption{connect.WithProtoJSON()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Write EOF
_, err = responseWriter.Write([]byte{1 << 1, 0, 0, 0, 2})
assert.Nil(t, err)
_, err = responseWriter.Write([]byte("{}"))
assert.Nil(t, err)
// Excess payload
_, err = responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)),
}, {
name: "grpc-web_excess_eof",
options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()},
handler: func(responseWriter http.ResponseWriter, _ *http.Request) {
_, err := responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
// Write EOF
var buf bytes.Buffer
trailer := http.Header{"grpc-status": []string{"0"}}
assert.Nil(t, trailer.Write(&buf))
var head [5]byte
head[0] = 1 << 7
binary.BigEndian.PutUint32(head[1:], uint32(buf.Len()))
_, err = responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(buf.Bytes())
assert.Nil(t, err)
// Excess payload
_, err = responseWriter.Write(head[:])
assert.Nil(t, err)
_, err = responseWriter.Write(payload)
assert.Nil(t, err)
},
expectCode: connect.CodeInternal,
expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)),
}}
for _, testcase := range testcases {
testcaseMux[t.Name()+"/"+testcase.name] = testcase.handler
Expand All @@ -2152,17 +2282,21 @@ func TestStreamUnexpectedEOF(t *testing.T) {
testcase := testcase
t.Run(testcase.name, func(t *testing.T) {
t.Parallel()
client := pingv1connect.NewPingServiceClient(
server.Client(),
server.URL,
testcase.options...,
)
const upTo = 2
request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo})
request.Header().Set("Test-Case", t.Name())
stream, err := client.CountUp(context.Background(), request)
assert.Nil(t, err)
for stream.Receive() {
for i := 0; stream.Receive() && i < upTo; i++ {
assert.Equal(t, stream.Msg().Number, 42)
}
assert.NotNil(t, stream.Err())
assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode)
t.Log(stream.Err())
assert.Equal(t, stream.Err().Error(), testcase.expectMsg)
})
}
Expand Down
2 changes: 1 addition & 1 deletion duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (d *duplexHTTPCall) CloseRead() error {
if d.response == nil {
return nil
}
if err := discard(d.response.Body); err != nil {
if _, err := discard(d.response.Body); err != nil {
_ = d.response.Body.Close()
return wrapIfRSTError(err)
}
Expand Down
8 changes: 7 additions & 1 deletion envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
if r.compressionPool == nil {
return errorf(
CodeInvalidArgument,
"gRPC protocol error: sent compressed message without Grpc-Encoding header",
"protocol error: sent compressed message without Grpc-Encoding header",
)
}
decompressed := r.bufferPool.Get()
Expand All @@ -199,6 +199,12 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
}

if env.Flags != 0 && env.Flags != flagEnvelopeCompressed {
// Drain the rest of the stream to ensure there is no extra data.
if n, err := discard(r.reader); err != nil {
return errorf(CodeInternal, "corrupt response: I/O error after end-stream message: %w", err)
} else if n > 0 {
return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n)
}
// One of the protocol-specific flags are set, so this is the end of the
// stream. Save the message for protocol-specific code to process and
// return a sentinel error. Since we've deferred functions to return env's
Expand Down
8 changes: 3 additions & 5 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,14 @@ func isCommaOrSpace(c rune) bool {
return c == ',' || c == ' '
}

func discard(reader io.Reader) error {
func discard(reader io.Reader) (int64, error) {
if lr, ok := reader.(*io.LimitedReader); ok {
_, err := io.Copy(io.Discard, lr)
return err
return io.Copy(io.Discard, lr)
}
// We don't want to get stuck throwing data away forever, so limit how much
// we're willing to do here.
lr := &io.LimitedReader{R: reader, N: discardLimit}
_, err := io.Copy(io.Discard, lr)
return err
return io.Copy(io.Discard, lr)
}

// negotiateCompression determines and validates the request compression and
Expand Down
2 changes: 1 addition & 1 deletion protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ func (cc *connectStreamingClientConn) Receive(msg any) error {
// If the error is EOF but not from a last message, we want to return
// io.ErrUnexpectedEOF instead.
if errors.Is(err, io.EOF) && !errors.Is(err, errSpecialEnvelope) {
err = NewError(CodeUnknown, io.ErrUnexpectedEOF)
err = errorf(CodeInternal, "protocol error: %w", io.ErrUnexpectedEOF)
}
// There's no error in the trailers, so this was probably an error
// converting the bytes to a message, an error reading from the network, or
Expand Down
12 changes: 6 additions & 6 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ var (
grpcAllowedMethods = map[string]struct{}{
http.MethodPost: {},
}
errTrailersWithoutGRPCStatus = fmt.Errorf("gRPC protocol error: no %s trailer", grpcHeaderStatus)
errTrailersWithoutGRPCStatus = fmt.Errorf("protocol error: no %s trailer: %w", grpcHeaderStatus, io.ErrUnexpectedEOF)

// defaultGrpcUserAgent follows
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#user-agents:
Expand Down Expand Up @@ -326,7 +326,7 @@ func (g *grpcClient) NewConn(
} else {
conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header {
// To access HTTP trailers, we need to read the body to EOF.
_ = discard(call)
_, _ = discard(call)
return call.ResponseTrailer()
}
}
Expand Down Expand Up @@ -737,7 +737,7 @@ func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error {

code, err := strconv.ParseUint(codeHeader, 10 /* base */, 32 /* bitsize */)
if err != nil {
return errorf(CodeInternal, "gRPC protocol error: invalid error code %q", codeHeader)
return errorf(CodeInternal, "protocol error: invalid error code %q", codeHeader)
}
message := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage))
retErr := NewWireError(Code(code), errors.New(message))
Expand Down Expand Up @@ -769,14 +769,14 @@ func grpcParseTimeout(timeout string) (time.Duration, error) {
}
unit, ok := grpcTimeoutUnitLookup[timeout[len(timeout)-1]]
if !ok {
return 0, fmt.Errorf("gRPC protocol error: timeout %q has invalid unit", timeout)
return 0, fmt.Errorf("protocol error: timeout %q has invalid unit", timeout)
}
num, err := strconv.ParseInt(timeout[:len(timeout)-1], 10 /* base */, 64 /* bitsize */)
if err != nil || num < 0 {
return 0, fmt.Errorf("gRPC protocol error: invalid timeout %q", timeout)
return 0, fmt.Errorf("protocol error: invalid timeout %q", timeout)
}
if num > 99999999 { // timeout must be ASCII string of at most 8 digits
return 0, fmt.Errorf("gRPC protocol error: timeout %q is too long", timeout)
return 0, fmt.Errorf("protocol error: timeout %q is too long", timeout)
}
if unit == time.Hour && num > grpcTimeoutMaxHours {
// Timeout is effectively unbounded, so ignore it. The grpc-go
Expand Down

0 comments on commit 4ef063d

Please sign in to comment.