diff --git a/modules/mlflow/mlflow-fargate/stack.py b/modules/mlflow/mlflow-fargate/stack.py index 31e5f07c..23eaf1b9 100644 --- a/modules/mlflow/mlflow-fargate/stack.py +++ b/modules/mlflow/mlflow-fargate/stack.py @@ -8,6 +8,7 @@ from aws_cdk import aws_ecr as ecr from aws_cdk import aws_ecs as ecs from aws_cdk import aws_ecs_patterns as ecs_patterns +from aws_cdk import aws_efs as efs from aws_cdk import aws_iam as iam from aws_cdk import aws_s3 as s3 from cdk_nag import AwsSolutionsChecks, NagPackSuppression, NagSuppressions @@ -70,7 +71,6 @@ def __init__( container = task_definition.add_container( "ContainerDef", - # TODO: add ability to pull specific tag image=ecs.ContainerImage.from_ecr_repository( repository=ecr.Repository.from_repository_name( self, @@ -80,18 +80,54 @@ def __init__( ), environment={ "BUCKET": f"s3://{artifacts_bucket_name}", - # TODO: Add persistence - # "HOST": database.db_instance_endpoint_address, - # "PORT": str(port), - # "DATABASE": db_name, - # "USERNAME": username, }, - # secrets={"PASSWORD": ecs.Secret.from_secrets_manager(db_password_secret)}, logging=ecs.LogDriver.aws_logs(stream_prefix="mlflow"), ) port_mapping = ecs.PortMapping(container_port=5000, host_port=5000, protocol=ecs.Protocol.TCP) container.add_port_mappings(port_mapping) + # Add EFS + fs = efs.FileSystem( + self, + "EfsFileSystem", + vpc=vpc, + encrypted=True, + throughput_mode=efs.ThroughputMode.ELASTIC, + performance_mode=efs.PerformanceMode.GENERAL_PURPOSE, + file_system_policy=iam.PolicyDocument( + statements=[ + iam.PolicyStatement( + actions=[ + "elasticfilesystem:ClientMount", + "elasticfilesystem:ClientWrite", + "elasticfilesystem:ClientRootAccess", + ], + principals=[iam.AnyPrincipal()], + resources=["*"], + conditions={"Bool": {"elasticfilesystem:AccessedViaMountTarget": "true"}}, + ), + ] + ), + ) + self.fs = fs + + # Attach and mount volume + task_definition.add_volume( + name="efs-volume", + efs_volume_configuration=ecs.EfsVolumeConfiguration( + file_system_id=fs.file_system_id, + transit_encryption="ENABLED", + ), + ) + container.add_mount_points( + ecs.MountPoint( + container_path="./mlruns", + source_volume="efs-volume", + read_only=False, + ) + ) + + # Create ECS Service service = ecs_patterns.NetworkLoadBalancedFargateService( self, "MlflowLBService", @@ -99,7 +135,11 @@ def __init__( cluster=cluster, task_definition=task_definition, task_subnets=ec2.SubnetSelection(subnets=subnets), + circuit_breaker=ecs.DeploymentCircuitBreaker(rollback=True), ) + self.service = service + + # Enable access logs lb_access_logs_bucket = s3.Bucket( self, "LBAccessLogsBucket", @@ -110,6 +150,10 @@ def __init__( service.load_balancer.log_access_logs(bucket=lb_access_logs_bucket) self.lb_access_logs_bucket = lb_access_logs_bucket + # Allow access to EFS from Fargate service + fs.grant_root_access(service.task_definition.task_role.grant_principal) + fs.connections.allow_default_port_from(service.service.connections) + # Setup security group service.service.connections.security_groups[0].add_ingress_rule( peer=ec2.Peer.ipv4(vpc.vpc_cidr_block), @@ -125,7 +169,6 @@ def __init__( scale_in_cooldown=Duration.seconds(60), scale_out_cooldown=Duration.seconds(60), ) - self.service = service # Add CDK nag solutions checks Aspects.of(self).add(AwsSolutionsChecks())