diff --git a/dapr/clients/grpc/interceptors.py b/dapr/clients/grpc/interceptors.py index 22098f53..adda29c1 100644 --- a/dapr/clients/grpc/interceptors.py +++ b/dapr/clients/grpc/interceptors.py @@ -1,7 +1,7 @@ from collections import namedtuple from typing import List, Tuple -from grpc import UnaryUnaryClientInterceptor, ClientCallDetails # type: ignore +from grpc import UnaryUnaryClientInterceptor, ClientCallDetails, StreamStreamClientInterceptor # type: ignore from dapr.conf import settings @@ -38,7 +38,7 @@ def intercept_unary_unary(self, continuation, client_call_details, request): return continuation(client_call_details, request) -class DaprClientInterceptor(UnaryUnaryClientInterceptor): +class DaprClientInterceptor(UnaryUnaryClientInterceptor, StreamStreamClientInterceptor): """The class implements a UnaryUnaryClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed. @@ -91,8 +91,8 @@ def _intercept_call(self, client_call_details: ClientCallDetails) -> ClientCallD return new_call_details def intercept_unary_unary(self, continuation, client_call_details, request): - """This method intercepts a unary-unary gRPC call. This is the implementation of the - abstract method defined in UnaryUnaryClientInterceptor defined in grpc. This is invoked + """This method intercepts a unary-unary gRPC call. It is the implementation of the + abstract method defined in UnaryUnaryClientInterceptor defined in grpc. It's invoked automatically by grpc based on the order in which interceptors are added to the channel. Args: @@ -108,3 +108,23 @@ def intercept_unary_unary(self, continuation, client_call_details, request): # Call continuation response = continuation(new_call_details, request) return response + + def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + """This method intercepts a stream-stream gRPC call. It is the implementation of the + abstract method defined in StreamStreamClientInterceptor defined in grpc. It's invoked + automatically by grpc based on the order in which interceptors are added to the channel. + + Args: + continuation: a callable to be invoked to continue with the RPC or next interceptor + client_call_details: a ClientCallDetails object describing the outgoing RPC + request_iterator: the request value for the RPC + + Returns: + A response object after invoking the continuation callable + """ + # Pre-process or intercept call + + new_call_details = self._intercept_call(client_call_details) + # Call continuation + response = continuation(new_call_details, request_iterator) + return response diff --git a/dapr/clients/grpc/subscription.py b/dapr/clients/grpc/subscription.py index 5ca30119..2c8c18e3 100644 --- a/dapr/clients/grpc/subscription.py +++ b/dapr/clients/grpc/subscription.py @@ -1,27 +1,16 @@ import json -from grpc import StreamStreamMultiCallable, RpcError, StatusCode # type: ignore +from grpc import RpcError, StatusCode, Call # type: ignore from dapr.clients.exceptions import StreamInactiveError from dapr.clients.grpc._response import TopicEventResponse +from dapr.clients.health import DaprHealth from dapr.proto import api_v1, appcallback_v1 import queue import threading from typing import Optional -def success(): - return appcallback_v1.TopicEventResponse.SUCCESS - - -def retry(): - return appcallback_v1.TopicEventResponse.RETRY - - -def drop(): - return appcallback_v1.TopicEventResponse.DROP - - class Subscription: def __init__(self, stub, pubsub_name, topic, metadata=None, dead_letter_topic=None): self._stub = stub @@ -29,10 +18,10 @@ def __init__(self, stub, pubsub_name, topic, metadata=None, dead_letter_topic=No self.topic = topic self.metadata = metadata or {} self.dead_letter_topic = dead_letter_topic or '' - self._stream: Optional[StreamStreamMultiCallable] = None # Type annotation for gRPC stream - self._response_thread: Optional[threading.Thread] = None # Type for thread - self._send_queue: queue.Queue = queue.Queue() # Type annotation for send queue - self._receive_queue: queue.Queue = queue.Queue() # Type annotation for receive queue + self._stream: Optional[Call] = None + self._response_thread: Optional[threading.Thread] = None + self._send_queue: queue.Queue = queue.Queue() + self._receive_queue: queue.Queue = queue.Queue() self._stream_active: bool = False self._stream_lock = threading.Lock() # Protects _stream_active @@ -56,7 +45,7 @@ def outgoing_request_iterator(): # Start sending back acknowledgement messages from the send queue while self._is_stream_active(): try: - response = self._send_queue.get(timeout=1) + response = self._send_queue.get() # Check again if the stream is still active if not self._is_stream_active(): break @@ -75,6 +64,7 @@ def outgoing_request_iterator(): self._response_thread.start() def _handle_incoming_messages(self): + reconnect = False try: # Check if the stream is not None if self._stream is not None: @@ -83,17 +73,26 @@ def _handle_incoming_messages(self): # Read messages from the stream and put them in the receive queue for message in self._stream: - if self._is_stream_active(): - self._receive_queue.put(message.event_message) - else: - break + self._receive_queue.put(message.event_message) except RpcError as e: - if e.code() != StatusCode.CANCELLED: + if e.code() == StatusCode.UNAVAILABLE: + print('Stream unavailable, attempting to reconnect...') + reconnect = True + elif e.code() != StatusCode.CANCELLED: print(f'gRPC error in stream: {e.details()}, Status Code: {e.code()}') + except Exception as e: raise Exception(f'Error while handling responses: {e}') finally: self._set_stream_inactive() + if reconnect: + self.reconnect_stream() + + def reconnect_stream(self): + DaprHealth.wait_until_ready() + print('Attempting to reconnect...') + self.close() + self.start() def next_message(self, timeout=None): msg = self.read_message_from_queue(self._receive_queue, timeout=timeout) diff --git a/examples/pubsub-streaming/subscriber.py b/examples/pubsub-streaming/subscriber.py index 701f5775..8b396281 100644 --- a/examples/pubsub-streaming/subscriber.py +++ b/examples/pubsub-streaming/subscriber.py @@ -1,6 +1,4 @@ from dapr.clients import DaprClient -from dapr.clients.grpc._response import TopicEventResponse -from dapr.clients.grpc.subscription import success, retry, drop def process_message(message):