From f2f22698d442ed77f75caf1d564aa88570a57091 Mon Sep 17 00:00:00 2001 From: Anton Kukushkin Date: Wed, 14 Feb 2024 14:49:27 +0000 Subject: [PATCH] use provided task subnets, scope down s3 permissions Signed-off-by: Anton Kukushkin --- modules/mlflow/mlflow-fargate/app.py | 8 ++++++-- modules/mlflow/mlflow-fargate/stack.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/modules/mlflow/mlflow-fargate/app.py b/modules/mlflow/mlflow-fargate/app.py index b9b7eeda..2ba6f371 100644 --- a/modules/mlflow/mlflow-fargate/app.py +++ b/modules/mlflow/mlflow-fargate/app.py @@ -38,11 +38,15 @@ def _param(name: str) -> str: artifacts_bucket_name = os.getenv(_param("ARTIFACTS_BUCKET_NAME")) # TODO: add persistent backend store +if not vpc_id: + raise ValueError("Missing input parameter vpc-id") + if not ecr_repo_name: raise ValueError("Missing input parameter ecr-repository-name") -if not vpc_id: - raise ValueError("Missing input parameter vpc-id") +if not artifacts_bucket_name: + raise ValueError("Missing input parameter artifacts-bucket-name") + app = aws_cdk.App() stack = MlflowFargateStack( diff --git a/modules/mlflow/mlflow-fargate/stack.py b/modules/mlflow/mlflow-fargate/stack.py index d43b29d0..e442ab1e 100644 --- a/modules/mlflow/mlflow-fargate/stack.py +++ b/modules/mlflow/mlflow-fargate/stack.py @@ -9,6 +9,7 @@ from aws_cdk import aws_ecs as ecs from aws_cdk import aws_ecs_patterns as ecs_patterns from aws_cdk import aws_iam as iam +from aws_cdk import aws_s3 as s3 from constructs import Construct, IConstruct @@ -37,13 +38,16 @@ def __init__( "TaskRole", assumed_by=iam.ServicePrincipal(service="ecs-tasks.amazonaws.com"), managed_policies=[ - iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess"), iam.ManagedPolicy.from_aws_managed_policy_name("AmazonECS_FullAccess"), ], ) - vpc = ec2.Vpc.from_lookup(self, "vpc", vpc_id=vpc_id) - subnets = [ec2.Subnet.from_subnet_id(self, f"sub-{subnet_id}", subnet_id) for subnet_id in subnet_ids] + # Grant artifacts bucket read-write permissions + model_bucket = s3.Bucket.from_bucket_name(self, "ArtifactsBucket", bucket_name=artifacts_bucket_name) + model_bucket.grant_read_write(role) + + vpc = ec2.Vpc.from_lookup(self, "Vpc", vpc_id=vpc_id) + subnets = [ec2.Subnet.from_subnet_id(self, f"Sub{subnet_id}", subnet_id) for subnet_id in subnet_ids] cluster = ecs.Cluster( self, @@ -91,6 +95,7 @@ def __init__( service_name=service_name, cluster=cluster, task_definition=task_definition, + task_subnets=ec2.SubnetSelection(subnets=subnets), ) # Setup security group