Skip to content

Commit

Permalink
grpcreflect: Fallback from v1 to v1alpha on "unavailable" error code (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump authored Jan 10, 2024
1 parent f139a6d commit bf8a7d8
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 32 deletions.
10 changes: 9 additions & 1 deletion grpcreflect/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,15 @@ func (cr *Client) doSendLocked(attemptCount int, prevErr error, req *refv1alpha.
if attemptCount >= 3 && prevErr != nil {
return nil, prevErr
}
if status.Code(prevErr) == codes.Unimplemented && cr.useV1() {
if (status.Code(prevErr) == codes.Unimplemented ||
status.Code(prevErr) == codes.Unavailable) &&
cr.useV1() {
// If v1 is unimplemented, fallback to v1alpha.
// We also fallback on unavailable because some servers have been
// observed to close the connection/cancel the stream, w/out sending
// back status or headers, when the service name is not known. When
// this happens, the RPC status code is unavailable.
// See https://github.com/fullstorydev/grpcurl/issues/434
cr.useV1Alpha = true
cr.lastTriedV1 = cr.now()
}
Expand Down
210 changes: 179 additions & 31 deletions grpcreflect/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package grpcreflect
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
Expand All @@ -15,8 +16,9 @@ import (

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/reflection"
rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
reflectv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
"google.golang.org/grpc/status"
_ "google.golang.org/protobuf/types/known/apipb"
_ "google.golang.org/protobuf/types/known/emptypb"
Expand All @@ -38,7 +40,7 @@ func TestMain(m *testing.M) {
defer func() {
p := recover()
if p != nil {
fmt.Fprintf(os.Stderr, "PANIC: %v\n", p)
_, _ = fmt.Fprintf(os.Stderr, "PANIC: %v\n", p)
}
os.Exit(code)
}()
Expand All @@ -50,18 +52,22 @@ func TestMain(m *testing.M) {
if err != nil {
panic(fmt.Sprintf("Failed to open server socket: %s", err.Error()))
}
go svr.Serve(l)
go func() {
_ = svr.Serve(l)
}()
defer svr.Stop()

// create grpc client
addr := l.Addr().String()
cconn, err := grpc.Dial(addr, grpc.WithInsecure())
cconn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error()))
}
defer cconn.Close()
defer func() {
_ = cconn.Close()
}()

stub := rpb.NewServerReflectionClient(cconn)
stub := reflectv1alpha.NewServerReflectionClient(cconn)
client = NewClientV1Alpha(context.Background(), stub)

code = m.Run()
Expand Down Expand Up @@ -243,7 +249,8 @@ func TestRecover(t *testing.T) {

// kill the stream
stream := client.stream
client.stream.CloseSend()
err = client.stream.CloseSend()
testutil.Ok(t, err)

// it should auto-recover and re-create stream
_, err = client.ListServices()
Expand All @@ -253,7 +260,7 @@ func TestRecover(t *testing.T) {

func TestMultipleFiles(t *testing.T) {
svr := grpc.NewServer()
rpb.RegisterServerReflectionServer(svr, testReflectionServer{})
reflectv1alpha.RegisterServerReflectionServer(svr, testReflectionServer{})

l, err := net.Listen("tcp", "127.0.0.1:0")
testutil.Ok(t, err, "failed to listen")
Expand All @@ -273,9 +280,9 @@ func TestMultipleFiles(t *testing.T) {

dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second)
defer dialCancel()
cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithInsecure(), grpc.WithBlock())
cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
testutil.Ok(t, err, "failed ot dial %v", l.Addr().String())
cl := rpb.NewServerReflectionClient(cc)
cl := reflectv1alpha.NewServerReflectionClient(cc)

client := NewClientV1Alpha(ctx, cl)
defer client.Reset()
Expand All @@ -292,7 +299,7 @@ func TestMultipleFiles(t *testing.T) {

type testReflectionServer struct{}

func (t testReflectionServer) ServerReflectionInfo(server rpb.ServerReflection_ServerReflectionInfoServer) error {
func (t testReflectionServer) ServerReflectionInfo(server reflectv1alpha.ServerReflection_ServerReflectionInfoServer) error {
const svcA_file = "ChdzYW5kYm94L3NlcnZpY2VfQS5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QRIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUESCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQRIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QRoSLnNhbmRib3guUmVzcG9uc2VBYgZwcm90bzM="
const svcB_file = "ChdzYW5kYm94L1NlcnZpY2VfQi5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QhIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUISCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQhIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QhoSLnNhbmRib3guUmVzcG9uc2VCYgZwcm90bzM="

Expand All @@ -303,50 +310,50 @@ func (t testReflectionServer) ServerReflectionInfo(server rpb.ServerReflection_S
} else if err != nil {
return err
}
var resp rpb.ServerReflectionResponse
var resp reflectv1alpha.ServerReflectionResponse
resp.OriginalRequest = req
switch req := req.MessageRequest.(type) {
case *rpb.ServerReflectionRequest_FileByFilename:
case *reflectv1alpha.ServerReflectionRequest_FileByFilename:
switch req.FileByFilename {
case "sandbox/service_A.proto":
resp.MessageResponse = msgResponseForFiles(svcA_file)
case "sandbox/service_B.proto":
resp.MessageResponse = msgResponseForFiles(svcB_file)
default:
resp.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &reflectv1alpha.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: "not found",
},
}
}
case *rpb.ServerReflectionRequest_FileContainingSymbol:
case *reflectv1alpha.ServerReflectionRequest_FileContainingSymbol:
switch req.FileContainingSymbol {
case "sandbox.Service_A":
resp.MessageResponse = msgResponseForFiles(svcA_file)
case "sandbox.Service_B":
// HERE is where we return two files instead of one
resp.MessageResponse = msgResponseForFiles(svcA_file, svcB_file)
default:
resp.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &reflectv1alpha.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: "not found",
},
}
}
case *rpb.ServerReflectionRequest_ListServices:
resp.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &rpb.ListServiceResponse{
Service: []*rpb.ServiceResponse{
case *reflectv1alpha.ServerReflectionRequest_ListServices:
resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ListServicesResponse{
ListServicesResponse: &reflectv1alpha.ListServiceResponse{
Service: []*reflectv1alpha.ServiceResponse{
{Name: "sandbox.Service_A"},
{Name: "sandbox.Service_B"},
},
},
}
default:
resp.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &rpb.ErrorResponse{
resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{
ErrorResponse: &reflectv1alpha.ErrorResponse{
ErrorCode: int32(codes.NotFound),
ErrorMessage: "not found",
},
Expand All @@ -358,7 +365,7 @@ func (t testReflectionServer) ServerReflectionInfo(server rpb.ServerReflection_S
}
}

func msgResponseForFiles(files ...string) *rpb.ServerReflectionResponse_FileDescriptorResponse {
func msgResponseForFiles(files ...string) *reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse {
descs := make([][]byte, len(files))
for i, f := range files {
b, err := base64.StdEncoding.DecodeString(f)
Expand All @@ -367,8 +374,8 @@ func msgResponseForFiles(files ...string) *rpb.ServerReflectionResponse_FileDesc
}
descs[i] = b
}
return &rpb.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &rpb.FileDescriptorResponse{
return &reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse{
FileDescriptorResponse: &reflectv1alpha.FileDescriptorResponse{
FileDescriptorProto: descs,
},
}
Expand Down Expand Up @@ -397,7 +404,7 @@ func TestAutoVersion(t *testing.T) {
testClientAuto(t,
func(s *grpc.Server) {
impl := reflection.NewServer(reflection.ServerOptions{Services: s})
rpb.RegisterServerReflectionServer(s, impl)
reflectv1alpha.RegisterServerReflectionServer(s, impl)
testprotosgrpc.RegisterDummyServiceServer(s, testService{})
},
[]string{
Expand Down Expand Up @@ -436,6 +443,8 @@ func TestAutoVersion(t *testing.T) {
"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
})
})

t.Run("fallback-on-unavailable", testClientAutoOnUnavailable)
}

func testClientAuto(t *testing.T, register func(*grpc.Server), expectedServices []string, expectedLog []string) {
Expand All @@ -446,14 +455,20 @@ func testClientAuto(t *testing.T, register func(*grpc.Server), expectedServices
if err != nil {
panic(fmt.Sprintf("Failed to open server socket: %s", err.Error()))
}
go svr.Serve(l)
go func() {
err := svr.Serve(l)
testutil.Ok(t, err)
}()
defer svr.Stop()

cconn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
cconn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error()))
}
defer cconn.Close()
defer func() {
err := cconn.Close()
testutil.Ok(t, err)
}()
client := NewClientAuto(context.Background(), cconn)
now := time.Now()
client.now = func() time.Time {
Expand Down Expand Up @@ -509,3 +524,136 @@ func (c *captureStreamNames) intercept(srv interface{}, ss grpc.ServerStream, in
func (c *captureStreamNames) handleUnknown(_ interface{}, _ grpc.ServerStream) error {
return status.Errorf(codes.Unimplemented, "WTF?")
}

func testClientAutoOnUnavailable(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
panic(fmt.Sprintf("Failed to open server socket: %s", err.Error()))
}
captureConn := &captureListener{Listener: l}

var capture captureStreamNames
svr := grpc.NewServer(
grpc.StreamInterceptor(capture.intercept),
grpc.UnknownServiceHandler(func(_ interface{}, _ grpc.ServerStream) error {
// On unknown method, forcibly close the net.Conn, without sending
// back any reply, which should result in an "unavailable" error.
return captureConn.latest().Close()
}),
)
impl := reflection.NewServer(reflection.ServerOptions{Services: svr})
reflectv1alpha.RegisterServerReflectionServer(svr, impl)
testprotosgrpc.RegisterDummyServiceServer(svr, testService{})

go func() {
err := svr.Serve(captureConn)
testutil.Ok(t, err)
}()
defer svr.Stop()

var captureErrs captureErrors
cconn, err := grpc.Dial(
l.Addr().String(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithStreamInterceptor(captureErrs.intercept),
)
if err != nil {
panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error()))
}
defer func() {
err := cconn.Close()
testutil.Ok(t, err)
}()
client := NewClientAuto(context.Background(), cconn)
now := time.Now()
client.now = func() time.Time {
return now
}

svcs, err := client.ListServices()
testutil.Ok(t, err)
sort.Strings(svcs)
testutil.Eq(t, []string{
"grpc.reflection.v1alpha.ServerReflection",
"testprotos.DummyService",
}, svcs)

// It should have tried v1 first and failed then tried v1alpha.
actualLog := capture.names()
testutil.Eq(t, []string{
"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo",
"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo",
}, actualLog)

// Make sure the error code observed by the client was unavailable and not unimplemented.
actualCodes := captureErrs.codes()
testutil.Eq(t, []codes.Code{codes.Unavailable}, actualCodes)
}

type captureListener struct {
net.Listener
mu sync.Mutex
conn net.Conn
}

func (c *captureListener) Accept() (net.Conn, error) {
conn, err := c.Listener.Accept()
if err == nil {
c.mu.Lock()
c.conn = conn
c.mu.Unlock()
}
return conn, err
}

func (c *captureListener) latest() net.Conn {
c.mu.Lock()
defer c.mu.Unlock()
return c.conn
}

type captureErrors struct {
mu sync.Mutex
observed []codes.Code
}

func (c *captureErrors) intercept(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
stream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
c.observe(err)
return nil, err
}
return &captureErrorStream{ClientStream: stream, c: c}, nil
}

func (c *captureErrors) observe(err error) {
c.mu.Lock()
c.observed = append(c.observed, status.Code(err))
c.mu.Unlock()
}

func (c *captureErrors) codes() []codes.Code {
c.mu.Lock()
defer c.mu.Unlock()
ret := make([]codes.Code, len(c.observed))
copy(ret, c.observed)
return ret
}

type captureErrorStream struct {
grpc.ClientStream
c *captureErrors
done int32
}

func (c *captureErrorStream) RecvMsg(m interface{}) error {
err := c.ClientStream.RecvMsg(m)
if err == nil || errors.Is(err, io.EOF) {
return nil
}
// Only record one error per RPC.
if atomic.CompareAndSwapInt32(&c.done, 0, 1) {
c.c.observe(err)
}
return err
}

0 comments on commit bf8a7d8

Please sign in to comment.