feat(python-client, airflow): support startup_timeout for pending state waiting (#84)

hussein-awala authored Jul 24, 2024
1 parent 1e0d779 commit 1a1ac96

Showing 7 changed files with 187 additions and 56 deletions.
82 changes: 41 additions & 41 deletions README.md
@@ -157,48 +157,48 @@ The Python client and the CLI can be configured with environment variables to av
time if you have a common configuration for all your apps. The environment variables are the same for both the client
and the CLI. Here is a list of the available environment variables:

| Environment Variable                      | Description                                           | Default        |
|-------------------------------------------|-------------------------------------------------------|----------------|
| SPARK_ON_K8S_DOCKER_IMAGE                 | The docker image to use for the spark pods            |                |
| SPARK_ON_K8S_APP_PATH                     | The path to the app file                              |                |
| SPARK_ON_K8S_NAMESPACE                    | The namespace to use                                  | default        |
| SPARK_ON_K8S_SERVICE_ACCOUNT              | The service account to use                            | spark          |
| SPARK_ON_K8S_SPARK_CONF                   | The spark configuration to use                        | {}             |
| SPARK_ON_K8S_CLASS_NAME                   | The class name to use                                 |                |
| SPARK_ON_K8S_PACKAGES                     | The maven packages list to add to the classpath       |                |
| SPARK_ON_K8S_APP_ARGUMENTS                | The arguments to pass to the app                      | []             |
| SPARK_ON_K8S_APP_WAITER                   | The waiter to use to wait for the app to finish       | no_wait        |
| SPARK_ON_K8S_IMAGE_PULL_POLICY            | The image pull policy to use                          | IfNotPresent   |
| SPARK_ON_K8S_UI_REVERSE_PROXY             | Whether to use a reverse proxy to access the spark UI | false          |
| SPARK_ON_K8S_DRIVER_CPU                   | The driver CPU                                        | 1              |
| SPARK_ON_K8S_DRIVER_MEMORY                | The driver memory                                     | 1024           |
| SPARK_ON_K8S_DRIVER_MEMORY_OVERHEAD       | The driver memory overhead                            | 512            |
| SPARK_ON_K8S_EXECUTOR_CPU                 | The executor CPU                                      | 1              |
| SPARK_ON_K8S_EXECUTOR_MEMORY              | The executor memory                                   | 1024           |
| SPARK_ON_K8S_EXECUTOR_MEMORY_OVERHEAD     | The executor memory overhead                          | 512            |
| SPARK_ON_K8S_EXECUTOR_MIN_INSTANCES       | The minimum number of executor instances              |                |
| SPARK_ON_K8S_EXECUTOR_MAX_INSTANCES       | The maximum number of executor instances              |                |
| SPARK_ON_K8S_EXECUTOR_INITIAL_INSTANCES   | The initial number of executor instances              |                |
| SPARK_ON_K8S_CONFIG_FILE                  | The path to the config file                           |                |
| SPARK_ON_K8S_CONTEXT                      | The context to use                                    |                |
| SPARK_ON_K8S_CLIENT_CONFIG                | The sync Kubernetes client configuration to use       |                |
| SPARK_ON_K8S_ASYNC_CLIENT_CONFIG          | The async Kubernetes client configuration to use      |                |
| SPARK_ON_K8S_IN_CLUSTER                   | Whether to use the in cluster Kubernetes config       | false          |
| SPARK_ON_K8S_API_DEFAULT_NAMESPACE        | The default namespace to use for the API              | default        |
| SPARK_ON_K8S_API_HOST                     | The host to use for the API                           | 127.0.0.1      |
| SPARK_ON_K8S_API_PORT                     | The port to use for the API                           | 8000           |
| SPARK_ON_K8S_API_WORKERS                  | The number of workers to use for the API              | 4              |
| SPARK_ON_K8S_API_LOG_LEVEL                | The log level to use for the API                      | info           |
| SPARK_ON_K8S_API_LIMIT_CONCURRENCY        | The limit concurrency to use for the API              | 1000           |
| SPARK_ON_K8S_API_SPARK_HISTORY_HOST       | The host to use for the spark history server          |                |
| SPARK_ON_K8S_SPARK_DRIVER_NODE_SELECTOR   | The node selector to use for the driver pod           | {}             |
| SPARK_ON_K8S_SPARK_EXECUTOR_NODE_SELECTOR | The node selector to use for the executor pods        | {}             |
| SPARK_ON_K8S_SPARK_DRIVER_LABELS          | The labels to use for the driver pod                  | {}             |
| SPARK_ON_K8S_SPARK_EXECUTOR_LABELS        | The labels to use for the executor pods               | {}             |
| SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS     | The annotations to use for the driver pod             | {}             |
| SPARK_ON_K8S_SPARK_EXECUTOR_ANNOTATIONS   | The annotations to use for the executor pods          | {}             |
| SPARK_ON_K8S_EXECUTOR_POD_TEMPLATE_PATH   | The path to the executor pod template                 |                |
| SPARK_ON_K8S_STARTUP_TIMEOUT              | The timeout to wait for the app to start in seconds   | 0 (no timeout) |
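The new `SPARK_ON_K8S_STARTUP_TIMEOUT` variable follows the same convention as the other numeric settings: it is read once from the environment, and `0` disables the timeout. A minimal standalone sketch of that lookup (the helper name is illustrative, not part of the library):

```python
import os

def read_startup_timeout() -> int:
    # Mirrors how Configuration reads SPARK_ON_K8S_STARTUP_TIMEOUT:
    # unset or "0" means "no timeout".
    return int(os.getenv("SPARK_ON_K8S_STARTUP_TIMEOUT", 0))

os.environ.pop("SPARK_ON_K8S_STARTUP_TIMEOUT", None)
print(read_startup_timeout())  # 0 (default: no timeout)

os.environ["SPARK_ON_K8S_STARTUP_TIMEOUT"] = "300"
print(read_startup_timeout())  # 300
```

Note that the value is parsed at import time in the real `Configuration` class, so the variable must be set before the library is imported.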

## Examples

42 changes: 28 additions & 14 deletions spark_on_k8s/airflow/operators.py
@@ -84,6 +84,8 @@ class SparkOnK8SOperator(BaseOperator):
deferrable (bool, optional): Whether the operator is deferrable. Defaults to False.
on_kill_action (Literal["keep", "delete", "kill"], optional): Action to take when the
operator is killed. Defaults to "delete".
startup_timeout (int, optional): Timeout for the Spark application to start.
Defaults to 0 (no timeout).
**kwargs: Other keyword arguments for BaseOperator.
"""

@@ -151,6 +153,7 @@ def __init__(
poll_interval: int = 10,
deferrable: bool = False,
on_kill_action: Literal["keep", "delete", "kill"] = OnKillAction.DELETE,
startup_timeout: int = 0,
**kwargs,
):
super().__init__(**kwargs)
@@ -188,6 +191,7 @@ def __init__(
self.poll_interval = poll_interval
self.deferrable = deferrable
self.on_kill_action = on_kill_action
self.startup_timeout = startup_timeout

def _render_nested_template_fields(
self,
@@ -366,23 +370,33 @@ def execute(self, context: Context):
),
method_name="execute_complete",
)
if self.app_waiter == "wait":
spark_app_manager.wait_for_app(
namespace=self.namespace,
pod_name=self._driver_pod_name,
poll_interval=self.poll_interval,
)
elif self.app_waiter == "log":
spark_app_manager.stream_logs(
namespace=self.namespace,
pod_name=self._driver_pod_name,
)
# wait for termination status
spark_app_manager.wait_for_app(
try:
if self.app_waiter == "wait":
spark_app_manager.wait_for_app(
namespace=self.namespace,
pod_name=self._driver_pod_name,
poll_interval=self.poll_interval,
startup_timeout=self.startup_timeout,
)
elif self.app_waiter == "log":
spark_app_manager.stream_logs(
namespace=self.namespace,
pod_name=self._driver_pod_name,
startup_timeout=self.startup_timeout,
)
# wait for termination status
spark_app_manager.wait_for_app(
namespace=self.namespace,
pod_name=self._driver_pod_name,
poll_interval=1,
)
except TimeoutError:
self.log.info("Deleting Spark application due to startup timeout...")
spark_app_manager.delete_app(
namespace=self.namespace,
pod_name=self._driver_pod_name,
poll_interval=1,
)
raise AirflowException("Spark application startup timeout exceeded") from None
app_status = spark_app_manager.app_status(
namespace=self.namespace,
pod_name=self._driver_pod_name,
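Condensed, the control flow the operator gains above is: run the chosen waiter, and if it raises a startup `TimeoutError`, delete the stuck application before failing the task. A hedged standalone sketch (`run_waiter` and `TaskFailure` are illustrative stand-ins; the real code raises `AirflowException`):

```python
class TaskFailure(Exception):
    """Stand-in for Airflow's AirflowException."""

def run_waiter(wait_fn, delete_fn):
    try:
        wait_fn()
    except TimeoutError:
        # The driver never left Pending: clean up the app, then fail the task.
        delete_fn()
        raise TaskFailure("Spark application startup timeout exceeded") from None

deleted = []

def never_starts():
    raise TimeoutError("App startup timeout")

try:
    run_waiter(never_starts, lambda: deleted.append("driver-pod"))
except TaskFailure as exc:
    print(deleted, exc)  # ['driver-pod'] Spark application startup timeout exceeded
```

The `from None` suppresses the inner `TimeoutError` traceback, so the task log shows only the operator-level failure.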
11 changes: 10 additions & 1 deletion spark_on_k8s/client.py
@@ -136,6 +136,7 @@ def submit_app(
executor_labels: dict[str, str] | ArgNotSet = NOTSET,
driver_tolerations: list[k8s.V1Toleration] | ArgNotSet = NOTSET,
executor_pod_template_path: str | ArgNotSet = NOTSET,
startup_timeout: int | ArgNotSet = NOTSET,
) -> str:
"""Submit a Spark app to Kubernetes
@@ -178,6 +179,7 @@ def submit_app(
executor_node_selector: Node selector for the executors
driver_tolerations: List of tolerations for the driver
executor_pod_template_path: Path to the executor pod template file
startup_timeout: Timeout in seconds to wait for the application to start
Returns:
Name of the Spark application pod
@@ -280,6 +282,8 @@ def submit_app(
driver_tolerations = []
if executor_pod_template_path is NOTSET or executor_pod_template_path is None:
executor_pod_template_path = Configuration.SPARK_ON_K8S_EXECUTOR_POD_TEMPLATE_PATH
if startup_timeout is NOTSET:
startup_timeout = Configuration.SPARK_ON_K8S_STARTUP_TIMEOUT

spark_conf = spark_conf or {}
main_class_parameters = app_arguments or []
@@ -449,11 +453,16 @@ def submit_app(
self.app_manager.stream_logs(
namespace=namespace,
pod_name=pod.metadata.name,
startup_timeout=startup_timeout,
should_print=should_print,
)
elif app_waiter == SparkAppWait.WAIT:
self.app_manager.wait_for_app(
namespace=namespace, pod_name=pod.metadata.name, should_print=should_print
namespace=namespace,
pod_name=pod.metadata.name,
poll_interval=5,
startup_timeout=startup_timeout,
should_print=should_print,
)
return pod.metadata.name
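The `NOTSET` sentinel used in `submit_app` lets the client distinguish "caller passed nothing" from falsy-but-explicit values such as `0`; unset arguments fall back to `Configuration`. A simplified sketch of that pattern (the sentinel class and default value are stand-ins, not the library's own objects):

```python
class _NotSet:
    def __repr__(self):
        return "NOTSET"

NOTSET = _NotSet()

CONFIG_STARTUP_TIMEOUT = 120  # stand-in for Configuration.SPARK_ON_K8S_STARTUP_TIMEOUT

def resolve_startup_timeout(startup_timeout=NOTSET):
    # Identity check, not truthiness: an explicit 0 ("no timeout") must survive.
    if startup_timeout is NOTSET:
        startup_timeout = CONFIG_STARTUP_TIMEOUT
    return startup_timeout

print(resolve_startup_timeout())   # 120 (falls back to configuration)
print(resolve_startup_timeout(0))  # 0 (explicit "no timeout" is kept)
```

This is why the diff checks `if startup_timeout is NOTSET` rather than `if not startup_timeout`.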

15 changes: 15 additions & 0 deletions spark_on_k8s/utils/app_manager.py
@@ -91,6 +91,7 @@ def wait_for_app(
app_id: str | None = None,
poll_interval: float = 10,
should_print: bool = False,
startup_timeout: float = 0,
):
"""Wait for a Spark app to finish.
@@ -99,8 +100,11 @@
pod_name (str): Pod name.
app_id (str): App ID.
poll_interval (float, optional): Poll interval in seconds. Defaults to 10.
startup_timeout (float, optional): Timeout in seconds to wait for the app to start.
Defaults to 0 (no timeout).
should_print (bool, optional): Whether to print logs instead of logging them.
"""
start_time = time.time()
termination_statuses = {SparkAppStatus.Succeeded, SparkAppStatus.Failed, SparkAppStatus.Unknown}
with self.k8s_client_manager.client() as client:
api = k8s.CoreV1Api(client)
@@ -111,6 +115,10 @@
)
if status in termination_statuses:
break
if status == SparkAppStatus.Pending:
if startup_timeout and start_time + startup_timeout < time.time():
raise TimeoutError("App startup timeout")

except ApiException as e:
if e.status == 404:
self.log(
@@ -135,6 +143,7 @@ def stream_logs(
namespace: str,
pod_name: str | None = None,
app_id: str | None = None,
startup_timeout: float = 0,
should_print: bool = False,
):
"""Stream logs from a Spark app.
@@ -143,8 +152,11 @@
namespace (str): Namespace.
pod_name (str): Pod name.
app_id (str): App ID.
startup_timeout (float, optional): Timeout in seconds to wait for the app to start.
Defaults to 0 (no timeout).
should_print (bool, optional): Whether to print logs instead of logging them.
"""
start_time = time.time()
if pod_name is None and app_id is None:
raise ValueError("Either pod_name or app_id must be specified")
if pod_name is None:
@@ -166,6 +178,9 @@
)
if pod.status.phase != "Pending":
break
if startup_timeout and start_time + startup_timeout < time.time():
raise TimeoutError("App startup timeout")
time.sleep(5)
watcher = watch.Watch()
for line in watcher.stream(
api.read_namespaced_pod_log,
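The check added to both `wait_for_app` and `stream_logs` boils down to: while the app is still Pending, compare the elapsed time against `startup_timeout`, with `0` disabling the check entirely. A self-contained sketch of that loop (the status source is simulated; the real code polls the Kubernetes API):

```python
import time

def wait_until_started(get_phase, startup_timeout=0.0, poll_interval=0.01):
    # Sketch of the new pending-state guard: startup_timeout == 0 waits forever.
    start_time = time.time()
    while True:
        phase = get_phase()
        if phase != "Pending":
            return phase
        if startup_timeout and start_time + startup_timeout < time.time():
            raise TimeoutError("App startup timeout")
        time.sleep(poll_interval)

# A pod that becomes Running on the third poll starts in time:
phases = iter(["Pending", "Pending", "Running"])
print(wait_until_started(lambda: next(phases), startup_timeout=1.0))  # Running

# A pod stuck in Pending trips the timeout:
try:
    wait_until_started(lambda: "Pending", startup_timeout=0.05)
except TimeoutError as exc:
    print(exc)  # App startup timeout
```

Only the Pending phase is subject to the timeout; once the app is running, the waiter reverts to polling until a termination status.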
1 change: 1 addition & 0 deletions spark_on_k8s/utils/configuration.py
@@ -48,6 +48,7 @@ class Configuration:
if getenv("SPARK_ON_K8S_DRIVER_ENV_VARS_FROM_SECRET")
else []
)
SPARK_ON_K8S_STARTUP_TIMEOUT = int(getenv("SPARK_ON_K8S_STARTUP_TIMEOUT", 0))

# Kubernetes client configuration
# K8S client configuration
52 changes: 52 additions & 0 deletions tests/airflow/test_operators.py
@@ -254,3 +254,55 @@ def test_job_adoption(
mock_submit_app.assert_called_once()
else:
mock_submit_app.assert_not_called()

@pytest.mark.parametrize(
"app_waiter",
[
pytest.param("log", id="log"),
pytest.param("wait", id="wait"),
],
)
@mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.stream_logs")
@mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.wait_for_app")
@mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.app_status")
@mock.patch("spark_on_k8s.client.SparkOnK8S.submit_app")
def test_startup_timeout(
self,
mock_submit_app,
mock_app_status,
mock_wait_for_app,
mock_stream_logs,
app_waiter,
):
from spark_on_k8s.airflow.operators import SparkOnK8SOperator

mock_app_status.return_value = SparkAppStatus.Succeeded
mock_submit_app.return_value = "test-pod-name"
spark_app_task = SparkOnK8SOperator(
task_id="spark_application",
namespace="test-namespace",
image="pyspark-job",
app_path="local:///opt/spark/work-dir/job.py",
startup_timeout=10,
app_waiter=app_waiter,
)
spark_app_task.execute(
{
"ti": mock.MagicMock(
xcom_pull=mock.MagicMock(side_effect=["test-namespace", "existing-pod"]),
)
}
)
if app_waiter == "log":
mock_stream_logs.assert_called_once_with(
namespace="test-namespace",
pod_name="test-pod-name",
startup_timeout=10,
)
else:
mock_wait_for_app.assert_called_once_with(
namespace="test-namespace",
pod_name="test-pod-name",
poll_interval=10,
startup_timeout=10,
)
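The assertion style in the test above — patch the manager methods, run the operator, then verify the forwarded keyword arguments — can be reproduced with plain `unittest.mock` objects. A minimal sketch (the `dispatch` helper is illustrative, not the operator's code):

```python
from unittest import mock

manager = mock.MagicMock()

def dispatch(app_waiter: str, startup_timeout: int):
    # Mirrors the operator's branching on app_waiter.
    if app_waiter == "log":
        manager.stream_logs(pod_name="test-pod-name", startup_timeout=startup_timeout)
    elif app_waiter == "wait":
        manager.wait_for_app(pod_name="test-pod-name", startup_timeout=startup_timeout)

dispatch("log", 10)
manager.stream_logs.assert_called_once_with(pod_name="test-pod-name", startup_timeout=10)

dispatch("wait", 10)
manager.wait_for_app.assert_called_once_with(pod_name="test-pod-name", startup_timeout=10)
print("assertions passed")
```

`assert_called_once_with` fails if the mock was called more than once or with different kwargs, which is exactly what the parametrized test relies on.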