diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index ef57372b..39e50fea 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -1847,6 +1847,36 @@ async def test_user_defined_tags_in_task_run_request_template( {"key": "OVERRIDE", "value": "NEW"}, ] +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_capacity_provider_strategy_in_task_run_request_template( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_run_request={ + "capacityProviderStrategy": [ + {"base": 0, "weight": 1, "capacityProvider": "FOO"}, + ] + }, + ), + aws_credentials=aws_credentials, + ) + + assert "launchType" not in configuration.task_run_request + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + assert task.get("capacityProviderStrategy") == [ + {"base": 0, "weight": 1, "capacityProvider": "FOO"}, + ] @pytest.mark.usefixtures("ecs_mocks") @pytest.mark.parametrize(