diff --git a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst index c2f433267c306..03e5a7a921c27 100644 --- a/docs/apache-airflow-providers-amazon/operators/sagemaker.rst +++ b/docs/apache-airflow-providers-amazon/operators/sagemaker.rst @@ -366,6 +366,20 @@ you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerAut :start-after: [START howto_operator_sagemaker_auto_ml] :end-before: [END howto_operator_sagemaker_auto_ml] +.. _howto/sensor:SageMakerProcessingSensor: + +Wait on an Amazon SageMaker processing job state +================================================ + +To check the state of an Amazon SageMaker processing job until it reaches a terminal state +you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerProcessingSensor`. + +.. exampleinclude:: /../../providers/tests/system/amazon/aws/example_sagemaker.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_sagemaker_processing] + :end-before: [END howto_sensor_sagemaker_processing] + Reference --------- diff --git a/providers/src/airflow/providers/amazon/aws/hooks/sagemaker.py b/providers/src/airflow/providers/amazon/aws/hooks/sagemaker.py index e16ab11b0c95c..10d7b8436c0d0 100644 --- a/providers/src/airflow/providers/amazon/aws/hooks/sagemaker.py +++ b/providers/src/airflow/providers/amazon/aws/hooks/sagemaker.py @@ -153,7 +153,9 @@ class SageMakerHook(AwsBaseHook): non_terminal_states = {"InProgress", "Stopping"} endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"} pipeline_non_terminal_states = {"Executing", "Stopping"} + processing_job_non_terminal_states = {"InProgress", "Stopping"} failed_states = {"Failed"} + processing_job_failed_states = {*failed_states, "Stopped"} training_failed_states = {*failed_states, "Stopped"} def __init__(self, *args, **kwargs): diff --git 
a/providers/src/airflow/providers/amazon/aws/sensors/sagemaker.py b/providers/src/airflow/providers/amazon/aws/sensors/sagemaker.py index e77628cf8d596..5863de92d9339 100644 --- a/providers/src/airflow/providers/amazon/aws/sensors/sagemaker.py +++ b/providers/src/airflow/providers/amazon/aws/sensors/sagemaker.py @@ -330,3 +330,35 @@ def get_sagemaker_response(self) -> dict: def state_from_response(self, response: dict) -> str: return response["AutoMLJobStatus"] + + +class SageMakerProcessingSensor(SageMakerBaseSensor): + """ + Poll the processing job until it reaches a terminal state; raise AirflowException with the failure reason. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:SageMakerProcessingSensor` + + :param job_name: Name of the processing job to watch. + """ + + template_fields: Sequence[str] = ("job_name",) + template_ext: Sequence[str] = () + + def __init__(self, *, job_name: str, **kwargs): + super().__init__(**kwargs) + self.job_name = job_name + + def non_terminal_states(self) -> set[str]: + return SageMakerHook.processing_job_non_terminal_states + + def failed_states(self) -> set[str]: + return SageMakerHook.processing_job_failed_states + + def get_sagemaker_response(self) -> dict: + self.log.info("Poking Sagemaker ProcessingJob %s", self.job_name) + return self.hook.describe_processing_job(self.job_name) + + def state_from_response(self, response: dict) -> str: + return response["ProcessingJobStatus"] diff --git a/providers/tests/amazon/aws/sensors/test_sagemaker_processing.py b/providers/tests/amazon/aws/sensors/test_sagemaker_processing.py new file mode 100644 index 0000000000000..0529a0b9b19d5 --- /dev/null +++ b/providers/tests/amazon/aws/sensors/test_sagemaker_processing.py @@ -0,0 +1,110 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook +from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerProcessingSensor + +DESCRIBE_PROCESSING_INPROGRESS_RESPONSE = { + "ProcessingJobStatus": "InProgress", + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, +} + +DESCRIBE_PROCESSING_COMPLETED_RESPONSE = { + "ProcessingJobStatus": "Completed", + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, +} + +DESCRIBE_PROCESSING_FAILED_RESPONSE = { + "ProcessingJobStatus": "Failed", + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, + "FailureReason": "Unknown", +} + +DESCRIBE_PROCESSING_STOPPING_RESPONSE = { + "ProcessingJobStatus": "Stopping", + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, +} + +DESCRIBE_PROCESSING_STOPPED_RESPONSE = { + "ProcessingJobStatus": "Stopped", + "ResponseMetadata": { + "HTTPStatusCode": 200, + }, +} + + +class TestSageMakerProcessingSensor: + @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") + def test_sensor_with_failure(self, mock_describe_job, mock_client): + mock_describe_job.side_effect = 
[DESCRIBE_PROCESSING_FAILED_RESPONSE] + sensor = SageMakerProcessingSensor( + task_id="test_task", poke_interval=2, aws_conn_id="aws_test", job_name="test_job_name" + ) + with pytest.raises(AirflowException): + sensor.execute(None) + mock_describe_job.assert_called_once_with("test_job_name") + + @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") + def test_sensor_with_stopped(self, mock_describe_job, mock_client): + mock_describe_job.side_effect = [DESCRIBE_PROCESSING_STOPPED_RESPONSE] + sensor = SageMakerProcessingSensor( + task_id="test_task", poke_interval=2, aws_conn_id="aws_test", job_name="test_job_name" + ) + with pytest.raises(AirflowException): + sensor.execute(None) + mock_describe_job.assert_called_once_with("test_job_name") + + @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "__init__") + @mock.patch.object(SageMakerHook, "describe_processing_job") + def test_sensor(self, mock_describe_job, hook_init, mock_client): + hook_init.return_value = None + + mock_describe_job.side_effect = [ + DESCRIBE_PROCESSING_INPROGRESS_RESPONSE, + DESCRIBE_PROCESSING_STOPPING_RESPONSE, + DESCRIBE_PROCESSING_COMPLETED_RESPONSE, + ] + sensor = SageMakerProcessingSensor( + task_id="test_task", poke_interval=0, aws_conn_id="aws_test", job_name="test_job_name" + ) + + sensor.execute(None) + + # make sure we called 3 times (terminated when it's completed) + assert mock_describe_job.call_count == 3 + + # make sure the hook was initialized with the specific params + calls = [mock.call(aws_conn_id="aws_test")] + hook_init.assert_has_calls(calls) diff --git a/providers/tests/system/amazon/aws/example_sagemaker.py b/providers/tests/system/amazon/aws/example_sagemaker.py index 96e9756659975..3098ac647ea75 100644 --- a/providers/tests/system/amazon/aws/example_sagemaker.py +++ b/providers/tests/system/amazon/aws/example_sagemaker.py @@ -47,6 +47,7 @@ ) from 
airflow.providers.amazon.aws.sensors.sagemaker import ( SageMakerAutoMLSensor, + SageMakerProcessingSensor, SageMakerTrainingSensor, SageMakerTransformSensor, SageMakerTuningSensor, @@ -390,6 +391,7 @@ def set_up(env_id, role_arn): ti.xcom_push(key="raw_data_s3_key", value=raw_data_s3_key) ti.xcom_push(key="ecr_repository_name", value=ecr_repository_name) ti.xcom_push(key="processing_config", value=processing_config) + ti.xcom_push(key="processing_job_name", value=processing_job_name) ti.xcom_push(key="input_data_uri", value=input_data_uri) ti.xcom_push(key="output_data_uri", value=f"s3://{bucket_name}/{training_output_s3_key}") ti.xcom_push(key="training_config", value=training_config) @@ -518,8 +520,18 @@ def delete_docker_image(image_name): task_id="preprocess_raw_data", config=test_setup["processing_config"], ) + + # SageMakerProcessingOperator waits by default, setting as False to test the Sensor below. + preprocess_raw_data.wait_for_completion = False + # [END howto_operator_sagemaker_processing] + # [START howto_sensor_sagemaker_processing] + await_preprocess = SageMakerProcessingSensor( + task_id="await_preprocess", job_name=test_setup["processing_job_name"] + ) + # [END howto_sensor_sagemaker_processing] + # [START howto_operator_sagemaker_training] train_model = SageMakerTrainingOperator( task_id="train_model", @@ -622,6 +634,7 @@ def delete_docker_image(image_name): await_automl, create_experiment, preprocess_raw_data, + await_preprocess, train_model, await_training, create_model,