Skip to content

Commit

Permalink
42411
Browse files Browse the repository at this point in the history
Add SageMakerProcessingSensor which can be used to wait on a SageMaker processing job.
  • Loading branch information
Jasmin committed Oct 18, 2024
1 parent 0de5587 commit b2d5971
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 0 deletions.
14 changes: 14 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/sagemaker.rst
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,20 @@ you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerAut
:start-after: [START howto_operator_sagemaker_auto_ml]
:end-before: [END howto_operator_sagemaker_auto_ml]

.. _howto/sensor:SageMakerProcessingSensor:

Wait on an Amazon SageMaker processing job state
================================================

To check the state of an Amazon SageMaker processing job until it reaches a terminal state,
you can use :class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerProcessingSensor`.

.. exampleinclude:: /../../providers/tests/system/amazon/aws/example_sagemaker.py
    :language: python
    :dedent: 4
    :start-after: [START howto_sensor_sagemaker_processing]
    :end-before: [END howto_sensor_sagemaker_processing]

Reference
---------

Expand Down
2 changes: 2 additions & 0 deletions providers/src/airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ class SageMakerHook(AwsBaseHook):
non_terminal_states = {"InProgress", "Stopping"}
endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
pipeline_non_terminal_states = {"Executing", "Stopping"}
# Non-terminal statuses for processing jobs; SageMakerProcessingSensor returns this set
# from its non_terminal_states() method.
processing_job_non_terminal_states = {"InProgress", "Stopping"}
failed_states = {"Failed"}
# Processing jobs treat "Stopped" as a failure in addition to the generic failed states;
# SageMakerProcessingSensor returns this set from its failed_states() method.
processing_job_failed_states = {*failed_states, "Stopped"}
training_failed_states = {*failed_states, "Stopped"}

def __init__(self, *args, **kwargs):
Expand Down
32 changes: 32 additions & 0 deletions providers/src/airflow/providers/amazon/aws/sensors/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,35 @@ def get_sagemaker_response(self) -> dict:

def state_from_response(self, response: dict) -> str:
return response["AutoMLJobStatus"]


class SageMakerProcessingSensor(SageMakerBaseSensor):
    """
    Poll the processing job until it reaches a terminal state; raise AirflowException with the failure reason.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:SageMakerProcessingSensor`

    :param job_name: Name of the processing job to watch.
    """

    template_fields: Sequence[str] = ("job_name",)
    template_ext: Sequence[str] = ()

    def __init__(self, *, job_name: str, **kwargs):
        super().__init__(**kwargs)
        self.job_name = job_name

    def non_terminal_states(self) -> set[str]:
        """Return the processing-job statuses that are not terminal yet."""
        return SageMakerHook.processing_job_non_terminal_states

    def failed_states(self) -> set[str]:
        """Return the processing-job statuses that count as a failure."""
        return SageMakerHook.processing_job_failed_states

    def get_sagemaker_response(self) -> dict:
        """Describe the watched processing job through the hook and return the raw response."""
        self.log.info("Poking Sagemaker ProcessingJob %s", self.job_name)
        return self.hook.describe_processing_job(self.job_name)

    def get_failed_reason_from_response(self, response: dict) -> str:
        """
        Return the ``FailureReason`` reported for the processing job.

        Falls back to ``"Unknown"`` when the response carries no ``FailureReason`` key
        (for example a job that was stopped rather than failed).
        """
        # NOTE(review): this is expected to override the failure-reason hook that
        # SageMakerBaseSensor consults when a failed state is seen -- confirm against the base class.
        return response.get("FailureReason", "Unknown")

    def state_from_response(self, response: dict) -> str:
        """Extract the job status from a describe-processing-job response."""
        return response["ProcessingJobStatus"]
110 changes: 110 additions & 0 deletions providers/tests/amazon/aws/sensors/test_sagemaker_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerProcessingSensor

def _describe_processing_response(status, **extra_fields):
    """Build a fresh describe-processing-job style payload with HTTP status code 200."""
    return {
        "ProcessingJobStatus": status,
        "ResponseMetadata": {
            "HTTPStatusCode": 200,
        },
        **extra_fields,
    }


# One canned response per processing-job status exercised by the tests below.
DESCRIBE_PROCESSING_INPROGRESS_RESPONSE = _describe_processing_response("InProgress")

DESCRIBE_PROCESSING_COMPLETED_RESPONSE = _describe_processing_response("Completed")

# Only the failed response carries a failure reason.
DESCRIBE_PROCESSING_FAILED_RESPONSE = _describe_processing_response("Failed", FailureReason="Unknown")

DESCRIBE_PROCESSING_STOPPING_RESPONSE = _describe_processing_response("Stopping")

DESCRIBE_PROCESSING_STOPPED_RESPONSE = _describe_processing_response("Stopped")


class TestSageMakerProcessingSensor:
    """Tests for ``SageMakerProcessingSensor`` with the ``SageMakerHook`` calls patched out."""

    @staticmethod
    def _make_sensor(poke_interval):
        """Create the sensor under test; only the poke interval differs between tests."""
        return SageMakerProcessingSensor(
            task_id="test_task",
            poke_interval=poke_interval,
            aws_conn_id="aws_test",
            job_name="test_job_name",
        )

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(SageMakerHook, "describe_processing_job")
    def test_sensor_with_failure(self, describe_mock, conn_mock):
        """A "Failed" status makes ``execute`` raise after a single describe call."""
        describe_mock.side_effect = [DESCRIBE_PROCESSING_FAILED_RESPONSE]
        failing_sensor = self._make_sensor(poke_interval=2)

        with pytest.raises(AirflowException):
            failing_sensor.execute(None)

        describe_mock.assert_called_once_with("test_job_name")

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(SageMakerHook, "describe_processing_job")
    def test_sensor_with_stopped(self, describe_mock, conn_mock):
        """A "Stopped" status is treated like a failure and also raises."""
        describe_mock.side_effect = [DESCRIBE_PROCESSING_STOPPED_RESPONSE]
        stopped_sensor = self._make_sensor(poke_interval=2)

        with pytest.raises(AirflowException):
            stopped_sensor.execute(None)

        describe_mock.assert_called_once_with("test_job_name")

    @mock.patch.object(SageMakerHook, "get_conn")
    @mock.patch.object(SageMakerHook, "__init__")
    @mock.patch.object(SageMakerHook, "describe_processing_job")
    def test_sensor(self, describe_mock, init_mock, conn_mock):
        """The sensor keeps poking through "InProgress" and "Stopping" until "Completed"."""
        # A patched __init__ must return None, as a real one would.
        init_mock.return_value = None

        polled_responses = [
            DESCRIBE_PROCESSING_INPROGRESS_RESPONSE,
            DESCRIBE_PROCESSING_STOPPING_RESPONSE,
            DESCRIBE_PROCESSING_COMPLETED_RESPONSE,
        ]
        describe_mock.side_effect = polled_responses

        completing_sensor = self._make_sensor(poke_interval=0)
        completing_sensor.execute(None)

        # Three describe calls in total: polling stops once the job is completed.
        assert describe_mock.call_count == 3

        # The hook must have been created with the connection id given to the sensor.
        expected_init_calls = [mock.call(aws_conn_id="aws_test")]
        init_mock.assert_has_calls(expected_init_calls)
13 changes: 13 additions & 0 deletions providers/tests/system/amazon/aws/example_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from airflow.providers.amazon.aws.sensors.sagemaker import (
SageMakerAutoMLSensor,
SageMakerProcessingSensor,
SageMakerTrainingSensor,
SageMakerTransformSensor,
SageMakerTuningSensor,
Expand Down Expand Up @@ -390,6 +391,7 @@ def set_up(env_id, role_arn):
ti.xcom_push(key="raw_data_s3_key", value=raw_data_s3_key)
ti.xcom_push(key="ecr_repository_name", value=ecr_repository_name)
ti.xcom_push(key="processing_config", value=processing_config)
ti.xcom_push(key="processing_job_name", value=processing_job_name)
ti.xcom_push(key="input_data_uri", value=input_data_uri)
ti.xcom_push(key="output_data_uri", value=f"s3://{bucket_name}/{training_output_s3_key}")
ti.xcom_push(key="training_config", value=training_config)
Expand Down Expand Up @@ -518,8 +520,18 @@ def delete_docker_image(image_name):
task_id="preprocess_raw_data",
config=test_setup["processing_config"],
)

# SageMakerProcessingOperator waits for completion by default; disable that here so the sensor below does the waiting.
preprocess_raw_data.wait_for_completion = False

# [END howto_operator_sagemaker_processing]

# [START howto_sensor_sagemaker_processing]
await_preprocess = SageMakerProcessingSensor(
task_id="await_preprocess", job_name=test_setup["processing_job_name"]
)
# [END howto_sensor_sagemaker_processing]

# [START howto_operator_sagemaker_training]
train_model = SageMakerTrainingOperator(
task_id="train_model",
Expand Down Expand Up @@ -622,6 +634,7 @@ def delete_docker_image(image_name):
await_automl,
create_experiment,
preprocess_raw_data,
await_preprocess,
train_model,
await_training,
create_model,
Expand Down

0 comments on commit b2d5971

Please sign in to comment.