From 9a1b4e160563fadf0c66dc79bbc1fa629e2ac6d2 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Fri, 20 Dec 2024 13:57:15 -0800 Subject: [PATCH] grpc: factor out setup func This uses a pattern that is new to our tests. setup accepts a variadic list of options, and uses a type switch to make use of those options during setup. This allows us to pass setup only the options that are relevant to any given test case, leaving the rest to sensible defaults. --- grpc/interceptors_test.go | 235 ++++++++++++++------------------------ 1 file changed, 87 insertions(+), 148 deletions(-) diff --git a/grpc/interceptors_test.go b/grpc/interceptors_test.go index 1b5415fedcd..84e78d3d41b 100644 --- a/grpc/interceptors_test.go +++ b/grpc/interceptors_test.go @@ -154,14 +154,14 @@ func TestWaitForReadyFalse(t *testing.T) { } } -// testServer is used to implement TestTimeouts, and will attempt to sleep for +// testTimeoutServer is used to implement TestTimeouts, and will attempt to sleep for // the given amount of time (unless it hits a timeout or cancel). -type testServer struct { +type testTimeoutServer struct { test_proto.UnimplementedChillerServer } // Chill implements ChillerServer.Chill -func (s *testServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) { +func (s *testTimeoutServer) Chill(ctx context.Context, in *test_proto.Time) (*test_proto.Time, error) { start := time.Now() // Sleep for either the requested amount of time, or the context times out or // is canceled. @@ -175,42 +175,9 @@ func (s *testServer) Chill(ctx context.Context, in *test_proto.Time) (*test_prot } func TestTimeouts(t *testing.T) { - // start server - lis, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("failed to listen: %v", err) - } - port := lis.Addr().(*net.TCPAddr).Port - - serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) - test.AssertNotError(t, err, "creating server metrics") - si := newServerMetadataInterceptor(serverMetrics, clock.NewFake()) - s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) - test_proto.RegisterChillerServer(s, &testServer{}) - go func() { - start := time.Now() - err := s.Serve(lis) - if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") { - t.Logf("s.Serve: %v after %s", err, time.Since(start)) - } - }() - defer s.Stop() - - // make client - clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) - test.AssertNotError(t, err, "creating client metrics") - ci := &clientMetadataInterceptor{ - timeout: 30 * time.Second, - metrics: clientMetrics, - clk: clock.NewFake(), - } - conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.Unary)) - if err != nil { - t.Fatalf("did not connect: %v", err) - } - c := test_proto.NewChillerClient(conn) + server := new(testTimeoutServer) + client, _, stop := setup(t, server, clock.NewFake()) + defer stop() testCases := []struct { timeout time.Duration @@ -224,7 +191,7 @@ func TestTimeouts(t *testing.T) { t.Run(tc.timeout.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) defer cancel() - _, err := c.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)}) + _, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second)}) if err == nil { t.Fatal("Got no error, expected a timeout") } @@ -236,58 +203,22 @@ func TestTimeouts(t *testing.T) { } func TestRequestTimeTagging(t *testing.T) { - clk := clock.NewFake() - // Listen for TCP requests on a random system assigned port number - lis, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("failed to listen: %v", err) - } - // Retrieve the concrete port numberthe system assigned our listener - port := lis.Addr().(*net.TCPAddr).Port - - // Create a new ChillerServer - serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) + server := new(testTimeoutServer) + metrics, err := newServerMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating server metrics") - si := newServerMetadataInterceptor(serverMetrics, clk) - s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) - test_proto.RegisterChillerServer(s, &testServer{}) - // Chill until ill - go func() { - start := time.Now() - err := s.Serve(lis) - if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") { - t.Logf("s.Serve: %v after %s", err, time.Since(start)) - } - }() - defer s.Stop() - - // Dial the ChillerServer - clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) - test.AssertNotError(t, err, "creating client metrics") - ci := &clientMetadataInterceptor{ - timeout: 30 * time.Second, - metrics: clientMetrics, - clk: clk, - } - conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.Unary)) - if err != nil { - t.Fatalf("did not connect: %v", err) - } - // Create a ChillerClient with the connection to the ChillerServer - c := test_proto.NewChillerClient(conn) + client, _, stop := setup(t, server, metrics) + defer stop() // Make an RPC request with the ChillerClient with a timeout higher than the // requested ChillerServer delay so that the RPC completes normally ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if _, err := c.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil { + if _, err := client.Chill(ctx, &test_proto.Time{Duration: durationpb.New(time.Second * 5)}); err != nil { t.Fatalf("Unexpected error calling Chill RPC: %s", err) } // There should be one histogram sample in the serverInterceptor rpcLag stat - test.AssertMetricWithLabelsEquals(t, si.metrics.rpcLag, prometheus.Labels{}, 1) + test.AssertMetricWithLabelsEquals(t, metrics.rpcLag, prometheus.Labels{}, 1) } func TestClockSkew(t *testing.T) { @@ -297,32 +228,23 @@ func TestClockSkew(t *testing.T) { clientClk := clock.NewFake() clientClk.Set(time.Now()) - // Listen for TCP requests on a random system assigned port number - lis, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("failed to listen: %v", err) - } - port := lis.Addr().(*net.TCPAddr).Port + _, serverPort, stop := setup(t, &testTimeoutServer{}, serverClk) + defer stop() - // Start a gRPC server listening on that port - serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) - test.AssertNotError(t, err, "creating server metrics") - si := newServerMetadataInterceptor(serverMetrics, serverClk) - s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) - test_proto.RegisterChillerServer(s, &testServer{}) - go func() { _ = s.Serve(lis) }() - defer s.Stop() - - // Start a gRPC client talking to the server clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) test.AssertNotError(t, err, "creating client metrics") - ci := &clientMetadataInterceptor{metrics: clientMetrics, clk: clientClk, timeout: time.Second} - conn, err := grpc.NewClient( - net.JoinHostPort("localhost", strconv.Itoa(port)), + ci := &clientMetadataInterceptor{ + timeout: 30 * time.Second, + metrics: clientMetrics, + clk: clientClk, + } + conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(serverPort)), grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.Unary), - ) - test.AssertNotError(t, err, "creating test client") + grpc.WithUnaryInterceptor(ci.Unary)) + if err != nil { + t.Fatalf("did not connect: %v", err) + } + client := test_proto.NewChillerClient(conn) // Create a context with plenty of timeout @@ -368,18 +290,15 @@ func (s *blockedServer) Chill(_ context.Context, _ *test_proto.Time) (*test_prot } func TestInFlightRPCStat(t *testing.T) { - clk := clock.NewFake() - // Listen for TCP requests on a random system assigned port number - lis, err := net.Listen("tcp", ":0") - if err != nil { - log.Fatalf("failed to listen: %v", err) - } - // Retrieve the concrete port numberthe system assigned our listener - port := lis.Addr().(*net.TCPAddr).Port - // Create a new blockedServer to act as a ChillerServer server := &blockedServer{} + metrics, err := newClientMetrics(metrics.NoopRegisterer) + test.AssertNotError(t, err, "creating client metrics") + + client, _, stop := setup(t, server, metrics) + defer stop() + // Increment the roadblock waitgroup - this will cause all chill RPCs to // the server to block until we call Done()! server.roadblock.Add(1) @@ -390,43 +309,11 @@ func TestInFlightRPCStat(t *testing.T) { numRPCs := 5 server.received.Add(numRPCs) - serverMetrics, err := newServerMetrics(metrics.NoopRegisterer) - test.AssertNotError(t, err, "creating server metrics") - si := newServerMetadataInterceptor(serverMetrics, clk) - s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) - test_proto.RegisterChillerServer(s, server) - // Chill until ill - go func() { - start := time.Now() - err := s.Serve(lis) - if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") { - t.Logf("s.Serve: %v after %s", err, time.Since(start)) - } - }() - defer s.Stop() - - // Dial the ChillerServer - clientMetrics, err := newClientMetrics(metrics.NoopRegisterer) - test.AssertNotError(t, err, "creating client metrics") - ci := &clientMetadataInterceptor{ - timeout: 30 * time.Second, - metrics: clientMetrics, - clk: clk, - } - conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(ci.Unary)) - if err != nil { - t.Fatalf("did not connect: %v", err) - } - // Create a ChillerClient with the connection to the ChillerServer - c := test_proto.NewChillerClient(conn) - // Fire off a few RPCs. They will block on the blockedServer's roadblock wg for range numRPCs { go func() { // Ignore errors, just chilllll. - _, _ = c.Chill(context.Background(), &test_proto.Time{}) + _, _ = client.Chill(context.Background(), &test_proto.Time{}) }() } @@ -441,7 +328,7 @@ func TestInFlightRPCStat(t *testing.T) { } // We expect the inFlightRPCs gauge for the Chiller.Chill RPCs to be equal to numRPCs. - test.AssertMetricWithLabelsEquals(t, ci.metrics.inFlightRPCs, labels, float64(numRPCs)) + test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, float64(numRPCs)) // Unblock the blockedServer to let all of the Chiller.Chill RPCs complete server.roadblock.Done() @@ -449,7 +336,7 @@ func TestInFlightRPCStat(t *testing.T) { time.Sleep(1 * time.Second) // Check the gauge value again - test.AssertMetricWithLabelsEquals(t, ci.metrics.inFlightRPCs, labels, 0) + test.AssertMetricWithLabelsEquals(t, metrics.inFlightRPCs, labels, 0) } func TestServiceAuthChecker(t *testing.T) { @@ -524,3 +411,55 @@ func TestServiceAuthChecker(t *testing.T) { err = ac.checkContextAuth(ctx, "/package.ServiceName/Method/") test.AssertNotError(t, err, "checking allowed cert") } + +// setup creates a server and client, returning the created client, the running server's port, and a stop function. +func setup(t *testing.T, server test_proto.ChillerServer, opts ...any) (test_proto.ChillerClient, int, func()) { + clk := clock.NewFake() + serverMetricsVal, err := newServerMetrics(metrics.NoopRegisterer) + test.AssertNotError(t, err, "creating server metrics") + clientMetricsVal, err := newClientMetrics(metrics.NoopRegisterer) + test.AssertNotError(t, err, "creating client metrics") + + for _, opt := range opts { + switch optTyped := opt.(type) { + case clock.FakeClock: + clk = optTyped + case clientMetrics: + clientMetricsVal = optTyped + case serverMetrics: + serverMetricsVal = optTyped + default: + t.Fatalf("setup called with unrecognize option %#v", t) + } + } + lis, err := net.Listen("tcp", ":0") + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + port := lis.Addr().(*net.TCPAddr).Port + + si := newServerMetadataInterceptor(serverMetricsVal, clk) + s := grpc.NewServer(grpc.UnaryInterceptor(si.Unary)) + test_proto.RegisterChillerServer(s, server) + + go func() { + start := time.Now() + err := s.Serve(lis) + if err != nil && !strings.HasSuffix(err.Error(), "use of closed network connection") { + t.Logf("s.Serve: %v after %s", err, time.Since(start)) + } + }() + + ci := &clientMetadataInterceptor{ + timeout: 30 * time.Second, + metrics: clientMetricsVal, + clk: clock.NewFake(), + } + conn, err := grpc.Dial(net.JoinHostPort("localhost", strconv.Itoa(port)), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUnaryInterceptor(ci.Unary)) + if err != nil { + t.Fatalf("did not connect: %v", err) + } + return test_proto.NewChillerClient(conn), port, s.Stop +}