Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup gRPC util funcs for timeout and percent encoding #596

Merged
merged 10 commits into from
Oct 3, 2023
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Autogenerated by makego. DO NOT EDIT.
/.tmp/
*.pprof
*.svg
cover.out
connect.test
7 changes: 5 additions & 2 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ issues:
# We need to init a global in-mem HTTP server for testable examples.
- linters: [gochecknoinits, gochecknoglobals]
path: example_init_test.go
# We need to initialize a global map from a slice.
- linters: [gochecknoinits, gochecknoglobals]
# We need to initialize default grpc User-Agent
- linters: [gochecknoglobals]
path: protocol_grpc.go
# We need to initialize default connect User-Agent
- linters: [gochecknoglobals]
path: protocol_connect.go
# We purposefully do an ineffectual assignment for an example.
- linters: [ineffassign]
path: client_example_test.go
Expand Down
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ clean: ## Delete intermediate build artifacts
test: build ## Run unit tests
go test -vet=off -race -cover ./...

.PHONY: bench
bench: BENCH ?= .*
bench: build ## Run benchmarks for root package
go test -vet=off -run '^$$' -bench '$(BENCH)' -benchmem -cpuprofile cpu.pprof -memprofile mem.pprof .

.PHONY: build
build: generate ## Build all packages
go build ./...
Expand Down
2 changes: 0 additions & 2 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ const (
)

// defaultConnectUserAgent returns a User-Agent string similar to those used in gRPC.
//
//nolint:gochecknoglobals
var defaultConnectUserAgent = fmt.Sprintf("connect-go/%s (%s)", Version, runtime.Version())

type protocolConnect struct{}
Expand Down
220 changes: 134 additions & 86 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"strconv"
"strings"
"time"
"unicode/utf8"

statusv1 "connectrpc.com/connect/internal/gen/connectext/grpc/status/v1"
)
Expand All @@ -42,33 +41,17 @@ const (

grpcFlagEnvelopeTrailer = 0b10000000

grpcTimeoutMaxHours = math.MaxInt64 / int64(time.Hour) // how many hours fit into a time.Duration?
grpcMaxTimeoutChars = 8 // from gRPC protocol

grpcContentTypeDefault = "application/grpc"
grpcWebContentTypeDefault = "application/grpc-web"
grpcContentTypePrefix = grpcContentTypeDefault + "+"
grpcWebContentTypePrefix = grpcWebContentTypeDefault + "+"

headerXUserAgent = "X-User-Agent"

upperhex = "0123456789ABCDEF"
)

var (
grpcTimeoutUnits = []struct {
size time.Duration
char byte
}{
{time.Nanosecond, 'n'},
{time.Microsecond, 'u'},
{time.Millisecond, 'm'},
{time.Second, 'S'},
{time.Minute, 'M'},
{time.Hour, 'H'},
}
grpcTimeoutUnitLookup = make(map[byte]time.Duration)
grpcAllowedMethods = map[string]struct{}{
http.MethodPost: {},
}
errTrailersWithoutGRPCStatus = fmt.Errorf("protocol error: no %s trailer: %w", grpcHeaderStatus, io.ErrUnexpectedEOF)

// defaultGrpcUserAgent follows
Expand All @@ -81,13 +64,10 @@ var (
//
// User-Agent → "grpc-" Language ?("-" Variant) "/" Version ?( " (" *(AdditionalProperty ";") ")" )
defaultGrpcUserAgent = fmt.Sprintf("grpc-go-connect/%s (%s)", Version, runtime.Version())
)

func init() {
for _, pair := range grpcTimeoutUnits {
grpcTimeoutUnitLookup[pair.char] = pair.size
grpcAllowedMethods = map[string]struct{}{
http.MethodPost: {},
}
}
)

type protocolGRPC struct {
web bool
Expand Down Expand Up @@ -285,11 +265,8 @@ func (g *grpcClient) NewConn(
header http.Header,
) streamingClientConn {
if deadline, ok := ctx.Deadline(); ok {
if encodedDeadline, err := grpcEncodeTimeout(time.Until(deadline)); err == nil {
// Tests verify that the error in encodeTimeout is unreachable, so we
// don't need to handle the error case.
header[grpcHeaderTimeout] = []string{encodedDeadline}
}
encodedDeadline := grpcEncodeTimeout(time.Until(deadline))
header[grpcHeaderTimeout] = []string{encodedDeadline}
}
duplexCall := newDuplexHTTPCall(
ctx,
Expand Down Expand Up @@ -748,7 +725,10 @@ func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error {
if err != nil {
return errorf(CodeInternal, "protocol error: invalid error code %q", codeHeader)
}
message := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage))
message, err := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage))
if err != nil {
return errorf(CodeInternal, "protocol error: invalid error message %q", message)
}
retErr := NewWireError(Code(code), errors.New(message))

detailsBinaryEncoded := getHeaderCanonical(trailer, grpcHeaderDetails)
Expand Down Expand Up @@ -776,9 +756,9 @@ func grpcParseTimeout(timeout string) (time.Duration, error) {
if timeout == "" {
return 0, errNoTimeout
}
unit, ok := grpcTimeoutUnitLookup[timeout[len(timeout)-1]]
if !ok {
return 0, fmt.Errorf("protocol error: timeout %q has invalid unit", timeout)
unit, err := grpcTimeoutUnitLookup(timeout[len(timeout)-1])
if err != nil {
return 0, err
}
num, err := strconv.ParseInt(timeout[:len(timeout)-1], 10 /* base */, 64 /* bitsize */)
if err != nil || num < 0 {
Expand All @@ -787,6 +767,7 @@ func grpcParseTimeout(timeout string) (time.Duration, error) {
if num > 99999999 { // timeout must be ASCII string of at most 8 digits
return 0, fmt.Errorf("protocol error: timeout %q is too long", timeout)
}
const grpcTimeoutMaxHours = math.MaxInt64 / int64(time.Hour) // how many hours fit into a time.Duration?
if unit == time.Hour && num > grpcTimeoutMaxHours {
// Timeout is effectively unbounded, so ignore it. The grpc-go
// implementation does the same thing.
Expand All @@ -795,19 +776,56 @@ func grpcParseTimeout(timeout string) (time.Duration, error) {
return time.Duration(num) * unit, nil
}

func grpcEncodeTimeout(timeout time.Duration) (string, error) {
func grpcEncodeTimeout(timeout time.Duration) string {
if timeout <= 0 {
return "0n", nil
return "0n"
}
for _, pair := range grpcTimeoutUnits {
digits := strconv.FormatInt(int64(timeout/pair.size), 10 /* base */)
if len(digits) < grpcMaxTimeoutChars {
return digits + string(pair.char), nil
}
// The gRPC protocol limits timeouts to 8 characters (not counting the unit),
// so timeouts must be strictly less than 1e8 of the appropriate unit.
const grpcTimeoutMaxValue = 1e8
akshayjshah marked this conversation as resolved.
Show resolved Hide resolved
var (
size time.Duration
unit byte
)
switch {
case timeout < time.Nanosecond*grpcTimeoutMaxValue:
size, unit = time.Nanosecond, 'n'
case timeout < time.Microsecond*grpcTimeoutMaxValue:
size, unit = time.Microsecond, 'u'
case timeout < time.Millisecond*grpcTimeoutMaxValue:
size, unit = time.Millisecond, 'm'
case timeout < time.Second*grpcTimeoutMaxValue:
size, unit = time.Second, 'S'
case timeout < time.Minute*grpcTimeoutMaxValue:
size, unit = time.Minute, 'M'
default:
// time.Duration is an int64 number of nanoseconds, so the largest
// expressible duration is less than 1e8 hours.
size, unit = time.Hour, 'H'
akshayjshah marked this conversation as resolved.
Show resolved Hide resolved
}
buf := make([]byte, 0, 9)
buf = strconv.AppendInt(buf, int64(timeout/size), 10 /* base */)
buf = append(buf, unit)
return string(buf)
}

func grpcTimeoutUnitLookup(unit byte) (time.Duration, error) {
switch unit {
case 'n':
return time.Nanosecond, nil
case 'u':
return time.Microsecond, nil
case 'm':
return time.Millisecond, nil
case 'S':
return time.Second, nil
case 'M':
return time.Minute, nil
case 'H':
return time.Hour, nil
default:
return 0, fmt.Errorf("protocol error: timeout has invalid unit %q", unit)
}
// The max time.Duration is smaller than the maximum expressible gRPC
// timeout, so we can't reach this case.
return "", errNoTimeout
}

func grpcCodecFromContentType(web bool, contentType string) string {
Expand Down Expand Up @@ -887,61 +905,91 @@ func grpcStatusFromError(err error) *statusv1.Status {
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses
// https://datatracker.ietf.org/doc/html/rfc3986#section-2.1
func grpcPercentEncode(msg string) string {
var hexCount int
for i := 0; i < len(msg); i++ {
// Characters that need to be escaped are defined in gRPC's HTTP/2 spec.
// They're different from the generic set defined in RFC 3986.
if c := msg[i]; c < ' ' || c > '~' || c == '%' {
return grpcPercentEncodeSlow(msg, i)
if grpcShouldEscape(msg[i]) {
hexCount++
}
}
return msg
}

// msg needs some percent-escaping. Bytes before offset don't require
// percent-encoding, so they can be copied to the output as-is.
func grpcPercentEncodeSlow(msg string, offset int) string {
if hexCount == 0 {
return msg
}
// We need to escape some characters, so we'll need to allocate a new string.
var out strings.Builder
out.Grow(2 * len(msg))
out.WriteString(msg[:offset])
for i := offset; i < len(msg); i++ {
c := msg[i]
if c < ' ' || c > '~' || c == '%' {
fmt.Fprintf(&out, "%%%02X", c)
continue
out.Grow(len(msg) + 2*hexCount)
mattrobenolt marked this conversation as resolved.
Show resolved Hide resolved
for i := 0; i < len(msg); i++ {
switch char := msg[i]; {
case grpcShouldEscape(char):
out.WriteByte('%')
out.WriteByte(upperhex[char>>4])
out.WriteByte(upperhex[char&15])
default:
out.WriteByte(char)
}
out.WriteByte(c)
}
return out.String()
}

func grpcPercentDecode(encoded string) string {
for i := 0; i < len(encoded); i++ {
if c := encoded[i]; c == '%' && i+2 < len(encoded) {
return grpcPercentDecodeSlow(encoded, i)
func grpcPercentDecode(input string) (string, error) {
percentCount := 0
for i := 0; i < len(input); {
switch input[i] {
case '%':
percentCount++
if err := validateHex(input[i:]); err != nil {
return "", err
}
i += 3
default:
i++
}
}
return encoded
}

// Similar to percentEncodeSlow: encoded is percent-encoded, and needs to be
// decoded byte-by-byte starting at offset.
func grpcPercentDecodeSlow(encoded string, offset int) string {
if percentCount == 0 {
return input, nil
}
// We need to unescape some characters, so we'll need to allocate a new string.
var out strings.Builder
out.Grow(len(encoded))
out.WriteString(encoded[:offset])
for i := offset; i < len(encoded); i++ {
c := encoded[i]
if c != '%' || i+2 >= len(encoded) {
out.WriteByte(c)
continue
out.Grow(len(input) - 2*percentCount)
for i := 0; i < len(input); i++ {
switch input[i] {
case '%':
out.WriteByte(unhex(input[i+1])<<4 | unhex(input[i+2]))
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
i += 2
default:
out.WriteByte(input[i])
}
parsed, err := strconv.ParseUint(encoded[i+1:i+3], 16 /* hex */, 8 /* bitsize */)
if err != nil {
out.WriteRune(utf8.RuneError)
} else {
out.WriteByte(byte(parsed))
}
return out.String(), nil
}

// Characters that need to be escaped are defined in gRPC's HTTP/2 spec.
// They're different from the generic set defined in RFC 3986.
func grpcShouldEscape(char byte) bool {
return char < ' ' || char > '~' || char == '%'
}

func unhex(char byte) byte {
switch {
case '0' <= char && char <= '9':
return char - '0'
case 'a' <= char && char <= 'f':
return char - 'a' + 10
case 'A' <= char && char <= 'F':
return char - 'A' + 10
}
return 0
}
akshayjshah marked this conversation as resolved.
Show resolved Hide resolved

func isHex(char byte) bool {
return ('0' <= char && char <= '9') || ('a' <= char && char <= 'f') || ('A' <= char && char <= 'F')
}

func validateHex(input string) error {
if len(input) < 3 || input[0] != '%' || !isHex(input[1]) || !isHex(input[2]) {
if len(input) > 3 {
input = input[:3]
}
i += 2
return fmt.Errorf("invalid percent-encoded string %q", input)
}
return out.String()
return nil
}
Loading
Loading