-
Notifications
You must be signed in to change notification settings - Fork 100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cleanup gRPC util funcs for timeout and percent encoding #596
Changes from 3 commits
1b6eed7
1b36244
bdde501
284c5a9
5d8527b
dc08bf0
8f52228
29edb36
871f4cc
fe1b25d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,6 @@ import ( | |
"strconv" | ||
"strings" | ||
"time" | ||
"unicode/utf8" | ||
|
||
statusv1 "connectrpc.com/connect/internal/gen/connectext/grpc/status/v1" | ||
) | ||
|
@@ -42,9 +41,6 @@ const ( | |
|
||
grpcFlagEnvelopeTrailer = 0b10000000 | ||
|
||
grpcTimeoutMaxHours = math.MaxInt64 / int64(time.Hour) // how many hours fit into a time.Duration? | ||
grpcMaxTimeoutChars = 8 // from gRPC protocol | ||
|
||
grpcContentTypeDefault = "application/grpc" | ||
grpcWebContentTypeDefault = "application/grpc-web" | ||
grpcContentTypePrefix = grpcContentTypeDefault + "+" | ||
|
@@ -54,21 +50,6 @@ const ( | |
) | ||
|
||
var ( | ||
grpcTimeoutUnits = []struct { | ||
size time.Duration | ||
char byte | ||
}{ | ||
{time.Nanosecond, 'n'}, | ||
{time.Microsecond, 'u'}, | ||
{time.Millisecond, 'm'}, | ||
{time.Second, 'S'}, | ||
{time.Minute, 'M'}, | ||
{time.Hour, 'H'}, | ||
} | ||
grpcTimeoutUnitLookup = make(map[byte]time.Duration) | ||
grpcAllowedMethods = map[string]struct{}{ | ||
http.MethodPost: {}, | ||
} | ||
errTrailersWithoutGRPCStatus = fmt.Errorf("protocol error: no %s trailer: %w", grpcHeaderStatus, io.ErrUnexpectedEOF) | ||
|
||
// defaultGrpcUserAgent follows | ||
|
@@ -83,12 +64,6 @@ var ( | |
defaultGrpcUserAgent = fmt.Sprintf("grpc-go-connect/%s (%s)", Version, runtime.Version()) | ||
) | ||
|
||
func init() { | ||
for _, pair := range grpcTimeoutUnits { | ||
grpcTimeoutUnitLookup[pair.char] = pair.size | ||
} | ||
} | ||
|
||
type protocolGRPC struct { | ||
web bool | ||
} | ||
|
@@ -134,7 +109,7 @@ type grpcHandler struct { | |
} | ||
|
||
func (g *grpcHandler) Methods() map[string]struct{} { | ||
return grpcAllowedMethods | ||
return map[string]struct{}{http.MethodPost: {}} | ||
} | ||
|
||
func (g *grpcHandler) ContentTypes() map[string]struct{} { | ||
|
@@ -285,11 +260,8 @@ func (g *grpcClient) NewConn( | |
header http.Header, | ||
) streamingClientConn { | ||
if deadline, ok := ctx.Deadline(); ok { | ||
if encodedDeadline, err := grpcEncodeTimeout(time.Until(deadline)); err == nil { | ||
// Tests verify that the error in encodeTimeout is unreachable, so we | ||
// don't need to handle the error case. | ||
header[grpcHeaderTimeout] = []string{encodedDeadline} | ||
} | ||
encodedDeadline := grpcEncodeTimeout(time.Until(deadline)) | ||
header[grpcHeaderTimeout] = []string{encodedDeadline} | ||
} | ||
duplexCall := newDuplexHTTPCall( | ||
ctx, | ||
|
@@ -748,7 +720,10 @@ func grpcErrorFromTrailer(protobuf Codec, trailer http.Header) *Error { | |
if err != nil { | ||
return errorf(CodeInternal, "protocol error: invalid error code %q", codeHeader) | ||
} | ||
message := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage)) | ||
message, err := grpcPercentDecode(getHeaderCanonical(trailer, grpcHeaderMessage)) | ||
if err != nil { | ||
return errorf(CodeInternal, "protocol error: invalid error message %q", message) | ||
} | ||
retErr := NewWireError(Code(code), errors.New(message)) | ||
|
||
detailsBinaryEncoded := getHeaderCanonical(trailer, grpcHeaderDetails) | ||
|
@@ -776,8 +751,8 @@ func grpcParseTimeout(timeout string) (time.Duration, error) { | |
if timeout == "" { | ||
return 0, errNoTimeout | ||
} | ||
unit, ok := grpcTimeoutUnitLookup[timeout[len(timeout)-1]] | ||
if !ok { | ||
unit := grpcTimeoutUnitLookup(timeout[len(timeout)-1]) | ||
if unit == 0 { | ||
return 0, fmt.Errorf("protocol error: timeout %q has invalid unit", timeout) | ||
} | ||
num, err := strconv.ParseInt(timeout[:len(timeout)-1], 10 /* base */, 64 /* bitsize */) | ||
|
@@ -787,6 +762,7 @@ func grpcParseTimeout(timeout string) (time.Duration, error) { | |
if num > 99999999 { // timeout must be ASCII string of at most 8 digits | ||
return 0, fmt.Errorf("protocol error: timeout %q is too long", timeout) | ||
} | ||
const grpcTimeoutMaxHours = math.MaxInt64 / int64(time.Hour) // how many hours fit into a time.Duration? | ||
if unit == time.Hour && num > grpcTimeoutMaxHours { | ||
// Timeout is effectively unbounded, so ignore it. The grpc-go | ||
// implementation does the same thing. | ||
|
@@ -795,19 +771,50 @@ func grpcParseTimeout(timeout string) (time.Duration, error) { | |
return time.Duration(num) * unit, nil | ||
} | ||
|
||
func grpcEncodeTimeout(timeout time.Duration) (string, error) { | ||
func grpcEncodeTimeout(timeout time.Duration) string { | ||
if timeout <= 0 { | ||
return "0n", nil | ||
return "0n" | ||
} | ||
for _, pair := range grpcTimeoutUnits { | ||
digits := strconv.FormatInt(int64(timeout/pair.size), 10 /* base */) | ||
if len(digits) < grpcMaxTimeoutChars { | ||
return digits + string(pair.char), nil | ||
} | ||
const grpcTimeoutMaxValue = 1e8 | ||
akshayjshah marked this conversation as resolved.
Show resolved
Hide resolved
|
||
var ( | ||
size time.Duration | ||
unit byte | ||
) | ||
switch { | ||
case timeout < time.Nanosecond*grpcTimeoutMaxValue: | ||
size, unit = time.Nanosecond, 'n' | ||
case timeout < time.Microsecond*grpcTimeoutMaxValue: | ||
size, unit = time.Microsecond, 'u' | ||
case timeout < time.Millisecond*grpcTimeoutMaxValue: | ||
size, unit = time.Millisecond, 'm' | ||
case timeout < time.Second*grpcTimeoutMaxValue: | ||
size, unit = time.Second, 'S' | ||
case timeout < time.Minute*grpcTimeoutMaxValue: | ||
size, unit = time.Minute, 'M' | ||
default: | ||
size, unit = time.Hour, 'H' | ||
akshayjshah marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
value := timeout / size | ||
return strconv.FormatInt(int64(value), 10 /* base */) + string(unit) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As long as we're golfing allocs, we can probably shave an alloc here by using |
||
} | ||
|
||
func grpcTimeoutUnitLookup(unit byte) time.Duration { | ||
switch unit { | ||
case 'n': | ||
return time.Nanosecond | ||
case 'u': | ||
return time.Microsecond | ||
case 'm': | ||
return time.Millisecond | ||
case 'S': | ||
return time.Second | ||
case 'M': | ||
return time.Minute | ||
case 'H': | ||
return time.Hour | ||
default: | ||
return 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. API-wise, this doesn't feel like an improvement to me: we're overloading the zero value in a way that feels odd in Go, and returning an error won't measurably affect performance. |
||
} | ||
// The max time.Duration is smaller than the maximum expressible gRPC | ||
// timeout, so we can't reach this case. | ||
return "", errNoTimeout | ||
} | ||
|
||
func grpcCodecFromContentType(web bool, contentType string) string { | ||
|
@@ -887,61 +894,92 @@ func grpcStatusFromError(err error) *statusv1.Status { | |
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses | ||
// https://datatracker.ietf.org/doc/html/rfc3986#section-2.1 | ||
func grpcPercentEncode(msg string) string { | ||
var hexCount int | ||
for i := 0; i < len(msg); i++ { | ||
// Characters that need to be escaped are defined in gRPC's HTTP/2 spec. | ||
// They're different from the generic set defined in RFC 3986. | ||
if c := msg[i]; c < ' ' || c > '~' || c == '%' { | ||
return grpcPercentEncodeSlow(msg, i) | ||
if grpcShouldEscape(msg[i]) { | ||
hexCount++ | ||
} | ||
} | ||
return msg | ||
} | ||
|
||
// msg needs some percent-escaping. Bytes before offset don't require | ||
// percent-encoding, so they can be copied to the output as-is. | ||
func grpcPercentEncodeSlow(msg string, offset int) string { | ||
if hexCount == 0 { | ||
return msg | ||
} | ||
// We need to escape some characters, so we'll need to allocate a new string. | ||
var out strings.Builder | ||
out.Grow(2 * len(msg)) | ||
out.WriteString(msg[:offset]) | ||
for i := offset; i < len(msg); i++ { | ||
c := msg[i] | ||
if c < ' ' || c > '~' || c == '%' { | ||
fmt.Fprintf(&out, "%%%02X", c) | ||
continue | ||
out.Grow(len(msg) + 2*hexCount) | ||
mattrobenolt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for i := 0; i < len(msg); i++ { | ||
switch char := msg[i]; { | ||
case grpcShouldEscape(char): | ||
out.WriteByte('%') | ||
out.WriteByte(upperhex[char>>4]) | ||
out.WriteByte(upperhex[char&15]) | ||
default: | ||
out.WriteByte(char) | ||
} | ||
out.WriteByte(c) | ||
} | ||
return out.String() | ||
} | ||
|
||
func grpcPercentDecode(encoded string) string { | ||
for i := 0; i < len(encoded); i++ { | ||
if c := encoded[i]; c == '%' && i+2 < len(encoded) { | ||
return grpcPercentDecodeSlow(encoded, i) | ||
func grpcPercentDecode(input string) (string, error) { | ||
percentCount := 0 | ||
for i := 0; i < len(input); { | ||
switch input[i] { | ||
case '%': | ||
percentCount++ | ||
if err := validateHex(input[i:]); err != nil { | ||
return "", err | ||
} | ||
i += 3 | ||
default: | ||
i++ | ||
} | ||
} | ||
return encoded | ||
} | ||
|
||
// Similar to percentEncodeSlow: encoded is percent-encoded, and needs to be | ||
// decoded byte-by-byte starting at offset. | ||
func grpcPercentDecodeSlow(encoded string, offset int) string { | ||
if percentCount == 0 { | ||
return input, nil | ||
} | ||
// We need to unescape some characters, so we'll need to allocate a new string. | ||
var out strings.Builder | ||
out.Grow(len(encoded)) | ||
out.WriteString(encoded[:offset]) | ||
for i := offset; i < len(encoded); i++ { | ||
c := encoded[i] | ||
if c != '%' || i+2 >= len(encoded) { | ||
out.WriteByte(c) | ||
continue | ||
out.Grow(len(input) - 2*percentCount) | ||
for i := 0; i < len(input); i++ { | ||
switch input[i] { | ||
case '%': | ||
out.WriteByte(unhex(input[i+1])<<4 | unhex(input[i+2])) | ||
emcfarlane marked this conversation as resolved.
Show resolved
Hide resolved
|
||
i += 2 | ||
default: | ||
out.WriteByte(input[i]) | ||
} | ||
parsed, err := strconv.ParseUint(encoded[i+1:i+3], 16 /* hex */, 8 /* bitsize */) | ||
if err != nil { | ||
out.WriteRune(utf8.RuneError) | ||
} else { | ||
out.WriteByte(byte(parsed)) | ||
} | ||
return out.String(), nil | ||
} | ||
|
||
// Characters that need to be escaped are defined in gRPC's HTTP/2 spec. | ||
// They're different from the generic set defined in RFC 3986. | ||
func grpcShouldEscape(char byte) bool { | ||
return char < ' ' || char > '~' || char == '%' | ||
} | ||
|
||
const upperhex = "0123456789ABCDEF" | ||
akshayjshah marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
func unhex(char byte) byte { | ||
switch { | ||
case '0' <= char && char <= '9': | ||
return char - '0' | ||
case 'a' <= char && char <= 'f': | ||
return char - 'a' + 10 | ||
case 'A' <= char && char <= 'F': | ||
return char - 'A' + 10 | ||
} | ||
return 0 | ||
} | ||
akshayjshah marked this conversation as resolved.
Show resolved
Hide resolved
|
||
func isHex(char byte) bool { | ||
return ('0' <= char && char <= '9') || ('a' <= char && char <= 'f') || ('A' <= char && char <= 'F') | ||
} | ||
|
||
func validateHex(input string) error { | ||
if len(input) < 3 || input[0] != '%' || !isHex(input[1]) || !isHex(input[2]) { | ||
if len(input) > 3 { | ||
input = input[:3] | ||
} | ||
i += 2 | ||
return fmt.Errorf("invalid percent-encoded string %q", input) | ||
} | ||
return out.String() | ||
return nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an extra allocation along the request path. The returned map isn't accessible to Connect users, so I think we ought to keep the code as-is.
Was there a safety concern here, or were we just trying to get rid of globals?