Skip to content

Commit

Permalink
Fix up retry policies and fix state change logging
Browse files (browse the repository at this point in the history)
Co-authored-by: Martynas Asipauskas <Martynas.Asipauskas@gresearch.co.uk>
  • Loading branch information
masipauskas and Martynas Asipauskas authored Oct 7, 2024
1 parent c94e4d8 commit 308d34c
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 17 deletions.
22 changes: 12 additions & 10 deletions third_party/airflow/armada/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def _create_channel(self) -> grpc.Channel:
)

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=15),
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
reraise=True,
)
@log_exceptions
Expand All @@ -63,7 +64,8 @@ def cancel_job(self, job_context: RunningJobContext) -> RunningJobContext:
return dataclasses.replace(job_context, job_state=JobState.CANCELLED.name)

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=15),
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
reraise=True,
)
@log_exceptions
Expand All @@ -88,7 +90,8 @@ def submit_job(
return RunningJobContext(queue, job.job_id, job_set_id, DateTime.utcnow())

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=15),
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
reraise=True,
)
@log_exceptions
Expand All @@ -112,26 +115,24 @@ def refresh_context(
cluster = run_details.cluster
return dataclasses.replace(job_context, job_state=state.name, cluster=cluster)

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=15),
reraise=True,
)
@log_exceptions
def context_from_xcom(self, ti: TaskInstance, re_attach: bool) -> RunningJobContext:
result = ti.xcom_pull(key="job_context")
if result:
return self.refresh_context(RunningJobContext(
return RunningJobContext(
armada_queue=result["armada_queue"],
job_id=result["armada_job_id"],
job_set_id=result["armada_job_set_id"],
job_state=result.get("armada_job_state", "UNKNOWN"),
submit_time=DateTime.utcnow() if re_attach else result.get("armada_job_submit_time", DateTime.utcnow()),
last_log_time=None if re_attach else result.get("armada_job_last_log_time", None)
), None)
)

return None

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=15),
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
reraise=True,
)
@log_exceptions
Expand All @@ -141,6 +142,7 @@ def context_to_xcom(self, ti: TaskInstance, ctx: RunningJobContext, lookout_url:
"armada_job_id": ctx.job_id,
"armada_job_set_id": ctx.job_set_id,
"armada_job_submit_time": ctx.submit_time,
"armada_job_state": ctx.job_state,
"armada_job_last_log_time": ctx.last_log_time,
"armada_lookout_url": lookout_url,
})
Expand Down
3 changes: 2 additions & 1 deletion third_party/airflow/armada/log_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def _k8s_client(self, k8s_context) -> client.CoreV1Api:


@tenacity.retry(
wait=tenacity.wait_exponential(max=15),
wait=tenacity.wait_exponential(max=3),
retry=tenacity.retry_if_exception_type(HTTPError),
stop=tenacity.stop_after_attempt(5),
reraise=True,
)
def fetch_container_logs(
Expand Down
12 changes: 8 additions & 4 deletions third_party/airflow/armada/operators/armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def pod_manager(self) -> KubernetesPodLogManager:
return KubernetesPodLogManager(token_retriever=self.k8s_token_retriever)

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=15),
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
retry=tenacity.retry_if_not_exception_type(jinja2.TemplateSyntaxError),
reraise=True,
)
Expand Down Expand Up @@ -350,12 +351,14 @@ def _check_job_status_and_fetch_logs(self, context) -> None:
self.job_context = dataclasses.replace(
self.job_context, last_log_time=last_log_time
)
self.hook.context_to_xcom(context["ti"], self.job_context, self.lookout_url(self.job_context.job_id))
except Exception as e:
self.log.warning(f"Error fetching logs {e}")

self.hook.context_to_xcom(context["ti"], self.job_context, self.lookout_url(self.job_context.job_id))

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=2),
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
reraise=True,
)
@log_exceptions
Expand All @@ -364,7 +367,8 @@ def _xcom_pull(self, context, key: str) -> Any:
return task_instance.xcom_pull(key=key)

@tenacity.retry(
wait=tenacity.wait_random_exponential(max=2),
wait=tenacity.wait_random_exponential(max=3),
stop=tenacity.stop_after_attempt(5),
reraise=True,
)
@log_exceptions
Expand Down
2 changes: 1 addition & 1 deletion third_party/airflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "armada_airflow"
version = "1.0.7"
version = "1.0.8"
description = "Armada Airflow Operator"
readme='README.md'
authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}]
Expand Down
2 changes: 1 addition & 1 deletion third_party/airflow/test/unit/operators/test_armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_publishes_xcom_state(context):
op = operator(JobSubmitRequestItem())
op.execute(context)

op.hook.context_to_xcom.assert_called_once()
assert op.hook.context_to_xcom.call_count == 2


def test_reattaches_to_running_job(context):
Expand Down

0 comments on commit 308d34c

Please sign in to comment.