diff --git a/client/python/armada_client/__init__.py b/client/python/armada_client/__init__.py index e69de29bb2d..c1e297ad39a 100644 --- a/client/python/armada_client/__init__.py +++ b/client/python/armada_client/__init__.py @@ -0,0 +1,14 @@ +try: + from .typings import JobState + from ._proto_methods import is_active, is_terminal + + JobState.is_active = is_active + JobState.is_terminal = is_terminal + + del is_active, is_terminal, JobState +except ImportError: + """ + Import errors occur during proto generation, where certain + modules import types that don't exist yet. We can safely ignore these failures + """ + pass diff --git a/client/python/armada_client/_proto_methods.py b/client/python/armada_client/_proto_methods.py new file mode 100644 index 00000000000..608527842f1 --- /dev/null +++ b/client/python/armada_client/_proto_methods.py @@ -0,0 +1,38 @@ +from armada_client.typings import JobState + + +def is_terminal(self) -> bool: + """ + Determines if a job state is terminal. + + Terminal states indicate that a job has completed its lifecycle, + whether successfully or due to failure. + + :param state: The current state of the job. + :type state: JobState + + :returns: True if the job state is terminal, False if it is active. + :rtype: bool + """ + terminal_states = { + JobState.SUCCEEDED, + JobState.FAILED, + JobState.CANCELLED, + JobState.PREEMPTED, + } + return self.value in terminal_states + + +def is_active(self) -> bool: + """ + Determines if a job state is active. + + Active states indicate that a job is still running or in a non-terminal state. + + :param state: The current state of the job. + :type state: JobState + + :returns: True if the job state is active, False if it is terminal. + :rtype: bool + """ + return not is_terminal(self.value) diff --git a/client/python/armada_client/asyncio_client.py b/client/python/armada_client/asyncio_client.py index 96ef1c03991..301e7923445 100644 --- a/client/python/armada_client/asyncio_client.py +++ b/client/python/armada_client/asyncio_client.py @@ -18,6 +18,8 @@ submit_pb2, submit_pb2_grpc, health_pb2, + job_pb2, + job_pb2_grpc, ) from armada_client.event import Event from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 @@ -104,6 +106,7 @@ def __init__( ) -> None: self.submit_stub = submit_pb2_grpc.SubmitStub(channel) self.event_stub = event_pb2_grpc.EventStub(channel) + self.job_stub = job_pb2_grpc.JobsStub(channel) self.event_timeout = event_timeout async def get_job_events_stream( @@ -169,7 +172,7 @@ async def event_health(self) -> health_pb2.HealthCheckResponse: async def submit_jobs( self, queue: str, job_set_id: str, job_request_items ) -> AsyncIterator[submit_pb2.JobSubmitResponse]: - """Submit a armada job. + """Submit an armada job. Uses SubmitJobs RPC to submit a job. @@ -185,6 +188,48 @@ async def submit_jobs( response = await self.submit_stub.SubmitJobs(request) return response + async def get_job_status(self, job_ids: List[str]) -> job_pb2.JobStatusResponse: + """ + Asynchronously retrieves the status of a list of jobs from Armada. + + :param job_ids: A list of unique job identifiers. + :type job_ids: List[str] + + :returns: The response from the server containing the job status. + :rtype: JobStatusResponse + """ + req = job_pb2.JobStatusRequest(job_ids=job_ids) + resp = await self.job_stub.GetJobStatus(req) + return resp + + async def get_job_details(self, job_ids: List[str]) -> job_pb2.JobDetailsResponse: + """ + Asynchronously retrieves the details of a job from Armada. + + :param job_ids: A list of unique job identifiers. + :type job_ids: List[str] + + :returns: The Armada job details response. + """ + req = job_pb2.JobDetailsRequest(job_ids=job_ids, expand_job_run=True) + resp = await self.job_stub.GetJobDetails(req) + return resp + + async def get_job_run_details( + self, run_ids: List[str] + ) -> job_pb2.JobRunDetailsResponse: + """ + Asynchronously retrieves the details of a job run from Armada. + + :param run_ids: A list of unique job run identifiers. + :type run_ids: List[str] + + :returns: The Armada run details response. + """ + req = job_pb2.JobRunDetailsRequest(run_ids=run_ids) + resp = await self.job_stub.GetJobRunDetails(req) + return resp + async def cancel_jobs( self, queue: str, diff --git a/client/python/armada_client/client.py b/client/python/armada_client/client.py index 93da95e3aef..1cf36c7c2ed 100644 --- a/client/python/armada_client/client.py +++ b/client/python/armada_client/client.py @@ -17,6 +17,8 @@ submit_pb2, submit_pb2_grpc, health_pb2, + job_pb2, + job_pb2_grpc, ) from armada_client.event import Event from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 @@ -102,6 +104,7 @@ def __init__(self, channel, event_timeout: timedelta = timedelta(minutes=15)): self.submit_stub = submit_pb2_grpc.SubmitStub(channel) self.event_stub = event_pb2_grpc.EventStub(channel) self.event_timeout = event_timeout + self.job_stub = job_pb2_grpc.JobsStub(channel) def get_job_events_stream( self, @@ -161,10 +164,47 @@ def event_health(self) -> health_pb2.HealthCheckResponse: """ return self.event_stub.Health(request=empty_pb2.Empty()) + def get_job_status(self, job_ids: List[str]) -> job_pb2.JobStatusResponse: + """ + Retrieves the status of a list of jobs from Armada. + + :param job_ids: A list of unique job identifiers. + :type job_ids: List[str] + + :returns: The response from the server containing the job status. + :rtype: JobStatusResponse + """ + req = job_pb2.JobStatusRequest(job_ids=job_ids) + return self.job_stub.GetJobStatus(req) + + def get_job_details(self, job_ids: List[str]) -> job_pb2.JobDetailsResponse: + """ + Retrieves the details of a job from Armada. + + :param job_ids: A list of unique job identifiers. + :type job_ids: List[str] + + :returns: The Armada job details response. + """ + req = job_pb2.JobDetailsRequest(job_ids=job_ids, expand_job_run=True) + return self.job_stub.GetJobDetails(req) + + def get_job_run_details(self, run_ids: List[str]) -> job_pb2.JobRunDetailsResponse: + """ + Retrieves the details of a job run from Armada. + + :param run_ids: A list of unique job run identifiers. + :type run_ids: List[str] + + :returns: The Armada run details response. + """ + req = job_pb2.JobRunDetailsRequest(run_ids=run_ids) + return self.job_stub.GetJobRunDetails(req) + def submit_jobs( self, queue: str, job_set_id: str, job_request_items ) -> submit_pb2.JobSubmitResponse: - """Submit a armada job. + """Submit an armada job. Uses SubmitJobs RPC to submit a job. diff --git a/client/python/pyproject.toml b/client/python/pyproject.toml index c286f711ca3..587c1f7ab00 100644 --- a/client/python/pyproject.toml +++ b/client/python/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "armada_client" -version = "0.3.2" +version = "0.3.3" description = "Armada gRPC API python client" readme = "README.md" requires-python = ">=3.7" diff --git a/client/python/tests/unit/server_mock.py b/client/python/tests/unit/server_mock.py index edfc81b31aa..8d19101203b 100644 --- a/client/python/tests/unit/server_mock.py +++ b/client/python/tests/unit/server_mock.py @@ -1,11 +1,16 @@ from google.protobuf import empty_pb2 + from armada_client.armada import ( submit_pb2_grpc, submit_pb2, event_pb2, event_pb2_grpc, health_pb2, + job_pb2_grpc, + job_pb2, ) +from armada_client.armada.job_pb2 import JobRunState +from armada_client.armada.submit_pb2 import JobState class SubmitService(submit_pb2_grpc.SubmitServicer): @@ -101,3 +106,46 @@ def Health(self, request, context): return health_pb2.HealthCheckResponse( status=health_pb2.HealthCheckResponse.SERVING ) + + +class QueryAPIService(job_pb2_grpc.JobsServicer): + DEFAULT_JOB_DETAILS = { + "queue": "test_queue", + "jobset": "test_jobset", + "namespace": "test_namespace", + "state": JobState.RUNNING, + "cancel_reason": "", + "latest_run_id": "0", + } + + DEFAULT_JOB_RUN_DETAILS = { + "job_id": "0", + "cluster": "test_cluster", + "node": "test_node", + "state": JobRunState.RUN_STATE_RUNNING, + } + + def GetJobStatus(self, request, context): + return job_pb2.JobStatusResponse( + job_states={job: JobState.RUNNING for job in request.job_ids} + ) + + def GetJobDetails(self, request, context): + return job_pb2.JobDetailsResponse( + job_details={ + job: job_pb2.JobDetails( + job_id=job, **QueryAPIService.DEFAULT_JOB_DETAILS + ) + for job in request.job_ids + } + ) + + def GetJobRunDetails(self, request, context): + return job_pb2.JobRunDetailsResponse( + job_run_details={ + run: job_pb2.JobRunDetails( + run_id=run, **QueryAPIService.DEFAULT_JOB_RUN_DETAILS + ) + for run in request.run_ids + } + ) diff --git a/client/python/tests/unit/test_asyncio_client.py b/client/python/tests/unit/test_asyncio_client.py index a7aebdcb224..6f4d8709c23 100644 --- a/client/python/tests/unit/test_asyncio_client.py +++ b/client/python/tests/unit/test_asyncio_client.py @@ -4,9 +4,17 @@ import pytest import pytest_asyncio -from server_mock import EventService, SubmitService - -from armada_client.armada import event_pb2_grpc, submit_pb2_grpc, submit_pb2, health_pb2 +from armada_client.typings import JobState +from armada_client.armada.job_pb2 import JobRunState +from server_mock import EventService, SubmitService, QueryAPIService + +from armada_client.armada import ( + event_pb2_grpc, + submit_pb2_grpc, + submit_pb2, + health_pb2, + job_pb2_grpc, +) from armada_client.asyncio_client import ArmadaAsyncIOClient from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 from armada_client.k8s.io.apimachinery.pkg.api.resource import ( @@ -14,7 +22,6 @@ ) from armada_client.permissions import Permissions, Subject -from armada_client.typings import JobState @pytest.fixture @@ -22,6 +29,7 @@ def server_mock(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server) event_pb2_grpc.add_EventServicer_to_server(EventService(), server) + job_pb2_grpc.add_JobsServicer_to_server(QueryAPIService(), server) server.add_insecure_port("[::]:50051") server.start() yield @@ -302,3 +310,36 @@ async def test_health_submit(aio_client): async def test_health_event(aio_client): health = await aio_client.event_health() assert health.SERVING == health_pb2.HealthCheckResponse.SERVING + + +@pytest.mark.asyncio +async def test_job_status(aio_client): + await test_create_queue(aio_client) + await test_submit_job(aio_client) + + job_status_response = await aio_client.get_job_status(["job-1"]) + assert job_status_response.job_states["job-1"] == submit_pb2.JobState.RUNNING + + +@pytest.mark.asyncio +async def test_job_details(aio_client): + await test_create_queue(aio_client) + await test_submit_job(aio_client) + + job_details_response = await aio_client.get_job_details(["job-1"]) + job_details = job_details_response.job_details + assert job_details["job-1"].state == submit_pb2.JobState.RUNNING + assert job_details["job-1"].job_id == "job-1" + assert job_details["job-1"].queue == "test_queue" + + +@pytest.mark.asyncio +async def test_job_run_details(aio_client): + await test_create_queue(aio_client) + await test_submit_job(aio_client) + + run_details_response = await aio_client.get_job_run_details(["run-1"]) + run_details = run_details_response.job_run_details + assert run_details["run-1"].state == JobRunState.RUN_STATE_RUNNING + assert run_details["run-1"].run_id == "run-1" + assert run_details["run-1"].cluster == "test_cluster" diff --git a/client/python/tests/unit/test_client.py b/client/python/tests/unit/test_client.py index ab227dff7b7..70eba72439b 100644 --- a/client/python/tests/unit/test_client.py +++ b/client/python/tests/unit/test_client.py @@ -3,9 +3,17 @@ import grpc import pytest -from server_mock import EventService, SubmitService - -from armada_client.armada import event_pb2_grpc, submit_pb2_grpc, submit_pb2, health_pb2 +from armada_client.typings import JobState +from armada_client.armada.job_pb2 import JobRunState +from server_mock import EventService, SubmitService, QueryAPIService + +from armada_client.armada import ( + event_pb2_grpc, + submit_pb2_grpc, + submit_pb2, + health_pb2, + job_pb2_grpc, +) from armada_client.client import ArmadaClient from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 from armada_client.k8s.io.apimachinery.pkg.api.resource import ( @@ -13,7 +21,6 @@ ) from armada_client.permissions import Permissions, Subject -from armada_client.typings import JobState @pytest.fixture(scope="session", autouse=True) @@ -21,6 +28,7 @@ def server_mock(): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server) event_pb2_grpc.add_EventServicer_to_server(EventService(), server) + job_pb2_grpc.add_JobsServicer_to_server(QueryAPIService(), server) server.add_insecure_port("[::]:50051") server.start() @@ -278,3 +286,31 @@ def test_health_submit(): def test_health_event(): health = tester.event_health() assert health.SERVING == health_pb2.HealthCheckResponse.SERVING + + +def test_job_status(): + test_create_queue() + test_submit_job() + + job_status_response = tester.get_job_status(["job-1"]) + assert job_status_response.job_states["job-1"] == submit_pb2.JobState.RUNNING + + +def test_job_details(): + test_create_queue() + test_submit_job() + + job_details = tester.get_job_details(["job-1"]).job_details + assert job_details["job-1"].state == submit_pb2.JobState.RUNNING + assert job_details["job-1"].job_id == "job-1" + assert job_details["job-1"].queue == "test_queue" + + +def test_job_run_details(): + test_create_queue() + test_submit_job() + + run_details = tester.get_job_run_details(["run-1"]).job_run_details + assert run_details["run-1"].state == JobRunState.RUN_STATE_RUNNING + assert run_details["run-1"].run_id == "run-1" + assert run_details["run-1"].cluster == "test_cluster" diff --git a/docs/python_armada_client.md b/docs/python_armada_client.md index c0eca79b234..e2dc1228f89 100644 --- a/docs/python_armada_client.md +++ b/docs/python_armada_client.md @@ -255,6 +255,28 @@ Health check for Event Service. +#### get_job_details(job_ids) +Retrieves the details of a job from Armada. + + +* **Parameters** + + **job_ids** (*List**[**str**]*) – A list of unique job identifiers. + + + +* **Returns** + + The Armada job details response. + + + +* **Return type** + + armada.job_pb2.JobDetailsResponse + + + #### get_job_events_stream(queue, job_set_id, from_message_id=None) Get event stream for a job set. @@ -296,6 +318,50 @@ for event in events: +#### get_job_run_details(run_ids) +Retrieves the details of a job run from Armada. + + +* **Parameters** + + **run_ids** (*List**[**str**]*) – A list of unique job run identifiers. + + + +* **Returns** + + The Armada run details response. + + + +* **Return type** + + armada.job_pb2.JobRunDetailsResponse + + + +#### get_job_status(job_ids) +Retrieves the status of a list of jobs from Armada. + + +* **Parameters** + + **job_ids** (*List**[**str**]*) – A list of unique job identifiers. + + + +* **Returns** + + The response from the server containing the job status. + + + +* **Return type** + + JobStatusResponse + + + #### get_queue(name) Get the queue by name. @@ -398,7 +464,7 @@ Health check for Submit Service. #### submit_jobs(queue, job_set_id, job_request_items) -Submit a armada job. +Submit an armada job. Uses SubmitJobs RPC to submit a job.