From 8098575694fdd1551a8930be6436dbe68dea63e8 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Fri, 8 Sep 2023 11:00:56 -0400 Subject: [PATCH 1/8] Passing --- buffer_pool.go | 18 +- connect.go | 2 +- connect_ext_test.go | 30 +- envelope.go | 357 ++++++--------- error_writer.go | 20 +- header.go | 3 - protocol.go | 14 +- protocol_connect.go | 912 ++++++++++++++++++++------------------- protocol_connect_test.go | 38 +- protocol_grpc.go | 383 ++++++++-------- protocol_grpc_test.go | 39 +- 11 files changed, 848 insertions(+), 968 deletions(-) diff --git a/buffer_pool.go b/buffer_pool.go index 262bc7a4..83bdbbce 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -20,7 +20,7 @@ import ( ) const ( - initialBufferSize = 512 + initialBufferSize = bytes.MinRead maxRecycleBufferSize = 8 * 1024 * 1024 // if >8MiB, don't hold onto a buffer ) @@ -29,26 +29,20 @@ type bufferPool struct { } func newBufferPool() *bufferPool { - return &bufferPool{ - Pool: sync.Pool{ - New: func() any { - return bytes.NewBuffer(make([]byte, 0, initialBufferSize)) - }, - }, - } + return &bufferPool{} } func (b *bufferPool) Get() *bytes.Buffer { if buf, ok := b.Pool.Get().(*bytes.Buffer); ok { + buf.Reset() return buf } return bytes.NewBuffer(make([]byte, 0, initialBufferSize)) } -func (b *bufferPool) Put(buffer *bytes.Buffer) { - if buffer.Cap() > maxRecycleBufferSize { +func (b *bufferPool) Put(buf *bytes.Buffer) { + if buf.Cap() > maxRecycleBufferSize { return } - buffer.Reset() - b.Pool.Put(buffer) + b.Pool.Put(buf) } diff --git a/connect.go b/connect.go index 2c8cc5bb..e75b8301 100644 --- a/connect.go +++ b/connect.go @@ -364,7 +364,7 @@ func receiveUnaryResponse[T any](conn StreamingClientConn) (*Response[T], error) // In a well-formed stream, the response message may be followed by a block // of in-stream trailers or HTTP trailers. To ensure that we receive the // trailers, try to read another message from the stream. - if err := conn.Receive(new(T)); err == nil { + if err := conn.Receive(nil); err == nil { return nil, NewError(CodeUnknown, errors.New("unary stream has multiple messages")) } else if err != nil && !errors.Is(err, io.EOF) { return nil, NewError(CodeUnknown, err) diff --git a/connect_ext_test.go b/connect_ext_test.go index f6c1ed1d..e13cd098 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -707,7 +707,7 @@ func TestGRPCMissingTrailersError(t *testing.T) { assert.Equal(t, connectErr.Code(), connect.CodeInternal) assert.True( t, - strings.HasSuffix(connectErr.Message(), "protocol error: no Grpc-Status trailer: unexpected EOF"), + strings.HasSuffix(connectErr.Message(), "protocol error: no Grpc-Status trailer"), ) } @@ -2165,7 +2165,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { assert.Nil(t, err) }, expectCode: connect.CodeInternal, - expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF", + expectMsg: "internal: protocol error: no Grpc-Status trailer", }, { name: "grpc-web_missing_end", options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, @@ -2178,7 +2178,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { assert.Nil(t, err) }, expectCode: connect.CodeInternal, - expectMsg: "internal: protocol error: no Grpc-Status trailer: unexpected EOF", + expectMsg: "internal: protocol error: no Grpc-Status trailer", }, { name: "connect_partial_payload", options: []connect.ClientOption{connect.WithProtoJSON()}, @@ -2190,8 +2190,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, err = responseWriter.Write(payload[:len(payload)-1]) assert.Nil(t, err) }, - expectCode: connect.CodeInvalidArgument, - expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), + expectCode: connect.CodeInternal, + expectMsg: "internal: incomplete envelope: unexpected EOF", }, { name: "grpc_partial_payload", options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()}, @@ -2203,8 +2203,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, err = responseWriter.Write(payload[:len(payload)-1]) assert.Nil(t, err) }, - expectCode: connect.CodeInvalidArgument, - expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), + expectCode: connect.CodeInternal, + expectMsg: "internal: incomplete envelope: unexpected EOF", }, { name: "grpc-web_partial_payload", options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, @@ -2216,8 +2216,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, err = responseWriter.Write(payload[:len(payload)-1]) assert.Nil(t, err) }, - expectCode: connect.CodeInvalidArgument, - expectMsg: fmt.Sprintf("invalid_argument: protocol error: promised %d bytes in enveloped message, got %d bytes", len(payload), len(payload)-1), + expectCode: connect.CodeInternal, + expectMsg: "internal: incomplete envelope: unexpected EOF", }, { name: "connect_partial_frame", options: []connect.ClientOption{connect.WithProtoJSON()}, @@ -2227,8 +2227,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, err := responseWriter.Write(head[:4]) assert.Nil(t, err) }, - expectCode: connect.CodeInvalidArgument, - expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + expectCode: connect.CodeInternal, + expectMsg: "internal: incomplete envelope: unexpected EOF", }, { name: "grpc_partial_frame", options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPC()}, @@ -2238,8 +2238,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, err := responseWriter.Write(head[:4]) assert.Nil(t, err) }, - expectCode: connect.CodeInvalidArgument, - expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + expectCode: connect.CodeInternal, + expectMsg: "internal: incomplete envelope: unexpected EOF", }, { name: "grpc-web_partial_frame", options: []connect.ClientOption{connect.WithProtoJSON(), connect.WithGRPCWeb()}, @@ -2249,8 +2249,8 @@ func TestStreamUnexpectedEOF(t *testing.T) { _, err := responseWriter.Write(head[:4]) assert.Nil(t, err) }, - expectCode: connect.CodeInvalidArgument, - expectMsg: "invalid_argument: protocol error: incomplete envelope: unexpected EOF", + expectCode: connect.CodeInternal, + expectMsg: "internal: incomplete envelope: unexpected EOF", }, { name: "connect_excess_eof", options: []connect.ClientOption{connect.WithProtoJSON()}, diff --git a/envelope.go b/envelope.go index 8f81ab03..2381be8a 100644 --- a/envelope.go +++ b/envelope.go @@ -25,13 +25,6 @@ import ( // same meaning in the gRPC-Web, gRPC-HTTP2, and Connect protocols. const flagEnvelopeCompressed = 0b00000001 -var errSpecialEnvelope = errorf( - CodeUnknown, - "final message has protocol-specific flags: %w", - // User code checks for end of stream with errors.Is(err, io.EOF). - io.EOF, -) - // envelope is a block of arbitrary bytes wrapped in gRPC and Connect's framing // protocol. // @@ -44,267 +37,179 @@ type envelope struct { Flags uint8 } -func (e *envelope) IsSet(flag uint8) bool { - return e.Flags&flag == flag -} - -type envelopeWriter struct { - writer io.Writer - codec Codec - compressMinBytes int - compressionPool *compressionPool - bufferPool *bufferPool - sendMaxBytes int -} - -func (w *envelopeWriter) Marshal(message any) *Error { - if message == nil { - if _, err := w.writer.Write(nil); err != nil { - if connectErr, ok := asError(err); ok { - return connectErr +func (e envelope) WriteTo(w io.Writer) (n int64, err error) { + prefix := [5]byte{} + prefix[0] = e.Flags + binary.BigEndian.PutUint32(prefix[1:5], uint32(e.Data.Len())) + for _, b := range [2][]byte{prefix[:], e.Data.Bytes()} { + wroteN, err := w.Write(b) + if err != nil { + if writeErr, ok := asError(err); ok { + return n, writeErr } - return NewError(CodeUnknown, err) - } - return nil - } - if appender, ok := w.codec.(marshalAppender); ok { - return w.marshalAppend(message, appender) - } - return w.marshal(message) -} - -// Write writes the enveloped message, compressing as necessary. It doesn't -// retain any references to the supplied envelope or its underlying data. -func (w *envelopeWriter) Write(env *envelope) *Error { - if env.IsSet(flagEnvelopeCompressed) || - w.compressionPool == nil || - env.Data.Len() < w.compressMinBytes { - if w.sendMaxBytes > 0 && env.Data.Len() > w.sendMaxBytes { - return errorf(CodeResourceExhausted, "message size %d exceeds sendMaxBytes %d", env.Data.Len(), w.sendMaxBytes) + return n, errorf(CodeUnknown, "write envelope: %w", err) } - return w.write(env) + n += int64(wroteN) } - data := w.bufferPool.Get() - defer w.bufferPool.Put(data) - if err := w.compressionPool.Compress(data, env.Data); err != nil { - return err - } - if w.sendMaxBytes > 0 && data.Len() > w.sendMaxBytes { - return errorf(CodeResourceExhausted, "compressed message size %d exceeds sendMaxBytes %d", data.Len(), w.sendMaxBytes) - } - return w.write(&envelope{ - Data: data, - Flags: env.Flags | flagEnvelopeCompressed, - }) + return n, nil } -func (w *envelopeWriter) marshalAppend(message any, codec marshalAppender) *Error { - // Codec supports MarshalAppend; try to re-use a []byte from the pool. - buffer := w.bufferPool.Get() - defer w.bufferPool.Put(buffer) - raw, err := codec.MarshalAppend(buffer.Bytes(), message) - if err != nil { - return errorf(CodeInternal, "marshal message: %w", err) +func marshal(dst *bytes.Buffer, message any, codec Codec) *Error { + if message == nil { + return nil } - if cap(raw) > buffer.Cap() { - // The buffer from the pool was too small, so MarshalAppend grew the slice. - // Pessimistically assume that the too-small buffer is insufficient for the - // application workload, so there's no point in keeping it in the pool. - // Instead, replace it with the larger, newly-allocated slice. This - // allocates, but it's a small, constant-size allocation. - *buffer = *bytes.NewBuffer(raw) - } else { - // MarshalAppend didn't allocate, but we need to fix the internal state of - // the buffer. Compared to replacing the buffer (as above), buffer.Write - // copies but avoids allocating. - buffer.Write(raw) + if codec, ok := codec.(marshalAppender); ok { + // Codec supports MarshalAppend; try to re-use a []byte from the pool. + raw, err := codec.MarshalAppend(dst.Bytes(), message) + if err != nil { + return errorf(CodeInternal, "marshal message: %w", err) + } + if cap(raw) > dst.Cap() { + // The buffer from the pool was too small, so MarshalAppend grew the slice. + // Pessimistically assume that the too-small buffer is insufficient for the + // application workload, so there's no point in keeping it in the pool. + // Instead, replace it with the larger, newly-allocated slice. This + // allocates, but it's a small, constant-size allocation. + *dst = *bytes.NewBuffer(raw) + } else { + // The buffer from the pool was large enough, MarshalAppend didn't allocate. + // Copy to the same byte slice is a nop. + dst.Write(raw[dst.Len():]) + } + return nil } - envelope := &envelope{Data: buffer} - return w.Write(envelope) -} - -func (w *envelopeWriter) marshal(message any) *Error { // Codec doesn't support MarshalAppend; let Marshal allocate a []byte. - raw, err := w.codec.Marshal(message) + raw, err := codec.Marshal(message) if err != nil { return errorf(CodeInternal, "marshal message: %w", err) } - buffer := bytes.NewBuffer(raw) - // Put our new []byte into the pool for later reuse. - defer w.bufferPool.Put(buffer) - envelope := &envelope{Data: buffer} - return w.Write(envelope) + dst.Write(raw) + return nil } -func (w *envelopeWriter) write(env *envelope) *Error { - prefix := [5]byte{} - prefix[0] = env.Flags - binary.BigEndian.PutUint32(prefix[1:5], uint32(env.Data.Len())) - if _, err := w.writer.Write(prefix[:]); err != nil { - if connectErr, ok := asError(err); ok { - return connectErr - } - return errorf(CodeUnknown, "write envelope: %w", err) - } - if _, err := io.Copy(w.writer, env.Data); err != nil { - return errorf(CodeUnknown, "write message: %w", err) +func unmarshal(src *bytes.Buffer, message any, codec Codec) *Error { + if err := codec.Unmarshal(src.Bytes(), message); err != nil { + return errorf(CodeInvalidArgument, "unmarshal into %T: %w", message, err) } return nil } -type envelopeReader struct { - reader io.Reader - codec Codec - last envelope - compressionPool *compressionPool - bufferPool *bufferPool - readMaxBytes int +func read(dst *bytes.Buffer, src io.Reader) (int, error) { + dst.Grow(bytes.MinRead) + b := dst.Bytes() + b = b[len(b):cap(b)] + n, err := src.Read(b) + _, _ = dst.Write(b[:n]) + return n, err } -func (r *envelopeReader) Unmarshal(message any) *Error { - buffer := r.bufferPool.Get() - defer r.bufferPool.Put(buffer) - - env := &envelope{Data: buffer} - err := r.Read(env) - switch { - case err == nil && - (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && - env.Data.Len() == 0: - // This is a standard message (because none of the top 7 bits are set) and - // there's no data, so the zero value of the message is correct. - return nil - case err != nil && errors.Is(err, io.EOF): - // The stream has ended. Propagate the EOF to the caller. - return err - case err != nil: - // Something's wrong. - return err - } - - data := env.Data - if data.Len() > 0 && env.IsSet(flagEnvelopeCompressed) { - if r.compressionPool == nil { - return errorf( - CodeInvalidArgument, - "protocol error: sent compressed message without Grpc-Encoding header", - ) - } - decompressed := r.bufferPool.Get() - defer r.bufferPool.Put(decompressed) - if err := r.compressionPool.Decompress(decompressed, data, int64(r.readMaxBytes)); err != nil { - return err - } - data = decompressed - } - - if env.Flags != 0 && env.Flags != flagEnvelopeCompressed { - // Drain the rest of the stream to ensure there is no extra data. - if n, err := discard(r.reader); err != nil { - return errorf(CodeInternal, "corrupt response: I/O error after end-stream message: %w", err) - } else if n > 0 { - return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n) +func readAll(dst *bytes.Buffer, src io.Reader, readMaxBytes int) *Error { + var totalN int64 + for { + readN, err := read(dst, src) + totalN += int64(readN) + if readMaxBytes > 0 && totalN > int64(readMaxBytes) { + discardN, err := discard(src) + if err != nil { + return errorf(CodeResourceExhausted, + "message is larger than configured max %d - unable to determine message size: %w", + readMaxBytes, err) + } + return errorf(CodeResourceExhausted, + "message size %d is larger than configured max %d", + totalN+discardN, readMaxBytes) } - // One of the protocol-specific flags are set, so this is the end of the - // stream. Save the message for protocol-specific code to process and - // return a sentinel error. Since we've deferred functions to return env's - // underlying buffer to a pool, we need to keep a copy. - copiedData := make([]byte, data.Len()) - copy(copiedData, data.Bytes()) - r.last = envelope{ - Data: bytes.NewBuffer(copiedData), - Flags: env.Flags, + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + if writeErr, ok := asError(err); ok { + return writeErr + } + if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", totalN); readMaxBytesErr != nil { + return readMaxBytesErr + } + return errorf(CodeUnknown, "read: %w", err) } - return errSpecialEnvelope - } - - if err := r.codec.Unmarshal(data.Bytes(), message); err != nil { - return errorf(CodeInvalidArgument, "unmarshal into %T: %w", message, err) } - return nil } -func (r *envelopeReader) Read(env *envelope) *Error { - prefixes := [5]byte{} - prefixBytesRead, err := r.reader.Read(prefixes[:]) +var errEOF = errorf(CodeInternal, "%w", io.EOF) - switch { - case (err == nil || errors.Is(err, io.EOF)) && - prefixBytesRead == 5 && - isSizeZeroPrefix(prefixes): - // Successfully read prefix and expect no additional data. - env.Flags = prefixes[0] - return nil - case err != nil && errors.Is(err, io.EOF) && prefixBytesRead == 0: - // The stream ended cleanly. That's expected, but we need to propagate them - // to the user so that they know that the stream has ended. We shouldn't - // add any alarming text about protocol errors, though. - return NewError(CodeUnknown, err) - case err != nil || prefixBytesRead < 5: +func readEnvelope(dst *bytes.Buffer, src io.Reader, readMaxBytes int) (uint8, *Error) { + prefix := [5]byte{} + if _, err := io.ReadFull(src, prefix[:]); err != nil { + if errors.Is(err, io.EOF) { + return 0, errEOF + } // Something else has gone wrong - the stream didn't end cleanly. if connectErr, ok := asError(err); ok { - return connectErr + return 0, connectErr } if maxBytesErr := asMaxBytesError(err, "read 5 byte message prefix"); maxBytesErr != nil { // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. - return maxBytesErr - } - if err == nil { - err = io.ErrUnexpectedEOF + return 0, maxBytesErr } - return errorf( - CodeInvalidArgument, - "protocol error: incomplete envelope: %w", err, - ) + return 0, errorf(CodeInternal, "incomplete envelope: %w", err) } - size := int(binary.BigEndian.Uint32(prefixes[1:5])) + + size := int(binary.BigEndian.Uint32(prefix[1:5])) if size < 0 { - return errorf(CodeInvalidArgument, "message size %d overflowed uint32", size) + return 0, errorf(CodeInvalidArgument, "message size %d overflowed uint32", size) } - if r.readMaxBytes > 0 && size > r.readMaxBytes { - _, err := io.CopyN(io.Discard, r.reader, int64(size)) - if err != nil && !errors.Is(err, io.EOF) { - return errorf(CodeUnknown, "read enveloped message: %w", err) + if readMaxBytes > 0 && size > readMaxBytes { + if _, err := discard(src); err != nil { + return 0, errorf(CodeUnknown, "read enveloped message: %w", err) } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, r.readMaxBytes) + return 0, errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, readMaxBytes) } if size > 0 { - // At layer 7, we don't know exactly what's happening down in L4. Large - // length-prefixed messages may arrive in chunks, so we may need to read - // the request body past EOF. We also need to take care that we don't retry - // forever if the message is malformed. - remaining := int64(size) - for remaining > 0 { - bytesRead, err := io.CopyN(env.Data, r.reader, remaining) - if err != nil && !errors.Is(err, io.EOF) { - if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil { - // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. - return maxBytesErr - } - return errorf(CodeUnknown, "read enveloped message: %w", err) + dst.Grow(size) + data := dst.Bytes()[dst.Len() : dst.Len()+size] + if _, err := io.ReadFull(src, data); err != nil { + if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil { + // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. + return 0, maxBytesErr } - if errors.Is(err, io.EOF) && bytesRead == 0 { - // We've gotten zero-length chunk of data. Message is likely malformed, - // don't wait for additional chunks. - return errorf( - CodeInvalidArgument, - "protocol error: promised %d bytes in enveloped message, got %d bytes", - size, - int64(size)-remaining, - ) - } - remaining -= bytesRead + return 0, errorf(CodeInternal, "incomplete envelope: %w", err) + } + if _, err := dst.Write(data); err != nil { + return 0, errorf(CodeInternal, "read enveloped message: %w", err) } } - env.Flags = prefixes[0] + return prefix[0], nil +} +func writeAll(dst io.Writer, src io.WriterTo) *Error { + if _, err := src.WriteTo(dst); err != nil { + if writeErr, ok := asError(err); ok { + return writeErr + } + return errorf(CodeInternal, "write message: %w", err) + } return nil } -func isSizeZeroPrefix(prefix [5]byte) bool { - for i := 1; i < 5; i++ { - if prefix[i] != 0 { - return false - } +func checkSendMaxBytes(length, sendMaxBytes int, isCompressed bool) *Error { + if sendMaxBytes <= 0 || length <= sendMaxBytes { + return nil + } + tmpl := "message size %d exceeds sendMaxBytes %d" + if isCompressed { + tmpl = "compressed message size %d exceeds sendMaxBytes %d" + } + return errorf(CodeResourceExhausted, tmpl, length, sendMaxBytes) +} + +func newErrInvalidEnvelopeFlags(flags uint8) *Error { + return errorf(CodeInternal, "protocol error: invalid envelope flags %08b", flags) +} + +// ensureEOF always returns io.EOF, unless there are extra bytes in src. +func ensureEOF(src io.Reader) error { + if n, err := discard(src); err != nil { + return err + } else if n > 0 { + return errorf(CodeInternal, "corrupt response: %d extra bytes after end of stream", n) } - return true + return io.EOF } diff --git a/error_writer.go b/error_writer.go index 9644715c..78928368 100644 --- a/error_writer.go +++ b/error_writer.go @@ -130,16 +130,18 @@ func (w *ErrorWriter) writeConnectUnary(response http.ResponseWriter, err error) } func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err error) error { - response.WriteHeader(http.StatusOK) - marshaler := &connectStreamingMarshaler{ - envelopeWriter: envelopeWriter{ - writer: response, - bufferPool: w.bufferPool, - }, + buffer := w.bufferPool.Get() + defer w.bufferPool.Put(buffer) + + end := newConnectEndStreamMessage(err, make(http.Header)) + if err := connectMarshalEndStreamMessage(buffer, end); err != nil { + return err } - // MarshalEndStream returns *Error: check return value to avoid typed nils. - if marshalErr := marshaler.MarshalEndStream(err, make(http.Header)); marshalErr != nil { - return marshalErr + + response.WriteHeader(http.StatusOK) + env := envelope{Data: buffer, Flags: connectFlagEnvelopeEndStream} + if err := writeAll(response, env); err != nil { + return err } return nil } diff --git a/header.go b/header.go index f827aba2..ecadc5eb 100644 --- a/header.go +++ b/header.go @@ -55,9 +55,6 @@ func mergeHeaders(into, from http.Header) { // bypasses the CanonicalMIMEHeaderKey operation when we // know the key is already in canonical form. func getHeaderCanonical(h http.Header, key string) string { - if h == nil { - return "" - } v := h[key] if len(v) == 0 { return "" diff --git a/protocol.go b/protocol.go index a02f24b0..b24ec603 100644 --- a/protocol.go +++ b/protocol.go @@ -40,7 +40,7 @@ const ( headerUserAgent = "User-Agent" headerTrailer = "Trailer" - discardLimit = 1024 * 1024 * 4 // 4MiB + discardLimit = 1024 * 1024 * 8 // 8MiB ) var errNoTimeout = errors.New("no timeout") @@ -284,13 +284,15 @@ func isCommaOrSpace(c rune) bool { } func discard(reader io.Reader) (int64, error) { - if lr, ok := reader.(*io.LimitedReader); ok { - return io.Copy(io.Discard, lr) - } // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. - lr := &io.LimitedReader{R: reader, N: discardLimit} - return io.Copy(io.Discard, lr) + n, err := io.CopyN(io.Discard, reader, discardLimit) + if errors.Is(err, io.EOF) { + err = nil + } else if n == discardLimit { + err = io.ErrShortBuffer + } + return n, err } // negotiateCompression determines and validates the request compression and diff --git a/protocol_connect.go b/protocol_connect.go index 110666d3..2d90e24a 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -22,7 +22,6 @@ import ( "errors" "fmt" "io" - "math" "net/http" "net/url" "runtime" @@ -251,55 +250,29 @@ func (h *connectHandler) NewConn( } if h.Spec.StreamType == StreamTypeUnary { conn = &connectUnaryHandlerConn{ - spec: h.Spec, - peer: peer, - request: request, - responseWriter: responseWriter, - marshaler: connectUnaryMarshaler{ - writer: responseWriter, - codec: codec, - compressMinBytes: h.CompressMinBytes, - compressionName: responseCompression, - compressionPool: h.CompressionPools.Get(responseCompression), - bufferPool: h.BufferPool, - header: responseWriter.Header(), - sendMaxBytes: h.SendMaxBytes, - }, - unmarshaler: connectUnaryUnmarshaler{ - reader: requestBody, - codec: codec, - compressionPool: h.CompressionPools.Get(requestCompression), - bufferPool: h.BufferPool, - readMaxBytes: h.ReadMaxBytes, - }, - responseTrailer: make(http.Header), + connectHandler: h, + peer: peer, + request: request, + requestBody: requestBody, + responseWriter: responseWriter, + responseTrailer: make(http.Header), + codec: codec, + recvCompressionName: requestCompression, + recvCompressionPool: h.CompressionPools.Get(requestCompression), + sendCompressionName: responseCompression, + sendCompressionPool: h.CompressionPools.Get(responseCompression), } } else { conn = &connectStreamingHandlerConn{ - spec: h.Spec, - peer: peer, - request: request, - responseWriter: responseWriter, - marshaler: connectStreamingMarshaler{ - envelopeWriter: envelopeWriter{ - writer: responseWriter, - codec: codec, - compressMinBytes: h.CompressMinBytes, - compressionPool: h.CompressionPools.Get(responseCompression), - bufferPool: h.BufferPool, - sendMaxBytes: h.SendMaxBytes, - }, - }, - unmarshaler: connectStreamingUnmarshaler{ - envelopeReader: envelopeReader{ - reader: requestBody, - codec: codec, - compressionPool: h.CompressionPools.Get(requestCompression), - bufferPool: h.BufferPool, - readMaxBytes: h.ReadMaxBytes, - }, - }, - responseTrailer: make(http.Header), + connectHandler: h, + peer: peer, + request: request, + requestBody: requestBody, + responseWriter: responseWriter, + responseTrailer: make(http.Header), + codec: codec, + recvCompressionPool: h.CompressionPools.Get(requestCompression), + sendCompressionPool: h.CompressionPools.Get(responseCompression), } } conn = wrapHandlerConnWithCodedErrors(conn) @@ -370,71 +343,29 @@ func (c *connectClient) NewConn( var conn streamingClientConn if spec.StreamType == StreamTypeUnary { unaryConn := &connectUnaryClientConn{ - spec: spec, - peer: c.Peer(), - duplexCall: duplexCall, - compressionPools: c.CompressionPools, - bufferPool: c.BufferPool, - marshaler: connectUnaryRequestMarshaler{ - connectUnaryMarshaler: connectUnaryMarshaler{ - writer: duplexCall, - codec: c.Codec, - compressMinBytes: c.CompressMinBytes, - compressionName: c.CompressionName, - compressionPool: c.CompressionPools.Get(c.CompressionName), - bufferPool: c.BufferPool, - header: duplexCall.Header(), - sendMaxBytes: c.SendMaxBytes, - }, - }, - unmarshaler: connectUnaryUnmarshaler{ - reader: duplexCall, - codec: c.Codec, - bufferPool: c.BufferPool, - readMaxBytes: c.ReadMaxBytes, - }, + connectClient: c, + spec: spec, + duplexCall: duplexCall, + compressionPool: c.CompressionPools.Get(c.CompressionName), responseHeader: make(http.Header), responseTrailer: make(http.Header), } if spec.IdempotencyLevel == IdempotencyNoSideEffects { - unaryConn.marshaler.enableGet = c.EnableGet - unaryConn.marshaler.getURLMaxBytes = c.GetURLMaxBytes - unaryConn.marshaler.getUseFallback = c.GetUseFallback - unaryConn.marshaler.duplexCall = duplexCall if stableCodec, ok := c.Codec.(stableCodec); ok { - unaryConn.marshaler.stableCodec = stableCodec + unaryConn.stableCodec = stableCodec } } conn = unaryConn duplexCall.SetValidateResponse(unaryConn.validateResponse) } else { streamingConn := &connectStreamingClientConn{ - spec: spec, - peer: c.Peer(), - duplexCall: duplexCall, - compressionPools: c.CompressionPools, - bufferPool: c.BufferPool, - codec: c.Codec, - marshaler: connectStreamingMarshaler{ - envelopeWriter: envelopeWriter{ - writer: duplexCall, - codec: c.Codec, - compressMinBytes: c.CompressMinBytes, - compressionPool: c.CompressionPools.Get(c.CompressionName), - bufferPool: c.BufferPool, - sendMaxBytes: c.SendMaxBytes, - }, - }, - unmarshaler: connectStreamingUnmarshaler{ - envelopeReader: envelopeReader{ - reader: duplexCall, - codec: c.Codec, - bufferPool: c.BufferPool, - readMaxBytes: c.ReadMaxBytes, - }, - }, - responseHeader: make(http.Header), - responseTrailer: make(http.Header), + connectClient: c, + spec: spec, + duplexCall: duplexCall, + sendCompressionPool: c.CompressionPools.Get(c.CompressionName), + recvCompressionPool: nil, // set by validateResponse + responseHeader: make(http.Header), + responseTrailer: make(http.Header), } conn = streamingConn duplexCall.SetValidateResponse(streamingConn.validateResponse) @@ -443,15 +374,17 @@ func (c *connectClient) NewConn( } type connectUnaryClientConn struct { - spec Spec - peer Peer - duplexCall *duplexHTTPCall - compressionPools readOnlyCompressionPools - bufferPool *bufferPool - marshaler connectUnaryRequestMarshaler - unmarshaler connectUnaryUnmarshaler - responseHeader http.Header - responseTrailer http.Header + *connectClient + + spec Spec + duplexCall *duplexHTTPCall + compressionPool *compressionPool // set by validateResponse + responseHeader http.Header + responseTrailer http.Header + alreadyRead bool + + // Get-related fields + stableCodec stableCodec } func (cc *connectUnaryClientConn) Spec() Spec { @@ -463,7 +396,16 @@ func (cc *connectUnaryClientConn) Peer() Peer { } func (cc *connectUnaryClientConn) Send(msg any) error { - if err := cc.marshaler.Marshal(msg); err != nil { + buffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(buffer) + + if cc.EnableGet { + if err := cc.trySendGet(buffer, msg); err != nil { + return err + } + return nil + } + if err := cc.sendMsg(buffer, msg); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -478,8 +420,29 @@ func (cc *connectUnaryClientConn) CloseRequest() error { } func (cc *connectUnaryClientConn) Receive(msg any) error { + if cc.alreadyRead { + return io.EOF + } + cc.alreadyRead = true cc.duplexCall.BlockUntilResponseReady() - if err := cc.unmarshaler.Unmarshal(msg); err != nil { + + buffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(buffer) + if err := readAll(buffer, cc.duplexCall, cc.ReadMaxBytes); err != nil { + return err + } + if cc.compressionPool != nil { + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) + + if err := cc.compressionPool.Decompress( + compressionBuffer, buffer, int64(cc.ReadMaxBytes), + ); err != nil { + return err + } + buffer = compressionBuffer // swap buffers + } + if err := unmarshal(buffer, msg, cc.Codec); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -514,55 +477,63 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err compression := getHeaderCanonical(response.Header, connectUnaryHeaderCompression) if compression != "" && compression != compressionIdentity && - !cc.compressionPools.Contains(compression) { + !cc.CompressionPools.Contains(compression) { return errorf( CodeInternal, "unknown encoding %q: accepted encodings are %v", compression, - cc.compressionPools.CommaSeparatedNames(), + cc.CompressionPools.CommaSeparatedNames(), ) } + cc.compressionPool = cc.CompressionPools.Get(compression) + if response.StatusCode == http.StatusOK { + return nil + } if response.StatusCode == http.StatusNotModified && cc.Spec().IdempotencyLevel == IdempotencyNoSideEffects { serverErr := NewWireError(CodeUnknown, errNotModifiedClient) // RFC 9110 doesn't allow trailers on 304s, so we only need to include headers. serverErr.meta = cc.responseHeader.Clone() return serverErr - } else if response.StatusCode != http.StatusOK { - unmarshaler := connectUnaryUnmarshaler{ - reader: response.Body, - compressionPool: cc.compressionPools.Get(compression), - bufferPool: cc.bufferPool, - } - var wireErr connectWireError - if err := unmarshaler.UnmarshalFunc(&wireErr, json.Unmarshal); err != nil { - return NewError( - connectHTTPToCode(response.StatusCode), - errors.New(response.Status), - ) - } - serverErr := wireErr.asError() - if serverErr == nil { - return nil + } + buffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(buffer) + + if err := readAll(buffer, response.Body, cc.ReadMaxBytes); err != nil { + return err + } + if cc.compressionPool != nil { + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) + if err := cc.compressionPool.Decompress( + compressionBuffer, buffer, int64(cc.ReadMaxBytes), + ); err != nil { + return err } - serverErr.meta = cc.responseHeader.Clone() - mergeHeaders(serverErr.meta, cc.responseTrailer) - return serverErr + buffer = compressionBuffer } - cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression) - return nil + var wireErr connectWireError + if err := json.Unmarshal(buffer.Bytes(), &wireErr); err != nil { + return errorf(CodeInternal, "failed to unmarshal error: %v", err) + } + + serverErr := wireErr.asError() + if serverErr == nil { + return nil + } + serverErr.meta = cc.responseHeader.Clone() + mergeHeaders(serverErr.meta, cc.responseTrailer) + return serverErr } type connectStreamingClientConn struct { - spec Spec - peer Peer - duplexCall *duplexHTTPCall - compressionPools readOnlyCompressionPools - bufferPool *bufferPool - codec Codec - marshaler connectStreamingMarshaler - unmarshaler connectStreamingUnmarshaler - responseHeader http.Header - responseTrailer http.Header + *connectClient + + spec Spec + duplexCall *duplexHTTPCall + sendCompressionPool *compressionPool + recvCompressionPool *compressionPool + responseHeader http.Header + responseTrailer http.Header } func (cc *connectStreamingClientConn) Spec() Spec { @@ -574,10 +545,32 @@ func (cc *connectStreamingClientConn) Peer() Peer { } func (cc *connectStreamingClientConn) Send(msg any) error { - if err := cc.marshaler.Marshal(msg); err != nil { + buffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(buffer) + + if err := marshal(buffer, msg, cc.Codec); err != nil { return err } - return nil // must be a literal nil: nil *Error is a non-nil error + var flags uint8 + if cc.sendCompressionPool != nil && buffer.Len() > cc.CompressMinBytes { + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) + if err := cc.sendCompressionPool.Compress( + compressionBuffer, buffer, + ); err != nil { + return err + } + buffer = compressionBuffer // swap buffers + flags |= flagEnvelopeCompressed + } + if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { + return err + } + env := envelope{Data: buffer, Flags: flags} + if err := writeAll(cc.duplexCall, env); err != nil { + return err + } + return nil // must be a literal nil: nil *error is a non-nil error } func (cc *connectStreamingClientConn) RequestHeader() http.Header { @@ -590,33 +583,63 @@ func (cc *connectStreamingClientConn) CloseRequest() error { func (cc *connectStreamingClientConn) Receive(msg any) error { cc.duplexCall.BlockUntilResponseReady() - err := cc.unmarshaler.Unmarshal(msg) - if err == nil { - return nil + buffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(buffer) + + flags, err := readEnvelope(buffer, cc.duplexCall, cc.ReadMaxBytes) + if err != nil { + // If the error is EOF but not from a last message, we want to return + // io.ErrUnexpectedEOF instead. + if errors.Is(err, io.EOF) { + err = errorf(CodeInternal, "protocol error: %w", io.ErrUnexpectedEOF) + } + return err } - // See if the server sent an explicit error in the end-of-stream message. - mergeHeaders(cc.responseTrailer, cc.unmarshaler.Trailer()) - if serverErr := cc.unmarshaler.EndStreamError(); serverErr != nil { - // This is expected from a protocol perspective, but receiving an - // end-of-stream message means that we're _not_ getting a regular message. - // For users to realize that the stream has ended, Receive must return an - // error. - serverErr.meta = cc.responseHeader.Clone() - mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) - return serverErr + if flags&flagEnvelopeCompressed != 0 { + if cc.recvCompressionPool == nil { + return errorf(CodeInvalidArgument, + "protocol error: received compressed message without %s header", + connectStreamingHeaderCompression, + ) + } + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) + if err := cc.recvCompressionPool.Decompress( + compressionBuffer, buffer, int64(cc.ReadMaxBytes), + ); err != nil { + return err + } + buffer = compressionBuffer } - // If the error is EOF but not from a last message, we want to return - // io.ErrUnexpectedEOF instead. - if errors.Is(err, io.EOF) && !errors.Is(err, errSpecialEnvelope) { - err = errorf(CodeInternal, "protocol error: %w", io.ErrUnexpectedEOF) + if flags != 0 && flags != flagEnvelopeCompressed { + end, err := connectUnmarshalEndStreamMessage(buffer, flags) + if err != nil { + return err + } + // See if the server sent an explicit error in the end-of-stream message. + mergeHeaders(cc.responseTrailer, end.Trailer) + if serverErr := end.Error.asError(); serverErr != nil { + // This is expected from a protocol perspective, but receiving an + // end-of-stream message means that we're _not_ getting a regular message. + // For users to realize that the stream has ended, Receive must return an + // error. + serverErr.meta = cc.responseHeader.Clone() + mergeHeaders(serverErr.meta, cc.responseTrailer) + cc.duplexCall.SetError(serverErr) + return serverErr + } + return ensureEOF(cc.duplexCall) } - // There's no error in the trailers, so this was probably an error - // converting the bytes to a message, an error reading from the network, or - // just an EOF. We're going to return it to the user, but we also want to - // setResponseError so Send errors out. - cc.duplexCall.SetError(err) - return err + + if err := unmarshal(buffer, msg, cc.Codec); err != nil { + // There's no error in the trailers, so this was probably an error + // converting the bytes to a message, an error reading from the network, or + // just an EOF. We're going to return it to the user, but we also want to + // setResponseError so Send errors out. + cc.duplexCall.SetError(err) + return err + } + return nil // must be a literal nil: nil *Error is a non-nil error } func (cc *connectStreamingClientConn) ResponseHeader() http.Header { @@ -644,32 +667,38 @@ func (cc *connectStreamingClientConn) validateResponse(response *http.Response) compression := getHeaderCanonical(response.Header, connectStreamingHeaderCompression) if compression != "" && compression != compressionIdentity && - !cc.compressionPools.Contains(compression) { + !cc.CompressionPools.Contains(compression) { return errorf( CodeInternal, "unknown encoding %q: accepted encodings are %v", compression, - cc.compressionPools.CommaSeparatedNames(), + cc.CompressionPools.CommaSeparatedNames(), ) } - cc.unmarshaler.compressionPool = cc.compressionPools.Get(compression) + cc.recvCompressionPool = cc.CompressionPools.Get(compression) mergeHeaders(cc.responseHeader, response.Header) return nil } type connectUnaryHandlerConn struct { - spec Spec - peer Peer - request *http.Request - responseWriter http.ResponseWriter - marshaler connectUnaryMarshaler - unmarshaler connectUnaryUnmarshaler - responseTrailer http.Header - wroteBody bool + *connectHandler + + peer Peer + request *http.Request + requestBody io.ReadCloser + responseWriter http.ResponseWriter + responseTrailer http.Header + codec Codec + recvCompressionName string + recvCompressionPool *compressionPool + sendCompressionName string + sendCompressionPool *compressionPool + alreadyRead bool + wroteBody bool } func (hc *connectUnaryHandlerConn) Spec() Spec { - return hc.spec + return hc.protocolHandlerParams.Spec } func (hc *connectUnaryHandlerConn) Peer() Peer { @@ -677,7 +706,30 @@ func (hc *connectUnaryHandlerConn) Peer() Peer { } func (hc *connectUnaryHandlerConn) Receive(msg any) error { - if err := hc.unmarshaler.Unmarshal(msg); err != nil { + if hc.alreadyRead { + return NewError(CodeInternal, io.EOF) + } + hc.alreadyRead = true + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + + if err := readAll(buffer, hc.requestBody, hc.ReadMaxBytes); err != nil { + return err + } + if buffer.Len() > 0 && hc.recvCompressionPool != nil { + compressionBuffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(compressionBuffer) + if err := hc.recvCompressionPool.Decompress( + compressionBuffer, buffer, int64(hc.ReadMaxBytes), + ); err != nil { + return err + } + buffer = compressionBuffer + } + if err := unmarshal(buffer, msg, hc.codec); err != nil { + return err + } + if err := ensureEOF(hc.requestBody); !errors.Is(err, io.EOF) { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -690,7 +742,33 @@ func (hc *connectUnaryHandlerConn) RequestHeader() http.Header { func (hc *connectUnaryHandlerConn) Send(msg any) error { hc.wroteBody = true hc.writeResponseHeader(nil /* error */) - if err := hc.marshaler.Marshal(msg); err != nil { + header := hc.responseWriter.Header() + + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + + if err := marshal(buffer, msg, hc.codec); err != nil { + return err + } + var isCompressed bool + if buffer.Len() > hc.CompressMinBytes && hc.sendCompressionPool != nil { + compressionBuffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(compressionBuffer) + + if err := hc.sendCompressionPool.Compress( + compressionBuffer, buffer, + ); err != nil { + return err + } + buffer = compressionBuffer // swap buffers + setHeaderCanonical(header, connectUnaryHeaderCompression, hc.sendCompressionName) + isCompressed = true + } + if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, isCompressed); err != nil { + delHeaderCanonical(header, connectUnaryHeaderCompression) + return err + } + if err := writeAll(hc.responseWriter, buffer); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -755,17 +833,21 @@ func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { } type connectStreamingHandlerConn struct { - spec Spec - peer Peer - request *http.Request - responseWriter http.ResponseWriter - marshaler connectStreamingMarshaler - unmarshaler connectStreamingUnmarshaler - responseTrailer http.Header + *connectHandler + + peer Peer + request *http.Request + requestBody io.ReadCloser + responseWriter http.ResponseWriter + responseTrailer http.Header + codec Codec + sendCompressionPool *compressionPool + recvCompressionPool *compressionPool + end *connectEndStreamMessage // set by Receive } func (hc *connectStreamingHandlerConn) Spec() Spec { - return hc.spec + return hc.protocolHandlerParams.Spec } func (hc *connectStreamingHandlerConn) Peer() Peer { @@ -773,12 +855,41 @@ func (hc *connectStreamingHandlerConn) Peer() Peer { } func (hc *connectStreamingHandlerConn) Receive(msg any) error { - if err := hc.unmarshaler.Unmarshal(msg); err != nil { - // Clients may not send end-of-stream metadata, so we don't need to handle - // errSpecialEnvelope. + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + + flags, err := readEnvelope(buffer, hc.request.Body, hc.ReadMaxBytes) + if err != nil { return err } - return nil // must be a literal nil: nil *Error is a non-nil error + if flags&flagEnvelopeCompressed != 0 { + if hc.recvCompressionPool == nil { + return errorf(CodeInvalidArgument, + "protocol error: received compressed message without %s header", + connectStreamingHeaderCompression, + ) + } + compressionBuffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(compressionBuffer) + if err := hc.recvCompressionPool.Decompress( + compressionBuffer, buffer, int64(hc.ReadMaxBytes), + ); err != nil { + return err + } + buffer = compressionBuffer + } + if flags != 0 && flags != flagEnvelopeCompressed { + end, err := connectUnmarshalEndStreamMessage(buffer, flags) + if err != nil { + return err + } + hc.end = end + return ensureEOF(hc.request.Body) + } + if err := unmarshal(buffer, msg, hc.codec); err != nil { + return err + } + return nil // must be a literal nil: nil *error is a non-nil error } func (hc *connectStreamingHandlerConn) RequestHeader() http.Header { @@ -786,11 +897,39 @@ func (hc *connectStreamingHandlerConn) RequestHeader() http.Header { } func (hc *connectStreamingHandlerConn) Send(msg any) error { - defer flushResponseWriter(hc.responseWriter) - if err := hc.marshaler.Marshal(msg); err != nil { + if msg == nil { + hc.responseWriter.WriteHeader(http.StatusOK) + return nil + } + + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + + if err := marshal(buffer, msg, hc.codec); err != nil { return err } - return nil // must be a literal nil: nil *Error is a non-nil error + var flags uint8 + if buffer.Len() > hc.CompressMinBytes && hc.sendCompressionPool != nil { + compressionBuffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(compressionBuffer) + + if err := hc.sendCompressionPool.Compress( + compressionBuffer, buffer, + ); err != nil { + return err + } + buffer = compressionBuffer // swap buffers + flags |= flagEnvelopeCompressed + } + if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { + return err + } + env := envelope{Data: buffer, Flags: flags} + if err := writeAll(hc.responseWriter, env); err != nil { + return err + } + flushResponseWriter(hc.responseWriter) + return nil // must be a literal nil: nil *error is a non-nil error } func (hc *connectStreamingHandlerConn) ResponseHeader() http.Header { @@ -803,7 +942,7 @@ func (hc *connectStreamingHandlerConn) ResponseTrailer() http.Header { func (hc *connectStreamingHandlerConn) Close(err error) error { defer flushResponseWriter(hc.responseWriter) - if err := hc.marshaler.MarshalEndStream(err, hc.responseTrailer); err != nil { + if err := hc.marshalEndStream(err, hc.responseTrailer); err != nil { _ = hc.request.Body.Close() return err } @@ -822,290 +961,133 @@ func (hc *connectStreamingHandlerConn) Close(err error) error { return nil // must be a literal nil: nil *Error is a non-nil error } -type connectStreamingMarshaler struct { - envelopeWriter -} - -func (m *connectStreamingMarshaler) MarshalEndStream(err error, trailer http.Header) *Error { - end := &connectEndStreamMessage{Trailer: trailer} - if err != nil { - end.Error = newConnectWireError(err) - if connectErr, ok := asError(err); ok { - mergeHeaders(end.Trailer, connectErr.meta) - } - } - data, marshalErr := json.Marshal(end) - if marshalErr != nil { - return errorf(CodeInternal, "marshal end stream: %w", marshalErr) +func (hc *connectStreamingHandlerConn) marshalEndStream(err error, trailer http.Header) *Error { + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + end := newConnectEndStreamMessage(err, trailer) + if err := connectMarshalEndStreamMessage(buffer, end); err != nil { + return err } - raw := bytes.NewBuffer(data) - defer m.envelopeWriter.bufferPool.Put(raw) - return m.Write(&envelope{ - Data: raw, - Flags: connectFlagEnvelopeEndStream, - }) + env := envelope{Data: buffer, Flags: connectFlagEnvelopeEndStream} + return writeAll(hc.responseWriter, env) } -type connectStreamingUnmarshaler struct { - envelopeReader - - endStreamErr *Error - trailer http.Header -} - -func (u *connectStreamingUnmarshaler) Unmarshal(message any) *Error { - err := u.envelopeReader.Unmarshal(message) - if err == nil { - return nil - } - if !errors.Is(err, errSpecialEnvelope) { +func (cc *connectUnaryClientConn) sendMsg(buffer *bytes.Buffer, msg any) error { + if err := marshal(buffer, msg, cc.Codec); err != nil { return err } - env := u.envelopeReader.last - if !env.IsSet(connectFlagEnvelopeEndStream) { - return errorf(CodeInternal, "protocol error: invalid envelope flags %d", env.Flags) - } - var end connectEndStreamMessage - if err := json.Unmarshal(env.Data.Bytes(), &end); err != nil { - return errorf(CodeInternal, "unmarshal end stream message: %w", err) - } - for name, value := range end.Trailer { - canonical := http.CanonicalHeaderKey(name) - if name != canonical { - delete(end.Trailer, name) - end.Trailer[canonical] = append(end.Trailer[canonical], value...) - } - } - u.trailer = end.Trailer - u.endStreamErr = end.Error.asError() - return errSpecialEnvelope -} - -func (u *connectStreamingUnmarshaler) Trailer() http.Header { - return u.trailer -} - -func (u *connectStreamingUnmarshaler) EndStreamError() *Error { - return u.endStreamErr -} - -type connectUnaryMarshaler struct { - writer io.Writer - codec Codec - compressMinBytes int - compressionName string - compressionPool *compressionPool - bufferPool *bufferPool - header http.Header - sendMaxBytes int -} + var isCompressed bool + if cc.compressionPool != nil && buffer.Len() > cc.CompressMinBytes { + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) -func (m *connectUnaryMarshaler) Marshal(message any) *Error { - if message == nil { - return m.write(nil) - } - var data []byte - var err error - if appender, ok := m.codec.(marshalAppender); ok { - data, err = appender.MarshalAppend(m.bufferPool.Get().Bytes(), message) - } else { - // Can't avoid allocating the slice, but we'll reuse it. - data, err = m.codec.Marshal(message) - } - if err != nil { - return errorf(CodeInternal, "marshal message: %w", err) - } - uncompressed := bytes.NewBuffer(data) - defer m.bufferPool.Put(uncompressed) - if len(data) < m.compressMinBytes || m.compressionPool == nil { - if m.sendMaxBytes > 0 && len(data) > m.sendMaxBytes { - return NewError(CodeResourceExhausted, fmt.Errorf("message size %d exceeds sendMaxBytes %d", len(data), m.sendMaxBytes)) + if err := cc.compressionPool.Compress( + compressionBuffer, buffer, + ); err != nil { + return err } - return m.write(data) + buffer = compressionBuffer // swap buffers + setHeaderCanonical(cc.duplexCall.Header(), connectUnaryHeaderCompression, cc.CompressionName) + isCompressed = true } - compressed := m.bufferPool.Get() - defer m.bufferPool.Put(compressed) - if err := m.compressionPool.Compress(compressed, uncompressed); err != nil { + if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, isCompressed); err != nil { + delHeaderCanonical(cc.duplexCall.Header(), connectUnaryHeaderCompression) return err } - if m.sendMaxBytes > 0 && compressed.Len() > m.sendMaxBytes { - return NewError(CodeResourceExhausted, fmt.Errorf("compressed message size %d exceeds sendMaxBytes %d", compressed.Len(), m.sendMaxBytes)) + if err := writeAll(cc.duplexCall, buffer); err != nil { + return err } - setHeaderCanonical(m.header, connectUnaryHeaderCompression, m.compressionName) - return m.write(compressed.Bytes()) + return nil } -func (m *connectUnaryMarshaler) write(data []byte) *Error { - if _, err := m.writer.Write(data); err != nil { - if connectErr, ok := asError(err); ok { - return connectErr +func (cc *connectUnaryClientConn) trySendGet(buffer *bytes.Buffer, msg any) error { + if cc.stableCodec == nil { + if cc.GetUseFallback { + return cc.sendMsg(buffer, msg) } - return errorf(CodeUnknown, "write message: %w", err) + return errorf(CodeInternal, "codec %s doesn't support stable marshal; can't use get", cc.Codec.Name()) + } + if msg != nil { + data, err := cc.stableCodec.MarshalStable(msg) + if err != nil { + return err + } + buffer.Write(data) } - return nil -} -type connectUnaryRequestMarshaler struct { - connectUnaryMarshaler + isTooBig := cc.SendMaxBytes > 0 && buffer.Len() > cc.SendMaxBytes + isCompressed := false - enableGet bool - getURLMaxBytes int - getUseFallback bool - stableCodec stableCodec - duplexCall *duplexHTTPCall -} + if isTooBig && cc.compressionPool != nil { + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) -func (m *connectUnaryRequestMarshaler) Marshal(message any) *Error { - if m.enableGet { - if m.stableCodec == nil && !m.getUseFallback { - return errorf(CodeInternal, "codec %s doesn't support stable marshal; can't use get", m.codec.Name()) - } - if m.stableCodec != nil { - return m.marshalWithGet(message) + if err := cc.compressionPool.Compress( + compressionBuffer, buffer, + ); err != nil { + return err } + buffer = compressionBuffer // swap buffers + isCompressed = true + isTooBig = cc.SendMaxBytes > 0 && buffer.Len() > cc.SendMaxBytes } - return m.connectUnaryMarshaler.Marshal(message) -} -func (m *connectUnaryRequestMarshaler) marshalWithGet(message any) *Error { - // TODO(jchadwick-buf): This function is mostly a superset of - // connectUnaryMarshaler.Marshal. This should be reconciled at some point. - var data []byte - var err error - if message != nil { - data, err = m.stableCodec.MarshalStable(message) - if err != nil { - return errorf(CodeInternal, "marshal message stable: %w", err) + if isTooBig { + if cc.GetUseFallback { + buffer.Reset() + return cc.sendMsg(buffer, msg) } - } - isTooBig := m.sendMaxBytes > 0 && len(data) > m.sendMaxBytes - if isTooBig && m.compressionPool == nil { - return NewError(CodeResourceExhausted, fmt.Errorf( - "message size %d exceeds sendMaxBytes %d: enabling request compression may help", - len(data), - m.sendMaxBytes, - )) - } - if !isTooBig { - url := m.buildGetURL(data, false /* compressed */) - if m.getURLMaxBytes <= 0 || len(url.String()) < m.getURLMaxBytes { - return m.writeWithGet(url) - } - if m.compressionPool == nil { - if m.getUseFallback { - return m.write(data) - } - return NewError(CodeResourceExhausted, fmt.Errorf( - "url size %d exceeds getURLMaxBytes %d: enabling request compression may help", - len(url.String()), - m.getURLMaxBytes, - )) + if isCompressed { + return errorf(CodeResourceExhausted, + "message size %d exceeds sendMaxBytes %d", + buffer.Len(), + cc.SendMaxBytes, + ) } + return errorf(CodeResourceExhausted, + "message size %d exceeds sendMaxBytes %d: enabling request compression may help", + buffer.Len(), + cc.SendMaxBytes, + ) } - // Compress message to try to make it fit in the URL. - uncompressed := bytes.NewBuffer(data) - defer m.bufferPool.Put(uncompressed) - compressed := m.bufferPool.Get() - defer m.bufferPool.Put(compressed) - if err := m.compressionPool.Compress(compressed, uncompressed); err != nil { - return err - } - if m.sendMaxBytes > 0 && compressed.Len() > m.sendMaxBytes { - return NewError(CodeResourceExhausted, fmt.Errorf("compressed message size %d exceeds sendMaxBytes %d", compressed.Len(), m.sendMaxBytes)) - } - url := m.buildGetURL(compressed.Bytes(), true /* compressed */) - if m.getURLMaxBytes <= 0 || len(url.String()) < m.getURLMaxBytes { - return m.writeWithGet(url) - } - if m.getUseFallback { - setHeaderCanonical(m.header, connectUnaryHeaderCompression, m.compressionName) - return m.write(compressed.Bytes()) + + header := cc.duplexCall.Header() + url := cc.buildGetURL(buffer.Bytes(), isCompressed) + if cc.GetURLMaxBytes > 0 && len(url.String()) > cc.GetURLMaxBytes { + if cc.GetUseFallback { + buffer.Reset() + setHeaderCanonical(header, connectUnaryHeaderCompression, cc.CompressionName) + return cc.sendMsg(buffer, msg) + } + return errorf(CodeResourceExhausted, + "compressed url size %d exceeds getURLMaxBytes %d", + len(url.String()), cc.GetURLMaxBytes) } - return NewError(CodeResourceExhausted, fmt.Errorf("compressed url size %d exceeds getURLMaxBytes %d", len(url.String()), m.getURLMaxBytes)) + + delete(header, connectHeaderProtocolVersion) + cc.duplexCall.SetMethod(http.MethodGet) + *cc.duplexCall.URL() = *url + return nil } -func (m *connectUnaryRequestMarshaler) buildGetURL(data []byte, compressed bool) *url.URL { - url := *m.duplexCall.URL() +func (cc *connectUnaryClientConn) buildGetURL(data []byte, compressed bool) *url.URL { + url := *cc.duplexCall.URL() query := url.Query() query.Set(connectUnaryConnectQueryParameter, connectUnaryConnectQueryValue) - query.Set(connectUnaryEncodingQueryParameter, m.codec.Name()) - if m.stableCodec.IsBinary() || compressed { + query.Set(connectUnaryEncodingQueryParameter, cc.Codec.Name()) + if cc.stableCodec.IsBinary() || compressed { query.Set(connectUnaryMessageQueryParameter, encodeBinaryQueryValue(data)) query.Set(connectUnaryBase64QueryParameter, "1") } else { query.Set(connectUnaryMessageQueryParameter, string(data)) } if compressed { - query.Set(connectUnaryCompressionQueryParameter, m.compressionName) + query.Set(connectUnaryCompressionQueryParameter, cc.CompressionName) } url.RawQuery = query.Encode() return &url } -func (m *connectUnaryRequestMarshaler) writeWithGet(url *url.URL) *Error { - delete(m.header, connectHeaderProtocolVersion) - m.duplexCall.SetMethod(http.MethodGet) - *m.duplexCall.URL() = *url - return nil -} - -type connectUnaryUnmarshaler struct { - reader io.Reader - codec Codec - compressionPool *compressionPool - bufferPool *bufferPool - alreadyRead bool - readMaxBytes int -} - -func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error { - return u.UnmarshalFunc(message, u.codec.Unmarshal) -} - -func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error { - if u.alreadyRead { - return NewError(CodeInternal, io.EOF) - } - u.alreadyRead = true - data := u.bufferPool.Get() - defer u.bufferPool.Put(data) - reader := u.reader - if u.readMaxBytes > 0 && int64(u.readMaxBytes) < math.MaxInt64 { - reader = io.LimitReader(u.reader, int64(u.readMaxBytes)+1) - } - // ReadFrom ignores io.EOF, so any error here is real. - bytesRead, err := data.ReadFrom(reader) - if err != nil { - if connectErr, ok := asError(err); ok { - return connectErr - } - if readMaxBytesErr := asMaxBytesError(err, "read first %d bytes of message", bytesRead); readMaxBytesErr != nil { - return readMaxBytesErr - } - return errorf(CodeUnknown, "read message: %w", err) - } - if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) { - // Attempt to read to end in order to allow connection re-use - discardedBytes, err := io.Copy(io.Discard, u.reader) - if err != nil { - return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err) - } - return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, u.readMaxBytes) - } - if data.Len() > 0 && u.compressionPool != nil { - decompressed := u.bufferPool.Get() - defer u.bufferPool.Put(decompressed) - if err := u.compressionPool.Decompress(decompressed, data, int64(u.readMaxBytes)); err != nil { - return err - } - data = decompressed - } - if err := unmarshal(data.Bytes(), message); err != nil { - return errorf(CodeInvalidArgument, "unmarshal into %T: %w", message, err) - } - return nil -} - type connectWireDetail ErrorDetail func (d *connectWireDetail) MarshalJSON() ([]byte, error) { @@ -1203,6 +1185,20 @@ type connectEndStreamMessage struct { Trailer http.Header `json:"metadata,omitempty"` } +func newConnectEndStreamMessage(err error, trailer http.Header) *connectEndStreamMessage { + if trailer == nil { + trailer = make(http.Header) + } + end := &connectEndStreamMessage{Trailer: trailer} + if err != nil { + end.Error = newConnectWireError(err) + if connectErr, ok := asError(err); ok { + mergeHeaders(end.Trailer, connectErr.meta) + } + } + return end +} + func connectCodeToHTTP(code Code) int { // Return literals rather than named constants from the HTTP package to make // it easier to compare this function to the Connect specification. @@ -1312,3 +1308,29 @@ func queryValueReader(data string, base64Encoded bool) io.Reader { } return strings.NewReader(data) } + +func connectUnmarshalEndStreamMessage(src *bytes.Buffer, flags uint8) (*connectEndStreamMessage, *Error) { + end := connectEndStreamMessage{} + if flags^connectFlagEnvelopeEndStream != 0 { + return nil, newErrInvalidEnvelopeFlags(flags) + } + if err := json.Unmarshal(src.Bytes(), &end); err != nil { + return nil, errorf(CodeInternal, "unmarshal end stream message: %w", err) + } + for name, value := range end.Trailer { + canonical := http.CanonicalHeaderKey(name) + if name != canonical { + delete(end.Trailer, name) + end.Trailer[canonical] = append(end.Trailer[canonical], value...) + } + } + return &end, nil +} +func connectMarshalEndStreamMessage(dst *bytes.Buffer, end *connectEndStreamMessage) *Error { + data, marshalErr := json.Marshal(end) + if marshalErr != nil { + return errorf(CodeInternal, "marshal end stream: %w", marshalErr) + } + dst.Write(data) + return nil +} diff --git a/protocol_connect_test.go b/protocol_connect_test.go index 692183d7..80d9257c 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -60,36 +60,26 @@ func TestConnectErrorDetailMarshalingNoDescriptor(t *testing.T) { func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { t.Parallel() - buffer := bytes.Buffer{} - bufferPool := newBufferPool() - - endStreamMessage := connectEndStreamMessage{Trailer: make(http.Header)} + buffer := &bytes.Buffer{} + endStreamMessage := newConnectEndStreamMessage(nil, make(http.Header)) endStreamMessage.Trailer["not-canonical-header"] = []string{"a"} endStreamMessage.Trailer["mixed-Canonical"] = []string{"b"} endStreamMessage.Trailer["Mixed-Canonical"] = []string{"b"} endStreamMessage.Trailer["Canonical-Header"] = []string{"c"} - endStreamData, err := json.Marshal(endStreamMessage) + err := connectMarshalEndStreamMessage(buffer, endStreamMessage) + assert.Nil(t, err) + + output := &bytes.Buffer{} + err = writeAll(output, envelope{Data: buffer, Flags: connectFlagEnvelopeEndStream}) assert.Nil(t, err) - writer := envelopeWriter{ - writer: &buffer, - bufferPool: bufferPool, - } - err = writer.Write(&envelope{ - Flags: connectFlagEnvelopeEndStream, - Data: bytes.NewBuffer(endStreamData), - }) + input := &bytes.Buffer{} + _, err = readEnvelope(input, output, -1) assert.Nil(t, err) - unmarshaler := connectStreamingUnmarshaler{ - envelopeReader: envelopeReader{ - reader: &buffer, - bufferPool: bufferPool, - }, - } - err = unmarshaler.Unmarshal(nil) // parameter won't be used - assert.ErrorIs(t, err, errSpecialEnvelope) - assert.Equal(t, unmarshaler.Trailer().Values("Not-Canonical-Header"), []string{"a"}) - assert.Equal(t, unmarshaler.Trailer().Values("Mixed-Canonical"), []string{"b", "b"}) - assert.Equal(t, unmarshaler.Trailer().Values("Canonical-Header"), []string{"c"}) + end, err := connectUnmarshalEndStreamMessage(input, connectFlagEnvelopeEndStream) + assert.Nil(t, err) + assert.Equal(t, end.Trailer.Values("Not-Canonical-Header"), []string{"a"}) + assert.Equal(t, end.Trailer.Values("Mixed-Canonical"), []string{"b", "b"}) + assert.Equal(t, end.Trailer.Values("Canonical-Header"), []string{"c"}) } diff --git a/protocol_grpc.go b/protocol_grpc.go index 8a9d0ada..0cfeecc0 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -16,6 +16,7 @@ package connect import ( "bufio" + "bytes" "context" "errors" "fmt" @@ -69,7 +70,6 @@ var ( grpcAllowedMethods = map[string]struct{}{ http.MethodPost: {}, } - errTrailersWithoutGRPCStatus = fmt.Errorf("protocol error: no %s trailer: %w", grpcHeaderStatus, io.ErrUnexpectedEOF) // defaultGrpcUserAgent follows // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#user-agents: @@ -196,38 +196,18 @@ func (g *grpcHandler) NewConn( protocolName = ProtocolGRPCWeb } conn := wrapHandlerConnWithCodedErrors(&grpcHandlerConn{ - spec: g.Spec, + grpcHandler: g, peer: Peer{ Addr: request.RemoteAddr, Protocol: protocolName, }, - web: g.web, - bufferPool: g.BufferPool, - protobuf: g.Codecs.Protobuf(), // for errors - marshaler: grpcMarshaler{ - envelopeWriter: envelopeWriter{ - writer: responseWriter, - compressionPool: g.CompressionPools.Get(responseCompression), - codec: codec, - compressMinBytes: g.CompressMinBytes, - bufferPool: g.BufferPool, - sendMaxBytes: g.SendMaxBytes, - }, - }, - responseWriter: responseWriter, - responseHeader: make(http.Header), - responseTrailer: make(http.Header), - request: request, - unmarshaler: grpcUnmarshaler{ - envelopeReader: envelopeReader{ - reader: request.Body, - codec: codec, - compressionPool: g.CompressionPools.Get(requestCompression), - bufferPool: g.BufferPool, - readMaxBytes: g.ReadMaxBytes, - }, - web: g.web, - }, + request: request, + responseWriter: responseWriter, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), + codec: codec, + recvCompressionPool: g.CompressionPools.Get(requestCompression), + sendCompressionPool: g.CompressionPools.Get(responseCompression), }) if failed != nil { // Negotiation failed, so we can't establish a stream. @@ -299,62 +279,28 @@ func (g *grpcClient) NewConn( header, ) conn := &grpcClientConn{ - spec: spec, - peer: g.Peer(), - duplexCall: duplexCall, - compressionPools: g.CompressionPools, - bufferPool: g.BufferPool, - protobuf: g.Protobuf, - marshaler: grpcMarshaler{ - envelopeWriter: envelopeWriter{ - writer: duplexCall, - compressionPool: g.CompressionPools.Get(g.CompressionName), - codec: g.Codec, - compressMinBytes: g.CompressMinBytes, - bufferPool: g.BufferPool, - sendMaxBytes: g.SendMaxBytes, - }, - }, - unmarshaler: grpcUnmarshaler{ - envelopeReader: envelopeReader{ - reader: duplexCall, - codec: g.Codec, - bufferPool: g.BufferPool, - readMaxBytes: g.ReadMaxBytes, - }, - }, - responseHeader: make(http.Header), - responseTrailer: make(http.Header), + grpcClient: g, + spec: spec, + duplexCall: duplexCall, + responseHeader: make(http.Header), + responseTrailer: make(http.Header), + sendCompressionPool: g.CompressionPools.Get(g.CompressionName), + recvCompressionPool: nil, // set in SetValidateResponse } duplexCall.SetValidateResponse(conn.validateResponse) - if g.web { - conn.unmarshaler.web = true - conn.readTrailers = func(unmarshaler *grpcUnmarshaler, _ *duplexHTTPCall) http.Header { - return unmarshaler.WebTrailer() - } - } else { - conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header { - // To access HTTP trailers, we need to read the body to EOF. - _, _ = discard(call) - return call.ResponseTrailer() - } - } return wrapClientConnWithCodedErrors(conn) } // grpcClientConn works for both gRPC and gRPC-Web. type grpcClientConn struct { - spec Spec - peer Peer - duplexCall *duplexHTTPCall - compressionPools readOnlyCompressionPools - bufferPool *bufferPool - protobuf Codec // for errors - marshaler grpcMarshaler - unmarshaler grpcUnmarshaler - responseHeader http.Header - responseTrailer http.Header - readTrailers func(*grpcUnmarshaler, *duplexHTTPCall) http.Header + *grpcClient + + spec Spec + duplexCall *duplexHTTPCall + responseHeader http.Header + responseTrailer http.Header + sendCompressionPool *compressionPool + recvCompressionPool *compressionPool } func (cc *grpcClientConn) Spec() Spec { @@ -366,7 +312,28 @@ func (cc *grpcClientConn) Peer() Peer { } func (cc *grpcClientConn) Send(msg any) error { - if err := cc.marshaler.Marshal(msg); err != nil { + buffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(buffer) + if err := marshal(buffer, msg, cc.Codec); err != nil { + return err + } + var flags uint8 + if buffer.Len() > cc.CompressMinBytes && cc.sendCompressionPool != nil { + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) + if err := cc.sendCompressionPool.Compress( + compressionBuffer, buffer, + ); err != nil { + return err + } + buffer = compressionBuffer // swap buffers + flags |= flagEnvelopeCompressed + } + if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { + return err + } + env := envelope{Data: buffer, Flags: flags} + if err := writeAll(cc.duplexCall, env); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -381,42 +348,69 @@ func (cc *grpcClientConn) CloseRequest() error { } func (cc *grpcClientConn) Receive(msg any) error { + buffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(buffer) cc.duplexCall.BlockUntilResponseReady() - err := cc.unmarshaler.Unmarshal(msg) - if err == nil { - return nil + + flags, err := readEnvelope(buffer, cc.duplexCall, cc.ReadMaxBytes) + if err != nil { + if !errors.Is(err, io.EOF) { + return err + } + // If we got an EOF, we need to check the trailers for an error. + mergeHeaders(cc.responseTrailer, cc.duplexCall.ResponseTrailer()) + if err := grpcErrorFromTrailer(cc.Protobuf, cc.responseTrailer); err != nil { + return err + } + return io.EOF + } + if flags&flagEnvelopeCompressed != 0 { + if cc.recvCompressionPool == nil { + return errorf(CodeInvalidArgument, + "protocol error: received compressed message without %s header", + grpcHeaderCompression, + ) + } + compressionBuffer := cc.BufferPool.Get() + defer cc.BufferPool.Put(compressionBuffer) + if err := cc.recvCompressionPool.Decompress( + compressionBuffer, buffer, int64(cc.ReadMaxBytes), + ); err != nil { + return err + } + buffer = compressionBuffer // swap buffers } - if getHeaderCanonical(cc.responseHeader, grpcHeaderStatus) != "" { - // We got what gRPC calls a trailers-only response, which puts the trailing - // metadata (including errors) into HTTP headers. validateResponse has - // already extracted the error. + if flags&grpcFlagEnvelopeTrailer != 0 { + if !cc.web { + return newErrInvalidEnvelopeFlags(flags) + } + // handle grpc-web trailers + // Per the gRPC-Web specification, trailers should be encoded as an HTTP/1 + // headers block _without_ the terminating newline. To make the headers + // parseable by net/textproto, we need to add the newline. + if err := buffer.WriteByte('\n'); err != nil { + return errorf(CodeInternal, "unmarshal web trailers: %w", err) + } + bufferedReader := bufio.NewReader(buffer) + mimeReader := textproto.NewReader(bufferedReader) + mimeHeader, mimeErr := mimeReader.ReadMIMEHeader() + if mimeErr != nil { + return errorf( + CodeInternal, + "gRPC-Web protocol error: trailers invalid: %w", + mimeErr, + ) + } + mergeHeaders(cc.responseTrailer, http.Header(mimeHeader)) + if err := grpcErrorFromTrailer(cc.Protobuf, cc.responseTrailer); err != nil { + return err + } + return ensureEOF(cc.duplexCall) + } + if err := unmarshal(buffer, msg, cc.Codec); err != nil { return err } - // See if the server sent an explicit error in the HTTP or gRPC-Web trailers. - mergeHeaders( - cc.responseTrailer, - cc.readTrailers(&cc.unmarshaler, cc.duplexCall), - ) - serverErr := grpcErrorFromTrailer(cc.protobuf, cc.responseTrailer) - if serverErr != nil && (errors.Is(err, io.EOF) || !errors.Is(serverErr, errTrailersWithoutGRPCStatus)) { - // We've either: - // - Cleanly read until the end of the response body and *not* received - // gRPC status trailers, which is a protocol error, or - // - Received an explicit error from the server. - // - // This is expected from a protocol perspective, but receiving trailers - // means that we're _not_ getting a message. For users to realize that - // the stream has ended, Receive must return an error. - serverErr.meta = cc.responseHeader.Clone() - mergeHeaders(serverErr.meta, cc.responseTrailer) - cc.duplexCall.SetError(serverErr) - return serverErr - } - // This was probably an error converting the bytes to a message or an error - // reading from the network. We're going to return it to the - // user, but we also want to setResponseError so Send errors out. - cc.duplexCall.SetError(err) - return err + return nil // must be a literal nil: nil *Error is a non-nil error } func (cc *grpcClientConn) ResponseHeader() http.Header { @@ -442,33 +436,32 @@ func (cc *grpcClientConn) validateResponse(response *http.Response) *Error { response, cc.responseHeader, cc.responseTrailer, - cc.compressionPools, - cc.protobuf, + cc.CompressionPools, + cc.Protobuf, ); err != nil { return err } compression := getHeaderCanonical(response.Header, grpcHeaderCompression) - cc.unmarshaler.envelopeReader.compressionPool = cc.compressionPools.Get(compression) + cc.recvCompressionPool = cc.CompressionPools.Get(compression) return nil } type grpcHandlerConn struct { - spec Spec - peer Peer - web bool - bufferPool *bufferPool - protobuf Codec // for errors - marshaler grpcMarshaler - responseWriter http.ResponseWriter - responseHeader http.Header - responseTrailer http.Header - wroteToBody bool - request *http.Request - unmarshaler grpcUnmarshaler + *grpcHandler + + peer Peer + request *http.Request + responseWriter http.ResponseWriter + responseHeader http.Header + responseTrailer http.Header + codec Codec + recvCompressionPool *compressionPool + sendCompressionPool *compressionPool + wroteToBody bool } func (hc *grpcHandlerConn) Spec() Spec { - return hc.spec + return hc.protocolHandlerParams.Spec } func (hc *grpcHandlerConn) Peer() Peer { @@ -476,8 +469,28 @@ func (hc *grpcHandlerConn) Peer() Peer { } func (hc *grpcHandlerConn) Receive(msg any) error { - if err := hc.unmarshaler.Unmarshal(msg); err != nil { - return err // already coded + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + + flags, err := readEnvelope(buffer, hc.request.Body, hc.ReadMaxBytes) + if err != nil { + return err + } + if flags&flagEnvelopeCompressed != 0 { + compressionBuffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(compressionBuffer) + if err := hc.recvCompressionPool.Decompress( + compressionBuffer, buffer, int64(hc.ReadMaxBytes), + ); err != nil { + return err + } + buffer = compressionBuffer + } + if flags != 0 && flags != flagEnvelopeCompressed { + return newErrInvalidEnvelopeFlags(flags) + } + if err := unmarshal(buffer, msg, hc.codec); err != nil { + return err } return nil // must be a literal nil: nil *Error is a non-nil error } @@ -487,14 +500,36 @@ func (hc *grpcHandlerConn) RequestHeader() http.Header { } func (hc *grpcHandlerConn) Send(msg any) error { - defer flushResponseWriter(hc.responseWriter) if !hc.wroteToBody { mergeHeaders(hc.responseWriter.Header(), hc.responseHeader) hc.wroteToBody = true } - if err := hc.marshaler.Marshal(msg); err != nil { + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + + if err := marshal(buffer, msg, hc.codec); err != nil { + return err + } + var flags uint8 + if buffer.Len() > hc.CompressMinBytes && hc.sendCompressionPool != nil { + compressionBuffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(compressionBuffer) + if err := hc.sendCompressionPool.Compress( + compressionBuffer, buffer, + ); err != nil { + return err + } + buffer = compressionBuffer // swap buffers + flags |= flagEnvelopeCompressed + } + if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { + return err + } + env := envelope{Data: buffer, Flags: flags} + if err := writeAll(hc.responseWriter, env); err != nil { return err } + flushResponseWriter(hc.responseWriter) return nil // must be a literal nil: nil *Error is a non-nil error } @@ -532,7 +567,7 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) { len(hc.responseTrailer)+2, // always make space for status & message ) mergeHeaders(mergedTrailers, hc.responseTrailer) - grpcErrorToTrailer(mergedTrailers, hc.protobuf, err) + grpcErrorToTrailer(mergedTrailers, hc.Codecs.Protobuf(), err) if hc.web && !hc.wroteToBody { // We're using gRPC-Web and we haven't yet written to the body. Since we're // not sending any response messages, the gRPC specification calls this a @@ -547,7 +582,7 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) { if hc.web { // We're using gRPC-Web and we've already sent the headers, so we write // trailing metadata to the HTTP body. - if err := hc.marshaler.MarshalWebTrailers(mergedTrailers); err != nil { + if err := hc.writeWebTrailers(mergedTrailers); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -595,13 +630,17 @@ func (hc *grpcHandlerConn) Close(err error) (retErr error) { return nil } -type grpcMarshaler struct { - envelopeWriter +func (hc *grpcHandlerConn) writeWebTrailers(trailer http.Header) *Error { + buffer := hc.BufferPool.Get() + defer hc.BufferPool.Put(buffer) + if err := grpcMarshalWebTrailers(buffer, trailer); err != nil { + return errorf(CodeInternal, "format trailers: %v", err) + } + env := envelope{Data: buffer, Flags: grpcFlagEnvelopeTrailer} + return writeAll(hc.responseWriter, env) } -func (m *grpcMarshaler) MarshalWebTrailers(trailer http.Header) *Error { - raw := m.envelopeWriter.bufferPool.Get() - defer m.envelopeWriter.bufferPool.Put(raw) +func grpcMarshalWebTrailers(dst *bytes.Buffer, trailer http.Header) error { for key, values := range trailer { // Per the Go specification, keys inserted during iteration may be produced // later in the iteration or may be skipped. For safety, avoid mutating the @@ -613,56 +652,10 @@ func (m *grpcMarshaler) MarshalWebTrailers(trailer http.Header) *Error { delete(trailer, key) trailer[lower] = values } - if err := trailer.Write(raw); err != nil { + if err := trailer.Write(dst); err != nil { return errorf(CodeInternal, "format trailers: %w", err) } - return m.Write(&envelope{ - Data: raw, - Flags: grpcFlagEnvelopeTrailer, - }) -} - -type grpcUnmarshaler struct { - envelopeReader envelopeReader - web bool - webTrailer http.Header -} - -func (u *grpcUnmarshaler) Unmarshal(message any) *Error { - err := u.envelopeReader.Unmarshal(message) - if err == nil { - return nil - } - if !errors.Is(err, errSpecialEnvelope) { - return err - } - env := u.envelopeReader.last - if !u.web || !env.IsSet(grpcFlagEnvelopeTrailer) { - return errorf(CodeInternal, "protocol error: invalid envelope flags %d", env.Flags) - } - - // Per the gRPC-Web specification, trailers should be encoded as an HTTP/1 - // headers block _without_ the terminating newline. To make the headers - // parseable by net/textproto, we need to add the newline. - if err := env.Data.WriteByte('\n'); err != nil { - return errorf(CodeInternal, "unmarshal web trailers: %w", err) - } - bufferedReader := bufio.NewReader(env.Data) - mimeReader := textproto.NewReader(bufferedReader) - mimeHeader, mimeErr := mimeReader.ReadMIMEHeader() - if mimeErr != nil { - return errorf( - CodeInternal, - "gRPC-Web protocol error: trailers invalid: %w", - mimeErr, - ) - } - u.webTrailer = http.Header(mimeHeader) - return errSpecialEnvelope -} - -func (u *grpcUnmarshaler) WebTrailer() http.Header { - return u.webTrailer + return nil } func grpcValidateResponse( @@ -687,12 +680,9 @@ func grpcValidateResponse( availableCompressors.CommaSeparatedNames(), ) } - // When there's no body, gRPC and gRPC-Web servers may send error information - // in the HTTP headers. - if err := grpcErrorFromTrailer( - protobuf, - response.Header, - ); err != nil && !errors.Is(err, errTrailersWithoutGRPCStatus) { + if getHeaderCanonical(response.Header, grpcHeaderStatus) != "" { + // We got what gRPC calls a trailers-only response, which puts the trailing + // metadata (including errors) into HTTP headers. // Per the specification, only the HTTP status code and Content-Type should // be treated as headers. The rest should be treated as trailing metadata. if contentType := getHeaderCanonical(response.Header, headerContentType); contentType != "" { @@ -700,10 +690,7 @@ func grpcValidateResponse( } mergeHeaders(trailer, response.Header) delHeaderCanonical(trailer, headerContentType) - // Also set the error metadata - err.meta = header.Clone() - mergeHeaders(err.meta, trailer) - return err + return grpcErrorFromTrailer(protobuf, trailer) } // The response is valid, so we should expose the headers. mergeHeaders(header, response.Header) @@ -738,7 +725,7 @@ func grpcHTTPToCode(httpCode int) Code { func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error { codeHeader := getHeaderCanonical(trailer, grpcHeaderStatus) if codeHeader == "" { - return NewError(CodeInternal, errTrailersWithoutGRPCStatus) + return errorf(CodeInternal, "protocol error: no %s trailer", grpcHeaderStatus) } if codeHeader == "0" { return nil @@ -750,6 +737,7 @@ func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error { } message := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage)) retErr := NewWireError(Code(code), errors.New(message)) + retErr.meta = trailer.Clone() detailsBinaryEncoded := getHeaderCanonical(trailer, grpcHeaderDetails) if len(detailsBinaryEncoded) > 0 { @@ -768,7 +756,6 @@ func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error { retErr.code = Code(status.Code) retErr.err = errors.New(status.Message) } - return retErr } diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index ee4d9fe3..464cc949 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -15,6 +15,7 @@ package connect import ( + "bytes" "errors" "math" "net/http" @@ -33,7 +34,6 @@ func TestGRPCHandlerSender(t *testing.T) { t.Parallel() newConn := func(web bool) *grpcHandlerConn { responseWriter := httptest.NewRecorder() - protobufCodec := &protoBinaryCodec{} bufferPool := newBufferPool() request, err := http.NewRequest( http.MethodPost, @@ -42,28 +42,17 @@ func TestGRPCHandlerSender(t *testing.T) { ) assert.Nil(t, err) return &grpcHandlerConn{ - spec: Spec{}, - web: web, - bufferPool: bufferPool, - protobuf: protobufCodec, - marshaler: grpcMarshaler{ - envelopeWriter: envelopeWriter{ - writer: responseWriter, - codec: protobufCodec, - bufferPool: bufferPool, + grpcHandler: &grpcHandler{ + protocolHandlerParams: protocolHandlerParams{ + Codecs: newReadOnlyCodecs(map[string]Codec{}), + BufferPool: bufferPool, }, + web: web, }, + request: request, responseWriter: responseWriter, responseHeader: make(http.Header), responseTrailer: make(http.Header), - request: request, - unmarshaler: grpcUnmarshaler{ - envelopeReader: envelopeReader{ - reader: request.Body, - codec: protobufCodec, - bufferPool: bufferPool, - }, - }, } } t.Run("web", func(t *testing.T) { @@ -175,21 +164,13 @@ func TestGRPCPercentEncoding(t *testing.T) { func TestGRPCWebTrailerMarshalling(t *testing.T) { t.Parallel() - responseWriter := httptest.NewRecorder() - marshaler := grpcMarshaler{ - envelopeWriter: envelopeWriter{ - writer: responseWriter, - bufferPool: newBufferPool(), - }, - } trailer := http.Header{} trailer.Add("grpc-status", "0") trailer.Add("Grpc-Message", "Foo") trailer.Add("User-Provided", "bar") - err := marshaler.MarshalWebTrailers(trailer) - assert.Nil(t, err) - responseWriter.Body.Next(5) // skip flags and message length - marshalled := responseWriter.Body.String() + var buf bytes.Buffer + assert.Nil(t, grpcMarshalWebTrailers(&buf, trailer)) + marshalled := buf.String() assert.Equal(t, marshalled, "grpc-message: Foo\r\ngrpc-status: 0\r\nuser-provided: bar\r\n") } From 4711456e507431ac4c8baba7903eaf1474d62e04 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Sat, 9 Sep 2023 13:12:12 -0400 Subject: [PATCH 2/8] Use tmp buffer for compressors --- compression.go | 16 ++++++--- protocol_connect.go | 86 +++++++++------------------------------------ protocol_grpc.go | 20 +++-------- 3 files changed, 32 insertions(+), 90 deletions(-) diff --git a/compression.go b/compression.go index 3f3a8812..b0ed38eb 100644 --- a/compression.go +++ b/compression.go @@ -78,7 +78,10 @@ func newCompressionPool( } } -func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readMaxBytes int64) *Error { +func (c *compressionPool) Decompress(pool *bufferPool, src *bytes.Buffer, readMaxBytes int64) *Error { + tmp := pool.Get() + defer pool.Put(tmp) + decompressor, err := c.getDecompressor(src) if err != nil { return errorf(CodeInvalidArgument, "get decompressor: %w", err) @@ -87,7 +90,7 @@ func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readM if readMaxBytes > 0 && readMaxBytes < math.MaxInt64 { reader = io.LimitReader(decompressor, readMaxBytes+1) } - bytesRead, err := dst.ReadFrom(reader) + bytesRead, err := tmp.ReadFrom(reader) if err != nil { _ = c.putDecompressor(decompressor) return errorf(CodeInvalidArgument, "decompress: %w", err) @@ -103,11 +106,15 @@ func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readM if err := c.putDecompressor(decompressor); err != nil { return errorf(CodeUnknown, "recycle decompressor: %w", err) } + *tmp, *src = *src, *tmp // swap buffers return nil } -func (c *compressionPool) Compress(dst *bytes.Buffer, src *bytes.Buffer) *Error { - compressor, err := c.getCompressor(dst) +func (c *compressionPool) Compress(pool *bufferPool, src *bytes.Buffer) *Error { + tmp := pool.Get() + defer pool.Put(tmp) + + compressor, err := c.getCompressor(tmp) if err != nil { return errorf(CodeUnknown, "get compressor: %w", err) } @@ -118,6 +125,7 @@ func (c *compressionPool) Compress(dst *bytes.Buffer, src *bytes.Buffer) *Error if err := c.putCompressor(compressor); err != nil { return errorf(CodeInternal, "recycle compressor: %w", err) } + *tmp, *src = *src, *tmp // swap buffers return nil } diff --git a/protocol_connect.go b/protocol_connect.go index 2d90e24a..1e33122f 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -432,15 +432,9 @@ func (cc *connectUnaryClientConn) Receive(msg any) error { return err } if cc.compressionPool != nil { - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) - - if err := cc.compressionPool.Decompress( - compressionBuffer, buffer, int64(cc.ReadMaxBytes), - ); err != nil { + if err := cc.compressionPool.Decompress(cc.BufferPool, buffer, int64(cc.ReadMaxBytes)); err != nil { return err } - buffer = compressionBuffer // swap buffers } if err := unmarshal(buffer, msg, cc.Codec); err != nil { return err @@ -502,14 +496,9 @@ func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Err return err } if cc.compressionPool != nil { - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) - if err := cc.compressionPool.Decompress( - compressionBuffer, buffer, int64(cc.ReadMaxBytes), - ); err != nil { + if err := cc.compressionPool.Compress(cc.BufferPool, buffer); err != nil { return err } - buffer = compressionBuffer } var wireErr connectWireError if err := json.Unmarshal(buffer.Bytes(), &wireErr); err != nil { @@ -553,14 +542,9 @@ func (cc *connectStreamingClientConn) Send(msg any) error { } var flags uint8 if cc.sendCompressionPool != nil && buffer.Len() > cc.CompressMinBytes { - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) - if err := cc.sendCompressionPool.Compress( - compressionBuffer, buffer, - ); err != nil { + if err := cc.sendCompressionPool.Compress(cc.BufferPool, buffer); err != nil { return err } - buffer = compressionBuffer // swap buffers flags |= flagEnvelopeCompressed } if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { @@ -602,14 +586,9 @@ func (cc *connectStreamingClientConn) Receive(msg any) error { connectStreamingHeaderCompression, ) } - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) - if err := cc.recvCompressionPool.Decompress( - compressionBuffer, buffer, int64(cc.ReadMaxBytes), - ); err != nil { + if err := cc.recvCompressionPool.Decompress(cc.BufferPool, buffer, int64(cc.ReadMaxBytes)); err != nil { return err } - buffer = compressionBuffer } if flags != 0 && flags != flagEnvelopeCompressed { end, err := connectUnmarshalEndStreamMessage(buffer, flags) @@ -717,14 +696,11 @@ func (hc *connectUnaryHandlerConn) Receive(msg any) error { return err } if buffer.Len() > 0 && hc.recvCompressionPool != nil { - compressionBuffer := hc.BufferPool.Get() - defer hc.BufferPool.Put(compressionBuffer) if err := hc.recvCompressionPool.Decompress( - compressionBuffer, buffer, int64(hc.ReadMaxBytes), + hc.BufferPool, buffer, int64(hc.ReadMaxBytes), ); err != nil { return err } - buffer = compressionBuffer } if err := unmarshal(buffer, msg, hc.codec); err != nil { return err @@ -750,19 +726,12 @@ func (hc *connectUnaryHandlerConn) Send(msg any) error { if err := marshal(buffer, msg, hc.codec); err != nil { return err } - var isCompressed bool - if buffer.Len() > hc.CompressMinBytes && hc.sendCompressionPool != nil { - compressionBuffer := hc.BufferPool.Get() - defer hc.BufferPool.Put(compressionBuffer) - - if err := hc.sendCompressionPool.Compress( - compressionBuffer, buffer, - ); err != nil { + isCompressed := buffer.Len() > hc.CompressMinBytes && hc.sendCompressionPool != nil + if isCompressed { + if err := hc.sendCompressionPool.Compress(hc.BufferPool, buffer); err != nil { return err } - buffer = compressionBuffer // swap buffers setHeaderCanonical(header, connectUnaryHeaderCompression, hc.sendCompressionName) - isCompressed = true } if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, isCompressed); err != nil { delHeaderCanonical(header, connectUnaryHeaderCompression) @@ -869,14 +838,11 @@ func (hc *connectStreamingHandlerConn) Receive(msg any) error { connectStreamingHeaderCompression, ) } - compressionBuffer := hc.BufferPool.Get() - defer hc.BufferPool.Put(compressionBuffer) if err := hc.recvCompressionPool.Decompress( - compressionBuffer, buffer, int64(hc.ReadMaxBytes), + hc.BufferPool, buffer, int64(hc.ReadMaxBytes), ); err != nil { return err } - buffer = compressionBuffer } if flags != 0 && flags != flagEnvelopeCompressed { end, err := connectUnmarshalEndStreamMessage(buffer, flags) @@ -910,15 +876,9 @@ func (hc *connectStreamingHandlerConn) Send(msg any) error { } var flags uint8 if buffer.Len() > hc.CompressMinBytes && hc.sendCompressionPool != nil { - compressionBuffer := hc.BufferPool.Get() - defer hc.BufferPool.Put(compressionBuffer) - - if err := hc.sendCompressionPool.Compress( - compressionBuffer, buffer, - ); err != nil { + if err := hc.sendCompressionPool.Compress(hc.BufferPool, buffer); err != nil { return err } - buffer = compressionBuffer // swap buffers flags |= flagEnvelopeCompressed } if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { @@ -976,19 +936,12 @@ func (cc *connectUnaryClientConn) sendMsg(buffer *bytes.Buffer, msg any) error { if err := marshal(buffer, msg, cc.Codec); err != nil { return err } - var isCompressed bool - if cc.compressionPool != nil && buffer.Len() > cc.CompressMinBytes { - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) - - if err := cc.compressionPool.Compress( - compressionBuffer, buffer, - ); err != nil { + isCompressed := cc.compressionPool != nil && buffer.Len() > cc.CompressMinBytes + if isCompressed { + if err := cc.compressionPool.Compress(cc.BufferPool, buffer); err != nil { return err } - buffer = compressionBuffer // swap buffers setHeaderCanonical(cc.duplexCall.Header(), connectUnaryHeaderCompression, cc.CompressionName) - isCompressed = true } if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, isCompressed); err != nil { delHeaderCanonical(cc.duplexCall.Header(), connectUnaryHeaderCompression) @@ -1016,19 +969,12 @@ func (cc *connectUnaryClientConn) trySendGet(buffer *bytes.Buffer, msg any) erro } isTooBig := cc.SendMaxBytes > 0 && buffer.Len() > cc.SendMaxBytes - isCompressed := false - - if isTooBig && cc.compressionPool != nil { - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) + isCompressed := isTooBig && cc.compressionPool != nil - if err := cc.compressionPool.Compress( - compressionBuffer, buffer, - ); err != nil { + if isCompressed { + if err := cc.compressionPool.Compress(cc.BufferPool, buffer); err != nil { return err } - buffer = compressionBuffer // swap buffers - isCompressed = true isTooBig = cc.SendMaxBytes > 0 && buffer.Len() > cc.SendMaxBytes } diff --git a/protocol_grpc.go b/protocol_grpc.go index 0cfeecc0..36d59fd2 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -319,14 +319,11 @@ func (cc *grpcClientConn) Send(msg any) error { } var flags uint8 if buffer.Len() > cc.CompressMinBytes && cc.sendCompressionPool != nil { - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) if err := cc.sendCompressionPool.Compress( - compressionBuffer, buffer, + cc.BufferPool, buffer, ); err != nil { return err } - buffer = compressionBuffer // swap buffers flags |= flagEnvelopeCompressed } if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { @@ -371,14 +368,11 @@ func (cc *grpcClientConn) Receive(msg any) error { grpcHeaderCompression, ) } - compressionBuffer := cc.BufferPool.Get() - defer cc.BufferPool.Put(compressionBuffer) if err := cc.recvCompressionPool.Decompress( - compressionBuffer, buffer, int64(cc.ReadMaxBytes), + cc.BufferPool, buffer, int64(cc.ReadMaxBytes), ); err != nil { return err } - buffer = compressionBuffer // swap buffers } if flags&grpcFlagEnvelopeTrailer != 0 { if !cc.web { @@ -477,14 +471,11 @@ func (hc *grpcHandlerConn) Receive(msg any) error { return err } if flags&flagEnvelopeCompressed != 0 { - compressionBuffer := hc.BufferPool.Get() - defer hc.BufferPool.Put(compressionBuffer) if err := hc.recvCompressionPool.Decompress( - compressionBuffer, buffer, int64(hc.ReadMaxBytes), + hc.BufferPool, buffer, int64(hc.ReadMaxBytes), ); err != nil { return err } - buffer = compressionBuffer } if flags != 0 && flags != flagEnvelopeCompressed { return newErrInvalidEnvelopeFlags(flags) @@ -512,14 +503,11 @@ func (hc *grpcHandlerConn) Send(msg any) error { } var flags uint8 if buffer.Len() > hc.CompressMinBytes && hc.sendCompressionPool != nil { - compressionBuffer := hc.BufferPool.Get() - defer hc.BufferPool.Put(compressionBuffer) if err := hc.sendCompressionPool.Compress( - compressionBuffer, buffer, + hc.BufferPool, buffer, ); err != nil { return err } - buffer = compressionBuffer // swap buffers flags |= flagEnvelopeCompressed } if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { From 7c05a40010ae4e07955477fbf0e1cc9d66441d0c Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Sat, 9 Sep 2023 14:40:01 -0400 Subject: [PATCH 3/8] Use writeEnvelope --- envelope.go | 47 ++++++++++++++-------------------------- error_writer.go | 4 +--- protocol.go | 6 ++--- protocol_connect.go | 12 +++++----- protocol_connect_test.go | 2 +- protocol_grpc.go | 14 +++++------- protocol_grpc_test.go | 1 + 7 files changed, 33 insertions(+), 53 deletions(-) diff --git a/envelope.go b/envelope.go index 2381be8a..401a25d3 100644 --- a/envelope.go +++ b/envelope.go @@ -25,34 +25,7 @@ import ( // same meaning in the gRPC-Web, gRPC-HTTP2, and Connect protocols. const flagEnvelopeCompressed = 0b00000001 -// envelope is a block of arbitrary bytes wrapped in gRPC and Connect's framing -// protocol. -// -// Each message is preceded by a 5-byte prefix. The first byte is a uint8 used -// as a set of bitwise flags, and the remainder is a uint32 indicating the -// message length. gRPC and Connect interpret the bitwise flags differently, so -// envelope leaves their interpretation up to the caller. -type envelope struct { - Data *bytes.Buffer - Flags uint8 -} - -func (e envelope) WriteTo(w io.Writer) (n int64, err error) { - prefix := [5]byte{} - prefix[0] = e.Flags - binary.BigEndian.PutUint32(prefix[1:5], uint32(e.Data.Len())) - for _, b := range [2][]byte{prefix[:], e.Data.Bytes()} { - wroteN, err := w.Write(b) - if err != nil { - if writeErr, ok := asError(err); ok { - return n, writeErr - } - return n, errorf(CodeUnknown, "write envelope: %w", err) - } - n += int64(wroteN) - } - return n, nil -} +var errEOF = errorf(CodeInternal, "%w", io.EOF) func marshal(dst *bytes.Buffer, message any, codec Codec) *Error { if message == nil { @@ -134,8 +107,6 @@ func readAll(dst *bytes.Buffer, src io.Reader, readMaxBytes int) *Error { } } -var errEOF = errorf(CodeInternal, "%w", io.EOF) - func readEnvelope(dst *bytes.Buffer, src io.Reader, readMaxBytes int) (uint8, *Error) { prefix := [5]byte{} if _, err := io.ReadFull(src, prefix[:]); err != nil { @@ -179,7 +150,8 @@ func readEnvelope(dst *bytes.Buffer, src io.Reader, readMaxBytes int) (uint8, *E } return prefix[0], nil } -func writeAll(dst io.Writer, src io.WriterTo) *Error { + +func writeAll(dst io.Writer, src *bytes.Buffer) *Error { if _, err := src.WriteTo(dst); err != nil { if writeErr, ok := asError(err); ok { return writeErr @@ -189,6 +161,19 @@ func writeAll(dst io.Writer, src io.WriterTo) *Error { return nil } +func writeEnvelope(dst io.Writer, src *bytes.Buffer, flags uint8) *Error { + prefix := [5]byte{} + prefix[0] = flags + binary.BigEndian.PutUint32(prefix[1:5], uint32(src.Len())) + if _, err := dst.Write(prefix[:]); err != nil { + if writeErr, ok := asError(err); ok { + return writeErr + } + return errorf(CodeUnknown, "write envelope: %w", err) + } + return writeAll(dst, src) +} + func checkSendMaxBytes(length, sendMaxBytes int, isCompressed bool) *Error { if sendMaxBytes <= 0 || length <= sendMaxBytes { return nil diff --git a/error_writer.go b/error_writer.go index 78928368..f50a6ca0 100644 --- a/error_writer.go +++ b/error_writer.go @@ -137,10 +137,8 @@ func (w *ErrorWriter) writeConnectStreaming(response http.ResponseWriter, err er if err := connectMarshalEndStreamMessage(buffer, end); err != nil { return err } - response.WriteHeader(http.StatusOK) - env := envelope{Data: buffer, Flags: connectFlagEnvelopeEndStream} - if err := writeAll(response, env); err != nil { + if err := writeEnvelope(response, buffer, connectFlagEnvelopeEndStream); err != nil { return err } return nil diff --git a/protocol.go b/protocol.go index b24ec603..3ca79e6a 100644 --- a/protocol.go +++ b/protocol.go @@ -286,13 +286,13 @@ func isCommaOrSpace(c rune) bool { func discard(reader io.Reader) (int64, error) { // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. - n, err := io.CopyN(io.Discard, reader, discardLimit) + wroteN, err := io.CopyN(io.Discard, reader, discardLimit) if errors.Is(err, io.EOF) { err = nil - } else if n == discardLimit { + } else if wroteN == discardLimit { err = io.ErrShortBuffer } - return n, err + return wroteN, err } // negotiateCompression determines and validates the request compression and diff --git a/protocol_connect.go b/protocol_connect.go index 1e33122f..7733c5cd 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -439,6 +439,9 @@ func (cc *connectUnaryClientConn) Receive(msg any) error { if err := unmarshal(buffer, msg, cc.Codec); err != nil { return err } + if err := ensureEOF(cc.duplexCall); !errors.Is(err, io.EOF) { + return err + } return nil // must be a literal nil: nil *Error is a non-nil error } @@ -550,8 +553,7 @@ func (cc *connectStreamingClientConn) Send(msg any) error { if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { return err } - env := envelope{Data: buffer, Flags: flags} - if err := writeAll(cc.duplexCall, env); err != nil { + if err := writeEnvelope(cc.duplexCall, buffer, flags); err != nil { return err } return nil // must be a literal nil: nil *error is a non-nil error @@ -884,8 +886,7 @@ func (hc *connectStreamingHandlerConn) Send(msg any) error { if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { return err } - env := envelope{Data: buffer, Flags: flags} - if err := writeAll(hc.responseWriter, env); err != nil { + if err := writeEnvelope(hc.responseWriter, buffer, flags); err != nil { return err } flushResponseWriter(hc.responseWriter) @@ -928,8 +929,7 @@ func (hc *connectStreamingHandlerConn) marshalEndStream(err error, trailer http. if err := connectMarshalEndStreamMessage(buffer, end); err != nil { return err } - env := envelope{Data: buffer, Flags: connectFlagEnvelopeEndStream} - return writeAll(hc.responseWriter, env) + return writeEnvelope(hc.responseWriter, buffer, connectFlagEnvelopeEndStream) } func (cc *connectUnaryClientConn) sendMsg(buffer *bytes.Buffer, msg any) error { diff --git a/protocol_connect_test.go b/protocol_connect_test.go index 80d9257c..bfebd257 100644 --- a/protocol_connect_test.go +++ b/protocol_connect_test.go @@ -70,7 +70,7 @@ func TestConnectEndOfResponseCanonicalTrailers(t *testing.T) { assert.Nil(t, err) output := &bytes.Buffer{} - err = writeAll(output, envelope{Data: buffer, Flags: connectFlagEnvelopeEndStream}) + err = writeEnvelope(output, buffer, connectFlagEnvelopeEndStream) assert.Nil(t, err) input := &bytes.Buffer{} diff --git a/protocol_grpc.go b/protocol_grpc.go index 36d59fd2..34a72adf 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -314,14 +314,13 @@ func (cc *grpcClientConn) Peer() Peer { func (cc *grpcClientConn) Send(msg any) error { buffer := cc.BufferPool.Get() defer cc.BufferPool.Put(buffer) + if err := marshal(buffer, msg, cc.Codec); err != nil { return err } var flags uint8 if buffer.Len() > cc.CompressMinBytes && cc.sendCompressionPool != nil { - if err := cc.sendCompressionPool.Compress( - cc.BufferPool, buffer, - ); err != nil { + if err := cc.sendCompressionPool.Compress(cc.BufferPool, buffer); err != nil { return err } flags |= flagEnvelopeCompressed @@ -329,8 +328,7 @@ func (cc *grpcClientConn) Send(msg any) error { if err := checkSendMaxBytes(buffer.Len(), cc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { return err } - env := envelope{Data: buffer, Flags: flags} - if err := writeAll(cc.duplexCall, env); err != nil { + if err := writeEnvelope(cc.duplexCall, buffer, flags); err != nil { return err } return nil // must be a literal nil: nil *Error is a non-nil error @@ -513,8 +511,7 @@ func (hc *grpcHandlerConn) Send(msg any) error { if err := checkSendMaxBytes(buffer.Len(), hc.SendMaxBytes, flags&flagEnvelopeCompressed > 0); err != nil { return err } - env := envelope{Data: buffer, Flags: flags} - if err := writeAll(hc.responseWriter, env); err != nil { + if err := writeEnvelope(hc.responseWriter, buffer, flags); err != nil { return err } flushResponseWriter(hc.responseWriter) @@ -624,8 +621,7 @@ func (hc *grpcHandlerConn) writeWebTrailers(trailer http.Header) *Error { if err := grpcMarshalWebTrailers(buffer, trailer); err != nil { return errorf(CodeInternal, "format trailers: %v", err) } - env := envelope{Data: buffer, Flags: grpcFlagEnvelopeTrailer} - return writeAll(hc.responseWriter, env) + return writeEnvelope(hc.responseWriter, buffer, grpcFlagEnvelopeTrailer) } func grpcMarshalWebTrailers(dst *bytes.Buffer, trailer http.Header) error { diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index 464cc949..1a6adb26 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -43,6 +43,7 @@ func TestGRPCHandlerSender(t *testing.T) { assert.Nil(t, err) return &grpcHandlerConn{ grpcHandler: &grpcHandler{ + //nolint:exhaustruct protocolHandlerParams: protocolHandlerParams{ Codecs: newReadOnlyCodecs(map[string]Codec{}), BufferPool: bufferPool, From 739164e365ab202559fcd6033e6b97e243e5b419 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Mon, 11 Sep 2023 19:18:54 -0400 Subject: [PATCH 4/8] Fix malicious client prefix reading --- connect_ext_test.go | 1 + envelope.go | 37 ++++++++-------- envelope_test.go | 103 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 envelope_test.go diff --git a/connect_ext_test.go b/connect_ext_test.go index e13cd098..d12cb4e6 100644 --- a/connect_ext_test.go +++ b/connect_ext_test.go @@ -2320,6 +2320,7 @@ func TestStreamUnexpectedEOF(t *testing.T) { for i := 0; stream.Receive() && i < upTo; i++ { assert.Equal(t, stream.Msg().Number, 42) } + t.Log("err:", stream.Err()) assert.NotNil(t, stream.Err()) assert.Equal(t, connect.CodeOf(stream.Err()), testcase.expectCode) assert.Equal(t, stream.Err().Error(), testcase.expectMsg) diff --git a/envelope.go b/envelope.go index 401a25d3..7ff31e49 100644 --- a/envelope.go +++ b/envelope.go @@ -69,10 +69,9 @@ func unmarshal(src *bytes.Buffer, message any, codec Codec) *Error { func read(dst *bytes.Buffer, src io.Reader) (int, error) { dst.Grow(bytes.MinRead) - b := dst.Bytes() - b = b[len(b):cap(b)] + b := dst.Bytes()[dst.Len():dst.Cap()] n, err := src.Read(b) - _, _ = dst.Write(b[:n]) + _, _ = dst.Write(b[:n]) // noop return n, err } @@ -125,28 +124,30 @@ func readEnvelope(dst *bytes.Buffer, src io.Reader, readMaxBytes int) (uint8, *E } size := int(binary.BigEndian.Uint32(prefix[1:5])) - if size < 0 { + switch { + case size < 0: return 0, errorf(CodeInvalidArgument, "message size %d overflowed uint32", size) - } - if readMaxBytes > 0 && size > readMaxBytes { + case readMaxBytes > 0 && size > readMaxBytes: if _, err := discard(src); err != nil { return 0, errorf(CodeUnknown, "read enveloped message: %w", err) } return 0, errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, readMaxBytes) + case size == 0: + return prefix[0], nil } - if size > 0 { - dst.Grow(size) - data := dst.Bytes()[dst.Len() : dst.Len()+size] - if _, err := io.ReadFull(src, data); err != nil { - if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil { - // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. - return 0, maxBytesErr - } - return 0, errorf(CodeInternal, "incomplete envelope: %w", err) - } - if _, err := dst.Write(data); err != nil { - return 0, errorf(CodeInternal, "read enveloped message: %w", err) + + // Don't allocate the entire buffer up front to avoid malicious clients. + // Instead, limit the size of the source to the message size. + src = io.LimitReader(src, int64(size)) + if readN, err := dst.ReadFrom(src); err != nil { + if maxBytesErr := asMaxBytesError(err, "read %d byte message", size); maxBytesErr != nil { + // We're reading from an http.MaxBytesHandler, and we've exceeded the read limit. + return 0, maxBytesErr } + return 0, errorf(CodeInternal, "incomplete envelope: %w", err) + } else if readN != int64(size) { + err = io.ErrUnexpectedEOF + return 0, errorf(CodeInternal, "incomplete envelope: %w", err) } return prefix[0], nil } diff --git a/envelope_test.go b/envelope_test.go new file mode 100644 index 00000000..4f843f97 --- /dev/null +++ b/envelope_test.go @@ -0,0 +1,103 @@ +// Copyright 2021-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connect + +import ( + "bytes" + "compress/gzip" + "io" + "strings" + "testing" + + "connectrpc.com/connect/internal/assert" + pingv1 "connectrpc.com/connect/internal/gen/connect/ping/v1" +) + +func TestBuffers_ReadAllAllocs(t *testing.T) { + t.Parallel() + str := "hello world" + strings.Repeat("b", bytes.MinRead) + src := strings.NewReader(str) + buf := &bytes.Buffer{} + + avg := testing.AllocsPerRun(4, func() { + _, _ = src.Seek(0, 0) + buf.Reset() + if err := readAll(buf, src, maxRecycleBufferSize); err != nil { + t.Fatal(err) + } + }) + t.Log(avg) + assert.Equal(t, str, buf.String()) + assert.True(t, avg <= 1.0) +} + +func TestBuffers_ReadEnvelopeAllocs(t *testing.T) { + t.Parallel() + env := bytes.Buffer{} + str := "hello world" + strings.Repeat("b", bytes.MinRead) + assert.Nil(t, writeEnvelope(&env, bytes.NewBufferString(str), 0)) + env.Write(make([]byte, 0, maxRecycleBufferSize)) // large stream, greater than maxRecycledBufferSize + src := bytes.NewReader(env.Bytes()) + buf := &bytes.Buffer{} + + avg := testing.AllocsPerRun(4, func() { + _, _ = src.Seek(0, 0) + buf.Reset() + _, err := readEnvelope(buf, src, maxRecycleBufferSize) + assert.Nil(t, err) + }) + t.Log(avg) + assert.Equal(t, len(str), buf.Len()) + assert.Equal(t, str, buf.String()) + assert.True(t, avg <= 2.0) +} + +func TestBuffers_Marshal(t *testing.T) { + t.Parallel() + codec := &protoBinaryCodec{} + msg := &pingv1.PingRequest{Text: "hello world"} + buf := &bytes.Buffer{} + + avg := testing.AllocsPerRun(4, func() { + buf.Reset() + assert.Nil(t, marshal(buf, msg, codec)) + assert.Nil(t, unmarshal(buf, msg, codec)) + assert.Equal(t, "hello world", msg.Text) + }) + t.Log(avg) + assert.True(t, avg <= 16.0) + assert.True(t, strings.Contains(buf.String(), "hello world")) +} + +func TestBuffers_Compress(t *testing.T) { + t.Parallel() + pool := &bufferPool{} + input := `{"text":"` + strings.Repeat("a", bytes.MinRead) + `"}` + src := bytes.NewBufferString(input) + comp := newCompressionPool( + func() Decompressor { return &gzip.Reader{} }, + func() Compressor { return gzip.NewWriter(io.Discard) }, + ) + + avg := testing.AllocsPerRun(4, func() { + assert.Nil(t, comp.Compress(pool, src)) + assert.True(t, src.Len() < len(input)) + assert.Nil(t, comp.Decompress(pool, src, 0)) + assert.True(t, src.Len() == len(input)) + }) + t.Log(avg) + // Don't check avg because it's dependent on buffer.Pool's behavior. + assert.Equal(t, input, src.String()) +} From c710296658ea7815294906c6e052deabf7ed56d1 Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 12 Sep 2023 10:04:01 -0400 Subject: [PATCH 5/8] Fix discard limit check --- protocol.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/protocol.go b/protocol.go index 3ca79e6a..3089d1a4 100644 --- a/protocol.go +++ b/protocol.go @@ -284,12 +284,19 @@ func isCommaOrSpace(c rune) bool { } func discard(reader io.Reader) (int64, error) { + lreader, ok := reader.(*io.LimitedReader) + if !ok { + lreader = &io.LimitedReader{R: reader, N: discardLimit} + } + limit := lreader.N // We don't want to get stuck throwing data away forever, so limit how much // we're willing to do here. - wroteN, err := io.CopyN(io.Discard, reader, discardLimit) + wroteN, err := io.Copy(io.Discard, lreader) if errors.Is(err, io.EOF) { err = nil - } else if wroteN == discardLimit { + } + if wroteN == limit { + // Ensure we error if we hit the limit. err = io.ErrShortBuffer } return wroteN, err From 9ddd9043bbe84b6050d36f3072d61b53746ef42e Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 12 Sep 2023 15:10:29 -0400 Subject: [PATCH 6/8] Use json.NewEncoder for connect EOS message --- protocol_connect.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/protocol_connect.go b/protocol_connect.go index 7733c5cd..5d0fca48 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -1273,10 +1273,9 @@ func connectUnmarshalEndStreamMessage(src *bytes.Buffer, flags uint8) (*connectE return &end, nil } func connectMarshalEndStreamMessage(dst *bytes.Buffer, end *connectEndStreamMessage) *Error { - data, marshalErr := json.Marshal(end) - if marshalErr != nil { - return errorf(CodeInternal, "marshal end stream: %w", marshalErr) + enc := json.NewEncoder(dst) + if err := enc.Encode(end); err != nil { + return errorf(CodeInternal, "marshal end stream: %w", err) } - dst.Write(data) return nil } From 9c11e78244b750401106ca9c419c462613159a0e Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 12 Sep 2023 15:16:26 -0400 Subject: [PATCH 7/8] Relax buffers marshalling test --- envelope_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/envelope_test.go b/envelope_test.go index 4f843f97..10aa34b7 100644 --- a/envelope_test.go +++ b/envelope_test.go @@ -77,7 +77,7 @@ func TestBuffers_Marshal(t *testing.T) { assert.Equal(t, "hello world", msg.Text) }) t.Log(avg) - assert.True(t, avg <= 16.0) + // Don't check avg because it's dependent on proto.Marshal's behavior. assert.True(t, strings.Contains(buf.String(), "hello world")) } From b6e458762d170c1e233a8604e057749e53400edd Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Tue, 3 Oct 2023 17:39:08 +0100 Subject: [PATCH 8/8] Fix lint and better buf for marshal --- envelope.go | 6 +++++- envelope_test.go | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/envelope.go b/envelope.go index c890d8cb..456df5eb 100644 --- a/envelope.go +++ b/envelope.go @@ -56,7 +56,11 @@ func marshal(dst *bytes.Buffer, message any, codec Codec) *Error { if err != nil { return errorf(CodeInternal, "marshal message: %w", err) } - dst.Write(raw) + if dst.Cap() < len(raw) { + *dst = *bytes.NewBuffer(raw) + } else { + dst.Write(raw) + } return nil } diff --git a/envelope_test.go b/envelope_test.go index 10aa34b7..351267f7 100644 --- a/envelope_test.go +++ b/envelope_test.go @@ -1,4 +1,4 @@ -// Copyright 2021-2023 Buf Technologies, Inc. +// Copyright 2021-2023 The Connect Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.