From bf8a7d8361451273ce8797d2fe89509b3560c0f6 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 10 Jan 2024 10:15:56 -0500 Subject: [PATCH] grpcreflect: Fallback from v1 to v1alpha on "unavailable" error code (#588) --- grpcreflect/client.go | 10 +- grpcreflect/client_test.go | 210 +++++++++++++++++++++++++++++++------ 2 files changed, 188 insertions(+), 32 deletions(-) diff --git a/grpcreflect/client.go b/grpcreflect/client.go index 00336614..192c11f1 100644 --- a/grpcreflect/client.go +++ b/grpcreflect/client.go @@ -515,7 +515,15 @@ func (cr *Client) doSendLocked(attemptCount int, prevErr error, req *refv1alpha. if attemptCount >= 3 && prevErr != nil { return nil, prevErr } - if status.Code(prevErr) == codes.Unimplemented && cr.useV1() { + if (status.Code(prevErr) == codes.Unimplemented || + status.Code(prevErr) == codes.Unavailable) && + cr.useV1() { + // If v1 is unimplemented, fallback to v1alpha. + // We also fallback on unavailable because some servers have been + // observed to close the connection/cancel the stream, w/out sending + // back status or headers, when the service name is not known. When + // this happens, the RPC status code is unavailable. + // See https://github.com/fullstorydev/grpcurl/issues/434 cr.useV1Alpha = true cr.lastTriedV1 = cr.now() } diff --git a/grpcreflect/client_test.go b/grpcreflect/client_test.go index 703ab618..290f67a2 100644 --- a/grpcreflect/client_test.go +++ b/grpcreflect/client_test.go @@ -3,6 +3,7 @@ package grpcreflect import ( "context" "encoding/base64" + "errors" "fmt" "io" "net" @@ -15,8 +16,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" - rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" + reflectv1alpha "google.golang.org/grpc/reflection/grpc_reflection_v1alpha" "google.golang.org/grpc/status" _ "google.golang.org/protobuf/types/known/apipb" _ "google.golang.org/protobuf/types/known/emptypb" @@ -38,7 +40,7 @@ func TestMain(m *testing.M) { defer func() { p := recover() if p != nil { - fmt.Fprintf(os.Stderr, "PANIC: %v\n", p) + _, _ = fmt.Fprintf(os.Stderr, "PANIC: %v\n", p) } os.Exit(code) }() @@ -50,18 +52,22 @@ func TestMain(m *testing.M) { if err != nil { panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) } - go svr.Serve(l) + go func() { + _ = svr.Serve(l) + }() defer svr.Stop() // create grpc client addr := l.Addr().String() - cconn, err := grpc.Dial(addr, grpc.WithInsecure()) + cconn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) } - defer cconn.Close() + defer func() { + _ = cconn.Close() + }() - stub := rpb.NewServerReflectionClient(cconn) + stub := reflectv1alpha.NewServerReflectionClient(cconn) client = NewClientV1Alpha(context.Background(), stub) code = m.Run() @@ -243,7 +249,8 @@ func TestRecover(t *testing.T) { // kill the stream stream := client.stream - client.stream.CloseSend() + err = client.stream.CloseSend() + testutil.Ok(t, err) // it should auto-recover and re-create stream _, err = client.ListServices() @@ -253,7 +260,7 @@ func TestRecover(t *testing.T) { func TestMultipleFiles(t *testing.T) { svr := grpc.NewServer() - rpb.RegisterServerReflectionServer(svr, testReflectionServer{}) + reflectv1alpha.RegisterServerReflectionServer(svr, testReflectionServer{}) l, err := net.Listen("tcp", "127.0.0.1:0") testutil.Ok(t, err, "failed to listen") @@ -273,9 +280,9 @@ func TestMultipleFiles(t *testing.T) { dialCtx, dialCancel := context.WithTimeout(ctx, 3*time.Second) defer dialCancel() - cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + cc, err := grpc.DialContext(dialCtx, l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) testutil.Ok(t, err, "failed ot dial %v", l.Addr().String()) - cl := rpb.NewServerReflectionClient(cc) + cl := reflectv1alpha.NewServerReflectionClient(cc) client := NewClientV1Alpha(ctx, cl) defer client.Reset() @@ -292,7 +299,7 @@ func TestMultipleFiles(t *testing.T) { type testReflectionServer struct{} -func (t testReflectionServer) ServerReflectionInfo(server rpb.ServerReflection_ServerReflectionInfoServer) error { +func (t testReflectionServer) ServerReflectionInfo(server reflectv1alpha.ServerReflection_ServerReflectionInfoServer) error { const svcA_file = "ChdzYW5kYm94L3NlcnZpY2VfQS5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QRIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUESCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQRIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QRoSLnNhbmRib3guUmVzcG9uc2VBYgZwcm90bzM=" const svcB_file = "ChdzYW5kYm94L1NlcnZpY2VfQi5wcm90bxIHc2FuZGJveCIWCghSZXF1ZXN0QhIKCgJpZBgBIAEoBSIYCglSZXNwb25zZUISCwoDc3RyGAEgASgJMj0KCVNlcnZpY2VfQhIwCgdFeGVjdXRlEhEuc2FuZGJveC5SZXF1ZXN0QhoSLnNhbmRib3guUmVzcG9uc2VCYgZwcm90bzM=" @@ -303,24 +310,24 @@ func (t testReflectionServer) ServerReflectionInfo(server rpb.ServerReflection_S } else if err != nil { return err } - var resp rpb.ServerReflectionResponse + var resp reflectv1alpha.ServerReflectionResponse resp.OriginalRequest = req switch req := req.MessageRequest.(type) { - case *rpb.ServerReflectionRequest_FileByFilename: + case *reflectv1alpha.ServerReflectionRequest_FileByFilename: switch req.FileByFilename { case "sandbox/service_A.proto": resp.MessageResponse = msgResponseForFiles(svcA_file) case "sandbox/service_B.proto": resp.MessageResponse = msgResponseForFiles(svcB_file) default: - resp.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ - ErrorResponse: &rpb.ErrorResponse{ + resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ + ErrorResponse: &reflectv1alpha.ErrorResponse{ ErrorCode: int32(codes.NotFound), ErrorMessage: "not found", }, } } - case *rpb.ServerReflectionRequest_FileContainingSymbol: + case *reflectv1alpha.ServerReflectionRequest_FileContainingSymbol: switch req.FileContainingSymbol { case "sandbox.Service_A": resp.MessageResponse = msgResponseForFiles(svcA_file) @@ -328,25 +335,25 @@ func (t testReflectionServer) ServerReflectionInfo(server rpb.ServerReflection_S // HERE is where we return two files instead of one resp.MessageResponse = msgResponseForFiles(svcA_file, svcB_file) default: - resp.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ - ErrorResponse: &rpb.ErrorResponse{ + resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ + ErrorResponse: &reflectv1alpha.ErrorResponse{ ErrorCode: int32(codes.NotFound), ErrorMessage: "not found", }, } } - case *rpb.ServerReflectionRequest_ListServices: - resp.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ - ListServicesResponse: &rpb.ListServiceResponse{ - Service: []*rpb.ServiceResponse{ + case *reflectv1alpha.ServerReflectionRequest_ListServices: + resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ListServicesResponse{ + ListServicesResponse: &reflectv1alpha.ListServiceResponse{ + Service: []*reflectv1alpha.ServiceResponse{ {Name: "sandbox.Service_A"}, {Name: "sandbox.Service_B"}, }, }, } default: - resp.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ - ErrorResponse: &rpb.ErrorResponse{ + resp.MessageResponse = &reflectv1alpha.ServerReflectionResponse_ErrorResponse{ + ErrorResponse: &reflectv1alpha.ErrorResponse{ ErrorCode: int32(codes.NotFound), ErrorMessage: "not found", }, @@ -358,7 +365,7 @@ func (t testReflectionServer) ServerReflectionInfo(server rpb.ServerReflection_S } } -func msgResponseForFiles(files ...string) *rpb.ServerReflectionResponse_FileDescriptorResponse { +func msgResponseForFiles(files ...string) *reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse { descs := make([][]byte, len(files)) for i, f := range files { b, err := base64.StdEncoding.DecodeString(f) @@ -367,8 +374,8 @@ func msgResponseForFiles(files ...string) *rpb.ServerReflectionResponse_FileDesc } descs[i] = b } - return &rpb.ServerReflectionResponse_FileDescriptorResponse{ - FileDescriptorResponse: &rpb.FileDescriptorResponse{ + return &reflectv1alpha.ServerReflectionResponse_FileDescriptorResponse{ + FileDescriptorResponse: &reflectv1alpha.FileDescriptorResponse{ FileDescriptorProto: descs, }, } @@ -397,7 +404,7 @@ func TestAutoVersion(t *testing.T) { testClientAuto(t, func(s *grpc.Server) { impl := reflection.NewServer(reflection.ServerOptions{Services: s}) - rpb.RegisterServerReflectionServer(s, impl) + reflectv1alpha.RegisterServerReflectionServer(s, impl) testprotosgrpc.RegisterDummyServiceServer(s, testService{}) }, []string{ @@ -436,6 +443,8 @@ func TestAutoVersion(t *testing.T) { "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", }) }) + + t.Run("fallback-on-unavailable", testClientAutoOnUnavailable) } func testClientAuto(t *testing.T, register func(*grpc.Server), expectedServices []string, expectedLog []string) { @@ -446,14 +455,20 @@ func testClientAuto(t *testing.T, register func(*grpc.Server), expectedServices if err != nil { panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) } - go svr.Serve(l) + go func() { + err := svr.Serve(l) + testutil.Ok(t, err) + }() defer svr.Stop() - cconn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure()) + cconn, err := grpc.Dial(l.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) } - defer cconn.Close() + defer func() { + err := cconn.Close() + testutil.Ok(t, err) + }() client := NewClientAuto(context.Background(), cconn) now := time.Now() client.now = func() time.Time { @@ -509,3 +524,136 @@ func (c *captureStreamNames) intercept(srv interface{}, ss grpc.ServerStream, in func (c *captureStreamNames) handleUnknown(_ interface{}, _ grpc.ServerStream) error { return status.Errorf(codes.Unimplemented, "WTF?") } + +func testClientAutoOnUnavailable(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(fmt.Sprintf("Failed to open server socket: %s", err.Error())) + } + captureConn := &captureListener{Listener: l} + + var capture captureStreamNames + svr := grpc.NewServer( + grpc.StreamInterceptor(capture.intercept), + grpc.UnknownServiceHandler(func(_ interface{}, _ grpc.ServerStream) error { + // On unknown method, forcibly close the net.Conn, without sending + // back any reply, which should result in an "unavailable" error. + return captureConn.latest().Close() + }), + ) + impl := reflection.NewServer(reflection.ServerOptions{Services: svr}) + reflectv1alpha.RegisterServerReflectionServer(svr, impl) + testprotosgrpc.RegisterDummyServiceServer(svr, testService{}) + + go func() { + err := svr.Serve(captureConn) + testutil.Ok(t, err) + }() + defer svr.Stop() + + var captureErrs captureErrors + cconn, err := grpc.Dial( + l.Addr().String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithStreamInterceptor(captureErrs.intercept), + ) + if err != nil { + panic(fmt.Sprintf("Failed to create grpc client: %s", err.Error())) + } + defer func() { + err := cconn.Close() + testutil.Ok(t, err) + }() + client := NewClientAuto(context.Background(), cconn) + now := time.Now() + client.now = func() time.Time { + return now + } + + svcs, err := client.ListServices() + testutil.Ok(t, err) + sort.Strings(svcs) + testutil.Eq(t, []string{ + "grpc.reflection.v1alpha.ServerReflection", + "testprotos.DummyService", + }, svcs) + + // It should have tried v1 first and failed then tried v1alpha. + actualLog := capture.names() + testutil.Eq(t, []string{ + "/grpc.reflection.v1.ServerReflection/ServerReflectionInfo", + "/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo", + }, actualLog) + + // Make sure the error code observed by the client was unavailable and not unimplemented. + actualCodes := captureErrs.codes() + testutil.Eq(t, []codes.Code{codes.Unavailable}, actualCodes) +} + +type captureListener struct { + net.Listener + mu sync.Mutex + conn net.Conn +} + +func (c *captureListener) Accept() (net.Conn, error) { + conn, err := c.Listener.Accept() + if err == nil { + c.mu.Lock() + c.conn = conn + c.mu.Unlock() + } + return conn, err +} + +func (c *captureListener) latest() net.Conn { + c.mu.Lock() + defer c.mu.Unlock() + return c.conn +} + +type captureErrors struct { + mu sync.Mutex + observed []codes.Code +} + +func (c *captureErrors) intercept(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + stream, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + c.observe(err) + return nil, err + } + return &captureErrorStream{ClientStream: stream, c: c}, nil +} + +func (c *captureErrors) observe(err error) { + c.mu.Lock() + c.observed = append(c.observed, status.Code(err)) + c.mu.Unlock() +} + +func (c *captureErrors) codes() []codes.Code { + c.mu.Lock() + defer c.mu.Unlock() + ret := make([]codes.Code, len(c.observed)) + copy(ret, c.observed) + return ret +} + +type captureErrorStream struct { + grpc.ClientStream + c *captureErrors + done int32 +} + +func (c *captureErrorStream) RecvMsg(m interface{}) error { + err := c.ClientStream.RecvMsg(m) + if err == nil || errors.Is(err, io.EOF) { + return nil + } + // Only record one error per RPC. + if atomic.CompareAndSwapInt32(&c.done, 0, 1) { + c.c.observe(err) + } + return err +}