diff --git a/third_party/airflow/armada/operators/jobservice.py b/third_party/airflow/armada/operators/jobservice.py index 2d5b4248e42..7a21b87a8cf 100644 --- a/third_party/airflow/armada/operators/jobservice.py +++ b/third_party/airflow/armada/operators/jobservice.py @@ -2,6 +2,8 @@ from google.protobuf import empty_pb2 +import tenacity + class JobServiceClient: """ @@ -18,6 +20,11 @@ class JobServiceClient: def __init__(self, channel): self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel) + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) def get_job_status( self, queue: str, job_set_id: str, job_id: str ) -> jobservice_pb2.JobServiceResponse: @@ -35,6 +42,11 @@ def get_job_status( ) return self.job_stub.GetJobStatus(job_service_request) + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) def health(self) -> jobservice_pb2.HealthCheckResponse: """Health Check for GRPC Request""" return self.job_stub.Health(request=empty_pb2.Empty()) diff --git a/third_party/airflow/armada/operators/jobservice_asyncio.py b/third_party/airflow/armada/operators/jobservice_asyncio.py index bf2de7f1565..aaef8ee9832 100644 --- a/third_party/airflow/armada/operators/jobservice_asyncio.py +++ b/third_party/airflow/armada/operators/jobservice_asyncio.py @@ -4,6 +4,8 @@ from google.protobuf import empty_pb2 +import tenacity + class JobServiceAsyncIOClient: """ @@ -20,6 +22,11 @@ class JobServiceAsyncIOClient: def __init__(self, channel: grpc.aio.Channel) -> None: self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel) + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) async def get_job_status( self, queue: str, job_set_id: str, job_id: str ) -> jobservice_pb2.JobServiceResponse: @@ -38,6 +45,11 @@ async def get_job_status( response = await self.job_stub.GetJobStatus(job_service_request) return response + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) async def health(self) -> jobservice_pb2.HealthCheckResponse: """Health Check for GRPC Request""" response = await self.job_stub.Health(request=empty_pb2.Empty()) diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index e9eab1dd64b..25ebc828412 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -12,7 +12,8 @@ dependencies = [ "apache-airflow>=2.3.1", "grpcio>=1.46.3", "grpcio-tools>=1.46.3", - "types-protobuf>=3.19.22" + "types-protobuf>=3.19.22", + "tenacity" ] authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] license = { text = "Apache Software License" } diff --git a/third_party/airflow/tests/unit/job_service_mock.py b/third_party/airflow/tests/unit/job_service_mock.py index 6787ede0c71..1315d21a47a 100644 --- a/third_party/airflow/tests/unit/job_service_mock.py +++ b/third_party/airflow/tests/unit/job_service_mock.py @@ -1,4 +1,5 @@ from armada.jobservice import jobservice_pb2, jobservice_pb2_grpc +import tenacity # TODO - Make this a bit smarter, so we can hit at least one full @@ -22,10 +23,32 @@ def mock_dummy_mapper_terminal(request): class JobService(jobservice_pb2_grpc.JobServiceServicer): + @tenacity.retry( + stop=tenacity.stop_after_attempt(4), + wait=tenacity.wait_exponential(), + reraise=True, + ) def GetJobStatus(self, request, context): return mock_dummy_mapper_terminal(request) + @tenacity.retry( + stop=tenacity.stop_after_attempt(4), + wait=tenacity.wait_exponential(), + reraise=True, + ) def Health(self, request, context): return jobservice_pb2.HealthCheckResponse( status=jobservice_pb2.HealthCheckResponse.SERVING ) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(4), + wait=tenacity.wait_exponential(), + reraise=True, + ) + def tenacity_example(target_count): + current_count = JobService.tenacity_example.retry.statistics["attempt_number"] + if current_count < target_count: + raise IOError("dummy error") + else: + return diff --git a/third_party/airflow/tests/unit/test_tenacity_retry.py b/third_party/airflow/tests/unit/test_tenacity_retry.py new file mode 100644 index 00000000000..e2f27013321 --- /dev/null +++ b/third_party/airflow/tests/unit/test_tenacity_retry.py @@ -0,0 +1,8 @@ +from job_service_mock import JobService + + +def test_tenacity_retry(): + target_count = 3 + JobService.tenacity_example(target_count) + retry_count = JobService.tenacity_example.retry.statistics["attempt_number"] + assert retry_count == target_count