diff --git a/CHANGELOG.md b/CHANGELOG.md index 7967e69f..f3cc4f89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed `push_to_s3` deployment step function to write paths `as_posix()` to allow support for deploying from windows [#314](https://github.com/PrefectHQ/prefect-aws/pull/314) +### Fixed + +- Resolved an issue where defining a custom network configuration with a subnet would erroneously report it as missing from the VPC when more than one subnet exists in the VPC. [#321](https://github.com/PrefectHQ/prefect-aws/pull/321) + ### Deprecated ### Removed diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 1b1c8073..a3822a3f 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -1349,10 +1349,10 @@ def _custom_network_configuration( + "Network configuration cannot be inferred." ) + subnet_ids = [subnet["SubnetId"] for subnet in subnets] + config_subnets = network_configuration.get("subnets", []) - if not all( - [conf_sn in sn.values() for conf_sn in config_subnets for sn in subnets] - ): + if not all(conf_sn in subnet_ids for conf_sn in config_subnets): raise ValueError( f"Subnets {config_subnets} not found within {vpc_message}." + "Please check that VPC is associated with supplied subnets." diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 1be6eb70..2c177aae 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -892,7 +892,7 @@ async def test_network_config_from_vpc_id( @pytest.mark.usefixtures("ecs_mocks") -async def test_network_config_from_custom_settings( +async def test_network_config_1_subnet_in_custom_settings_1_in_vpc( aws_credentials: AwsCredentials, flow_run: FlowRun ): session = aws_credentials.get_boto3_session() @@ -937,6 +937,107 @@ async def test_network_config_from_custom_settings( } +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_1_sn_in_custom_settings_many_in_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + ec2_resource.create_subnet(CidrBlock="10.0.3.0/24", VpcId=vpc.id) + ec2_resource.create_subnet(CidrBlock="10.0.4.0/24", VpcId=vpc.id) + + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_many_subnet_in_custom_settings_many_in_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnets = [ + ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id), + ec2_resource.create_subnet(CidrBlock="10.0.33.0/24", VpcId=vpc.id), + ec2_resource.create_subnet(CidrBlock="10.0.44.0/24", VpcId=vpc.id), + ] + subnet_ids = [subnet.id for subnet in subnets] + + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": subnet_ids, + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": subnet_ids, + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + @pytest.mark.usefixtures("ecs_mocks") async def test_network_config_from_custom_settings_invalid_subnet( aws_credentials: AwsCredentials, flow_run: FlowRun @@ -978,6 +1079,48 @@ async def test_network_config_from_custom_settings_invalid_subnet( await run_then_stop_task(worker, configuration, flow_run) +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_custom_settings_invalid_subnet_multiple_vpc_subnets( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + invalid_subnet_id = "subnet-3bf19de7" + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [invalid_subnet_id, subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + with pytest.raises( + ValueError, + match=( + rf"Subnets \['{invalid_subnet_id}', '{subnet.id}'\] not found within VPC" + f" with ID {vpc.id}.Please check that VPC is associated with supplied" + " subnets." + ), + ): + async with ECSWorker(work_pool_name="test") as worker: + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + await run_then_stop_task(worker, configuration, flow_run) + + @pytest.mark.usefixtures("ecs_mocks") async def test_network_config_configure_network_requires_vpc_id( aws_credentials: AwsCredentials, flow_run: FlowRun