diff --git a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcIoRequestProcessor.java b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcIoRequestProcessor.java index 62339a1c5e..c3904ac366 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcIoRequestProcessor.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcIoRequestProcessor.java @@ -4,21 +4,11 @@ import static com.linkedin.venice.listener.ReadQuotaEnforcementHandler.INVALID_REQUEST_RESOURCE_MSG; import static com.linkedin.venice.listener.ReadQuotaEnforcementHandler.SERVER_OVER_CAPACITY_MSG; -import com.google.protobuf.ByteString; -import com.linkedin.davinci.listener.response.ReadResponse; -import com.linkedin.venice.HttpConstants; -import com.linkedin.venice.exceptions.VeniceException; import com.linkedin.venice.listener.QuotaEnforcementHandler; import com.linkedin.venice.listener.QuotaEnforcementHandler.QuotaEnforcementResult; -import com.linkedin.venice.listener.RequestStatsRecorder; import com.linkedin.venice.listener.StorageReadRequestHandler; import com.linkedin.venice.listener.request.RouterRequest; -import com.linkedin.venice.listener.response.AbstractReadResponse; -import com.linkedin.venice.protocols.MultiKeyResponse; -import com.linkedin.venice.protocols.SingleGetResponse; -import com.linkedin.venice.protocols.VeniceServerResponse; import com.linkedin.venice.response.VeniceReadResponseStatus; -import io.grpc.stub.StreamObserver; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -46,10 +36,12 @@ public class GrpcIoRequestProcessor { private static final Logger LOGGER = LogManager.getLogger(GrpcIoRequestProcessor.class); private final QuotaEnforcementHandler quotaEnforcementHandler; private final StorageReadRequestHandler storageReadRequestHandler; + private final GrpcReplyProcessor replyProcessor; public GrpcIoRequestProcessor(GrpcServiceDependencies services) { this.quotaEnforcementHandler = services.getQuotaEnforcementHandler(); this.storageReadRequestHandler = services.getStorageReadRequestHandler(); + this.replyProcessor = services.getGrpcReplyProcessor(); } /** @@ -74,7 +66,7 @@ public GrpcIoRequestProcessor(GrpcServiceDependencies services) { * indicating an unknown quota enforcement result. * * - * After determining the appropriate response, the method calls {@link #sendResponse(GrpcRequestContext)} to + * After determining the appropriate response, the method calls {@link GrpcReplyProcessor#sendResponse(GrpcRequestContext)} to * finalize and send the response to the client. * * This method is executed in the gRPC executor thread, so it should not perform any blocking operations. @@ -88,8 +80,9 @@ public void processRequest(GrpcRequestContext requestContext) { QuotaEnforcementResult result = quotaEnforcementHandler.enforceQuota(request); // If the request is allowed, hand it off to the storage read request handler if (result == ALLOWED) { - GrpcStorageResponseHandlerCallback callback = GrpcStorageResponseHandlerCallback.create(requestContext); - storageReadRequestHandler.queueIoRequestForAsyncProcessing(request, callback); + storageReadRequestHandler.queueIoRequestForAsyncProcessing( + request, + GrpcStorageResponseHandlerCallback.create(requestContext, replyProcessor)); return; } @@ -110,167 +103,9 @@ public void processRequest(GrpcRequestContext requestContext) { default: requestContext.setReadResponseStatus(VeniceReadResponseStatus.INTERNAL_SERVER_ERROR); requestContext.setErrorMessage("Unknown quota enforcement result: " + result); + LOGGER.error("Unknown quota enforcement result: {}", result); } - sendResponse(requestContext); - } - - /** - * Callers must ensure that all fields in the request context are properly set before invoking this method. - * Callers must also use the appropriate {@link GrpcRequestContext#readResponseStatus} to comply with the API contract. - * - * @param requestContext The context of the request for which a response is being sent - * @param The type of the response observer - */ - public static void sendResponse(GrpcRequestContext requestContext) { - GrpcRequestContext.GrpcRequestType grpcRequestType = requestContext.getGrpcRequestType(); - switch (grpcRequestType) { - case SINGLE_GET: - sendSingleGetResponse((GrpcRequestContext) requestContext); - break; - case MULTI_GET: - case COMPUTE: - sendMultiKeyResponse((GrpcRequestContext) requestContext); - break; - case LEGACY: - sendVeniceServerResponse((GrpcRequestContext) requestContext); - break; - default: - VeniceException veniceException = new VeniceException("Unknown response type: " + grpcRequestType); - LOGGER.error("Unknown response type: {}", grpcRequestType, veniceException); - throw veniceException; - } - } - - /** - * Sends a single get response to the client and records the request statistics via {@link #reportRequestStats}. - * Since {@link io.grpc.stub.StreamObserver} is not thread-safe, synchronization is required before invoking - * {@link io.grpc.stub.StreamObserver#onNext} and {@link io.grpc.stub.StreamObserver#onCompleted}. - * - * @param requestContext The context of the gRPC request, which includes the response and stats recorder to be updated. - */ - public static void sendSingleGetResponse(GrpcRequestContext requestContext) { - ReadResponse readResponse = requestContext.getReadResponse(); - SingleGetResponse.Builder builder = SingleGetResponse.newBuilder(); - VeniceReadResponseStatus responseStatus = requestContext.getReadResponseStatus(); - - if (readResponse == null) { - builder.setStatusCode(requestContext.getReadResponseStatus().getCode()); - builder.setErrorMessage(requestContext.getErrorMessage()); - } else if (readResponse.isFound()) { - builder.setRcu(readResponse.getRCU()) - .setStatusCode(responseStatus.getCode()) - .setSchemaId(readResponse.getResponseSchemaIdHeader()) - .setCompressionStrategy(readResponse.getCompressionStrategy().getValue()) - .setContentLength(readResponse.getResponseBody().readableBytes()) - .setContentType(HttpConstants.AVRO_BINARY) - .setValue(GrpcUtils.toByteString(readResponse.getResponseBody())); - } else { - builder.setStatusCode(responseStatus.getCode()) - .setRcu(readResponse.getRCU()) - .setErrorMessage("Key not found") - .setContentLength(0); - } - - StreamObserver responseObserver = requestContext.getResponseObserver(); - synchronized (responseObserver) { - responseObserver.onNext(builder.build()); - responseObserver.onCompleted(); - } - - reportRequestStats(requestContext); - } - - /** - * Sends a multi key response (multiGet and compute requests) to the client and records the request statistics via {@link #reportRequestStats}. - * Since {@link io.grpc.stub.StreamObserver} is not thread-safe, synchronization is required before invoking - * {@link io.grpc.stub.StreamObserver#onNext} and {@link io.grpc.stub.StreamObserver#onCompleted}. - * - * @param requestContext The context of the gRPC request, which includes the response and stats recorder to be updated. - */ - public static void sendMultiKeyResponse(GrpcRequestContext requestContext) { - ReadResponse readResponse = requestContext.getReadResponse(); - MultiKeyResponse.Builder builder = MultiKeyResponse.newBuilder(); - VeniceReadResponseStatus responseStatus = requestContext.getReadResponseStatus(); - - if (readResponse == null) { - builder.setStatusCode(responseStatus.getCode()); - builder.setErrorMessage(requestContext.getErrorMessage()); - } else if (readResponse.isFound()) { - builder.setStatusCode(responseStatus.getCode()) - .setRcu(readResponse.getRCU()) - .setCompressionStrategy(readResponse.getCompressionStrategy().getValue()) - .setContentLength(readResponse.getResponseBody().readableBytes()) - .setContentType(HttpConstants.AVRO_BINARY) - .setValue(GrpcUtils.toByteString(readResponse.getResponseBody())); - } else { - builder.setStatusCode(responseStatus.getCode()) - .setRcu(readResponse.getRCU()) - .setErrorMessage("Key not found") - .setContentLength(0); - } - - StreamObserver responseObserver = requestContext.getResponseObserver(); - synchronized (responseObserver) { - responseObserver.onNext(builder.build()); - responseObserver.onCompleted(); - } - reportRequestStats(requestContext); - } - - /** - * Sends response (for the legacy API) to the client and records the request statistics via {@link #reportRequestStats}. - * Since {@link io.grpc.stub.StreamObserver} is not thread-safe, synchronization is required before invoking - * {@link io.grpc.stub.StreamObserver#onNext} and {@link io.grpc.stub.StreamObserver#onCompleted}. - * - * @param requestContext The context of the gRPC request, which includes the response and stats recorder to be updated. - */ - public static void sendVeniceServerResponse(GrpcRequestContext requestContext) { - ReadResponse readResponse = requestContext.getReadResponse(); - VeniceServerResponse.Builder builder = VeniceServerResponse.newBuilder(); - VeniceReadResponseStatus responseStatus = requestContext.getReadResponseStatus(); - - if (readResponse == null) { - builder.setErrorCode(responseStatus.getCode()); - builder.setErrorMessage(requestContext.getErrorMessage()); - } else if (readResponse.isFound()) { - builder.setErrorCode(responseStatus.getCode()) - .setResponseRCU(readResponse.getRCU()) - .setCompressionStrategy(readResponse.getCompressionStrategy().getValue()) - .setIsStreamingResponse(readResponse.isStreamingResponse()) - .setSchemaId(readResponse.getResponseSchemaIdHeader()) - .setData(GrpcUtils.toByteString(readResponse.getResponseBody())); - } else { - builder.setErrorCode(responseStatus.getCode()).setErrorMessage("Key not found").setData(ByteString.EMPTY); - } - - StreamObserver responseObserver = requestContext.getResponseObserver(); - synchronized (responseObserver) { - responseObserver.onNext(builder.build()); - responseObserver.onCompleted(); - } - - reportRequestStats(requestContext); - } - - /** - * Records the request statistics based on the provided {@link GrpcRequestContext}. - * This method updates the {@link RequestStatsRecorder} with statistics from the {@link GrpcRequestContext} and {@link ReadResponse}. - * @param requestContext The context of the gRPC request, which contains the response and stats recorder to be updated. - */ - public static void reportRequestStats(GrpcRequestContext requestContext) { - ReadResponse readResponse = requestContext.getReadResponse(); - RequestStatsRecorder requestStatsRecorder = requestContext.getRequestStatsRecorder(); - AbstractReadResponse abstractReadResponse = (AbstractReadResponse) readResponse; - if (readResponse == null) { - requestStatsRecorder.setReadResponseStats(null).setResponseSize(0); - } else if (readResponse.isFound()) { - requestStatsRecorder.setReadResponseStats(abstractReadResponse.getReadResponseStatsRecorder()) - .setResponseSize(abstractReadResponse.getResponseBody().readableBytes()); - } else { - requestStatsRecorder.setReadResponseStats(abstractReadResponse.getReadResponseStatsRecorder()).setResponseSize(0); - } - - RequestStatsRecorder.recordRequestCompletionStats(requestContext.getRequestStatsRecorder(), true, -1); + replyProcessor.sendResponse(requestContext); } } diff --git a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcReplyProcessor.java b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcReplyProcessor.java new file mode 100644 index 0000000000..e0e3520500 --- /dev/null +++ b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcReplyProcessor.java @@ -0,0 +1,184 @@ +package com.linkedin.venice.grpc; + +import com.google.protobuf.ByteString; +import com.linkedin.davinci.listener.response.ReadResponse; +import com.linkedin.venice.HttpConstants; +import com.linkedin.venice.exceptions.VeniceException; +import com.linkedin.venice.listener.RequestStatsRecorder; +import com.linkedin.venice.listener.response.AbstractReadResponse; +import com.linkedin.venice.protocols.MultiKeyResponse; +import com.linkedin.venice.protocols.SingleGetResponse; +import com.linkedin.venice.protocols.VeniceServerResponse; +import com.linkedin.venice.response.VeniceReadResponseStatus; +import io.grpc.stub.StreamObserver; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + + +/** + * This class is responsible for sending responses to the client and recording request statistics. + * Though it has methods which do not have any side effects, we have not made them static to allow + * for easier testing of the callers. + */ +public class GrpcReplyProcessor { + private static final Logger LOGGER = LogManager.getLogger(GrpcIoRequestProcessor.class); + + /** + * Callers must ensure that all fields in the request context are properly set before invoking this method. + * Callers must also use the appropriate {@link GrpcRequestContext#readResponseStatus} to comply with the API contract. + * + * @param requestContext The context of the request for which a response is being sent + * @param The type of the response observer + */ + void sendResponse(GrpcRequestContext requestContext) { + GrpcRequestContext.GrpcRequestType grpcRequestType = requestContext.getGrpcRequestType(); + switch (grpcRequestType) { + case SINGLE_GET: + sendSingleGetResponse((GrpcRequestContext) requestContext); + break; + case MULTI_GET: + case COMPUTE: + sendMultiKeyResponse((GrpcRequestContext) requestContext); + break; + case LEGACY: + sendVeniceServerResponse((GrpcRequestContext) requestContext); + break; + default: + VeniceException veniceException = new VeniceException("Unknown response type: " + grpcRequestType); + LOGGER.error("Unknown response type: {}", grpcRequestType, veniceException); + throw veniceException; + } + } + + /** + * Sends a single get response to the client and records the request statistics via {@link #reportRequestStats}. + * Since {@link io.grpc.stub.StreamObserver} is not thread-safe, synchronization is required before invoking + * {@link io.grpc.stub.StreamObserver#onNext} and {@link io.grpc.stub.StreamObserver#onCompleted}. + * + * @param requestContext The context of the gRPC request, which includes the response and stats recorder to be updated. + */ + void sendSingleGetResponse(GrpcRequestContext requestContext) { + ReadResponse readResponse = requestContext.getReadResponse(); + SingleGetResponse.Builder builder = SingleGetResponse.newBuilder(); + VeniceReadResponseStatus responseStatus = requestContext.getReadResponseStatus(); + + if (readResponse == null) { + builder.setStatusCode(requestContext.getReadResponseStatus().getCode()); + builder.setErrorMessage(requestContext.getErrorMessage()); + } else if (readResponse.isFound()) { + builder.setRcu(readResponse.getRCU()) + .setStatusCode(responseStatus.getCode()) + .setSchemaId(readResponse.getResponseSchemaIdHeader()) + .setCompressionStrategy(readResponse.getCompressionStrategy().getValue()) + .setContentLength(readResponse.getResponseBody().readableBytes()) + .setContentType(HttpConstants.AVRO_BINARY) + .setValue(GrpcUtils.toByteString(readResponse.getResponseBody())); + } else { + builder.setStatusCode(responseStatus.getCode()) + .setRcu(readResponse.getRCU()) + .setErrorMessage("Key not found") + .setContentLength(0); + } + + StreamObserver responseObserver = requestContext.getResponseObserver(); + synchronized (responseObserver) { + responseObserver.onNext(builder.build()); + responseObserver.onCompleted(); + } + + reportRequestStats(requestContext); + } + + /** + * Sends a multi key response (multiGet and compute requests) to the client and records the request statistics via {@link #reportRequestStats}. + * Since {@link StreamObserver} is not thread-safe, synchronization is required before invoking + * {@link StreamObserver#onNext} and {@link StreamObserver#onCompleted}. + * + * @param requestContext The context of the gRPC request, which includes the response and stats recorder to be updated. + */ + void sendMultiKeyResponse(GrpcRequestContext requestContext) { + ReadResponse readResponse = requestContext.getReadResponse(); + MultiKeyResponse.Builder builder = MultiKeyResponse.newBuilder(); + VeniceReadResponseStatus responseStatus = requestContext.getReadResponseStatus(); + + if (readResponse == null) { + builder.setStatusCode(responseStatus.getCode()); + builder.setErrorMessage(requestContext.getErrorMessage()); + } else if (readResponse.isFound()) { + builder.setStatusCode(responseStatus.getCode()) + .setRcu(readResponse.getRCU()) + .setCompressionStrategy(readResponse.getCompressionStrategy().getValue()) + .setContentLength(readResponse.getResponseBody().readableBytes()) + .setContentType(HttpConstants.AVRO_BINARY) + .setValue(GrpcUtils.toByteString(readResponse.getResponseBody())); + } else { + builder.setStatusCode(responseStatus.getCode()) + .setRcu(readResponse.getRCU()) + .setErrorMessage("Key not found") + .setContentLength(0); + } + + StreamObserver responseObserver = requestContext.getResponseObserver(); + synchronized (responseObserver) { + responseObserver.onNext(builder.build()); + responseObserver.onCompleted(); + } + reportRequestStats(requestContext); + } + + /** + * Sends response (for the legacy API) to the client and records the request statistics via {@link #reportRequestStats}. + * Since {@link StreamObserver} is not thread-safe, synchronization is required before invoking + * {@link StreamObserver#onNext} and {@link StreamObserver#onCompleted}. + * + * @param requestContext The context of the gRPC request, which includes the response and stats recorder to be updated. + */ + void sendVeniceServerResponse(GrpcRequestContext requestContext) { + ReadResponse readResponse = requestContext.getReadResponse(); + VeniceServerResponse.Builder builder = VeniceServerResponse.newBuilder(); + VeniceReadResponseStatus readResponseStatus = requestContext.getReadResponseStatus(); + + if (readResponse == null) { + builder.setErrorCode(readResponseStatus.getCode()); + builder.setErrorMessage(requestContext.getErrorMessage()); + } else if (readResponse.isFound()) { + builder.setErrorCode(readResponseStatus.getCode()) + .setResponseRCU(readResponse.getRCU()) + .setCompressionStrategy(readResponse.getCompressionStrategy().getValue()) + .setIsStreamingResponse(readResponse.isStreamingResponse()) + .setSchemaId(readResponse.getResponseSchemaIdHeader()) + .setData(GrpcUtils.toByteString(readResponse.getResponseBody())); + } else { + builder.setErrorCode(readResponseStatus.getCode()).setErrorMessage("Key not found").setData(ByteString.EMPTY); + } + + StreamObserver responseObserver = requestContext.getResponseObserver(); + synchronized (responseObserver) { + responseObserver.onNext(builder.build()); + responseObserver.onCompleted(); + } + + reportRequestStats(requestContext); + } + + /** + * Records the request statistics based on the provided {@link GrpcRequestContext}. + * This method updates the {@link RequestStatsRecorder} with statistics from the {@link GrpcRequestContext} and {@link ReadResponse}. + * @param requestContext The context of the gRPC request, which contains the response and stats recorder to be updated. + */ + void reportRequestStats(GrpcRequestContext requestContext) { + ReadResponse readResponse = requestContext.getReadResponse(); + RequestStatsRecorder requestStatsRecorder = requestContext.getRequestStatsRecorder(); + AbstractReadResponse abstractReadResponse = (AbstractReadResponse) readResponse; + if (readResponse == null) { + requestStatsRecorder.setReadResponseStats(null).setResponseSize(0); + } else if (readResponse.isFound()) { + requestStatsRecorder.setReadResponseStats(abstractReadResponse.getReadResponseStatsRecorder()) + .setResponseSize(abstractReadResponse.getResponseBody().readableBytes()); + } else { + requestStatsRecorder.setReadResponseStats(abstractReadResponse.getReadResponseStatsRecorder()).setResponseSize(0); + } + + RequestStatsRecorder.recordRequestCompletionStats(requestContext.getRequestStatsRecorder(), true, -1); + } +} diff --git a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcRequestContext.java b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcRequestContext.java index da467ba64e..0b5feb4308 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcRequestContext.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcRequestContext.java @@ -1,5 +1,7 @@ package com.linkedin.venice.grpc; +import static java.util.Objects.requireNonNull; + import com.linkedin.davinci.listener.response.ReadResponse; import com.linkedin.venice.listener.RequestStatsRecorder; import com.linkedin.venice.listener.request.RouterRequest; @@ -28,9 +30,10 @@ private GrpcRequestContext( RequestStatsRecorder requestStatsRecorder, StreamObserver responseObserver, GrpcRequestType grpcRequestType) { - this.requestStatsRecorder = requestStatsRecorder; - this.responseObserver = responseObserver; - this.grpcRequestType = grpcRequestType; + this.requestStatsRecorder = + requireNonNull(requestStatsRecorder, "RequestStatsRecorder cannot be null in GrpcRequestContext"); + this.responseObserver = requireNonNull(responseObserver, "ResponseObserver cannot be null in GrpcRequestContext"); + this.grpcRequestType = requireNonNull(grpcRequestType, "GrpcRequestType cannot be null in GrpcRequestContext"); } public static GrpcRequestContext create( diff --git a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcServiceDependencies.java b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcServiceDependencies.java index 323561d5ea..3cba03c8e0 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcServiceDependencies.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcServiceDependencies.java @@ -15,6 +15,7 @@ public class GrpcServiceDependencies { private final AggServerHttpRequestStats singleGetStats; private final AggServerHttpRequestStats multiGetStats; private final AggServerHttpRequestStats computeStats; + private final GrpcReplyProcessor grpcReplyProcessor; private GrpcServiceDependencies(Builder builder) { this.diskHealthCheckService = builder.diskHealthCheckService; @@ -23,6 +24,7 @@ private GrpcServiceDependencies(Builder builder) { this.singleGetStats = builder.singleGetStats; this.multiGetStats = builder.multiGetStats; this.computeStats = builder.computeStats; + this.grpcReplyProcessor = builder.grpcReplyProcessor; } public DiskHealthCheckService getDiskHealthCheckService() { @@ -49,6 +51,10 @@ public AggServerHttpRequestStats getComputeStats() { return computeStats; } + public GrpcReplyProcessor getGrpcReplyProcessor() { + return grpcReplyProcessor; + } + public static class Builder { private DiskHealthCheckService diskHealthCheckService; private StorageReadRequestHandler storageReadRequestHandler; @@ -56,6 +62,7 @@ public static class Builder { private AggServerHttpRequestStats singleGetStats; private AggServerHttpRequestStats multiGetStats; private AggServerHttpRequestStats computeStats; + private GrpcReplyProcessor grpcReplyProcessor; public Builder setDiskHealthCheckService(DiskHealthCheckService diskHealthCheckService) { this.diskHealthCheckService = diskHealthCheckService; @@ -87,11 +94,20 @@ public Builder setComputeStats(AggServerHttpRequestStats computeStats) { return this; } + public Builder setGrpcReplyProcessor(GrpcReplyProcessor grpcReplyProcessor) { + this.grpcReplyProcessor = grpcReplyProcessor; + return this; + } + public GrpcServiceDependencies build() { // Validate that all required fields are set if (quotaEnforcementHandler == null) { - quotaEnforcementHandler = new NoOpReadQuotaEnforcementHandler(); + quotaEnforcementHandler = NoOpReadQuotaEnforcementHandler.getInstance(); + } + if (grpcReplyProcessor == null) { + grpcReplyProcessor = new GrpcReplyProcessor(); } + singleGetStats = Objects.requireNonNull(singleGetStats, "singleGetStats cannot be null"); multiGetStats = Objects.requireNonNull(multiGetStats, "multiGetStats cannot be null"); computeStats = Objects.requireNonNull(computeStats, "computeStats cannot be null"); diff --git a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcStorageResponseHandlerCallback.java b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcStorageResponseHandlerCallback.java index 4a766f7d61..fc0a48629d 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcStorageResponseHandlerCallback.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/grpc/GrpcStorageResponseHandlerCallback.java @@ -13,25 +13,34 @@ */ public class GrpcStorageResponseHandlerCallback implements StorageResponseHandlerCallback { private final GrpcRequestContext requestContext; + private final GrpcReplyProcessor grpcReplyProcessor; - private GrpcStorageResponseHandlerCallback(GrpcRequestContext requestContext) { + private GrpcStorageResponseHandlerCallback(GrpcRequestContext requestContext, GrpcReplyProcessor grpcReplyProcessor) { this.requestContext = requestContext; + this.grpcReplyProcessor = grpcReplyProcessor; } // Factory method for creating an instance of this class. - public static GrpcStorageResponseHandlerCallback create(GrpcRequestContext requestContext) { - return new GrpcStorageResponseHandlerCallback(requestContext); + public static GrpcStorageResponseHandlerCallback create( + GrpcRequestContext requestContext, + GrpcReplyProcessor grpcReplyProcessor) { + return new GrpcStorageResponseHandlerCallback(requestContext, grpcReplyProcessor); } @Override public void onReadResponse(ReadResponse readResponse) { + if (readResponse == null) { + onError(VeniceReadResponseStatus.INTERNAL_SERVER_ERROR, "StorageHandler returned a unexpected null response"); + return; + } + if (readResponse.isFound()) { requestContext.setReadResponseStatus(VeniceReadResponseStatus.OK); } else { requestContext.setReadResponseStatus(VeniceReadResponseStatus.KEY_NOT_FOUND); } requestContext.setReadResponse(readResponse); - GrpcIoRequestProcessor.sendResponse(requestContext); + grpcReplyProcessor.sendResponse(requestContext); } @Override @@ -39,6 +48,6 @@ public void onError(VeniceReadResponseStatus readResponseStatus, String message) requestContext.setReadResponseStatus(readResponseStatus); requestContext.setErrorMessage(message); requestContext.setReadResponse(null); - GrpcIoRequestProcessor.sendResponse(requestContext); + grpcReplyProcessor.sendResponse(requestContext); } } diff --git a/services/venice-server/src/main/java/com/linkedin/venice/listener/NoOpReadQuotaEnforcementHandler.java b/services/venice-server/src/main/java/com/linkedin/venice/listener/NoOpReadQuotaEnforcementHandler.java index 71b749266a..a42fd3eddc 100644 --- a/services/venice-server/src/main/java/com/linkedin/venice/listener/NoOpReadQuotaEnforcementHandler.java +++ b/services/venice-server/src/main/java/com/linkedin/venice/listener/NoOpReadQuotaEnforcementHandler.java @@ -7,6 +7,15 @@ * A no-op implementation of {@link QuotaEnforcementHandler} that allows all requests. */ public class NoOpReadQuotaEnforcementHandler implements QuotaEnforcementHandler { + private static final NoOpReadQuotaEnforcementHandler INSTANCE = new NoOpReadQuotaEnforcementHandler(); + + private NoOpReadQuotaEnforcementHandler() { + } + + public static NoOpReadQuotaEnforcementHandler getInstance() { + return INSTANCE; + } + @Override public QuotaEnforcementResult enforceQuota(RouterRequest request) { return QuotaEnforcementResult.ALLOWED; diff --git a/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcIoRequestProcessorTest.java b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcIoRequestProcessorTest.java new file mode 100644 index 0000000000..8763432a76 --- /dev/null +++ b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcIoRequestProcessorTest.java @@ -0,0 +1,94 @@ +package com.linkedin.venice.grpc; + +import static com.linkedin.venice.listener.QuotaEnforcementHandler.QuotaEnforcementResult.REJECTED; +import static com.linkedin.venice.listener.ReadQuotaEnforcementHandler.INVALID_REQUEST_RESOURCE_MSG; +import static com.linkedin.venice.listener.ReadQuotaEnforcementHandler.SERVER_OVER_CAPACITY_MSG; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.linkedin.venice.listener.QuotaEnforcementHandler; +import com.linkedin.venice.listener.QuotaEnforcementHandler.QuotaEnforcementResult; +import com.linkedin.venice.listener.StorageReadRequestHandler; +import com.linkedin.venice.listener.request.RouterRequest; +import com.linkedin.venice.response.VeniceReadResponseStatus; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + + +public class GrpcIoRequestProcessorTest { + private QuotaEnforcementHandler quotaEnforcementHandler; + private StorageReadRequestHandler storageReadRequestHandler; + private GrpcReplyProcessor grpcReplyProcessor; + private GrpcServiceDependencies grpcServiceDependencies; + private GrpcIoRequestProcessor processor; + private GrpcRequestContext requestContext; + private RouterRequest request; + + @BeforeMethod + public void setUp() { + quotaEnforcementHandler = mock(QuotaEnforcementHandler.class); + storageReadRequestHandler = mock(StorageReadRequestHandler.class); + grpcReplyProcessor = mock(GrpcReplyProcessor.class); + grpcServiceDependencies = mock(GrpcServiceDependencies.class); + when(grpcServiceDependencies.getQuotaEnforcementHandler()).thenReturn(quotaEnforcementHandler); + when(grpcServiceDependencies.getStorageReadRequestHandler()).thenReturn(storageReadRequestHandler); + when(grpcServiceDependencies.getGrpcReplyProcessor()).thenReturn(grpcReplyProcessor); + + request = mock(RouterRequest.class); + when(request.getResourceName()).thenReturn("testResource_v1"); + requestContext = mock(GrpcRequestContext.class); + when(requestContext.getRouterRequest()).thenReturn(request); + when(requestContext.getGrpcRequestType()).thenReturn(GrpcRequestContext.GrpcRequestType.SINGLE_GET); + + processor = new GrpcIoRequestProcessor(grpcServiceDependencies); + } + + @Test + public void testProcessRequestAllowed() { + // Case when quota enforcement result is ALLOWED + when(quotaEnforcementHandler.enforceQuota(request)).thenReturn(QuotaEnforcementResult.ALLOWED); + + processor.processRequest(requestContext); + + // Verify that the request is handed off to the storage read request handler + verify(storageReadRequestHandler) + .queueIoRequestForAsyncProcessing(eq(request), any(GrpcStorageResponseHandlerCallback.class)); + verify(requestContext, never()).setErrorMessage(anyString()); + verify(requestContext, never()).setReadResponseStatus(any()); + + // should not call grpcReplyProcessor.sendResponse() because the request is handed off to the storage read request + // handler which is mocked + verify(grpcReplyProcessor, never()).sendResponse(requestContext); + } + + @Test + public void testProcessRequestQuotaEnforcementErrors() { + // BAD_REQUEST case + when(quotaEnforcementHandler.enforceQuota(request)).thenReturn(QuotaEnforcementResult.BAD_REQUEST); + when(request.getResourceName()).thenReturn("testResource"); + processor.processRequest(requestContext); + verify(requestContext).setErrorMessage(INVALID_REQUEST_RESOURCE_MSG + "testResource"); + verify(requestContext).setReadResponseStatus(VeniceReadResponseStatus.BAD_REQUEST); + verify(grpcReplyProcessor).sendResponse(requestContext); + + // REJECTED case + when(quotaEnforcementHandler.enforceQuota(request)).thenReturn(REJECTED); + processor.processRequest(requestContext); + verify(requestContext).setErrorMessage("Quota exceeded for resource: testResource"); + verify(requestContext).setReadResponseStatus(VeniceReadResponseStatus.TOO_MANY_REQUESTS); + verify(grpcReplyProcessor, times(2)).sendResponse(requestContext); + + // OVER_CAPACITY case + when(quotaEnforcementHandler.enforceQuota(request)).thenReturn(QuotaEnforcementResult.OVER_CAPACITY); + processor.processRequest(requestContext); + verify(requestContext).setErrorMessage(SERVER_OVER_CAPACITY_MSG); + verify(requestContext).setReadResponseStatus(VeniceReadResponseStatus.SERVICE_UNAVAILABLE); + verify(grpcReplyProcessor, times(3)).sendResponse(requestContext); + } +} diff --git a/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcReplyProcessorTest.java b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcReplyProcessorTest.java new file mode 100644 index 0000000000..53de0ef56e --- /dev/null +++ b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcReplyProcessorTest.java @@ -0,0 +1,271 @@ +package com.linkedin.venice.grpc; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +import com.google.protobuf.ByteString; +import com.linkedin.davinci.listener.response.ReadResponse; +import com.linkedin.venice.compression.CompressionStrategy; +import com.linkedin.venice.protocols.MultiKeyResponse; +import com.linkedin.venice.protocols.SingleGetResponse; +import com.linkedin.venice.protocols.VeniceServerResponse; +import com.linkedin.venice.response.VeniceReadResponseStatus; +import io.grpc.stub.StreamObserver; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.testng.annotations.Test; + + +public class GrpcReplyProcessorTest { + @Test + public void testSendResponse() { + GrpcReplyProcessor replyProcessor = new GrpcReplyProcessor(); + + // Spy on the GrpcReplyProcessor to verify the method calls + GrpcReplyProcessor spyReplyProcessor = spy(replyProcessor); + + // Case 1: Test for SINGLE_GET case + GrpcRequestContext singleGetRequestContext = mock(GrpcRequestContext.class); + when(singleGetRequestContext.getGrpcRequestType()).thenReturn(GrpcRequestContext.GrpcRequestType.SINGLE_GET); + doNothing().when(spyReplyProcessor).sendSingleGetResponse(singleGetRequestContext); + spyReplyProcessor.sendResponse(singleGetRequestContext); + verify(spyReplyProcessor).sendSingleGetResponse(singleGetRequestContext); + verify(spyReplyProcessor, never()).sendMultiKeyResponse(any()); + verify(spyReplyProcessor, never()).sendVeniceServerResponse(any()); + + // Case 2: Test for MULTI_GET case + spyReplyProcessor = spy(replyProcessor); + GrpcRequestContext multiGetRequestContext = mock(GrpcRequestContext.class); + when(multiGetRequestContext.getGrpcRequestType()).thenReturn(GrpcRequestContext.GrpcRequestType.MULTI_GET); + doNothing().when(spyReplyProcessor).sendMultiKeyResponse(multiGetRequestContext); + spyReplyProcessor.sendResponse(multiGetRequestContext); + verify(spyReplyProcessor).sendMultiKeyResponse(multiGetRequestContext); + verify(spyReplyProcessor, never()).sendSingleGetResponse(any()); + verify(spyReplyProcessor, never()).sendVeniceServerResponse(any()); + + when(multiGetRequestContext.getGrpcRequestType()).thenReturn(GrpcRequestContext.GrpcRequestType.COMPUTE); + doNothing().when(spyReplyProcessor).sendMultiKeyResponse(multiGetRequestContext); + spyReplyProcessor.sendResponse(multiGetRequestContext); + verify(spyReplyProcessor, times(2)).sendMultiKeyResponse(multiGetRequestContext); + verify(spyReplyProcessor, never()).sendSingleGetResponse(any()); + verify(spyReplyProcessor, never()).sendVeniceServerResponse(any()); + + // Case 3: Test for LEGACY case + spyReplyProcessor = spy(replyProcessor); + GrpcRequestContext legacyRequestContext = mock(GrpcRequestContext.class); + when(legacyRequestContext.getGrpcRequestType()).thenReturn(GrpcRequestContext.GrpcRequestType.LEGACY); + doNothing().when(spyReplyProcessor).sendVeniceServerResponse(legacyRequestContext); + spyReplyProcessor.sendResponse(legacyRequestContext); + verify(spyReplyProcessor).sendVeniceServerResponse(legacyRequestContext); + verify(spyReplyProcessor, never()).sendSingleGetResponse(any()); + verify(spyReplyProcessor, never()).sendMultiKeyResponse(any()); + } + + @Test + public void testSendSingleGetResponse() { + GrpcReplyProcessor replyProcessor = new GrpcReplyProcessor(); + GrpcReplyProcessor spyReplyProcessor = spy(replyProcessor); + doNothing().when(spyReplyProcessor).reportRequestStats(any()); + + // 1. Test Case: readResponse is null + GrpcRequestContext requestContext = mock(GrpcRequestContext.class); + StreamObserver responseObserver = mock(StreamObserver.class); + when(requestContext.getResponseObserver()).thenReturn(responseObserver); + VeniceReadResponseStatus status = VeniceReadResponseStatus.TOO_MANY_REQUESTS; + when(requestContext.getReadResponseStatus()).thenReturn(status); + when(requestContext.getReadResponse()).thenReturn(null); + when(requestContext.getErrorMessage()).thenReturn("Some error"); + + spyReplyProcessor.sendSingleGetResponse(requestContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SingleGetResponse.class); + verify(responseObserver).onNext(captor.capture()); + SingleGetResponse capturedResponse = captor.getValue(); + assertEquals(capturedResponse.getStatusCode(), status.getCode()); + assertEquals(capturedResponse.getErrorMessage(), "Some error"); + verify(responseObserver).onCompleted(); + + InOrder inOrder = inOrder(responseObserver); + inOrder.verify(responseObserver).onNext(any()); + inOrder.verify(responseObserver).onCompleted(); + + // 2. Test Case: readResponse is found + status = VeniceReadResponseStatus.OK; + when(requestContext.getReadResponseStatus()).thenReturn(status); + ReadResponse readResponse = mock(ReadResponse.class); + when(readResponse.isFound()).thenReturn(true); + when(requestContext.getReadResponse()).thenReturn(readResponse); + when(readResponse.getRCU()).thenReturn(1); + when(readResponse.getResponseSchemaIdHeader()).thenReturn(1); + when(readResponse.getCompressionStrategy()).thenReturn(CompressionStrategy.GZIP); + + ByteBuf responseBody = Unpooled.EMPTY_BUFFER; + when(readResponse.getResponseBody()).thenReturn(responseBody); + + spyReplyProcessor.sendSingleGetResponse(requestContext); + + verify(responseObserver, times(2)).onNext(captor.capture()); // Capturing the second call + capturedResponse = captor.getValue(); + assertEquals(capturedResponse.getStatusCode(), status.getCode()); + assertEquals(capturedResponse.getRcu(), 1); + assertEquals(capturedResponse.getSchemaId(), 1); + assertEquals(capturedResponse.getCompressionStrategy(), CompressionStrategy.GZIP.getValue()); + assertEquals(capturedResponse.getContentLength(), 0); + assertEquals(capturedResponse.getValue(), ByteString.EMPTY); + verify(responseObserver, times(2)).onCompleted(); + + inOrder.verify(responseObserver).onNext(any()); + inOrder.verify(responseObserver).onCompleted(); + + // 3. Test Case: readResponse is not found + status = VeniceReadResponseStatus.KEY_NOT_FOUND; + when(requestContext.getReadResponseStatus()).thenReturn(status); + when(readResponse.isFound()).thenReturn(false); + when(readResponse.getRCU()).thenReturn(5); + when(readResponse.getResponseSchemaIdHeader()).thenReturn(-1); + spyReplyProcessor.sendSingleGetResponse(requestContext); + + verify(responseObserver, times(3)).onNext(captor.capture()); // Capturing the third call + capturedResponse = captor.getValue(); + assertEquals(capturedResponse.getStatusCode(), status.getCode()); + assertEquals(capturedResponse.getRcu(), 5); + assertEquals(capturedResponse.getErrorMessage(), "Key not found"); + assertEquals(capturedResponse.getContentLength(), 0); + verify(responseObserver, times(3)).onCompleted(); + inOrder.verify(responseObserver).onNext(any()); + inOrder.verify(responseObserver).onCompleted(); + + verify(spyReplyProcessor, times(3)).reportRequestStats(requestContext); + } + + @Test + public void testSendVeniceServerResponse() { + GrpcReplyProcessor replyProcessor = new GrpcReplyProcessor(); + GrpcReplyProcessor spyReplyProcessor = spy(replyProcessor); + doNothing().when(spyReplyProcessor).reportRequestStats(any()); + + GrpcRequestContext requestContext = mock(GrpcRequestContext.class); + StreamObserver responseObserver = mock(StreamObserver.class); + + // 1. Test Case: readResponse is null + when(requestContext.getResponseObserver()).thenReturn(responseObserver); + VeniceReadResponseStatus status = VeniceReadResponseStatus.BAD_REQUEST; + when(requestContext.getReadResponseStatus()).thenReturn(status); + when(requestContext.getReadResponse()).thenReturn(null); + when(requestContext.getErrorMessage()).thenReturn("Null read response"); + + spyReplyProcessor.sendVeniceServerResponse(requestContext); + + ArgumentCaptor captor = ArgumentCaptor.forClass(VeniceServerResponse.class); + verify(responseObserver).onNext(captor.capture()); + VeniceServerResponse capturedResponse = captor.getValue(); + assertEquals(capturedResponse.getErrorCode(), status.getCode()); + assertEquals(capturedResponse.getErrorMessage(), "Null read response"); + verify(responseObserver).onCompleted(); + + // 2. Test Case: readResponse is found + ReadResponse readResponse = mock(ReadResponse.class); + when(readResponse.getResponseBody()).thenReturn(Unpooled.EMPTY_BUFFER); + when(requestContext.getReadResponse()).thenReturn(readResponse); + status = VeniceReadResponseStatus.OK; + when(requestContext.getReadResponseStatus()).thenReturn(status); + when(requestContext.getErrorMessage()).thenReturn(null); + when(readResponse.isFound()).thenReturn(true); + when(requestContext.getReadResponse()).thenReturn(readResponse); + when(readResponse.getRCU()).thenReturn(120); + when(readResponse.getCompressionStrategy()).thenReturn(CompressionStrategy.GZIP); + when(readResponse.isStreamingResponse()).thenReturn(true); + when(readResponse.getResponseSchemaIdHeader()).thenReturn(2); + + spyReplyProcessor.sendVeniceServerResponse(requestContext); + + verify(responseObserver, times(2)).onNext(captor.capture()); // Capturing the second call + capturedResponse = captor.getValue(); + assertEquals(capturedResponse.getErrorCode(), status.getCode()); + assertEquals(capturedResponse.getResponseRCU(), 120); + assertEquals(capturedResponse.getCompressionStrategy(), CompressionStrategy.GZIP.getValue()); + assertEquals(capturedResponse.getSchemaId(), 2); + assertTrue(capturedResponse.getIsStreamingResponse()); + assertEquals(capturedResponse.getData(), ByteString.EMPTY); + verify(responseObserver, times(2)).onCompleted(); + + // 3. Test Case: readResponse is not found + when(readResponse.isFound()).thenReturn(false); + status = VeniceReadResponseStatus.KEY_NOT_FOUND; + when(requestContext.getReadResponseStatus()).thenReturn(status); + + spyReplyProcessor.sendVeniceServerResponse(requestContext); + + verify(responseObserver, times(3)).onNext(captor.capture()); // Capturing the third call + capturedResponse = captor.getValue(); + assertEquals(capturedResponse.getErrorCode(), status.getCode()); + assertEquals(capturedResponse.getErrorMessage(), "Key not found"); + assertEquals(capturedResponse.getData(), ByteString.EMPTY); + verify(responseObserver, times(3)).onCompleted(); + + // 4. Verify synchronization on responseObserver + InOrder inOrder = inOrder(responseObserver); + inOrder.verify(responseObserver).onNext(any()); + inOrder.verify(responseObserver).onCompleted(); + + // 5. Verify reportRequestStats is called at the end + verify(spyReplyProcessor, times(3)).reportRequestStats(requestContext); + } + + // @Test + // public void testReportRequestStats() { + // GrpcReplyProcessor processor = new GrpcReplyProcessor(); + // + // // Mock the GrpcRequestContext and its dependencies + // GrpcRequestContext requestContext = mock(GrpcRequestContext.class); + // ReadResponse readResponse = mock(ReadResponse.class); + // RequestStatsRecorder requestStatsRecorder = mock(RequestStatsRecorder.class); + // AbstractReadResponse abstractReadResponse = mock(AbstractReadResponse.class); + // ByteBuf responseBody = mock(ByteBuf.class); + // + // // Mock behavior for requestContext + // when(requestContext.getReadResponse()).thenReturn(readResponse); + // when(requestContext.getRequestStatsRecorder()).thenReturn(requestStatsRecorder); + // + // // 1. Test Case: readResponse is null + // when(requestContext.getReadResponse()).thenReturn(null); + // + // processor.reportRequestStats(requestContext); + // + // // Verify that recorder sets read response stats to null and response size to 0 + // verify(requestStatsRecorder).setReadResponseStats(null); + // verify(requestStatsRecorder).setResponseSize(0); + // + // // 2. Test Case: readResponse.isFound() is true + // when(readResponse.isFound()).thenReturn(true); + // when(requestContext.getReadResponse()).thenReturn(abstractReadResponse); + // when(abstractReadResponse.getResponseBody()).thenReturn(responseBody); + // when(responseBody.readableBytes()).thenReturn(512); + // when(abstractReadResponse.getReadResponseStatsRecorder()).thenReturn(mock(ReadResponseStatsRecorder.class)); + // + // processor.reportRequestStats(requestContext); + // + // // Verify that recorder sets the stats and the correct response size + // verify(requestStatsRecorder, times(2)).setReadResponseStats(abstractReadResponse.getReadResponseStatsRecorder()); + // verify(requestStatsRecorder, times(2)).setResponseSize(512); + // + // // 3. Test Case: readResponse.isFound() is false + // when(readResponse.isFound()).thenReturn(false); + // processor.reportRequestStats(requestContext); + // + // // Verify that recorder sets the stats and the response size to 0 + // verify(requestStatsRecorder, times(3)).setReadResponseStats(abstractReadResponse.getReadResponseStatsRecorder()); + // verify(requestStatsRecorder).setResponseSize(0); + // } +} diff --git a/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcServiceDependenciesTest.java b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcServiceDependenciesTest.java new file mode 100644 index 0000000000..f96c5421e7 --- /dev/null +++ b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcServiceDependenciesTest.java @@ -0,0 +1,80 @@ +package com.linkedin.venice.grpc; + +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertSame; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import com.linkedin.davinci.storage.DiskHealthCheckService; +import com.linkedin.venice.listener.NoOpReadQuotaEnforcementHandler; +import com.linkedin.venice.listener.StorageReadRequestHandler; +import com.linkedin.venice.stats.AggServerHttpRequestStats; +import org.mockito.Mockito; +import org.testng.annotations.Test; + + +public class GrpcServiceDependenciesTest { + @Test + public void testBuilderValidation() { + DiskHealthCheckService diskHealthCheckService = Mockito.mock(DiskHealthCheckService.class); + StorageReadRequestHandler storageReadRequestHandler = Mockito.mock(StorageReadRequestHandler.class); + NoOpReadQuotaEnforcementHandler quotaEnforcementHandler = Mockito.mock(NoOpReadQuotaEnforcementHandler.class); + AggServerHttpRequestStats singleGetStats = Mockito.mock(AggServerHttpRequestStats.class); + AggServerHttpRequestStats multiGetStats = Mockito.mock(AggServerHttpRequestStats.class); + AggServerHttpRequestStats computeStats = Mockito.mock(AggServerHttpRequestStats.class); + GrpcReplyProcessor grpcReplyProcessor = Mockito.mock(GrpcReplyProcessor.class); + + // Test with all fields set + GrpcServiceDependencies dependencies = + new GrpcServiceDependencies.Builder().setDiskHealthCheckService(diskHealthCheckService) + .setStorageReadRequestHandler(storageReadRequestHandler) + .setQuotaEnforcementHandler(quotaEnforcementHandler) + .setSingleGetStats(singleGetStats) + .setMultiGetStats(multiGetStats) + .setComputeStats(computeStats) + .setGrpcReplyProcessor(grpcReplyProcessor) + .build(); + + assertNotNull(dependencies); + assertSame(dependencies.getDiskHealthCheckService(), diskHealthCheckService); + assertSame(dependencies.getStorageReadRequestHandler(), storageReadRequestHandler); + assertSame(dependencies.getQuotaEnforcementHandler(), quotaEnforcementHandler); + assertSame(dependencies.getSingleGetStats(), singleGetStats); + assertSame(dependencies.getMultiGetStats(), multiGetStats); + assertSame(dependencies.getComputeStats(), computeStats); + assertSame(dependencies.getGrpcReplyProcessor(), grpcReplyProcessor); + } + + @Test + public void testBuilderValidationWithMissingFields() { + assertThrows( + NullPointerException.class, + () -> new GrpcServiceDependencies.Builder().setDiskHealthCheckService(null).build()); + } + + @Test + public void testBuilderValidationWithDefaultValues() { + DiskHealthCheckService diskHealthCheckService = Mockito.mock(DiskHealthCheckService.class); + StorageReadRequestHandler storageReadRequestHandler = Mockito.mock(StorageReadRequestHandler.class); + AggServerHttpRequestStats singleGetStats = Mockito.mock(AggServerHttpRequestStats.class); + AggServerHttpRequestStats multiGetStats = Mockito.mock(AggServerHttpRequestStats.class); + AggServerHttpRequestStats computeStats = Mockito.mock(AggServerHttpRequestStats.class); + + GrpcServiceDependencies dependencies = + new GrpcServiceDependencies.Builder().setDiskHealthCheckService(diskHealthCheckService) + .setStorageReadRequestHandler(storageReadRequestHandler) + .setSingleGetStats(singleGetStats) + .setMultiGetStats(multiGetStats) + .setComputeStats(computeStats) + .build(); + + assertNotNull(dependencies); + assertTrue(dependencies.getQuotaEnforcementHandler() instanceof NoOpReadQuotaEnforcementHandler); + assertNotNull(dependencies.getGrpcReplyProcessor()); + assertSame(dependencies.getDiskHealthCheckService(), diskHealthCheckService); + assertSame(dependencies.getStorageReadRequestHandler(), storageReadRequestHandler); + assertSame(dependencies.getSingleGetStats(), singleGetStats); + assertSame(dependencies.getMultiGetStats(), multiGetStats); + assertSame(dependencies.getComputeStats(), computeStats); + } +} diff --git a/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcStorageResponseHandlerCallbackTest.java b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcStorageResponseHandlerCallbackTest.java new file mode 100644 index 0000000000..e1651a1dd7 --- /dev/null +++ b/services/venice-server/src/test/java/com/linkedin/venice/grpc/GrpcStorageResponseHandlerCallbackTest.java @@ -0,0 +1,86 @@ +package com.linkedin.venice.grpc; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.linkedin.davinci.listener.response.ReadResponse; +import com.linkedin.venice.response.VeniceReadResponseStatus; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + + +public class GrpcStorageResponseHandlerCallbackTest { + private GrpcRequestContext requestContext; + private GrpcReplyProcessor grpcReplyProcessor; + private GrpcStorageResponseHandlerCallback callback; + + @BeforeMethod + public void setUp() { + requestContext = mock(GrpcRequestContext.class); + grpcReplyProcessor = mock(GrpcReplyProcessor.class); + callback = GrpcStorageResponseHandlerCallback.create(requestContext, grpcReplyProcessor); + } + + @Test + public void testOnReadResponseVariousScenarios() { + // Case 1: Response found + ReadResponse foundResponse = mock(ReadResponse.class); + when(foundResponse.isFound()).thenReturn(true); + ByteBuf foundResponseBody = Unpooled.EMPTY_BUFFER; + when(foundResponse.getResponseBody()).thenReturn(foundResponseBody); + + callback.onReadResponse(foundResponse); + + verify(requestContext).setReadResponseStatus(VeniceReadResponseStatus.OK); + verify(requestContext).setReadResponse(foundResponse); + verify(grpcReplyProcessor).sendResponse(requestContext); + + reset(requestContext, grpcReplyProcessor); // Resetting mocks before next scenario + + // Case 2: Response not found (key not found) + ReadResponse notFoundResponse = mock(ReadResponse.class); + when(notFoundResponse.isFound()).thenReturn(false); + + callback.onReadResponse(notFoundResponse); + + verify(requestContext).setReadResponseStatus(VeniceReadResponseStatus.KEY_NOT_FOUND); + verify(requestContext).setReadResponse(notFoundResponse); + verify(grpcReplyProcessor).sendResponse(requestContext); + + reset(requestContext, grpcReplyProcessor); // Resetting mocks before next scenario + + // Case 3: Null ReadResponse + callback.onReadResponse(null); + + verify(requestContext).setReadResponseStatus(VeniceReadResponseStatus.INTERNAL_SERVER_ERROR); + verify(requestContext).setReadResponse(null); + verify(grpcReplyProcessor).sendResponse(requestContext); + } + + @Test + public void testOnErrorVariousScenarios() { + // Case 1: Standard error case + VeniceReadResponseStatus errorStatus = VeniceReadResponseStatus.INTERNAL_SERVER_ERROR; + String errorMessage = "An error occurred"; + + callback.onError(errorStatus, errorMessage); + + verify(requestContext).setReadResponseStatus(errorStatus); + verify(requestContext).setErrorMessage(errorMessage); + verify(requestContext).setReadResponse(null); + verify(grpcReplyProcessor).sendResponse(requestContext); + + reset(requestContext, grpcReplyProcessor); + + // Case 2: Null error message + callback.onError(errorStatus, null); + verify(requestContext).setReadResponseStatus(errorStatus); + verify(requestContext).setErrorMessage(null); + verify(requestContext).setReadResponse(null); + verify(grpcReplyProcessor).sendResponse(requestContext); + } +}