Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resiliency for ArmadaOperator services calls #2591

Closed
wants to merge 14 commits into from
12 changes: 12 additions & 0 deletions third_party/airflow/armada/operators/jobservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from google.protobuf import empty_pb2

import tenacity


class JobServiceClient:
"""
Expand All @@ -18,6 +20,11 @@ class JobServiceClient:
def __init__(self, channel):
self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel)

@tenacity.retry(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd love some kind of test that can verify that this library works with GRPC.

Can you look into adding a unit test? You could use our mock python grpc class for jobservice.

Otherwise this is looking really good!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test should check whether tenacity makes retries or not , right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep.

stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(),
reraise=True,
)
def get_job_status(
self, queue: str, job_set_id: str, job_id: str
) -> jobservice_pb2.JobServiceResponse:
Expand All @@ -35,6 +42,11 @@ def get_job_status(
)
return self.job_stub.GetJobStatus(job_service_request)

@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(),
reraise=True,
)
def health(self) -> jobservice_pb2.HealthCheckResponse:
"""Health Check for GRPC Request"""
return self.job_stub.Health(request=empty_pb2.Empty())
12 changes: 12 additions & 0 deletions third_party/airflow/armada/operators/jobservice_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from google.protobuf import empty_pb2

import tenacity


class JobServiceAsyncIOClient:
"""
Expand All @@ -20,6 +22,11 @@ class JobServiceAsyncIOClient:
def __init__(self, channel: grpc.aio.Channel) -> None:
self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel)

@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(),
reraise=True,
)
async def get_job_status(
self, queue: str, job_set_id: str, job_id: str
) -> jobservice_pb2.JobServiceResponse:
Expand All @@ -38,6 +45,11 @@ async def get_job_status(
response = await self.job_stub.GetJobStatus(job_service_request)
return response

@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(),
reraise=True,
)
async def health(self) -> jobservice_pb2.HealthCheckResponse:
"""Health Check for GRPC Request"""
response = await self.job_stub.Health(request=empty_pb2.Empty())
Expand Down
3 changes: 2 additions & 1 deletion third_party/airflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ dependencies = [
"apache-airflow>=2.3.1",
"grpcio>=1.46.3",
"grpcio-tools>=1.46.3",
"types-protobuf>=3.19.22"
"types-protobuf>=3.19.22",
"tenacity"
]
authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}]
license = { text = "Apache Software License" }
Expand Down
23 changes: 23 additions & 0 deletions third_party/airflow/tests/unit/job_service_mock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from armada.jobservice import jobservice_pb2, jobservice_pb2_grpc
import tenacity


# TODO - Make this a bit smarter, so we can hit at least one full
Expand All @@ -22,10 +23,32 @@ def mock_dummy_mapper_terminal(request):


class JobService(jobservice_pb2_grpc.JobServiceServicer):
@tenacity.retry(
stop=tenacity.stop_after_attempt(4),
wait=tenacity.wait_exponential(),
reraise=True,
)
def GetJobStatus(self, request, context):
return mock_dummy_mapper_terminal(request)

@tenacity.retry(
stop=tenacity.stop_after_attempt(4),
wait=tenacity.wait_exponential(),
reraise=True,
)
def Health(self, request, context):
return jobservice_pb2.HealthCheckResponse(
status=jobservice_pb2.HealthCheckResponse.SERVING
)

@tenacity.retry(
stop=tenacity.stop_after_attempt(4),
wait=tenacity.wait_exponential(),
reraise=True,
)
def tenacity_example(target_count):
current_count = JobService.tenacity_example.retry.statistics["attempt_number"]
if current_count < target_count:
raise IOError("dummy error")
else:
return
8 changes: 8 additions & 0 deletions third_party/airflow/tests/unit/test_tenacity_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from job_service_mock import JobService


def test_tenacity_retry():
target_count = 3
JobService.tenacity_example(target_count)
retry_count = JobService.tenacity_example.retry.statistics["attempt_number"]
assert retry_count == target_count
Loading