diff --git a/README.md b/README.md
index 92654cb..b7d7381 100644
--- a/README.md
+++ b/README.md
@@ -157,48 +157,48 @@ The Python client and the CLI can be configured with environment variables to avoid passing them every
 time if you have a common configuration for all your apps. The environment variables are the same for
 both the client and the CLI. Here is a list of the available environment variables:
 
-| Environment Variable                      | Description                                           | Default      |
-|-------------------------------------------|-------------------------------------------------------|--------------|
-| SPARK_ON_K8S_DOCKER_IMAGE                 | The docker image to use for the spark pods            |              |
-| SPARK_ON_K8S_APP_PATH                     | The path to the app file                              |              |
-| SPARK_ON_K8S_NAMESPACE                    | The namespace to use                                  | default      |
-| SPARK_ON_K8S_SERVICE_ACCOUNT              | The service account to use                            | spark        |
-| SPARK_ON_K8S_SPARK_CONF                   | The spark configuration to use                        | {}           |
-| SPARK_ON_K8S_CLASS_NAME                   | The class name to use                                 |              |
+| Environment Variable                      | Description                                           | Default        |
+|-------------------------------------------|-------------------------------------------------------|----------------|
+| SPARK_ON_K8S_DOCKER_IMAGE                 | The docker image to use for the spark pods            |                |
+| SPARK_ON_K8S_APP_PATH                     | The path to the app file                              |                |
+| SPARK_ON_K8S_NAMESPACE                    | The namespace to use                                  | default        |
+| SPARK_ON_K8S_SERVICE_ACCOUNT              | The service account to use                            | spark          |
+| SPARK_ON_K8S_SPARK_CONF                   | The spark configuration to use                        | {}             |
+| SPARK_ON_K8S_CLASS_NAME                   | The class name to use                                 |                |
-| SPARK_ON_K8S_PACKAGES                     | The maven packages list to add to the classpath       |
+| SPARK_ON_K8S_PACKAGES                     | The maven packages list to add to the classpath       |                |
-| SPARK_ON_K8S_APP_ARGUMENTS                | The arguments to pass to the app                      | []           |
-| SPARK_ON_K8S_APP_WAITER                   | The waiter to use to wait for the app to finish       | no_wait      |
-| SPARK_ON_K8S_IMAGE_PULL_POLICY            | The image pull policy to use                          | IfNotPresent |
-| SPARK_ON_K8S_UI_REVERSE_PROXY             | Whether to use a reverse proxy to access the spark UI | false        |
-| SPARK_ON_K8S_DRIVER_CPU                   | The driver CPU                                        | 1            |
-| SPARK_ON_K8S_DRIVER_MEMORY                | The driver memory                                     | 1024         |
-| SPARK_ON_K8S_DRIVER_MEMORY_OVERHEAD       | The driver memory overhead                            | 512          |
-| SPARK_ON_K8S_EXECUTOR_CPU                 | The executor CPU                                      | 1            |
-| SPARK_ON_K8S_EXECUTOR_MEMORY              | The executor memory                                   | 1024         |
-| SPARK_ON_K8S_EXECUTOR_MEMORY_OVERHEAD     | The executor memory overhead                          | 512          |
-| SPARK_ON_K8S_EXECUTOR_MIN_INSTANCES       | The minimum number of executor instances              |              |
-| SPARK_ON_K8S_EXECUTOR_MAX_INSTANCES       | The maximum number of executor instances              |              |
-| SPARK_ON_K8S_EXECUTOR_INITIAL_INSTANCES   | The initial number of executor instances              |              |
-| SPARK_ON_K8S_CONFIG_FILE                  | The path to the config file                           |              |
-| SPARK_ON_K8S_CONTEXT                      | The context to use                                    |              |
-| SPARK_ON_K8S_CLIENT_CONFIG                | The sync Kubernetes client configuration to use       |              |
-| SPARK_ON_K8S_ASYNC_CLIENT_CONFIG          | The async Kubernetes client configuration to use      |              |
-| SPARK_ON_K8S_IN_CLUSTER                   | Whether to use the in cluster Kubernetes config       | false        |
-| SPARK_ON_K8S_API_DEFAULT_NAMESPACE        | The default namespace to use for the API              | default      |
-| SPARK_ON_K8S_API_HOST                     | The host to use for the API                           | 127.0.0.1    |
-| SPARK_ON_K8S_API_PORT                     | The port to use for the API                           | 8000         |
-| SPARK_ON_K8S_API_WORKERS                  | The number of workers to use for the API              | 4            |
-| SPARK_ON_K8S_API_LOG_LEVEL                | The log level to use for the API                      | info         |
-| SPARK_ON_K8S_API_LIMIT_CONCURRENCY        | The limit concurrency to use for the API              | 1000         |
-| SPARK_ON_K8S_API_SPARK_HISTORY_HOST       | The host to use for the spark history server          |              |
-| SPARK_ON_K8S_SPARK_DRIVER_NODE_SELECTOR   | The node selector to use for the driver pod           | {}           |
-| SPARK_ON_K8S_SPARK_EXECUTOR_NODE_SELECTOR | The node selector to use for the executor pods        | {}           |
-| SPARK_ON_K8S_SPARK_DRIVER_LABELS          | The labels to use for the driver pod                  | {}           |
-| SPARK_ON_K8S_SPARK_EXECUTOR_LABELS        | The labels to use for the executor pods               | {}           |
-| SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS     | The annotations to use for the driver pod             | {}           |
-| SPARK_ON_K8S_SPARK_EXECUTOR_ANNOTATIONS   | The annotations to use for the executor pods          | {}           |
-| SPARK_ON_K8S_EXECUTOR_POD_TEMPLATE_PATH   | The path to the executor pod template                 |              |
-
+| SPARK_ON_K8S_APP_ARGUMENTS                | The arguments to pass to the app                      | []             |
+| SPARK_ON_K8S_APP_WAITER                   | The waiter to use to wait for the app to finish       | no_wait        |
+| SPARK_ON_K8S_IMAGE_PULL_POLICY            | The image pull policy to use                          | IfNotPresent   |
+| SPARK_ON_K8S_UI_REVERSE_PROXY             | Whether to use a reverse proxy to access the spark UI | false          |
+| SPARK_ON_K8S_DRIVER_CPU                   | The driver CPU                                        | 1              |
+| SPARK_ON_K8S_DRIVER_MEMORY                | The driver memory                                     | 1024           |
+| SPARK_ON_K8S_DRIVER_MEMORY_OVERHEAD       | The driver memory overhead                            | 512            |
+| SPARK_ON_K8S_EXECUTOR_CPU                 | The executor CPU                                      | 1              |
+| SPARK_ON_K8S_EXECUTOR_MEMORY              | The executor memory                                   | 1024           |
+| SPARK_ON_K8S_EXECUTOR_MEMORY_OVERHEAD     | The executor memory overhead                          | 512            |
+| SPARK_ON_K8S_EXECUTOR_MIN_INSTANCES       | The minimum number of executor instances              |                |
+| SPARK_ON_K8S_EXECUTOR_MAX_INSTANCES       | The maximum number of executor instances              |                |
+| SPARK_ON_K8S_EXECUTOR_INITIAL_INSTANCES   | The initial number of executor instances              |                |
+| SPARK_ON_K8S_CONFIG_FILE                  | The path to the config file                           |                |
+| SPARK_ON_K8S_CONTEXT                      | The context to use                                    |                |
+| SPARK_ON_K8S_CLIENT_CONFIG                | The sync Kubernetes client configuration to use       |                |
+| SPARK_ON_K8S_ASYNC_CLIENT_CONFIG          | The async Kubernetes client configuration to use      |                |
+| SPARK_ON_K8S_IN_CLUSTER                   | Whether to use the in cluster Kubernetes config       | false          |
+| SPARK_ON_K8S_API_DEFAULT_NAMESPACE        | The default namespace to use for the API              | default        |
+| SPARK_ON_K8S_API_HOST                     | The host to use for the API                           | 127.0.0.1      |
+| SPARK_ON_K8S_API_PORT                     | The port to use for the API                           | 8000           |
+| SPARK_ON_K8S_API_WORKERS                  | The number of workers to use for the API              | 4              |
+| SPARK_ON_K8S_API_LOG_LEVEL                | The log level to use for the API                      | info           |
+| SPARK_ON_K8S_API_LIMIT_CONCURRENCY        | The limit concurrency to use for the API              | 1000           |
+| SPARK_ON_K8S_API_SPARK_HISTORY_HOST       | The host to use for the spark history server          |                |
+| SPARK_ON_K8S_SPARK_DRIVER_NODE_SELECTOR   | The node selector to use for the driver pod           | {}             |
+| SPARK_ON_K8S_SPARK_EXECUTOR_NODE_SELECTOR | The node selector to use for the executor pods        | {}             |
+| SPARK_ON_K8S_SPARK_DRIVER_LABELS          | The labels to use for the driver pod                  | {}             |
+| SPARK_ON_K8S_SPARK_EXECUTOR_LABELS        | The labels to use for the executor pods               | {}             |
+| SPARK_ON_K8S_SPARK_DRIVER_ANNOTATIONS     | The annotations to use for the driver pod             | {}             |
+| SPARK_ON_K8S_SPARK_EXECUTOR_ANNOTATIONS   | The annotations to use for the executor pods          | {}             |
+| SPARK_ON_K8S_EXECUTOR_POD_TEMPLATE_PATH   | The path to the executor pod template                 |                |
+| SPARK_ON_K8S_STARTUP_TIMEOUT              | The timeout to wait for the app to start in seconds   | 0 (no timeout) |
 
 ## Examples
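
For reference, a minimal sketch of driving this configuration from the environment. This is not part of the diff; it assumes that `Configuration` (see `spark_on_k8s/utils/configuration.py` below) reads these variables once at import time, so they must be set before `spark_on_k8s` is imported:

```python
import os

# Placeholder values for illustration; any variable from the table above
# can be set the same way.
os.environ["SPARK_ON_K8S_NAMESPACE"] = "spark"
os.environ["SPARK_ON_K8S_SERVICE_ACCOUNT"] = "spark"
os.environ["SPARK_ON_K8S_STARTUP_TIMEOUT"] = "60"  # new in this change

# Import only after setting the variables: Configuration calls getenv(...)
# in its class body, i.e. once, when the module is first imported.
from spark_on_k8s.utils.configuration import Configuration

assert Configuration.SPARK_ON_K8S_STARTUP_TIMEOUT == 60
```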
diff --git a/spark_on_k8s/airflow/operators.py b/spark_on_k8s/airflow/operators.py
index cd1fbc6..c4a9f24 100644
--- a/spark_on_k8s/airflow/operators.py
+++ b/spark_on_k8s/airflow/operators.py
@@ -84,6 +84,8 @@ class SparkOnK8SOperator(BaseOperator):
         deferrable (bool, optional): Whether the operator is deferrable. Defaults to False.
         on_kill_action (Literal["keep", "delete", "kill"], optional): Action to take when the operator
             is killed. Defaults to "delete".
+        startup_timeout (int, optional): Timeout in seconds to wait for the Spark application to start.
+            Defaults to 0 (no timeout).
         **kwargs: Other keyword arguments for BaseOperator.
     """
 
@@ -151,6 +153,7 @@ def __init__(
         poll_interval: int = 10,
         deferrable: bool = False,
         on_kill_action: Literal["keep", "delete", "kill"] = OnKillAction.DELETE,
+        startup_timeout: int = 0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -188,6 +191,7 @@
         self.poll_interval = poll_interval
         self.deferrable = deferrable
         self.on_kill_action = on_kill_action
+        self.startup_timeout = startup_timeout
 
     def _render_nested_template_fields(
         self,
@@ -366,23 +370,33 @@ def execute(self, context: Context):
                 ),
                 method_name="execute_complete",
             )
-        if self.app_waiter == "wait":
-            spark_app_manager.wait_for_app(
-                namespace=self.namespace,
-                pod_name=self._driver_pod_name,
-                poll_interval=self.poll_interval,
-            )
-        elif self.app_waiter == "log":
-            spark_app_manager.stream_logs(
-                namespace=self.namespace,
-                pod_name=self._driver_pod_name,
-            )
-            # wait for termination status
-            spark_app_manager.wait_for_app(
+        try:
+            if self.app_waiter == "wait":
+                spark_app_manager.wait_for_app(
+                    namespace=self.namespace,
+                    pod_name=self._driver_pod_name,
+                    poll_interval=self.poll_interval,
+                    startup_timeout=self.startup_timeout,
+                )
+            elif self.app_waiter == "log":
+                spark_app_manager.stream_logs(
+                    namespace=self.namespace,
+                    pod_name=self._driver_pod_name,
+                    startup_timeout=self.startup_timeout,
+                )
+                # wait for termination status
+                spark_app_manager.wait_for_app(
+                    namespace=self.namespace,
+                    pod_name=self._driver_pod_name,
+                    poll_interval=1,
+                )
+        except TimeoutError:
+            self.log.info("Deleting Spark application due to startup timeout...")
+            spark_app_manager.delete_app(
                 namespace=self.namespace,
                 pod_name=self._driver_pod_name,
-                poll_interval=1,
             )
+            raise AirflowException("Spark application startup timeout exceeded") from None
         app_status = spark_app_manager.app_status(
             namespace=self.namespace,
             pod_name=self._driver_pod_name,
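
A usage sketch for the new operator argument (not part of the diff; the image, namespace, and app path are placeholders reused from the tests below). With `app_waiter="wait"` or `"log"`, a driver pod that is still `Pending` after `startup_timeout` seconds is deleted and the task fails with an `AirflowException`:

```python
from airflow import DAG
from spark_on_k8s.airflow.operators import SparkOnK8SOperator

with DAG(dag_id="spark_startup_timeout_example") as dag:
    submit = SparkOnK8SOperator(
        task_id="spark_application",
        namespace="spark",
        image="pyspark-job",  # placeholder image
        app_path="local:///opt/spark/work-dir/job.py",
        app_waiter="wait",
        startup_timeout=120,  # seconds the driver pod may stay Pending
    )
```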
diff --git a/spark_on_k8s/client.py b/spark_on_k8s/client.py
index 8dd4848..324653f 100644
--- a/spark_on_k8s/client.py
+++ b/spark_on_k8s/client.py
@@ -136,6 +136,7 @@ def submit_app(
         executor_labels: dict[str, str] | ArgNotSet = NOTSET,
         driver_tolerations: list[k8s.V1Toleration] | ArgNotSet = NOTSET,
         executor_pod_template_path: str | ArgNotSet = NOTSET,
+        startup_timeout: int | ArgNotSet = NOTSET,
     ) -> str:
         """Submit a Spark app to Kubernetes
 
@@ -178,6 +179,7 @@ def submit_app(
             executor_node_selector: Node selector for the executors
             driver_tolerations: List of tolerations for the driver
             executor_pod_template_path: Path to the executor pod template file
+            startup_timeout: Timeout in seconds to wait for the application to start
 
         Returns:
             Name of the Spark application pod
@@ -280,6 +282,8 @@ def submit_app(
             driver_tolerations = []
         if executor_pod_template_path is NOTSET or executor_pod_template_path is None:
             executor_pod_template_path = Configuration.SPARK_ON_K8S_EXECUTOR_POD_TEMPLATE_PATH
+        if startup_timeout is NOTSET:
+            startup_timeout = Configuration.SPARK_ON_K8S_STARTUP_TIMEOUT
         spark_conf = spark_conf or {}
         main_class_parameters = app_arguments or []
 
@@ -449,11 +453,16 @@ def submit_app(
             self.app_manager.stream_logs(
                 namespace=namespace,
                 pod_name=pod.metadata.name,
+                startup_timeout=startup_timeout,
                 should_print=should_print,
             )
         elif app_waiter == SparkAppWait.WAIT:
             self.app_manager.wait_for_app(
-                namespace=namespace, pod_name=pod.metadata.name, should_print=should_print
+                namespace=namespace,
+                pod_name=pod.metadata.name,
+                poll_interval=5,
+                startup_timeout=startup_timeout,
+                should_print=should_print,
             )
 
         return pod.metadata.name
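
The equivalent through the client (again a sketch with placeholder values, not part of the diff). When `startup_timeout` is left at `NOTSET`, it falls back to `Configuration.SPARK_ON_K8S_STARTUP_TIMEOUT`:

```python
from spark_on_k8s.client import SparkOnK8S

spark_client = SparkOnK8S()
try:
    spark_client.submit_app(
        image="pyspark-job",  # placeholder image
        app_path="local:///opt/spark/work-dir/job.py",
        namespace="spark",
        app_waiter="wait",   # "wait" and "log" both enforce the timeout
        startup_timeout=60,  # seconds the driver pod may stay Pending
    )
except TimeoutError:
    # Raised by SparkAppManager.wait_for_app/stream_logs when the driver
    # pod has not left the Pending phase within startup_timeout seconds.
    ...
```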
diff --git a/spark_on_k8s/utils/app_manager.py b/spark_on_k8s/utils/app_manager.py
index ed9693e..4516b3c 100644
--- a/spark_on_k8s/utils/app_manager.py
+++ b/spark_on_k8s/utils/app_manager.py
@@ -91,6 +91,7 @@ def wait_for_app(
         app_id: str | None = None,
         poll_interval: float = 10,
         should_print: bool = False,
+        startup_timeout: float = 0,
     ):
         """Wait for a Spark app to finish.
 
@@ -99,8 +100,11 @@
             pod_name (str): Pod name.
             app_id (str): App ID.
             poll_interval (float, optional): Poll interval in seconds. Defaults to 10.
+            startup_timeout (float, optional): Timeout in seconds to wait for the app to start.
+                Defaults to 0 (no timeout).
             should_print (bool, optional): Whether to print logs instead of logging them.
         """
+        start_time = time.time()
         termination_statuses = {SparkAppStatus.Succeeded, SparkAppStatus.Failed, SparkAppStatus.Unknown}
         with self.k8s_client_manager.client() as client:
             api = k8s.CoreV1Api(client)
@@ -111,6 +115,10 @@
                     )
                     if status in termination_statuses:
                         break
+                    if status == SparkAppStatus.Pending:
+                        if startup_timeout and start_time + startup_timeout < time.time():
+                            raise TimeoutError("App startup timeout")
+
                 except ApiException as e:
                     if e.status == 404:
                         self.log(
@@ -135,6 +143,7 @@ def stream_logs(
         namespace: str,
         pod_name: str | None = None,
         app_id: str | None = None,
+        startup_timeout: float = 0,
         should_print: bool = False,
     ):
         """Stream logs from a Spark app.
@@ -143,8 +152,11 @@
             namespace (str): Namespace.
             pod_name (str): Pod name.
             app_id (str): App ID.
+            startup_timeout (float, optional): Timeout in seconds to wait for the app to start.
+                Defaults to 0 (no timeout).
             should_print (bool, optional): Whether to print logs instead of logging them.
         """
+        start_time = time.time()
         if pod_name is None and app_id is None:
             raise ValueError("Either pod_name or app_id must be specified")
         if pod_name is None:
@@ -166,6 +178,9 @@
                 )
                 if pod.status.phase != "Pending":
                     break
+                if startup_timeout and start_time + startup_timeout < time.time():
+                    raise TimeoutError("App startup timeout")
+                time.sleep(5)
             watcher = watch.Watch()
             for line in watcher.stream(
                 api.read_namespaced_pod_log,
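
The manager can also be used directly; a sketch assuming `SparkAppManager()` can be constructed with defaults (the method signature and the `TimeoutError` come from the hunks above, the pod name is a placeholder):

```python
from spark_on_k8s.utils.app_manager import SparkAppManager

app_manager = SparkAppManager()  # assumption: default construction
try:
    # Poll the driver pod every 5 seconds; give up if it is still
    # Pending 30 seconds after we started waiting.
    app_manager.wait_for_app(
        namespace="spark",
        pod_name="pyspark-job-example-driver",  # placeholder pod name
        poll_interval=5,
        startup_timeout=30,
    )
except TimeoutError:
    print("Driver pod never left the Pending phase")
```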
""" + start_time = time.time() if pod_name is None and app_id is None: raise ValueError("Either pod_name or app_id must be specified") if pod_name is None: @@ -166,6 +178,9 @@ def stream_logs( ) if pod.status.phase != "Pending": break + if startup_timeout and start_time + startup_timeout < time.time(): + raise TimeoutError("App startup timeout") + time.sleep(5) watcher = watch.Watch() for line in watcher.stream( api.read_namespaced_pod_log, diff --git a/spark_on_k8s/utils/configuration.py b/spark_on_k8s/utils/configuration.py index 55a7d04..8f4fd56 100644 --- a/spark_on_k8s/utils/configuration.py +++ b/spark_on_k8s/utils/configuration.py @@ -48,6 +48,7 @@ class Configuration: if getenv("SPARK_ON_K8S_DRIVER_ENV_VARS_FROM_SECRET") else [] ) + SPARK_ON_K8S_STARTUP_TIMEOUT = int(getenv("SPARK_ON_K8S_STARTUP_TIMEOUT", 0)) # Kubernetes client configuration # K8S client configuration diff --git a/tests/airflow/test_operators.py b/tests/airflow/test_operators.py index acb8f03..c1c1c1f 100644 --- a/tests/airflow/test_operators.py +++ b/tests/airflow/test_operators.py @@ -254,3 +254,55 @@ def test_job_adoption( mock_submit_app.assert_called_once() else: mock_submit_app.assert_not_called() + + @pytest.mark.parametrize( + "app_waiter", + [ + pytest.param("log", id="log"), + pytest.param("wait", id="wait"), + ], + ) + @mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.stream_logs") + @mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.wait_for_app") + @mock.patch("spark_on_k8s.utils.app_manager.SparkAppManager.app_status") + @mock.patch("spark_on_k8s.client.SparkOnK8S.submit_app") + def test_startup_timeout( + self, + mock_submit_app, + mock_app_status, + mock_wait_for_app, + mock_stream_logs, + app_waiter, + ): + from spark_on_k8s.airflow.operators import SparkOnK8SOperator + + mock_app_status.return_value = SparkAppStatus.Succeeded + mock_submit_app.return_value = "test-pod-name" + spark_app_task = SparkOnK8SOperator( + task_id="spark_application", + namespace="test-namespace", + image="pyspark-job", + app_path="local:///opt/spark/work-dir/job.py", + startup_timeout=10, + app_waiter=app_waiter, + ) + spark_app_task.execute( + { + "ti": mock.MagicMock( + xcom_pull=mock.MagicMock(side_effect=["test-namespace", "existing-pod"]), + ) + } + ) + if app_waiter == "log": + mock_stream_logs.assert_called_once_with( + namespace="test-namespace", + pod_name="test-pod-name", + startup_timeout=10, + ) + else: + mock_wait_for_app.assert_called_once_with( + namespace="test-namespace", + pod_name="test-pod-name", + poll_interval=10, + startup_timeout=10, + ) diff --git a/tests/test_spark_client.py b/tests/test_spark_client.py index 75f75c2..bdc5c49 100644 --- a/tests/test_spark_client.py +++ b/tests/test_spark_client.py @@ -10,6 +10,7 @@ import pytest from freezegun import freeze_time from kubernetes import client as k8s +from mock.mock import MagicMock from spark_on_k8s import client as client_module from spark_on_k8s.client import ExecutorInstances, PodResources, SparkOnK8S, default_app_id_suffix from spark_on_k8s.utils import configuration as configuration_module @@ -839,3 +840,42 @@ def test_submit_app_with_executor_pod_template_path( if conf.startswith("spark.kubernetes.executor") } assert executor_config.get("spark.kubernetes.executor.podTemplateFile") == "s3a://bucket/executor.yml" + + @pytest.mark.parametrize( + "app_waiter", + [ + pytest.param("log", id="log"), + pytest.param("wait", id="wait"), + ], + ) + @mock.patch("spark_on_k8s.k8s.sync_client.KubernetesClientManager.create_client") + 
@mock.patch("kubernetes.client.api.core_v1_api.CoreV1Api.read_namespaced_pod") + @mock.patch("kubernetes.client.api.core_v1_api.CoreV1Api.create_namespaced_pod") + @mock.patch("kubernetes.client.api.core_v1_api.CoreV1Api.create_namespaced_service") + def test_submit_app_with_startup_timeout( + self, + mock_create_namespaced_service, + mock_create_namespaced_pod, + mock_read_namespaced_pod, + mock_create_client, + app_waiter, + ): + """Test the method submit_app""" + mock_read_namespaced_pod.return_value = MagicMock(**{"status.phase": "Pending"}) + spark_client = SparkOnK8S() + with pytest.raises(TimeoutError, match="App startup timeout"): + spark_client.submit_app( + image="pyspark-job", + app_path="local:///opt/spark/work-dir/job.py", + namespace="spark", + service_account="spark", + app_name="pyspark-job-example", + app_arguments=["100000"], + app_waiter=app_waiter, + image_pull_policy="Never", + ui_reverse_proxy=True, + driver_resources=PodResources(cpu=1, memory=2048, memory_overhead=1024), + executor_instances=ExecutorInstances(min=2, max=5, initial=5), + executor_pod_template_path="s3a://bucket/executor.yml", + startup_timeout=5, + )