From eb02b823e9fc128e77253e09ef0e823fe81edc9a Mon Sep 17 00:00:00 2001 From: Martynas Asipauskas Date: Wed, 21 Aug 2024 17:37:25 +0100 Subject: [PATCH] Fixing linting / ci errors (#202) --- docs/python_airflow_operator.md | 12 +++++++----- third_party/airflow/armada/hooks.py | 9 +++++---- third_party/airflow/armada/model.py | 3 +-- third_party/airflow/armada/operators/armada.py | 5 +++-- .../airflow/test/operators/test_armada.py | 18 +++++++++--------- third_party/airflow/test/test_model.py | 2 +- third_party/airflow/tox.ini | 2 +- 7 files changed, 27 insertions(+), 24 deletions(-) diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 91ee1a728cd..32fa081fe83 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -119,6 +119,8 @@ Args: #### template_fields(_: Sequence[str_ _ = ('job_request', 'job_set_prefix'_ ) + +#### template_fields_renderers(_: Dict[str, str_ _ = {'job_request': 'py'_ ) Initializes a new ArmadaOperator. @@ -225,7 +227,7 @@ Bases: `object` -### _class_ armada.model.RunningJobContext(armada_queue: 'str', job_id: 'str', job_set_id: 'str', cluster: 'Optional[str]' = None, start_time: 'DateTime' = DateTime(2024, 8, 21, 14, 36, 0, 499682, tzinfo=Timezone('UTC')), last_log_time: 'Optional[DateTime]' = None, job_state: 'str' = 'UNKNOWN') +### _class_ armada.model.RunningJobContext(armada_queue: 'str', job_id: 'str', job_set_id: 'str', submit_time: 'DateTime', cluster: 'Optional[str]' = None, last_log_time: 'Optional[DateTime]' = None, job_state: 'str' = 'UNKNOWN') Bases: `object` @@ -241,10 +243,10 @@ Bases: `object` * **job_set_id** (*str*) – - * **cluster** (*str** | **None*) – + * **submit_time** (*DateTime*) – - * **start_time** (*DateTime*) – + * **cluster** (*str** | **None*) – * **last_log_time** (*DateTime** | **None*) – @@ -266,6 +268,6 @@ Bases: `object` #### last_log_time(_: DateTime | Non_ _ = Non_ ) -#### start_time(_: DateTim_ _ = DateTime(2024, 8, 21, 14, 36, 0, 499682, tzinfo=Timezone('UTC')_ ) - #### _property_ state(_: JobStat_ ) + +#### submit_time(_: DateTim_ ) diff --git a/third_party/airflow/armada/hooks.py b/third_party/airflow/armada/hooks.py index 09fe7342739..a894d09249e 100644 --- a/third_party/airflow/armada/hooks.py +++ b/third_party/airflow/armada/hooks.py @@ -13,6 +13,7 @@ from armada_client.armada.submit_pb2 import JobSubmitRequestItem from armada_client.client import ArmadaClient from armada_client.typings import JobState +from pendulum import DateTime from .model import RunningJobContext @@ -33,12 +34,12 @@ def client_for(args: GrpcChannelArgs) -> ArmadaClient: with ArmadaClientFactory.CLIENTS_LOCK: if channel_args_key not in ArmadaClientFactory.CLIENTS: ArmadaClientFactory.CLIENTS[channel_args_key] = ArmadaClient( - channel=ArmadaClientFactory.create_channel(args) + channel=ArmadaClientFactory._create_channel(args) ) return ArmadaClientFactory.CLIENTS[channel_args_key] @staticmethod - def create_channel(args: GrpcChannelArgs) -> grpc.Channel: + def _create_channel(args: GrpcChannelArgs) -> grpc.Channel: if args.auth is None: return grpc.insecure_channel( target=args.target, options=args.options, compression=args.compression @@ -97,9 +98,9 @@ def submit_job( if job.error: raise AirflowException(f"Error submitting job to Armada: {job.error}") - return RunningJobContext(queue, job.job_id, job_set_id, None) + return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow()) - def update_context( + def refresh_context( self, job_context: RunningJobContext, tracking_url: str ) -> RunningJobContext: response = self.client.get_job_status([job_context.job_id]) diff --git a/third_party/airflow/armada/model.py b/third_party/airflow/armada/model.py index ca9c0520aab..91db62420e0 100644 --- a/third_party/airflow/armada/model.py +++ b/third_party/airflow/armada/model.py @@ -22,7 +22,6 @@ def __init__( options: Optional[Sequence[Tuple[str, Any]]] = None, compression: Optional[grpc.Compression] = None, auth: Optional[grpc.AuthMetadataPlugin] = None, - # auth_details: Optional[Dict[str, Any]] = None, ): self.target = target self.options = options @@ -59,8 +58,8 @@ class RunningJobContext: armada_queue: str job_id: str job_set_id: str + submit_time: DateTime cluster: Optional[str] = None - start_time: DateTime = DateTime.utcnow() last_log_time: Optional[DateTime] = None job_state: str = JobState.UNKNOWN.name diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index 81da89aff29..aa07227b80e 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -237,6 +237,7 @@ def _reattach_or_submit_job( armada_queue=existing_run["armada_queue"], job_id=existing_run["armada_job_id"], job_set_id=existing_run["armada_job_set_id"], + submit_time=DateTime.utcnow(), ) # We haven't got a running job, submit a new one and persist state to xcom. @@ -276,7 +277,7 @@ def _running_job_terminated(self, context: RunningJobContext): def _not_acknowledged_within_timeout(self) -> bool: if self.job_context.state == JobState.UNKNOWN: if ( - DateTime.utcnow().diff(self.job_context.start_time).in_seconds() + DateTime.utcnow().diff(self.job_context.submit_time).in_seconds() > self.job_acknowledgement_timeout ): return True @@ -284,7 +285,7 @@ def _not_acknowledged_within_timeout(self) -> bool: @log_exceptions def _check_job_status_and_fetch_logs(self) -> None: - self.job_context = self.hook.update_context( + self.job_context = self.hook.refresh_context( self.job_context, self._trigger_tracking_message(self.job_context.job_id) ) diff --git a/third_party/airflow/test/operators/test_armada.py b/third_party/airflow/test/operators/test_armada.py index 899a70ea0bd..c9435cafaed 100644 --- a/third_party/airflow/test/operators/test_armada.py +++ b/third_party/airflow/test/operators/test_armada.py @@ -24,7 +24,7 @@ def default_hook() -> MagicMock: mock = MagicMock() job_context = running_job_context() mock.submit_job.return_value = job_context - mock.update_context.return_value = dataclasses.replace( + mock.refresh_context.return_value = dataclasses.replace( job_context, job_state=JobState.SUCCEEDED.name, cluster=DEFAULT_CLUSTER ) mock.cancel_job.return_value = dataclasses.replace( @@ -88,15 +88,15 @@ def operator( def running_job_context( cluster: str = None, - start_time: DateTime = DateTime.now(), + submit_time: DateTime = DateTime.now(), job_state: str = JobState.UNKNOWN.name, ) -> RunningJobContext: return RunningJobContext( DEFAULT_QUEUE, DEFAULT_JOB_ID, DEFAULT_JOB_SET, + submit_time, cluster, - start_time, job_state=job_state, ) @@ -118,7 +118,7 @@ def running_job_context( def test_execute(job_states, context): op = operator(JobSubmitRequestItem()) - op.hook.update_context.side_effect = [ + op.hook.refresh_context.side_effect = [ running_job_context(cluster="cluster-1", job_state=s.name) for s in job_states ] @@ -127,7 +127,7 @@ def test_execute(job_states, context): op.hook.submit_job.assert_called_once_with( DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request ) - assert op.hook.update_context.call_count == len(job_states) + assert op.hook.refresh_context.call_count == len(job_states) # We're not polling for logs op.pod_manager.fetch_container_logs.assert_not_called() @@ -136,7 +136,7 @@ def test_execute(job_states, context): @patch("pendulum.DateTime.utcnow", return_value=DEFAULT_CURRENT_TIME) def test_execute_in_deferrable(_, context): op = operator(JobSubmitRequestItem(), deferrable=True) - op.hook.update_context.side_effect = [ + op.hook.refresh_context.side_effect = [ running_job_context(cluster="cluster-1", job_state=s.name) for s in [JobState.QUEUED, JobState.QUEUED] ] @@ -164,7 +164,7 @@ def test_execute_in_deferrable(_, context): def test_execute_fail(terminal_state, context): op = operator(JobSubmitRequestItem()) - op.hook.update_context.side_effect = [ + op.hook.refresh_context.side_effect = [ running_job_context(cluster="cluster-1", job_state=s.name) for s in [JobState.RUNNING, terminal_state] ] @@ -179,7 +179,7 @@ def test_execute_fail(terminal_state, context): op.hook.submit_job.assert_called_once_with( DEFAULT_QUEUE, DEFAULT_JOB_SET, op.job_request ) - assert op.hook.update_context.call_count == 2 + assert op.hook.refresh_context.call_count == 2 # We're not polling for logs op.pod_manager.fetch_container_logs.assert_not_called() @@ -200,7 +200,7 @@ def test_on_kill_terminates_running_job(): def test_not_acknowledged_within_timeout_terminates_running_job(context): job_context = running_job_context() op = operator(JobSubmitRequestItem(), job_acknowledgement_timeout_s=-1) - op.hook.update_context.return_value = job_context + op.hook.refresh_context.return_value = job_context with pytest.raises(AirflowException) as exec_info: op.execute(context) diff --git a/third_party/airflow/test/test_model.py b/third_party/airflow/test/test_model.py index d633681ced6..906b7315ad9 100644 --- a/third_party/airflow/test/test_model.py +++ b/third_party/airflow/test/test_model.py @@ -10,8 +10,8 @@ def test_roundtrip_running_job_context(): "queue_123", "job_id_123", "job_set_id_123", - "cluster-1.armada.localhost", DateTime.utcnow(), + "cluster-1.armada.localhost", DateTime.utcnow().add(minutes=-2), JobState.RUNNING.name, ) diff --git a/third_party/airflow/tox.ini b/third_party/airflow/tox.ini index 09dd8ce15ea..c8434a36981 100644 --- a/third_party/airflow/tox.ini +++ b/third_party/airflow/tox.ini @@ -13,7 +13,7 @@ allowlist_externals = find xargs commands = - coverage run -m unittest discover + coverage run -m pytest test/ coverage xml # This executes the dag files in examples but really only checks for imports and python errors bash -c "find examples/ -maxdepth 1 -type f -name *.py | xargs python3"