Skip to content

Commit

Permalink
use provided task subnets, scope down s3 permissions
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Kukushkin <kukushkin.anton@gmail.com>
  • Loading branch information
kukushking committed Feb 14, 2024
1 parent 9a34e40 commit f2f2269
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
8 changes: 6 additions & 2 deletions modules/mlflow/mlflow-fargate/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ def _param(name: str) -> str:
artifacts_bucket_name = os.getenv(_param("ARTIFACTS_BUCKET_NAME"))
# TODO: add persistent backend store

if not vpc_id:
raise ValueError("Missing input parameter vpc-id")

if not ecr_repo_name:
raise ValueError("Missing input parameter ecr-repository-name")

if not vpc_id:
raise ValueError("Missing input parameter vpc-id")
if not artifacts_bucket_name:
raise ValueError("Missing input parameter artifacts-bucket-name")


app = aws_cdk.App()
stack = MlflowFargateStack(
Expand Down
11 changes: 8 additions & 3 deletions modules/mlflow/mlflow-fargate/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aws_cdk import aws_ecs as ecs
from aws_cdk import aws_ecs_patterns as ecs_patterns
from aws_cdk import aws_iam as iam
from aws_cdk import aws_s3 as s3
from constructs import Construct, IConstruct


Expand Down Expand Up @@ -37,13 +38,16 @@ def __init__(
"TaskRole",
assumed_by=iam.ServicePrincipal(service="ecs-tasks.amazonaws.com"),
managed_policies=[
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3FullAccess"),
iam.ManagedPolicy.from_aws_managed_policy_name("AmazonECS_FullAccess"),
],
)

vpc = ec2.Vpc.from_lookup(self, "vpc", vpc_id=vpc_id)
subnets = [ec2.Subnet.from_subnet_id(self, f"sub-{subnet_id}", subnet_id) for subnet_id in subnet_ids]
# Grant artifacts bucket read-write permissions
model_bucket = s3.Bucket.from_bucket_name(self, "ArtifactsBucket", bucket_name=artifacts_bucket_name)
model_bucket.grant_read_write(role)

vpc = ec2.Vpc.from_lookup(self, "Vpc", vpc_id=vpc_id)
subnets = [ec2.Subnet.from_subnet_id(self, f"Sub{subnet_id}", subnet_id) for subnet_id in subnet_ids]

cluster = ecs.Cluster(
self,
Expand Down Expand Up @@ -91,6 +95,7 @@ def __init__(
service_name=service_name,
cluster=cluster,
task_definition=task_definition,
task_subnets=ec2.SubnetSelection(subnets=subnets),
)

# Setup security group
Expand Down

0 comments on commit f2f2269

Please sign in to comment.