diff --git a/README.md b/README.md index c551084a..e6eeb507 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ _not_ make breaking changes in the 1.x series of releases. Offered under the [Apache 2 license][license]. [APIv2]: https://blog.golang.org/protobuf-apiv2 -[Buf Studio]: https://studio.buf.build/ +[Buf Studio]: https://buf.build/studio [Getting Started]: https://connect.build/docs/go/getting-started [blog]: https://buf.build/blog/connect-a-better-grpc [connect-crosstest]: https://github.com/bufbuild/connect-crosstest diff --git a/connect_ext_test.go b/connect_ext_test.go index 2c9cdba0..62dfd7d7 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -707,7 +707,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { assert.Equal(t, connectErr.Code(), connect.CodeInternal) assert.True( t, - strings.HasSuffix(connectErr.Message(), "gRPC protocol error: no Grpc-Status trailer"), + strings.HasSuffix(connectErr.Message(), "protocol error: no Grpc-Status trailer: unexpected EOF"), ) } @@ -2099,8 +2099,6 @@ func TestStreamUnexpectedEOF(t *testing.T) { return } _, _ = io.Copy(io.Discard, request.Body) - header := responseWriter.Header() - header.Set("Content-Type", "application/connect+json") testcase(responseWriter, request) }) server := httptest.NewUnstartedServer(mux) @@ -2108,42 +2106,174 @@ func TestStreamUnexpectedEOF(t *testing.T) { server.StartTLS() t.Cleanup(server.Close) - client := pingv1connect.NewPingServiceClient( - server.Client(), - server.URL, - connect.WithProtoJSON(), - ) head := [5]byte{} payload := []byte(`{"number": 42}`) binary.BigEndian.PutUint32(head[1:], uint32(len(payload))) testcases := []struct { name string handler http.HandlerFunc + options []connect.ClientOption expectCode connect.Code expectMsg string }{{ - name: "stream_unexpected_eof", - handler: func(responseWriter http.ResponseWriter, request *http.Request) { - _, _ = responseWriter.Write(head[:]) + name: "connect_missing_end", + options: []connect.ClientOption{connect.WithProtoJSON()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/connect+json") + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + }, + expectCode: connect.CodeInternal, + expectMsg: "internal: protocol error: unexpected EOF", + }, { + name: "grpc_missing_end", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/grpc+json") + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + }, + expectCode: connect.CodeInternal, + expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF", + }, { + name: "grpc-web_missing_end", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/grpc-web+json") + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) _, _ = responseWriter.Write(payload) + assert.Nil(t, err) }, - expectCode: connect.CodeUnknown, - expectMsg: "unknown: unexpected EOF", + expectCode: connect.CodeInternal, + expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF", }, { - name: "stream_partial_payload", - handler: func(responseWriter http.ResponseWriter, request *http.Request) { - _, _ = responseWriter.Write(head[:]) - _, _ = responseWriter.Write(payload[:len(payload)-1]) + name: "connect_partial_payload", + options: []connect.ClientOption{connect.WithProtoJSON()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/connect+json") + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload[:len(payload)-1]) + assert.Nil(t, err) + }, + expectCode: connect.CodeInvalidArgument, + expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), + }, { + name: "grpc_partial_payload", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/grpc+json") + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload[:len(payload)-1]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), }, { - name: "stream_partial_frame", - handler: func(responseWriter http.ResponseWriter, request *http.Request) { - _, _ = responseWriter.Write(head[:4]) + name: "grpc-web_partial_payload", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/grpc-web+json") + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload[:len(payload)-1]) + assert.Nil(t, err) + }, + expectCode: connect.CodeInvalidArgument, + expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), + }, { + name: "connect_partial_frame", + options: []connect.ClientOption{connect.WithProtoJSON()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/connect+json") + _, err := responseWriter.Write(head[:4]) + assert.Nil(t, err) }, expectCode: connect.CodeInvalidArgument, expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + }, { + name: "grpc_partial_frame", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/grpc+json") + _, err := responseWriter.Write(head[:4]) + assert.Nil(t, err) + }, + expectCode: connect.CodeInvalidArgument, + expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + }, { + name: "grpc-web_partial_frame", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + header := responseWriter.Header() + header.Set("Content-Type", "application/grpc-web+json") + _, err := responseWriter.Write(head[:4]) + assert.Nil(t, err) + }, + expectCode: connect.CodeInvalidArgument, + expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + }, { + name: "connect_excess_eof", + options: []connect.ClientOption{connect.WithProtoJSON()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + // Write EOF + _, err = responseWriter.Write([]byte{1 << 1, 0, 0, 0, 2}) + assert.Nil(t, err) + _, err = responseWriter.Write([]byte("{}")) + assert.Nil(t, err) + // Excess payload + _, err = responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + }, + expectCode: connect.CodeInternal, + expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)), + }, { + name: "grpc-web_excess_eof", + options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, + handler: func(responseWriter http.ResponseWriter, _ *http.Request) { + _, err := responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + // Write EOF + var buf bytes.Buffer + trailer := http.Header{"grpc-status": []string{"0"}} + assert.Nil(t, trailer.Write(&buf)) + var head [5]byte + head[0] = 1 << 7 + binary.BigEndian.PutUint32(head[1:], uint32(buf.Len())) + _, err = responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(buf.Bytes()) + assert.Nil(t, err) + // Excess payload + _, err = responseWriter.Write(head[:]) + assert.Nil(t, err) + _, err = responseWriter.Write(payload) + assert.Nil(t, err) + }, + expectCode: connect.CodeInternal, + expectMsg: fmt.Sprintf("internal: corrupt response: %d extra bytes after end of stream", len(payload)+len(head)), }} for _, testcase := range testcases { testcaseMux[t.Name()+"/"+testcase.name] = testcase.handler @@ -2152,17 +2282,21 @@ func TestStreamUnexpectedEOF(t *testing.T) { testcase := testcase t.Run(testcase.name, func(t *testing.T) { t.Parallel() + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + testcase.options..., + ) const upTo = 2 request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) request.Header().Set("Test-Case", t.Name()) stream, err := client.CountUp(context.Background(), request) assert.Nil(t, err) - for stream.Receive() { + for i := 0; stream.Receive() && i < upTo; i++ { assert.Equal(t, stream.Msg().Number, 42) } assert.NotNil(t, stream.Err()) assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode) - t.Log(stream.Err()) assert.Equal(t, stream.Err().Error(), testcase.expectMsg) }) } diff --git a/duplex_http_call.go b/duplex_http_call.go index 4dac0092..439f55c2 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -179,7 +179,7 @@ func (d *duplexHTTPCall) CloseRead() error { if d.response == nil { return nil } - if err := discard(d.response.Body); err != nil { + if _, err := discard(d.response.Body); err != nil { _ = d.response.Body.Close() return wrapIfRSTError(err) } diff --git a/envelope.go b/envelope.go index 36e21b9a..559ede69 100644 --- a/envelope.go +++ b/envelope.go @@ -187,7 +187,7 @@ func (r *envelopeReader) Unmarshal(message any) *Error { if r.compressionPool == nil { return errorf( CodeInvalidArgument, - "gRPC protocol error: sent compressed message without Grpc-Encoding header", + "protocol error: sent compressed message without Grpc-Encoding header", ) } decompressed := r.bufferPool.Get() @@ -199,6 +199,12 @@ func (r *envelopeReader) Unmarshal(message any) *Error { } if env.Flags != 0 && env.Flags != flagEnvelopeCompressed { + // Drain the rest of the stream to ensure there is no extra data. + if n, err := discard(r.reader); err != nil { + return errorf(CodeInternal, "corrupt response: I/O error after end-stream message: %w", err) + } else if n > 0 { + return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n) + } // One of the protocol-specific flags are set, so this is the end of the // stream. Save the message for protocol-specific code to process and // return a sentinel error. Since we've deferred functions to return env's diff --git a/protocol.go b/protocol.go index 8486ca96..a02f24b0 100644 --- a/protocol.go +++ b/protocol.go @@ -283,16 +283,14 @@ func isCommaOrSpace(c rune) bool { return c == ',' || c == ' ' } -func discard(reader io.Reader) error { +func discard(reader io.Reader) (int64, error) { if lr, ok := reader.(*io.LimitedReader); ok { - _, err := io.Copy(io.Discard, lr) - return err + return io.Copy(io.Discard, lr) } // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. lr := &io.LimitedReader{R: reader, N: discardLimit} - _, err := io.Copy(io.Discard, lr) - return err + return io.Copy(io.Discard, lr) } // negotiateCompression determines and validates the request compression and diff --git a/protocol_connect.go b/protocol_connect.go index 07b06af4..460871a4 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -609,7 +609,7 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { // If the error is EOF but not from a last message, we want to return // io.ErrUnexpectedEOF instead. if errors.Is(err, io.EOF) && !errors.Is(err, errSpecialEnvelope) { - err = NewError(CodeUnknown, io.ErrUnexpectedEOF) + err = errorf(CodeInternal, "protocol error: %w", io.ErrUnexpectedEOF) } // There's no error in the trailers, so this was probably an error // converting the bytes to a message, an error reading from the network, or diff --git a/protocol_grpc.go b/protocol_grpc.go index 62e3355e..7e94eaf6 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -67,7 +67,7 @@ var ( grpcAllowedMethods = map[string]struct{}{ http.MethodPost: {}, } - errTrailersWithoutGRPCStatus = fmt.Errorf("gRPC protocol error: no %s trailer", grpcHeaderStatus) + errTrailersWithoutGRPCStatus = fmt.Errorf("protocol error: no %s trailer: %w", grpcHeaderStatus, io.ErrUnexpectedEOF) // defaultGrpcUserAgent follows // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#user-agents: @@ -326,7 +326,7 @@ func (g *grpcClient) NewConn( } else { conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { // To access HTTP trailers, we need to read the body to EOF. - _ = discard(call) + _, _ = discard(call) return call.ResponseTrailer() } } @@ -737,7 +737,7 @@ func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error { code, err := strconv.ParseUint(codeHeader, 10 /* base */, 32 /* bitsize */) if err != nil { - return errorf(CodeInternal, "gRPC protocol error: invalid error code %q", codeHeader) + return errorf(CodeInternal, "protocol error: invalid error code %q", codeHeader) } message := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage)) retErr := NewWireError(Code(code), errors.New(message)) @@ -769,14 +769,14 @@ func grpcParseTimeout(timeout string) (time.Duration, error) { } unit, ok := grpcTimeoutUnitLookup[timeout[len(timeout)-1]] if !ok { - return 0, fmt.Errorf("gRPC protocol error: timeout %q has invalid unit", timeout) + return 0, fmt.Errorf("protocol error: timeout %q has invalid unit", timeout) } num, err := strconv.ParseInt(timeout[:len(timeout)-1], 10 /* base */, 64 /* bitsize */) if err != nil || num < 0 { - return 0, fmt.Errorf("gRPC protocol error: invalid timeout %q", timeout) + return 0, fmt.Errorf("protocol error: invalid timeout %q", timeout) } if num > 99999999 { // timeout must be ASCII string of at most 8 digits - return 0, fmt.Errorf("gRPC protocol error: timeout %q is too long", timeout) + return 0, fmt.Errorf("protocol error: timeout %q is too long", timeout) } if unit == time.Hour && num > grpcTimeoutMaxHours { // Timeout is effectively unbounded, so ignore it. The grpc-go