Skip to content

Commit

Permalink
Add network isolation for training and processing jobs & Inter contai…
Browse files Browse the repository at this point in the history
…ner encryption in sagemaker training and hyper parameter tuning job (#133)

* add KMS encryption for storage

* adding encryption to ECR

* Encrypting each bucket with KMS MANAGED

* Network isolation for training and processing jobs &  Inter container encryption in sagemaker training job and hyper parameter tuning job

* Network isolation for training and processing jobs &  Inter container encryption in sagemaker training job and hyper parameter tuning job

* checkpoint

* Initial changes based on PR comments

* Re-arrange imports for validation script

* Convert lists to env vars correctly

* Create SG in module

* Resolve VPC in catalog stack instead of product stack

* Formatting

* Add permissions

* Do not enable network isolation on preprocess

* Updated changelog

---------

Co-authored-by: Srinivas Reddy <srinivasreddych@outlook.com>
Co-authored-by: Leon Luttenberger <LeonLuttenberger@users.noreply.github.com>
Co-authored-by: Anton Kukushkin <kukushkin.anton@gmail.com>
Co-authored-by: Ethan Bunce <ebunce@amazon.co.uk>
  • Loading branch information
5 people authored Aug 7, 2024
1 parent dc22221 commit 2818c2e
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### **Added**

- added documentation for Ray on EKS manifests
- Added network isolation and inter container encryption for xgboost template

### **Changed**

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import cdk_nag
from aws_cdk import BundlingOptions, BundlingOutput, DockerImage, Stack, Tags
from aws_cdk import aws_ec2 as ec2
from aws_cdk import aws_iam as iam
from aws_cdk import aws_s3_assets as s3_assets
from aws_cdk import aws_servicecatalog as servicecatalog
Expand Down Expand Up @@ -74,6 +75,10 @@ def __init__(
managed_policies=[iam.ManagedPolicy.from_aws_managed_policy_name("AdministratorAccess")],
)

dev_vpc = None
if dev_vpc_id:
dev_vpc = ec2.Vpc.from_lookup(self, "dev-vpc", vpc_id=dev_vpc_id)

templates_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
for template_name in next(os.walk(templates_dir))[1]:
if template_name == "__pycache__":
Expand All @@ -90,6 +95,7 @@ def __init__(
build_app_asset=build_app_asset,
deploy_app_asset=deploy_app_asset,
dev_vpc_id=dev_vpc_id,
dev_vpc=dev_vpc,
dev_subnet_ids=dev_subnet_ids,
dev_security_group_ids=dev_security_group_ids,
pre_prod_vpc_id=pre_prod_vpc_id,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Any
import json
from typing import Any, List

import aws_cdk
from aws_cdk import Aws
Expand Down Expand Up @@ -29,6 +29,10 @@ def __init__(
model_bucket: s3.IBucket,
pipeline_artifact_bucket: s3.IBucket,
repo_asset: s3_assets.Asset,
enable_network_isolation: str,
encrypt_inter_container_traffic: str,
subnet_ids: List[str],
security_group_ids: List[str],
**kwargs: Any,
) -> None:
super().__init__(scope, construct_id, **kwargs)
Expand Down Expand Up @@ -171,6 +175,37 @@ def __init__(
],
),
)
sagemaker_execution_role.add_to_policy(
iam.PolicyStatement(
actions=[
"ec2:CreateNetworkInterface",
"ec2:CreateNetworkInterfacePermission",
"ec2:DeleteNetworkInterface",
],
resources=[
f"arn:{Aws.PARTITION}:ec2:{Aws.REGION}:{Aws.ACCOUNT_ID}:network-interface/*",
*[
f"arn:{Aws.PARTITION}:ec2:{Aws.REGION}:{Aws.ACCOUNT_ID}:subnet/{subnet_id}"
for subnet_id in subnet_ids
],
*[
f"arn:{Aws.PARTITION}:ec2:{Aws.REGION}:{Aws.ACCOUNT_ID}:security-group/{security_group_id}"
for security_group_id in security_group_ids
],
],
),
)
sagemaker_execution_role.add_to_policy(
iam.PolicyStatement(
actions=[
"ec2:DescribeNetworkInterfaces",
"ec2:DescribeVpcs",
"ec2:DescribeSubnets",
"ec2:DescribeDhcpOptions",
],
resources=["*"],
),
)

# Grant extra permissions for the CodeBuild role
codebuild_role.add_to_policy(
Expand Down Expand Up @@ -232,6 +267,12 @@ def __init__(
environment=codebuild.BuildEnvironment(
build_image=codebuild.LinuxBuildImage.STANDARD_5_0,
environment_variables={
"ENABLE_NETWORK_ISOLATION": codebuild.BuildEnvironmentVariable(value=enable_network_isolation),
"ENCRYPT_INTER_CONTAINER_TRAFFIC": codebuild.BuildEnvironmentVariable(
value=encrypt_inter_container_traffic,
),
"SUBNET_IDS": codebuild.BuildEnvironmentVariable(value=json.dumps(subnet_ids)),
"SECURITY_GROUP_IDS": codebuild.BuildEnvironmentVariable(value=json.dumps(security_group_ids)),
"SAGEMAKER_PROJECT_NAME": codebuild.BuildEnvironmentVariable(value=project_name),
"SAGEMAKER_PROJECT_ID": codebuild.BuildEnvironmentVariable(value=project_id),
"SAGEMAKER_DOMAIN_ID": codebuild.BuildEnvironmentVariable(value=domain_id),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Any
from typing import Any, List

import aws_cdk.aws_ec2 as ec2
import aws_cdk.aws_iam as iam
import aws_cdk.aws_kms as kms
import aws_cdk.aws_s3 as s3
Expand Down Expand Up @@ -30,6 +31,8 @@ def __init__(
prod_account_id: str,
sagemaker_domain_id: str,
sagemaker_domain_arn: str,
dev_vpc: ec2.IVpc,
dev_subnet_ids: List[str],
**kwargs: Any,
) -> None:
super().__init__(scope, id)
Expand Down Expand Up @@ -68,6 +71,27 @@ def __init__(
default=prod_account_id,
).value_as_string

enable_network_isolation = CfnParameter(
self,
"EnableNetworkIsolation",
type="String",
description=(
"Enable network isolation. Will NOT enable network isolation on preprocess step as it requires access "
"to S3 for training data."
),
allowed_values=["true", "false"],
default="false",
).value_as_string

encrypt_inter_container_traffic = CfnParameter(
self,
"EncryptInterContainerTraffic",
type="String",
description="Encrypt inter container traffic",
allowed_values=["true", "false"],
default="false",
).value_as_string

Tags.of(self).add("sagemaker:project-id", sagemaker_project_id)
Tags.of(self).add("sagemaker:project-name", sagemaker_project_name)
if sagemaker_domain_id:
Expand Down Expand Up @@ -229,6 +253,12 @@ def __init__(
removal_policy=RemovalPolicy.DESTROY,
)

security_group_ids = []
if dev_vpc and dev_subnet_ids:
security_group_ids = [ec2.SecurityGroup(self, "Security Group", vpc=dev_vpc).security_group_id]
else:
dev_subnet_ids = []

BuildPipelineConstruct(
self,
"build",
Expand All @@ -240,6 +270,10 @@ def __init__(
model_bucket=model_bucket,
pipeline_artifact_bucket=pipeline_artifact_bucket,
repo_asset=build_app_asset,
enable_network_isolation=enable_network_isolation,
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
subnet_ids=dev_subnet_ids,
security_group_ids=security_group_ids,
)

CfnOutput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
Implements a get_pipeline(**kwargs) method.
"""

import json
import logging
import os
from typing import Any, Optional

import boto3
Expand All @@ -21,6 +23,7 @@
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.model_metrics import MetricsSource, ModelMetrics
from sagemaker.network import NetworkConfig
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
Expand All @@ -33,6 +36,12 @@

# BASE_DIR = os.path.dirname(os.path.realpath(__file__))

ENABLE_NETWORK_ISOLATION = os.getenv("ENABLE_NETWORK_ISOLATION", "true").lower() == "true"
ENCRYPT_INTER_CONTAINER_TRAFFIC = os.getenv("ENCRYPT_INTER_CONTAINER_TRAFFIC", "true").lower() == "true"
SUBNET_IDS = json.loads(os.getenv("SUBNET_IDS", "[]"))
SECURITY_GROUP_IDS = json.loads(os.getenv("SECURITY_GROUP_IDS", "[]"))


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -86,6 +95,21 @@ def get_pipeline(
if role is None:
role = sagemaker.session.get_execution_role(sagemaker_session)

# define network config
network_config = NetworkConfig(
subnets=SUBNET_IDS if SUBNET_IDS else None,
security_group_ids=SECURITY_GROUP_IDS if SECURITY_GROUP_IDS else None,
enable_network_isolation=ENABLE_NETWORK_ISOLATION,
encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC,
)
# define network config without network isolation to allow S3 access for preprocessor
network_config_without_isolation = NetworkConfig(
subnets=SUBNET_IDS if SUBNET_IDS else None,
security_group_ids=SECURITY_GROUP_IDS if SECURITY_GROUP_IDS else None,
enable_network_isolation=False,
encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC,
)

# parameters for pipeline execution
processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
processing_instance_type = ParameterString(name="ProcessingInstanceType", default_value="ml.m5.xlarge")
Expand Down Expand Up @@ -122,6 +146,7 @@ def get_pipeline(
sagemaker_session=sagemaker_session,
role=role,
output_kms_key=bucket_kms_id,
network_config=network_config_without_isolation,
)
step_process = ProcessingStep(
name="PreprocessAbaloneData",
Expand Down Expand Up @@ -160,6 +185,10 @@ def get_pipeline(
sagemaker_session=sagemaker_session,
role=role,
output_kms_key=bucket_kms_id,
subnets=SUBNET_IDS if SUBNET_IDS else None,
security_group_ids=SECURITY_GROUP_IDS if SECURITY_GROUP_IDS else None,
enable_network_isolation=ENABLE_NETWORK_ISOLATION,
encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC,
)
xgb_train.set_hyperparameters(
objective="reg:linear",
Expand Down Expand Up @@ -196,6 +225,7 @@ def get_pipeline(
sagemaker_session=sagemaker_session,
role=role,
output_kms_key=bucket_kms_id,
network_config=network_config,
)
evaluation_report = PropertyFile(
name="AbaloneEvaluationReport",
Expand Down

0 comments on commit 2818c2e

Please sign in to comment.