Skip to content

Commit

Permalink
feat(airflow): try to adopt running job in case of unexpected failures (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala authored May 4, 2024
1 parent 2e1bf35 commit 204b311
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 10 deletions.
49 changes: 41 additions & 8 deletions spark_on_k8s/airflow/operators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from enum import Enum
from typing import TYPE_CHECKING, Any

Expand All @@ -16,6 +17,7 @@

from airflow.utils.context import Context
from spark_on_k8s.client import ExecutorInstances, PodResources
from spark_on_k8s.utils.app_manager import SparkAppManager


class _AirflowKubernetesClientManager(KubernetesClientManager):
Expand Down Expand Up @@ -80,6 +82,9 @@ class SparkOnK8SOperator(BaseOperator):
**kwargs: Other keyword arguments for BaseOperator.
"""

_XCOM_DRIVER_POD_NAMESPACE = "driver_pod_namespace"
_XCOM_DRIVER_POD_NAME = "driver_pod_name"

_driver_pod_name: str | None = None

template_fields = (
Expand Down Expand Up @@ -197,9 +202,30 @@ def _render_nested_template_fields(

super()._render_nested_template_fields(content, context, jinja_env, seen_oids)

def execute(self, context):
def _persist_pod_name(self, context: Context):
context["ti"].xcom_push(key=self._XCOM_DRIVER_POD_NAMESPACE, value=self.namespace)
context["ti"].xcom_push(key=self._XCOM_DRIVER_POD_NAME, value=self._driver_pod_name)

def _try_to_adopt_job(self, context: Context, spark_app_manager: SparkAppManager) -> bool:
    """Try to adopt a still-running driver pod left over from a previous try.

    Reads the namespace and driver pod name persisted in XCom by an
    earlier attempt. When they point at a pod in this task's namespace
    whose Spark application is still running, the pod name is remembered
    on the operator so the rest of execution tracks the existing job.

    Returns True when a running job was adopted, False otherwise.
    """
    from spark_on_k8s.utils.spark_app_status import SparkAppStatus

    task_instance = context["ti"]
    previous_namespace = task_instance.xcom_pull(key=self._XCOM_DRIVER_POD_NAMESPACE)
    # Only consider adoption when the recorded pod lives in the namespace
    # this task currently targets.
    if not previous_namespace or previous_namespace != self.namespace:
        return False
    previous_pod_name = task_instance.xcom_pull(key=self._XCOM_DRIVER_POD_NAME)
    if not previous_pod_name:
        return False
    try:
        # Best effort: any failure to query the status means we simply
        # fall back to submitting a fresh application.
        app_status = spark_app_manager.app_status(
            namespace=previous_namespace,
            pod_name=previous_pod_name,
        )
    except Exception:
        return False
    if app_status != SparkAppStatus.Running:
        return False
    self._driver_pod_name = previous_pod_name
    return True

def _submit_new_job(self, context: Context):
from spark_on_k8s.client import ExecutorInstances, PodResources, SparkOnK8S
from spark_on_k8s.utils.app_manager import SparkAppManager

# post-process template fields
if self.driver_resources:
Expand Down Expand Up @@ -267,6 +293,19 @@ def execute(self, context):
executor_pod_template_path=self.executor_pod_template_path,
**submit_app_kwargs,
)

def execute(self, context: Context):
from spark_on_k8s.utils.app_manager import SparkAppManager

k8s_client_manager = _AirflowKubernetesClientManager(
kubernetes_conn_id=self.kubernetes_conn_id,
)
spark_app_manager = SparkAppManager(
k8s_client_manager=k8s_client_manager,
)
if not self._try_to_adopt_job(context, spark_app_manager):
self._submit_new_job(context)
self._persist_pod_name(context)
if self.app_waiter == "no_wait":
return
if self.deferrable:
Expand All @@ -279,12 +318,6 @@ def execute(self, context):
),
method_name="execute_complete",
)
k8s_client_manager = _AirflowKubernetesClientManager(
kubernetes_conn_id=self.kubernetes_conn_id,
)
spark_app_manager = SparkAppManager(
k8s_client_manager=k8s_client_manager,
)
if self.app_waiter == "wait":
spark_app_manager.wait_for_app(
namespace=self.namespace,
Expand Down
52 changes: 50 additions & 2 deletions tests/airflow/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from unittest import mock

import pytest
from spark_on_k8s.utils.spark_app_status import SparkAppStatus

from conftest import PYTHON_312

Expand Down Expand Up @@ -43,7 +44,13 @@ def test_execute(self, mock_submit_app):
driver_tolerations=test_tolerations,
executor_pod_template_path="s3a://bucket/executor.yml",
)
spark_app_task.execute(None)
spark_app_task.execute(
{
"ti": mock.MagicMock(
xcom_pull=mock.MagicMock(return_value=None),
)
}
)
mock_submit_app.assert_called_once_with(
namespace="spark",
image="pyspark-job",
Expand Down Expand Up @@ -129,7 +136,13 @@ def test_rendering_templates(self, mock_submit_app):
"template_secret_value": "value from connection",
},
)
spark_app_task.execute(None)
spark_app_task.execute(
{
"ti": mock.MagicMock(
xcom_pull=mock.MagicMock(return_value=None),
)
}
)
app_id_suffix_kwarg = mock_submit_app.call_args.kwargs.get("app_id_suffix")
mock_submit_app.assert_called_once_with(
namespace="spark",
Expand Down Expand Up @@ -167,3 +180,38 @@ def test_rendering_templates(self, mock_submit_app):
executor_pod_template_path=None,
)
assert app_id_suffix_kwarg() == "-suffix"

@pytest.mark.parametrize(
    "job_status, should_submit",
    [
        (SparkAppStatus.Running, False),
        (SparkAppStatus.Succeeded, True),
        (SparkAppStatus.Failed, True),
    ],
)
@mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.wait_for_app")
@mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.app_status")
@mock.patch("spark_on_k8s.client.SparkOnK8S.submit_app")
def test_job_adoption(
    self, mock_submit_app, mock_app_status, mock_wait_for_app, job_status, should_submit
):
    """A still-running previous job is adopted; any other status triggers a new submission."""
    from spark_on_k8s.airflow.operators import SparkOnK8SOperator

    # First status call answers the adoption probe; the second one is
    # consumed while waiting for the (possibly new) application.
    mock_app_status.side_effect = [job_status, SparkAppStatus.Succeeded]
    operator = SparkOnK8SOperator(
        task_id="spark_application",
        namespace="test-namespace",
        image="pyspark-job",
        app_path="local:///opt/spark/work-dir/job.py",
    )
    # XCom holds the namespace and pod name persisted by a previous try,
    # returned in the order the operator pulls them.
    fake_ti = mock.MagicMock(
        xcom_pull=mock.MagicMock(side_effect=["test-namespace", "existing-pod"]),
    )
    operator.execute({"ti": fake_ti})
    assert mock_submit_app.call_count == (1 if should_submit else 0)

0 comments on commit 204b311

Please sign in to comment.