Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup gRPC util funcs for timeout and percent encoding #596

Merged
merged 10 commits into from
Oct 3, 2023
7 changes: 5 additions & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ issues:
# We need to init a global in-mem HTTP server for testable examples.
- linters: [gochecknoinits, gochecknoglobals]
path: example_init_test.go
# We need to initialize a global map from a slice.
- linters: [gochecknoinits, gochecknoglobals]
# We need to initialize default grpc User-Agent
- linters: [gochecknoglobals]
path: protocol_grpc.go
# We need to initialize default connect User-Agent
- linters: [gochecknoglobals]
path: protocol_connect.go
# We purposefully do an ineffectual assignment for an example.
- linters: [ineffassign]
path: client_example_test.go
Expand Down
2 changes: 0 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ const (
)

// defaultConnectUserAgent returns a User-Agent string similar to those used in gRPC.
//
//nolint:gochecknoglobals
var defaultConnectUserAgent = fmt.Sprintf("connect-go/%s (%s)", Version, runtime.Version())

type protocolConnect struct{}
Expand Down
210 changes: 124 additions & 86 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"strconv"
"strings"
"time"
"unicode/utf8"

statusv1 "connectrpc.com/connect/internal/gen/connectext/grpc/status/v1"
)
Expand All @@ -42,9 +41,6 @@ const (

grpcFlagEnvelopeTrailer = 0b10000000

grpcTimeoutMaxHours = math.MaxInt64 / int64(time.Hour) // how many hours fit into a time.Duration?
grpcMaxTimeoutChars = 8 // from gRPC protocol

grpcContentTypeDefault = "application/grpc"
grpcWebContentTypeDefault = "application/grpc-web"
grpcContentTypePrefix = grpcContentTypeDefault + "+"
Expand All @@ -54,21 +50,6 @@ const (
)

var (
grpcTimeoutUnits = []struct {
size time.Duration
char byte
}{
{time.Nanosecond, 'n'},
{time.Microsecond, 'u'},
{time.Millisecond, 'm'},
{time.Second, 'S'},
{time.Minute, 'M'},
{time.Hour, 'H'},
}
grpcTimeoutUnitLookup = make(map[byte]time.Duration)
grpcAllowedMethods = map[string]struct{}{
http.MethodPost: {},
}
errTrailersWithoutGRPCStatus = fmt.Errorf("protocol error: no %s trailer: %w", grpcHeaderStatus, io.ErrUnexpectedEOF)

// defaultGrpcUserAgent follows
Expand All @@ -83,12 +64,6 @@ var (
defaultGrpcUserAgent = fmt.Sprintf("grpc-go-connect/%s (%s)", Version, runtime.Version())
)

func init() {
for _, pair := range grpcTimeoutUnits {
grpcTimeoutUnitLookup[pair.char] = pair.size
}
}

type protocolGRPC struct {
web bool
}
Expand Down Expand Up @@ -134,7 +109,7 @@ type grpcHandler struct {
}

func (g *grpcHandler) Methods() map[string]struct{} {
return grpcAllowedMethods
return map[string]struct{}{http.MethodPost: {}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an extra allocation along the request path. The returned map isn't accessible to Connect users, so I think we ought to keep the code as-is.

Was there a safety concern here, or were we just trying to get rid of globals?

}

func (g *grpcHandler) ContentTypes() map[string]struct{} {
Expand Down Expand Up @@ -285,11 +260,8 @@ func (g *grpcClient) NewConn(
header http.Header,
) streamingClientConn {
if deadline, ok := ctx.Deadline(); ok {
if encodedDeadline, err := grpcEncodeTimeout(time.Until(deadline)); err == nil {
// Tests verify that the error in encodeTimeout is unreachable, so we
// don't need to handle the error case.
header[grpcHeaderTimeout] = []string{encodedDeadline}
}
encodedDeadline := grpcEncodeTimeout(time.Until(deadline))
header[grpcHeaderTimeout] = []string{encodedDeadline}
}
duplexCall := newDuplexHTTPCall(
ctx,
Expand Down Expand Up @@ -748,7 +720,10 @@ func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error {
if err != nil {
return errorf(CodeInternal, "protocol error: invalid error code %q", codeHeader)
}
message := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage))
message, err := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage))
if err != nil {
return errorf(CodeInternal, "protocol error: invalid error message %q", message)
}
retErr := NewWireError(Code(code), errors.New(message))

detailsBinaryEncoded := getHeaderCanonical(trailer, grpcHeaderDetails)
Expand Down Expand Up @@ -776,8 +751,8 @@ func grpcParseTimeout(timeout string) (time.Duration, error) {
if timeout == "" {
return 0, errNoTimeout
}
unit, ok := grpcTimeoutUnitLookup[timeout[len(timeout)-1]]
if !ok {
unit := grpcTimeoutUnitLookup(timeout[len(timeout)-1])
if unit == 0 {
return 0, fmt.Errorf("protocol error: timeout %q has invalid unit", timeout)
}
num, err := strconv.ParseInt(timeout[:len(timeout)-1], 10 /* base */, 64 /* bitsize */)
Expand All @@ -787,6 +762,7 @@ func grpcParseTimeout(timeout string) (time.Duration, error) {
if num > 99999999 { // timeout must be ASCII string of at most 8 digits
return 0, fmt.Errorf("protocol error: timeout %q is too long", timeout)
}
const grpcTimeoutMaxHours = math.MaxInt64 / int64(time.Hour) // how many hours fit into a time.Duration?
if unit == time.Hour && num > grpcTimeoutMaxHours {
// Timeout is effectively unbounded, so ignore it. The grpc-go
// implementation does the same thing.
Expand All @@ -795,19 +771,50 @@ func grpcParseTimeout(timeout string) (time.Duration, error) {
return time.Duration(num) * unit, nil
}

func grpcEncodeTimeout(timeout time.Duration) (string, error) {
func grpcEncodeTimeout(timeout time.Duration) string {
if timeout <= 0 {
return "0n", nil
return "0n"
}
for _, pair := range grpcTimeoutUnits {
digits := strconv.FormatInt(int64(timeout/pair.size), 10 /* base */)
if len(digits) < grpcMaxTimeoutChars {
return digits + string(pair.char), nil
}
const grpcTimeoutMaxValue = 1e8
akshayjshah marked this conversation as resolved.
Show resolved Hide resolved
var (
size time.Duration
unit byte
)
switch {
case timeout < time.Nanosecond*grpcTimeoutMaxValue:
size, unit = time.Nanosecond, 'n'
case timeout < time.Microsecond*grpcTimeoutMaxValue:
size, unit = time.Microsecond, 'u'
case timeout < time.Millisecond*grpcTimeoutMaxValue:
size, unit = time.Millisecond, 'm'
case timeout < time.Second*grpcTimeoutMaxValue:
size, unit = time.Second, 'S'
case timeout < time.Minute*grpcTimeoutMaxValue:
size, unit = time.Minute, 'M'
default:
size, unit = time.Hour, 'H'
akshayjshah marked this conversation as resolved.
Show resolved Hide resolved
}
value := timeout / size
return strconv.FormatInt(int64(value), 10 /* base */) + string(unit)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as we're golfing allocs, we can probably shave an alloc here by using strconv.AppendInt.

}

func grpcTimeoutUnitLookup(unit byte) time.Duration {
switch unit {
case 'n':
return time.Nanosecond
case 'u':
return time.Microsecond
case 'm':
return time.Millisecond
case 'S':
return time.Second
case 'M':
return time.Minute
case 'H':
return time.Hour
default:
return 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API-wise, this doesn't feel like an improvement to me: we're overloading the zero value in a way that feels odd in Go, and returning an error won't measurably affect performance.

}
// The max time.Duration is smaller than the maximum expressible gRPC
// timeout, so we can't reach this case.
return "", errNoTimeout
}

func grpcCodecFromContentType(web bool, contentType string) string {
Expand Down Expand Up @@ -887,61 +894,92 @@ func grpcStatusFromError(err error) *statusv1.Status {
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses
// https://datatracker.ietf.org/doc/html/rfc3986#section-2.1
func grpcPercentEncode(msg string) string {
var hexCount int
for i := 0; i < len(msg); i++ {
// Characters that need to be escaped are defined in gRPC's HTTP/2 spec.
// They're different from the generic set defined in RFC 3986.
if c := msg[i]; c < ' ' || c > '~' || c == '%' {
return grpcPercentEncodeSlow(msg, i)
if grpcShouldEscape(msg[i]) {
hexCount++
}
}
return msg
}

// msg needs some percent-escaping. Bytes before offset don't require
// percent-encoding, so they can be copied to the output as-is.
func grpcPercentEncodeSlow(msg string, offset int) string {
if hexCount == 0 {
return msg
}
// We need to escape some characters, so we'll need to allocate a new string.
var out strings.Builder
out.Grow(2 * len(msg))
out.WriteString(msg[:offset])
for i := offset; i < len(msg); i++ {
c := msg[i]
if c < ' ' || c > '~' || c == '%' {
fmt.Fprintf(&out, "%%%02X", c)
continue
out.Grow(len(msg) + 2*hexCount)
mattrobenolt marked this conversation as resolved.
Show resolved Hide resolved
for i := 0; i < len(msg); i++ {
switch char := msg[i]; {
case grpcShouldEscape(char):
out.WriteByte('%')
out.WriteByte(upperhex[char>>4])
out.WriteByte(upperhex[char&15])
default:
out.WriteByte(char)
}
out.WriteByte(c)
}
return out.String()
}

func grpcPercentDecode(encoded string) string {
for i := 0; i < len(encoded); i++ {
if c := encoded[i]; c == '%' && i+2 < len(encoded) {
return grpcPercentDecodeSlow(encoded, i)
func grpcPercentDecode(input string) (string, error) {
percentCount := 0
for i := 0; i < len(input); {
switch input[i] {
case '%':
percentCount++
if err := validateHex(input[i:]); err != nil {
return "", err
}
i += 3
default:
i++
}
}
return encoded
}

// Similar to percentEncodeSlow: encoded is percent-encoded, and needs to be
// decoded byte-by-byte starting at offset.
func grpcPercentDecodeSlow(encoded string, offset int) string {
if percentCount == 0 {
return input, nil
}
// We need to unescape some characters, so we'll need to allocate a new string.
var out strings.Builder
out.Grow(len(encoded))
out.WriteString(encoded[:offset])
for i := offset; i < len(encoded); i++ {
c := encoded[i]
if c != '%' || i+2 >= len(encoded) {
out.WriteByte(c)
continue
out.Grow(len(input) - 2*percentCount)
for i := 0; i < len(input); i++ {
switch input[i] {
case '%':
out.WriteByte(unhex(input[i+1])<<4 | unhex(input[i+2]))
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
i += 2
default:
out.WriteByte(input[i])
}
parsed, err := strconv.ParseUint(encoded[i+1:i+3], 16 /* hex */, 8 /* bitsize */)
if err != nil {
out.WriteRune(utf8.RuneError)
} else {
out.WriteByte(byte(parsed))
}
return out.String(), nil
}

// Characters that need to be escaped are defined in gRPC's HTTP/2 spec.
// They're different from the generic set defined in RFC 3986.
func grpcShouldEscape(char byte) bool {
return char < ' ' || char > '~' || char == '%'
}

const upperhex = "0123456789ABCDEF"
akshayjshah marked this conversation as resolved.
Show resolved Hide resolved

func unhex(char byte) byte {
switch {
case '0' <= char && char <= '9':
return char - '0'
case 'a' <= char && char <= 'f':
return char - 'a' + 10
case 'A' <= char && char <= 'F':
return char - 'A' + 10
}
return 0
}
akshayjshah marked this conversation as resolved.
Show resolved Hide resolved
func isHex(char byte) bool {
return ('0' <= char && char <= '9') || ('a' <= char && char <= 'f') || ('A' <= char && char <= 'F')
}

func validateHex(input string) error {
if len(input) < 3 || input[0] != '%' || !isHex(input[1]) || !isHex(input[2]) {
if len(input) > 3 {
input = input[:3]
}
i += 2
return fmt.Errorf("invalid percent-encoded string %q", input)
}
return out.String()
return nil
}
Loading
Loading