diff --git a/CHANGELOG.md b/CHANGELOG.md index 259b594c..271c1711 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added retries to ECS task run creation for ECS worker - [#303](https://github.com/PrefectHQ/prefect-aws/pull/303) + ### Changed ### Deprecated diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 5b55c971..699873a7 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -70,6 +70,7 @@ ) from pydantic import Field, root_validator from slugify import slugify +from tenacity import retry, stop_after_attempt, wait_fixed, wait_random from typing_extensions import Literal from prefect_aws import AwsCredentials @@ -122,6 +123,11 @@ taskDefinition: "{{ task_definition_arn }}" """ +# Create task run retry settings +MAX_CREATE_TASK_RUN_ATTEMPTS = 3 +CREATE_TASK_RUN_MIN_DELAY_SECONDS = 1 +CREATE_TASK_RUN_MIN_DELAY_JITTER_SECONDS = 0 +CREATE_TASK_RUN_MAX_DELAY_JITTER_SECONDS = 3 _TASK_DEFINITION_CACHE: Dict[UUID, str] = {} _TAG_REGEX = r"[^a-zA-Z0-9-_.=+-@: ]+" @@ -1421,6 +1427,14 @@ def _prepare_task_run_request( return task_run_request + @retry( + stop=stop_after_attempt(MAX_CREATE_TASK_RUN_ATTEMPTS), + wait=wait_fixed(CREATE_TASK_RUN_MIN_DELAY_SECONDS) + + wait_random( + CREATE_TASK_RUN_MIN_DELAY_JITTER_SECONDS, + CREATE_TASK_RUN_MAX_DELAY_JITTER_SECONDS, + ), + ) def _create_task_run(self, ecs_client: _ECSClient, task_run_request: dict) -> str: """ Create a run of a task definition. diff --git a/requirements.txt b/requirements.txt index 1e700c1a..40b6007d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ boto3>=1.24.53 botocore>=1.27.53 -prefect>=2.10.11 mypy_boto3_s3>=1.24.94 -mypy_boto3_secretsmanager>=1.26.49 \ No newline at end of file +mypy_boto3_secretsmanager>=1.26.49 +prefect>=2.10.11 +tenacity>=8.0.0 \ No newline at end of file diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index f222d191..4ebfb158 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -12,6 +12,7 @@ from moto.ec2.utils import generate_instance_identity_document from prefect.server.schemas.core import FlowRun from prefect.utilities.asyncutils import run_sync_in_worker_thread +from tenacity import RetryError from prefect_aws.workers.ecs_worker import ( _TASK_DEFINITION_CACHE, @@ -1900,3 +1901,26 @@ async def test_kill_infrastructure_with_grace_period(aws_credentials, caplog, fl # Logs warning assert "grace period of 60s requested, but AWS does not support" in caplog.text + + +async def test_retry_on_failed_task_start( + aws_credentials: AwsCredentials, flow_run, ecs_mocks +): + run_task_mock = MagicMock(return_value=[]) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, command="echo test" + ) + + inject_moto_patches( + ecs_mocks, + { + "run_task": [run_task_mock], + }, + ) + + with pytest.raises(RetryError): + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration, flow_run) + + assert run_task_mock.call_count == 3