diff --git a/docs/integrations/catalog/prefect-aws.yaml b/docs/integrations/catalog/prefect-aws.yaml index 1d9d18c95e7a..113e169d330c 100644 --- a/docs/integrations/catalog/prefect-aws.yaml +++ b/docs/integrations/catalog/prefect-aws.yaml @@ -1,6 +1,6 @@ collectionName: prefect-aws author: Prefect authorUrl: https://prefect.io -documentation: https://prefecthq.github.io/prefect-aws/ +documentation: prefect-aws iconUrl: /img/collections/aws.png tag: AWS \ No newline at end of file diff --git a/docs/integrations/prefect-aws/batch.md b/docs/integrations/prefect-aws/batch.md new file mode 100644 index 000000000000..6262d9aa37b9 --- /dev/null +++ b/docs/integrations/prefect-aws/batch.md @@ -0,0 +1,6 @@ +--- +description: Tasks for interacting with AWS Batch +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.batch diff --git a/docs/integrations/prefect-aws/client_waiter.md b/docs/integrations/prefect-aws/client_waiter.md new file mode 100644 index 000000000000..b7c5b8ebd1b9 --- /dev/null +++ b/docs/integrations/prefect-aws/client_waiter.md @@ -0,0 +1,6 @@ +--- +description: Task for waiting on a long-running AWS job +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.client_waiter diff --git a/docs/integrations/prefect-aws/credentials.md b/docs/integrations/prefect-aws/credentials.md new file mode 100644 index 000000000000..f6447ace782f --- /dev/null +++ b/docs/integrations/prefect-aws/credentials.md @@ -0,0 +1,6 @@ +--- +description: Module handling AWS credentials +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.credentials \ No newline at end of file diff --git a/docs/integrations/prefect-aws/deployments/steps.md b/docs/integrations/prefect-aws/deployments/steps.md new file mode 100644 index 000000000000..b9386b8903f3 --- /dev/null +++ b/docs/integrations/prefect-aws/deployments/steps.md @@ -0,0 +1,6 @@ +--- +description: Prefect deployment steps for managing deployment code storage via AWS S3. +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.deployments.steps \ No newline at end of file diff --git a/docs/integrations/prefect-aws/ecs.md b/docs/integrations/prefect-aws/ecs.md new file mode 100644 index 000000000000..98ca6ece8767 --- /dev/null +++ b/docs/integrations/prefect-aws/ecs.md @@ -0,0 +1,6 @@ +--- +description: Integrations with the Amazon Elastic Container Service. +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.ecs \ No newline at end of file diff --git a/docs/integrations/prefect-aws/ecs_guide.md b/docs/integrations/prefect-aws/ecs_guide.md new file mode 100644 index 000000000000..7e5380ac6217 --- /dev/null +++ b/docs/integrations/prefect-aws/ecs_guide.md @@ -0,0 +1,356 @@ +# ECS Worker Guide + +## Why use ECS for flow run execution? + +ECS (Elastic Container Service) tasks are a good option for executing Prefect flow runs for several reasons: + +1. **Scalability**: ECS scales your infrastructure in response to demand, effectively managing Prefect flow runs. ECS automatically administers container distribution across multiple instances based on demand. +2. **Flexibility**: ECS lets you choose between AWS Fargate and Amazon EC2 for container operation. Fargate abstracts the underlying infrastructure, while EC2 has faster job start times and offers additional control over instance management and configuration. +3. 
**AWS Integration**: Easily connect with other AWS services, such as AWS IAM and CloudWatch. +4. **Containerization**: ECS supports Docker containers and offers managed execution. Containerization encourages reproducible deployments. + +## ECS flow run execution + +Prefect enables remote flow execution via [workers](https://docs.prefect.io/concepts/work-pools/#worker-overview) and [work pools](https://docs.prefect.io/concepts/work-pools/#work-pool-overview). To learn more about these concepts please see our [deployment tutorial](https://docs.prefect.io/tutorial/deployments/). + +For details on how workers and work pools are implemented for ECS, see the diagram below. + +```mermaid +%%{ + init: { + 'theme': 'base', + 'themeVariables': { + 'primaryColor': '#2D6DF6', + 'primaryTextColor': '#fff', + 'lineColor': '#FE5A14', + 'secondaryColor': '#E04BF0', + 'tertiaryColor': '#fff' + } + } +}%% +graph TB + + subgraph ecs_cluster[ECS cluster] + subgraph ecs_service[ECS service] + td_worker[Worker task definition] --> |defines| prefect_worker((Prefect worker)) + end + prefect_worker -->|kicks off| ecs_task + fr_task_definition[Flow run task definition] + + + subgraph ecs_task["ECS task execution"] + style ecs_task text-align:center,display:flex + + + flow_run((Flow run)) + + end + fr_task_definition -->|defines| ecs_task + end + + subgraph prefect_cloud[Prefect Cloud] + subgraph prefect_workpool[ECS work pool] + workqueue[Work queue] + end + end + + subgraph github["ECR"] + flow_code{{"Flow code"}} + end + flow_code --> |pulls| ecs_task + prefect_worker -->|polls| workqueue + prefect_workpool -->|configures| fr_task_definition +``` + +## ECS and Prefect + +!!! tip "ECS tasks != Prefect tasks" + An ECS task is **not** the same thing as a [Prefect task](https://docs.prefect.io/latest/concepts/tasks/#tasks-overview). + + ECS tasks are groupings of containers that run within an ECS Cluster. An ECS task's behavior is determined by its task definition. + + An [*ECS task definition*](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_definitions.html) is the blueprint for the ECS task. It describes which Docker containers to run and what you want to have happen inside these containers. + +ECS tasks are instances of a task definition. A Task Execution launches container(s) as defined in the task definition **until they are stopped or exit on their own**. This setup is ideal for ephemeral processes such as a Prefect flow run. + +The ECS task running the Prefect [worker](https://docs.prefect.io/latest/concepts/work-pools/#worker-overview) should be an [**ECS Service**](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ecs_services.html), given its long-running nature and need for **auto-recovery in case of failure**. An ECS service automatically replaces any task that fails, which is ideal for managing a long-running process such as a Prefect worker. + +When a Prefect [flow](https://docs.prefect.io/latest/concepts/flows/) is scheduled to run it goes into the work pool specified in the flow's [deployment](https://docs.prefect.io/latest/concepts/deployments). [Work pools](https://docs.prefect.io/latest/concepts/work-pools/?h=work#work-pool-overview) are typed according to the infrastructure the flow will run on. Flow runs scheduled in an `ecs` typed work pool are executed as ECS tasks. Only Prefect ECS [workers](https://docs.prefect.io/latest/concepts/work-pools/#worker-types) can poll an `ecs` typed work pool. 
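For example, a minimal sketch with the Prefect CLI (the pool name is a placeholder; Step 1 of this guide walks through the same command in context):

```bash
# Create a work pool whose type matches the ECS worker
prefect work-pool create --type ecs my-ecs-pool

# List work pools; the type shown next to each pool determines which workers can poll it
prefect work-pool ls
```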
+ +When the ECS worker receives a scheduled flow run from the ECS work pool it is polling, it spins up the specified infrastructure on AWS ECS. The worker knows to build an ECS task definition for each flow run based on the configuration specified in the work pool. + +Once the flow run completes, the ECS containers of the cluster are spun down to a single container that continues to run the Prefect worker. This worker continues polling for work from the Prefect work pool. + +If you specify a task definition [ARN (Amazon Resource Name)](https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) in the work pool, the worker will use that ARN when spinning up the ECS Task, rather than creating a task definition from the fields supplied in the work pool configuration. + +You can use either EC2 or Fargate as the capacity provider. Fargate simplifies initiation, but lengthens infrastructure setup time for each flow run. Using EC2 for the ECS cluster can reduce setup time. In this example, we will show how to use Fargate. + +
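As a sketch of what pointing a deployment at an existing task definition might look like (the pool name and ARN below are placeholders), the work pool's job variables can be overridden per deployment in `prefect.yaml`:

```yaml
work_pool:
  name: my-ecs-pool
  job_variables:
    # Hypothetical ARN; when set, the worker reuses this task definition
    # instead of registering a new one from the work pool's fields.
    task_definition_arn: arn:aws:ecs:us-east-1:123456789012:task-definition/my-flow:1
    launch_type: FARGATE
```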
+ + +!!! tip + If you prefer infrastructure as code check out this [Terraform module](https://github.com/PrefectHQ/prefect-recipes/tree/main/devops/infrastructure-as-code/aws/tf-prefect2-ecs-worker) to provision an ECS cluster with a worker. + +## Prerequisites + +- An AWS account with permissions to create ECS services and IAM roles. +- The AWS CLI installed on your local machine. You can [download it from the AWS website](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html). +- An [ECS Cluster](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/clusters.html) to host both the worker and the flow runs it submits. This guide uses the default cluster. To create your own follow [this guide](https://docs.aws.amazon.com/AmazonECS/latest/userguide/create_cluster.html). +- A [VPC](https://docs.aws.amazon.com/vpc/latest/userguide/what-is-amazon-vpc.html) configured for your ECS tasks. This guide uses the default VPC. +- Prefect Cloud account or Prefect self-managed instance. + +## Step 1: Set up an ECS work pool + +Before setting up the worker, create a [work pool](https://docs.prefect.io/latest/concepts/work-pools/#work-pool-configuration) of type ECS for the worker to pull work from. If doing so from the CLI, be sure to [authenticate with Prefect Cloud](https://docs.prefect.io/latest/cloud/cloud-quickstart/#log-into-prefect-cloud-from-a-terminal). + +Create a work pool from the CLI: + +```bash +prefect work-pool create --type ecs my-ecs-pool +``` + +Or from the Prefect UI: +![WorkPool](img/Workpool_UI.png) + +!!! + Because this guide uses Fargate as the capacity provider and the default VPC and ECS cluster, no further configuration is needed. + +Next, set up a Prefect ECS worker that will discover and pull work from this work pool. + +## Step 2: Start a Prefect worker in your ECS cluster + +First start by creating the [IAM role](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-custom.html#roles-creatingrole-custom-trust-policy-console) required in order for your worker and flows to run. The sample flow in this guide doesn't interact with many other AWS services, so you will only be creating one role, `taskExecutionRole`. To create an IAM role for the ECS task using the AWS CLI, follow these steps: + +### 1. Create a trust policy + +The trust policy will specify that the ECS service containing the Prefect worker will be able to assume the role required for calling other AWS services. + +Save this policy to a file, such as `ecs-trust-policy.json`: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": { + "Service": "ecs-tasks.amazonaws.com" + }, + "Action": "sts:AssumeRole" + } + ] +} +``` + +### 2. Create the IAM roles + +Use the `aws iam create-role` command to create the roles that you will be using. For this guide, the `ecsTaskExecutionRole` will be used by the worker to start ECS tasks, and will also be the role assigned to the ECS tasks running your Prefect flows. + +```bash + aws iam create-role \ + --role-name ecsTaskExecutionRole \ + --assume-role-policy-document file://ecs-trust-policy.json +``` + +!!! tip + Depending on the requirements of your flows, it is advised to create a [second role for your ECS tasks](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html). This role will contain the permissions required by the ECS tasks in which your flows will run. 
For example, if your workflow loads data into an S3 bucket, you would need a role with additional permissions to access S3. + +### 3. Attach the policy to the role + +For this guide the ECS worker will require permissions to pull images from ECR and publish logs to CloudWatch. Amazon has a managed policy named `AmazonECSTaskExecutionRolePolicy` that grants the permissions necessary for starting ECS tasks. [See here](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_execution_IAM_role.html) for other common execution role permissions. Attach this policy to your task execution role: + +```bash + aws iam attach-role-policy \ + --role-name ecsTaskExecutionRole \ + --policy-arn arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy +``` + +Remember to replace the `--role-name` and `--policy-arn` with the actual role name and policy Amazon Resource Name (ARN) you want to use. + +## Step 3: Creating an ECS worker service + +### 1. Launch an ECS Service to host the worker + +Next, create an ECS task definition that specifies the Docker image for the Prefect worker, the resources it requires, and the command it should run. In this example, the command to start the worker is `prefect worker start --pool my-ecs-pool`. + +**Create a JSON file with the following contents:** + +```json +{ + "family": "prefect-worker-task", + "networkMode": "awsvpc", + "requiresCompatibilities": [ + "FARGATE" + ], + "cpu": "512", + "memory": "1024", + "executionRoleArn": "", + "taskRoleArn": "", + "containerDefinitions": [ + { + "name": "prefect-worker", + "image": "prefecthq/prefect:2-latest", + "cpu": 512, + "memory": 1024, + "essential": true, + "command": [ + "/bin/sh", + "-c", + "pip install prefect-aws && prefect worker start --pool my-ecs-pool --type ecs" + ], + "environment": [ + { + "name": "PREFECT_API_URL", + "value": "prefect-api-url>" + }, + { + "name": "PREFECT_API_KEY", + "value": "" + } + ] + } + ] +} +``` + +- Use `prefect config view` to view the `PREFECT_API_URL` for your current Prefect profile. Use this to replace ``. + +- For the `PREFECT_API_KEY`, if you are on a paid plan you can create a [service account](https://docs.prefect.io/latest/cloud/users/service-accounts/) for the worker. If your are on a free plan, you can pass a user’s API key. + +- Replace both instances of `` with the ARN of the IAM role you created in Step 2. You can grab this by running: +``` +aws iam get-role --role-name taskExecutionRole --query 'Role.[RoleName, Arn]' --output text +``` + +- Notice that the CPU and Memory allocations are relatively small. The worker's main responsibility is to submit work through API calls to AWS, _not_ to execute your Prefect flow code. + +!!! tip + To avoid hardcoding your API key into the task definition JSON see [how to add sensitive data using AWS secrets manager to the container definition](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/specifying-sensitive-data-tutorial.html#specifying-sensitive-data-tutorial-create-taskdef). + +### 2. Register the task definition + +Before creating a service, you first need to register a task definition. You can do that using the `register-task-definition` command in the AWS CLI. Here is an example: + +```bash +aws ecs register-task-definition --cli-input-json file://task-definition.json +``` + +Replace `task-definition.json` with the name of your JSON file. + +### 3. 
Create an ECS service to host your worker + +Finally, create a service that will manage your Prefect worker: + +Open a terminal window and run the following command to create an ECS Fargate service: + +```bash +aws ecs create-service \ + --service-name prefect-worker-service \ + --cluster \ + --task-definition \ + --launch-type FARGATE \ + --desired-count 1 \ + --network-configuration "awsvpcConfiguration={subnets=[],securityGroups=[],assignPublicIp='ENABLED'}" +``` + +- Replace `` with the name of your ECS cluster. +- Replace `` with the ARN of the task definition you just registered. +- Replace `` with a comma-separated list of your VPC subnet IDs. Ensure that these subnets are aligned with the vpc specified on the work pool in step 1. You can view subnet ids with the following command: + `aws ec2 describe-subnets --filter Name=` +- Replace `` with a comma-separated list of your VPC security group IDs. + +!!! tip "Sanity check" + The work pool page in the Prefect UI allows you to check the health of your workers - make sure your new worker is live! Note that it can take a few minutes for an ECS service to come online. + If your worker does not come online and you are using the command from this guide, you may not be using the default VPC. For connectivity issues, check your VPC's configuration and refer to the [ECS outbound networking guide](https://docs.aws.amazon.com/AmazonECS/latest/bestpracticesguide/networking-outbound.html). + +## Step 4: Pick up a flow run with your new worker + +This guide uses ECR to store a Docker image containing your flow code. To do this, we will write a flow, then deploy it using build and push steps that copy flow code into a Docker image and push that image to an ECR repository. + +### 1. Write a simple test flow + +`my_flow.py` + +```python +from prefect import flow, get_run_logger + +@flow +def my_flow(): + logger = get_run_logger() + logger.info("Hello from ECS!!") + +if __name__ == "__main__": + my_flow() +``` + +### 2. Create an ECR repository + +Use the following AWS CLI command to create an ECR repository. The name you choose for your repository will be reused in the next step when defining your Prefect deployment. + +```bash +aws ecr create-repository \ +--repository-name \ +--region +``` + +### 3. Create a `prefect.yaml` file + +To have Prefect build your image when deploying your flow create a `prefect.yaml` file with the following specification: + +```yaml +name: ecs-worker-guide +# this is pre-populated by running prefect init +prefect-version: 2.14.20 + +# build section allows you to manage and build docker images +build: +- prefect_docker.deployments.steps.build_docker_image: + id: build_image + requires: prefect-docker>=0.3.1 + image_name: + tag: latest + dockerfile: auto + +# push section allows you to manage if and how this project is uploaded to remote locations +push: +- prefect_docker.deployments.steps.push_docker_image: + requires: prefect-docker>=0.3.1 + image_name: '{{ build_image.image_name }}' + tag: '{{ build_image.tag }}' + + # the deployments section allows you to provide configuration for deploying flows +deployments: +- name: my_ecs_deployment + version: + tags: [] + description: + entrypoint: flow.py:my_flow + parameters: {} + work_pool: + name: ecs-dev-pool + work_queue_name: + job_variables: + image: '{{ build_image.image }}' + schedules: [] +pull: + - prefect.deployments.steps.set_working_directory: + directory: /opt/prefect/ecs-worker-guide + +``` + +### 4. 
[Deploy](https://docs.prefect.io/tutorial/deployments/#create-a-deployment) the flow to the Prefect Cloud or your self-managed server instance, specifying the ECS work pool when prompted + +```bash +prefect deploy my_flow.py:my_ecs_deployment +``` + +### 5. Find the deployment in the UI and click the **Quick Run** button! + +## Optional next steps + +1. Now that you are confident your ECS worker is healthy, you can experiment with different work pool configurations. + + - Do your flow runs require higher `CPU`? + - Would an EC2 `Launch Type` speed up your flow run execution? + + These infrastructure configuration values can be set on your ECS work pool or they can be overridden on the deployment level through [job_variables](https://docs.prefect.io/concepts/infrastructure/#kubernetesjob-overrides-and-customizations) if desired. diff --git a/docs/integrations/prefect-aws/ecs_worker.md b/docs/integrations/prefect-aws/ecs_worker.md new file mode 100644 index 000000000000..4f93ec820c48 --- /dev/null +++ b/docs/integrations/prefect-aws/ecs_worker.md @@ -0,0 +1,6 @@ +--- +description: Worker integration with the Amazon Elastic Container Service. +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.workers.ecs_worker \ No newline at end of file diff --git a/docs/integrations/prefect-aws/glue_job.md b/docs/integrations/prefect-aws/glue_job.md new file mode 100644 index 000000000000..b9728a102bc0 --- /dev/null +++ b/docs/integrations/prefect-aws/glue_job.md @@ -0,0 +1,6 @@ +--- +description: Tasks for interacting with AWS Glue Job +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.glue_job \ No newline at end of file diff --git a/docs/integrations/prefect-aws/img/ECSCluster_UI.png b/docs/integrations/prefect-aws/img/ECSCluster_UI.png new file mode 100644 index 000000000000..f3119b657233 Binary files /dev/null and b/docs/integrations/prefect-aws/img/ECSCluster_UI.png differ diff --git a/docs/integrations/prefect-aws/img/LaunchType_UI.png b/docs/integrations/prefect-aws/img/LaunchType_UI.png new file mode 100644 index 000000000000..bba88714349a Binary files /dev/null and b/docs/integrations/prefect-aws/img/LaunchType_UI.png differ diff --git a/docs/integrations/prefect-aws/img/VPC_UI.png b/docs/integrations/prefect-aws/img/VPC_UI.png new file mode 100644 index 000000000000..0ab8354d48d0 Binary files /dev/null and b/docs/integrations/prefect-aws/img/VPC_UI.png differ diff --git a/docs/integrations/prefect-aws/img/Workpool_UI.png b/docs/integrations/prefect-aws/img/Workpool_UI.png new file mode 100644 index 000000000000..755dbe97048f Binary files /dev/null and b/docs/integrations/prefect-aws/img/Workpool_UI.png differ diff --git a/docs/integrations/prefect-aws/img/favicon.ico b/docs/integrations/prefect-aws/img/favicon.ico new file mode 100644 index 000000000000..159c41525dd9 Binary files /dev/null and b/docs/integrations/prefect-aws/img/favicon.ico differ diff --git a/docs/integrations/prefect-aws/img/prefect-logo-mark.png b/docs/integrations/prefect-aws/img/prefect-logo-mark.png new file mode 100644 index 000000000000..0d6968217e74 Binary files /dev/null and b/docs/integrations/prefect-aws/img/prefect-logo-mark.png differ diff --git a/docs/integrations/prefect-aws/index.md b/docs/integrations/prefect-aws/index.md new file mode 100644 index 000000000000..f95ef42501e1 --- /dev/null +++ b/docs/integrations/prefect-aws/index.md @@ -0,0 +1,135 @@ +# `prefect-aws` + +


+ +## Welcome + +`prefect-aws` makes it easy to leverage the capabilities of AWS in your workflows. + +## Getting started + +### Installation + +Prefect requires Python 3.8 or newer. + +We recommend using a Python virtual environment manager such as pipenv, conda, or virtualenv. + +Install `prefect-aws` + +```bash +pip install prefect-aws +``` + +### Registering blocks + +Register [blocks](https://docs.prefect.io/ui/blocks/) in this module to make them available for use. + +```bash +prefect block register -m prefect_aws +``` + +A list of available blocks in `prefect-aws` and their setup instructions can be found [here](https://PrefectHQ.github.io/prefect-aws/#blocks-catalog). + +### Saving credentials to a block + +You will need an AWS account and credentials to use `prefect-aws`. + +1. Refer to the [AWS Configuration documentation](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html#cli-configure-quickstart-creds) on how to retrieve your access key ID and secret access key +2. Copy the access key ID and secret access key +3. Create an `AWSCredenitals` block in the Prefect UI or use a Python script like the one below and replace the placeholders with your credential information and desired block name: + +```python +from prefect_aws import AwsCredentials +AwsCredentials( + aws_access_key_id="PLACEHOLDER", + aws_secret_access_key="PLACEHOLDER", + aws_session_token=None, # replace this with token if necessary + region_name="us-east-2" +).save("BLOCK-NAME-PLACEHOLDER") +``` + +Congrats! You can now load the saved block to use your credentials in your Python code: + +```python +from prefect_aws import AwsCredentials +AwsCredentials.load("BLOCK-NAME-PLACEHOLDER") +``` + +### Using Prefect with AWS S3 + +`prefect_aws` allows you to read and write objects with AWS S3 within your Prefect flows. + +The provided code snippet shows how you can use `prefect_aws` to upload a file to a AWS S3 bucket and download the same file under a different file name. + +Note, the following code assumes that the bucket already exists. + +```python +from pathlib import Path +from prefect import flow +from prefect_aws import AwsCredentials, S3Bucket + +@flow +def s3_flow(): + # create a dummy file to upload + file_path = Path("test-example.txt") + file_path.write_text("Hello, Prefect!") + + aws_credentials = AwsCredentials.load("BLOCK-NAME-PLACEHOLDER") + s3_bucket = S3Bucket( + bucket_name="BUCKET-NAME-PLACEHOLDER", + credentials=aws_credentials + ) + + s3_bucket_path = s3_bucket.upload_from_path(file_path) + downloaded_file_path = s3_bucket.download_object_to_path( + s3_bucket_path, "downloaded-test-example.txt" + ) + return downloaded_file_path.read_text() + +s3_flow() +``` + +### Using Prefect with AWS Secrets Manager + +`prefect_aws` allows you to read and write secrets with AWS Secrets Manager within your Prefect flows. + +The provided code snippet shows how you can use `prefect_aws` to write a secret to the Secret Manager, read the secret data, delete the secret, and finally return the secret data. 
+ +```python +from prefect import flow +from prefect_aws import AwsCredentials, AwsSecret + +@flow +def secrets_manager_flow(): + aws_credentials = AwsCredentials.load("BLOCK-NAME-PLACEHOLDER") + aws_secret = AwsSecret(secret_name="test-example", aws_credentials=aws_credentials) + aws_secret.write_secret(secret_data=b"Hello, Prefect!") + secret_data = aws_secret.read_secret() + aws_secret.delete_secret() + return secret_data + +secrets_manager_flow() +``` + +### Using Prefect with AWS ECS + +`prefect_aws` allows you to use [AWS ECS](https://aws.amazon.com/ecs/) as infrastructure for your deployments. Using ECS for scheduled flow runs enables the dynamic provisioning of infrastructure for containers and unlocks greater scalability. This setup gives you all of the observation and orchestration benefits of Prefect, while also providing you the scalability of ECS. + +See the [ECS guide](/ecs_guide/) for a full walkthrough. + +## Resources + +Refer to the API documentation on the sidebar to explore all the capabilities of Prefect AWS! + +For more tips on how to use blocks and tasks in Prefect integration libraries, check out the [docs](https://docs.prefect.io/integrations/usage/)! + +For more information about how to use Prefect, please refer to the [Prefect documentation](https://docs.prefect.io/). + +If you encounter any bugs while using `prefect-aws`, feel free to open an issue in the [`prefect`](https://github.com/PrefectHQ/prefect) repository. + +If you have any questions or issues while using `prefect-aws`, you can find help in the [Prefect Slack community](https://prefect.io/slack). \ No newline at end of file diff --git a/docs/integrations/prefect-aws/lambda_function.md b/docs/integrations/prefect-aws/lambda_function.md new file mode 100644 index 000000000000..3f5c52e88171 --- /dev/null +++ b/docs/integrations/prefect-aws/lambda_function.md @@ -0,0 +1,6 @@ +--- +description: Module handling AWS Lambda functions +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.lambda_function \ No newline at end of file diff --git a/docs/integrations/prefect-aws/s3.md b/docs/integrations/prefect-aws/s3.md new file mode 100644 index 000000000000..60e656fd42bb --- /dev/null +++ b/docs/integrations/prefect-aws/s3.md @@ -0,0 +1,6 @@ +--- +description: Tasks for interacting with AWS S3 +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.s3 \ No newline at end of file diff --git a/docs/integrations/prefect-aws/secrets_manager.md b/docs/integrations/prefect-aws/secrets_manager.md new file mode 100644 index 000000000000..83a014765fe8 --- /dev/null +++ b/docs/integrations/prefect-aws/secrets_manager.md @@ -0,0 +1,6 @@ +--- +description: Tasks for interacting with AWS Secrets Manager +notes: This documentation page is generated from source file docstrings. 
+--- + +::: prefect_aws.secrets_manager \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index e5771f70ba66..b15fd21fb6bb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -238,6 +238,21 @@ nav: - Using Integrations: integrations/usage.md - Contributing Integrations: integrations/contribute.md - Libraries: + - AWS: + - integrations/prefect-aws/index.md + - Guides: + - Setting up an ECS Worker: integrations/prefect-aws/ecs_guide.md + - Batch: integrations/prefect-aws/batch.md + - Client Waiter: integrations/prefect-aws/client_waiter.md + - Credentials: integrations/prefect-aws/credentials.md + - ECS Worker: integrations/prefect-aws/ecs_worker.md + - ECS (deprecated): integrations/prefect-aws/ecs.md + - Glue Job: integrations/prefect-aws/glue_job.md + - Lambda: integrations/prefect-aws/lambda_function.md + - Deployments: + - Steps: integrations/prefect-aws/deployments/steps.md + - S3: integrations/prefect-aws/s3.md + - Secrets Manager: integrations/prefect-aws/secrets_manager.md - Azure: - integrations/prefect-azure/index.md - ACI Worker Guide: integrations/prefect-azure/aci_worker.md diff --git a/scripts/serve_docs b/scripts/serve_docs index 786edf4d1dec..d67b922de441 100755 --- a/scripts/serve_docs +++ b/scripts/serve_docs @@ -30,12 +30,9 @@ cd "$(dirname "$0")/.." uv venv "$temp_dir/.venv" source "$temp_dir/.venv/bin/activate" -# Install all integration libraries from src/integrations +# Install all integration libraries from src/integrations and core dev dependencies as editable integration_packages=$(find src/integrations -name "pyproject.toml" -exec dirname {} \;) -uv pip install $integration_packages - -# Install the prefect package in editable mode -uv pip install -e ".[dev]" +uv pip install $integration_packages -e ".[dev]" # Build and serve the docs mkdocs serve -a localhost:8000 \ No newline at end of file diff --git a/src/integrations/prefect-aws/README.md b/src/integrations/prefect-aws/README.md new file mode 100644 index 000000000000..5526f03c5dbe --- /dev/null +++ b/src/integrations/prefect-aws/README.md @@ -0,0 +1,24 @@ +# `prefect-aws` + +


+ +## Welcome! + +`prefect-aws` makes it easy to leverage the capabilities of AWS in your flows, featuring support for ECS, S3, Secrets Manager, and Batch. + +### Installation + +To start using `prefect-aws`: + +```bash +pip install prefect-aws +``` + +### Contributing + +Thanks for thinking about chipping in! Check out this [step-by-step guide](https://prefecthq.github.io/prefect-aws/#installation) on how to get started. diff --git a/src/integrations/prefect-aws/prefect_aws/__init__.py b/src/integrations/prefect-aws/prefect_aws/__init__.py new file mode 100644 index 000000000000..838ed6a7ebd1 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/__init__.py @@ -0,0 +1,29 @@ +from . import _version +from .credentials import AwsCredentials, MinIOCredentials +from .client_parameters import AwsClientParameters +from .lambda_function import LambdaFunction +from .s3 import S3Bucket +from .ecs import ECSTask +from .secrets_manager import AwsSecret +from .workers import ECSWorker + +from prefect._internal.compatibility.deprecated import ( + register_renamed_module, +) + +register_renamed_module( + "prefect_aws.projects", "prefect_aws.deployments", start_date="Jun 2023" +) + +__all__ = [ + "AwsCredentials", + "AwsClientParameters", + "LambdaFunction", + "MinIOCredentials", + "S3Bucket", + "ECSTask", + "AwsSecret", + "ECSWorker", +] + +__version__ = _version.__version__ diff --git a/src/integrations/prefect-aws/prefect_aws/batch.py b/src/integrations/prefect-aws/prefect_aws/batch.py new file mode 100644 index 000000000000..4be9dbd7bfe9 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/batch.py @@ -0,0 +1,73 @@ +"""Tasks for interacting with AWS Batch""" + +from typing import Any, Dict, Optional + +from prefect import get_run_logger, task +from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect_aws.credentials import AwsCredentials + + +@task +async def batch_submit( + job_name: str, + job_queue: str, + job_definition: str, + aws_credentials: AwsCredentials, + **batch_kwargs: Optional[Dict[str, Any]], +) -> str: + """ + Submit a job to the AWS Batch job service. + + Args: + job_name: The AWS batch job name. + job_queue: Name of the AWS batch job queue. + job_definition: The AWS batch job definition. + aws_credentials: Credentials to use for authentication with AWS. + **batch_kwargs: Additional keyword arguments to pass to the boto3 + `submit_job` function. See the documentation for + [submit_job](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html#Batch.Client.submit_job) + for more details. + + Returns: + The id corresponding to the job. + + Example: + Submits a job to batch. 
+ + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.batch import batch_submit + + + @flow + def example_batch_submit_flow(): + aws_credentials = AwsCredentials( + aws_access_key_id="acccess_key_id", + aws_secret_access_key="secret_access_key" + ) + job_id = batch_submit( + "job_name", + "job_queue", + "job_definition", + aws_credentials + ) + return job_id + + example_batch_submit_flow() + ``` + + """ # noqa + logger = get_run_logger() + logger.info("Preparing to submit %s job to %s job queue", job_name, job_queue) + + batch_client = aws_credentials.get_boto3_session().client("batch") + + response = await run_sync_in_worker_thread( + batch_client.submit_job, + jobName=job_name, + jobQueue=job_queue, + jobDefinition=job_definition, + **batch_kwargs, + ) + return response["jobId"] diff --git a/src/integrations/prefect-aws/prefect_aws/client_parameters.py b/src/integrations/prefect-aws/prefect_aws/client_parameters.py new file mode 100644 index 000000000000..6b47c422b48b --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/client_parameters.py @@ -0,0 +1,161 @@ +"""Module handling Client parameters""" + +import warnings +from typing import Any, Dict, Optional, Union + +from botocore import UNSIGNED +from botocore.client import Config +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect_aws.utilities import hash_collection + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import BaseModel, Field, FilePath, root_validator, validator +else: + from pydantic import BaseModel, Field, FilePath, root_validator, validator + + +class AwsClientParameters(BaseModel): + """ + Model used to manage extra parameters that you can pass when you initialize + the Client. If you want to find more information, see + [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html) + for more info about the possible client configurations. + + Attributes: + api_version: The API version to use. By default, botocore will + use the latest API version when creating a client. You only need + to specify this parameter if you want to use a previous API version + of the client. + use_ssl: Whether or not to use SSL. By default, SSL is used. + Note that not all services support non-ssl connections. + verify: Whether or not to verify SSL certificates. By default + SSL certificates are verified. If False, SSL will still be used + (unless use_ssl is False), but SSL certificates + will not be verified. Passing a file path to this is deprecated. + verify_cert_path: A filename of the CA cert bundle to + use. You can specify this argument if you want to use a + different CA cert bundle than the one used by botocore. + endpoint_url: The complete URL to use for the constructed + client. Normally, botocore will automatically construct the + appropriate URL to use when communicating with a service. You + can specify a complete URL (including the "http/https" scheme) + to override this behavior. If this value is provided, + then ``use_ssl`` is ignored. + config: Advanced configuration for Botocore clients. See + [botocore docs](https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html) + for more details. 
+ """ # noqa E501 + + api_version: Optional[str] = Field( + default=None, description="The API version to use.", title="API Version" + ) + use_ssl: bool = Field( + default=True, description="Whether or not to use SSL.", title="Use SSL" + ) + verify: Union[bool, FilePath] = Field( + default=True, description="Whether or not to verify SSL certificates." + ) + verify_cert_path: Optional[FilePath] = Field( + default=None, + description="Path to the CA cert bundle to use.", + title="Certificate Authority Bundle File Path", + ) + endpoint_url: Optional[str] = Field( + default=None, + description="The complete URL to use for the constructed client.", + title="Endpoint URL", + ) + config: Optional[Dict[str, Any]] = Field( + default=None, + description="Advanced configuration for Botocore clients.", + title="Botocore Config", + ) + + def __hash__(self): + return hash( + ( + self.api_version, + self.use_ssl, + self.verify, + self.verify_cert_path, + self.endpoint_url, + hash_collection(self.config), + ) + ) + + @validator("config", pre=True) + def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, Any]: + """ + Casts lists to Config instances. + """ + if isinstance(value, Config): + return value.__dict__["_user_provided_options"] + return value + + @root_validator + def deprecated_verify_cert_path(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """ + If verify is not a bool, raise a warning. + """ + verify = values.get("verify") + + # deprecate using verify in favor of verify_cert_path + # so the UI looks nicer + if verify is not None and not isinstance(verify, bool): + warnings.warn( + ( + "verify should be a boolean. " + "If you want to use a CA cert bundle, use verify_cert_path instead." + ), + DeprecationWarning, + ) + return values + + @root_validator + def verify_cert_path_and_verify(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """ + If verify_cert_path is set but verify is False, raise a warning. + """ + verify = values.get("verify", True) + verify_cert_path = values.get("verify_cert_path") + + if not verify and verify_cert_path: + warnings.warn( + "verify_cert_path is set but verify is False. " + "verify_cert_path will be ignored." + ) + values["verify_cert_path"] = None + elif not isinstance(verify, bool) and verify_cert_path: + warnings.warn( + "verify_cert_path is set but verify is also set as a file path. " + "verify_cert_path will take precedence." + ) + values["verify"] = True + return values + + def get_params_override(self) -> Dict[str, Any]: + """ + Return the dictionary of the parameters to override. + The parameters to override are the one which are not None. 
+ """ + params = self.dict() + if params.get("verify_cert_path"): + # to ensure that verify doesn't re-overwrite verify_cert_path + params.pop("verify") + + params_override = {} + for key, value in params.items(): + if value is None: + continue + elif key == "config": + params_override[key] = Config(**value) + # botocore UNSIGNED is an instance while actual signers can + # be fetched as strings + if params_override[key].signature_version == "unsigned": + params_override[key].signature_version = UNSIGNED + elif key == "verify_cert_path": + params_override["verify"] = value + else: + params_override[key] = value + return params_override diff --git a/src/integrations/prefect-aws/prefect_aws/client_waiter.py b/src/integrations/prefect-aws/prefect_aws/client_waiter.py new file mode 100644 index 000000000000..f289fb40a7d4 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/client_waiter.py @@ -0,0 +1,76 @@ +"""Task for waiting on a long-running AWS job""" + +from typing import Any, Dict, Optional + +from botocore.waiter import WaiterModel, create_waiter_with_client + +from prefect import get_run_logger, task +from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect_aws.credentials import AwsCredentials + + +@task +async def client_waiter( + client: str, + waiter_name: str, + aws_credentials: AwsCredentials, + waiter_definition: Optional[Dict[str, Any]] = None, + **waiter_kwargs: Optional[Dict[str, Any]], +): + """ + Uses the underlying boto3 waiter functionality. + + Args: + client: The AWS client on which to wait (e.g., 'client_wait', 'ec2', etc). + waiter_name: The name of the waiter to instantiate. + You may also use a custom waiter name, if you supply + an accompanying waiter definition dict. + aws_credentials: Credentials to use for authentication with AWS. + waiter_definition: A valid custom waiter model, as a dict. Note that if + you supply a custom definition, it is assumed that the provided + 'waiter_name' is contained within the waiter definition dict. + **waiter_kwargs: Arguments to pass to the `waiter.wait(...)` method. Will + depend upon the specific waiter being called. + + Example: + Run an ec2 waiter until instance_exists. 
+ ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.client_wait import client_waiter + + @flow + def example_client_wait_flow(): + aws_credentials = AwsCredentials( + aws_access_key_id="acccess_key_id", + aws_secret_access_key="secret_access_key" + ) + + waiter = client_waiter( + "ec2", + "instance_exists", + aws_credentials + ) + + return waiter + example_client_wait_flow() + ``` + """ + logger = get_run_logger() + logger.info("Waiting on %s job", client) + + boto_client = aws_credentials.get_boto3_session().client(client) + + if waiter_definition is not None: + # Use user-provided waiter definition + waiter_model = WaiterModel(waiter_definition) + waiter = create_waiter_with_client(waiter_name, waiter_model, boto_client) + elif waiter_name in boto_client.waiter_names: + waiter = boto_client.get_waiter(waiter_name) + else: + raise ValueError( + f"The waiter name, {waiter_name}, is not a valid boto waiter; " + "if using a custom waiter, you must provide a waiter definition" + ) + + await run_sync_in_worker_thread(waiter.wait, **waiter_kwargs) diff --git a/src/integrations/prefect-aws/prefect_aws/credentials.py b/src/integrations/prefect-aws/prefect_aws/credentials.py new file mode 100644 index 000000000000..04e020ce692f --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/credentials.py @@ -0,0 +1,306 @@ +"""Module handling AWS credentials""" + +from enum import Enum +from functools import lru_cache +from threading import Lock +from typing import Any, Optional, Union + +import boto3 +from mypy_boto3_s3 import S3Client +from mypy_boto3_secretsmanager import SecretsManagerClient +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect.blocks.abstract import CredentialsBlock + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field, SecretStr +else: + from pydantic import Field, SecretStr + +from prefect_aws.client_parameters import AwsClientParameters + +_LOCK = Lock() + + +class ClientType(Enum): + """The supported boto3 clients.""" + + S3 = "s3" + ECS = "ecs" + BATCH = "batch" + SECRETS_MANAGER = "secretsmanager" + + +@lru_cache(maxsize=8, typed=True) +def _get_client_cached(ctx, client_type: Union[str, ClientType]) -> Any: + """ + Helper method to cache and dynamically get a client type. + + Args: + client_type: The client's service name. + + Returns: + An authenticated client. + + Raises: + ValueError: if the client is not supported. + """ + with _LOCK: + if isinstance(client_type, ClientType): + client_type = client_type.value + + client = ctx.get_boto3_session().client( + service_name=client_type, + **ctx.aws_client_parameters.get_params_override(), + ) + return client + + +class AwsCredentials(CredentialsBlock): + """ + Block used to manage authentication with AWS. AWS authentication is + handled via the `boto3` module. Refer to the + [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html) + for more info about the possible credential configurations. 
+ + Example: + Load stored AWS credentials: + ```python + from prefect_aws import AwsCredentials + + aws_credentials_block = AwsCredentials.load("BLOCK_NAME") + ``` + """ # noqa E501 + + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + _block_type_name = "AWS Credentials" + _documentation_url = "https://prefecthq.github.io/prefect-aws/credentials/#prefect_aws.credentials.AwsCredentials" # noqa + + aws_access_key_id: Optional[str] = Field( + default=None, + description="A specific AWS access key ID.", + title="AWS Access Key ID", + ) + aws_secret_access_key: Optional[SecretStr] = Field( + default=None, + description="A specific AWS secret access key.", + title="AWS Access Key Secret", + ) + aws_session_token: Optional[str] = Field( + default=None, + description=( + "The session key for your AWS account. " + "This is only needed when you are using temporary credentials." + ), + title="AWS Session Token", + ) + profile_name: Optional[str] = Field( + default=None, description="The profile to use when creating your session." + ) + region_name: Optional[str] = Field( + default=None, + description="The AWS Region where you want to create new connections.", + ) + aws_client_parameters: AwsClientParameters = Field( + default_factory=AwsClientParameters, + description="Extra parameters to initialize the Client.", + title="AWS Client Parameters", + ) + + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + field_hashes = ( + hash(self.aws_access_key_id), + hash(self.aws_secret_access_key), + hash(self.aws_session_token), + hash(self.profile_name), + hash(self.region_name), + hash(self.aws_client_parameters), + ) + return hash(field_hashes) + + def get_boto3_session(self) -> boto3.Session: + """ + Returns an authenticated boto3 session that can be used to create clients + for AWS services + + Example: + Create an S3 client from an authorized boto3 session: + ```python + aws_credentials = AwsCredentials( + aws_access_key_id = "access_key_id", + aws_secret_access_key = "secret_access_key" + ) + s3_client = aws_credentials.get_boto3_session().client("s3") + ``` + """ + + if self.aws_secret_access_key: + aws_secret_access_key = self.aws_secret_access_key.get_secret_value() + else: + aws_secret_access_key = None + + return boto3.Session( + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=self.aws_session_token, + profile_name=self.profile_name, + region_name=self.region_name, + ) + + def get_client(self, client_type: Union[str, ClientType]): + """ + Helper method to dynamically get a client type. + + Args: + client_type: The client's service name. + + Returns: + An authenticated client. + + Raises: + ValueError: if the client is not supported. + """ + if isinstance(client_type, ClientType): + client_type = client_type.value + + return _get_client_cached(ctx=self, client_type=client_type) + + def get_s3_client(self) -> S3Client: + """ + Gets an authenticated S3 client. + + Returns: + An authenticated S3 client. + """ + return self.get_client(client_type=ClientType.S3) + + def get_secrets_manager_client(self) -> SecretsManagerClient: + """ + Gets an authenticated Secrets Manager client. + + Returns: + An authenticated Secrets Manager client. + """ + return self.get_client(client_type=ClientType.SECRETS_MANAGER) + + +class MinIOCredentials(CredentialsBlock): + """ + Block used to manage authentication with MinIO. 
Refer to the + [MinIO docs](https://docs.min.io/docs/minio-server-configuration-guide.html) + for more info about the possible credential configurations. + + Attributes: + minio_root_user: Admin or root user. + minio_root_password: Admin or root password. + region_name: Location of server, e.g. "us-east-1". + + Example: + Load stored MinIO credentials: + ```python + from prefect_aws import MinIOCredentials + + minio_credentials_block = MinIOCredentials.load("BLOCK_NAME") + ``` + """ # noqa E501 + + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/676cb17bcbdff601f97e0a02ff8bcb480e91ff40-250x250.png" # noqa + _block_type_name = "MinIO Credentials" + _description = ( + "Block used to manage authentication with MinIO. Refer to the MinIO " + "docs: https://docs.min.io/docs/minio-server-configuration-guide.html " + "for more info about the possible credential configurations." + ) + _documentation_url = "https://prefecthq.github.io/prefect-aws/credentials/#prefect_aws.credentials.MinIOCredentials" # noqa + + minio_root_user: str = Field(default=..., description="Admin or root user.") + minio_root_password: SecretStr = Field( + default=..., description="Admin or root password." + ) + region_name: Optional[str] = Field( + default=None, + description="The AWS Region where you want to create new connections.", + ) + aws_client_parameters: AwsClientParameters = Field( + default_factory=AwsClientParameters, + description="Extra parameters to initialize the Client.", + ) + + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + return hash( + ( + hash(self.minio_root_user), + hash(self.minio_root_password), + hash(self.region_name), + hash(frozenset(self.aws_client_parameters.dict().items())), + ) + ) + + def get_boto3_session(self) -> boto3.Session: + """ + Returns an authenticated boto3 session that can be used to create clients + and perform object operations on MinIO server. + + Example: + Create an S3 client from an authorized boto3 session + + ```python + minio_credentials = MinIOCredentials( + minio_root_user = "minio_root_user", + minio_root_password = "minio_root_password" + ) + s3_client = minio_credentials.get_boto3_session().client( + service="s3", + endpoint_url="http://localhost:9000" + ) + ``` + """ + + minio_root_password = ( + self.minio_root_password.get_secret_value() + if self.minio_root_password + else None + ) + + return boto3.Session( + aws_access_key_id=self.minio_root_user, + aws_secret_access_key=minio_root_password, + region_name=self.region_name, + ) + + def get_client(self, client_type: Union[str, ClientType]): + """ + Helper method to dynamically get a client type. + + Args: + client_type: The client's service name. + + Returns: + An authenticated client. + + Raises: + ValueError: if the client is not supported. + """ + if isinstance(client_type, ClientType): + client_type = client_type.value + + return _get_client_cached(ctx=self, client_type=client_type) + + def get_s3_client(self) -> S3Client: + """ + Gets an authenticated S3 client. + + Returns: + An authenticated S3 client. 
+ """ + return self.get_client(client_type=ClientType.S3) diff --git a/src/integrations/prefect-aws/prefect_aws/deployments/__init__.py b/src/integrations/prefect-aws/prefect_aws/deployments/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/integrations/prefect-aws/prefect_aws/deployments/steps.py b/src/integrations/prefect-aws/prefect_aws/deployments/steps.py new file mode 100644 index 000000000000..54f21bff7bd5 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/deployments/steps.py @@ -0,0 +1,249 @@ +""" +Prefect deployment steps for code storage and retrieval in S3 and S3 +compatible services. +""" +from pathlib import Path, PurePosixPath +from typing import Dict, Optional + +import boto3 +from botocore.client import Config +from typing_extensions import TypedDict + +from prefect._internal.compatibility.deprecated import deprecated_callable +from prefect.utilities.filesystem import filter_files, relative_path_to_current_platform + + +class PushToS3Output(TypedDict): + """ + The output of the `push_to_s3` step. + """ + + bucket: str + folder: str + + +@deprecated_callable(start_date="Jun 2023", help="Use `PushToS3Output` instead.") +class PushProjectToS3Output(PushToS3Output): + """Deprecated. Use `PushToS3Output` instead.""" + + +class PullFromS3Output(TypedDict): + """ + The output of the `pull_from_s3` step. + """ + + bucket: str + folder: str + directory: str + + +@deprecated_callable(start_date="Jun 2023", help="Use `PullFromS3Output` instead.") +class PullProjectFromS3Output(PullFromS3Output): + """Deprecated. Use `PullFromS3Output` instead..""" + + +@deprecated_callable(start_date="Jun 2023", help="Use `push_to_s3` instead.") +def push_project_to_s3(*args, **kwargs): + """Deprecated. Use `push_to_s3` instead.""" + push_to_s3(*args, **kwargs) + + +def push_to_s3( + bucket: str, + folder: str, + credentials: Optional[Dict] = None, + client_parameters: Optional[Dict] = None, + ignore_file: Optional[str] = ".prefectignore", +) -> PushToS3Output: + """ + Pushes the contents of the current working directory to an S3 bucket, + excluding files and folders specified in the ignore_file. + + Args: + bucket: The name of the S3 bucket where files will be uploaded. + folder: The folder in the S3 bucket where files will be uploaded. + credentials: A dictionary of AWS credentials (aws_access_key_id, + aws_secret_access_key, aws_session_token) or MinIO credentials + (minio_root_user, minio_root_password). + client_parameters: A dictionary of additional parameters to pass to the boto3 + client. + ignore_file: The name of the file containing ignore patterns. + + Returns: + A dictionary containing the bucket and folder where files were uploaded. 
+ + Examples: + Push files to an S3 bucket: + ```yaml + push: + - prefect_aws.deployments.steps.push_to_s3: + requires: prefect-aws + bucket: my-bucket + folder: my-project + ``` + + Push files to an S3 bucket using credentials stored in a block: + ```yaml + push: + - prefect_aws.deployments.steps.push_to_s3: + requires: prefect-aws + bucket: my-bucket + folder: my-project + credentials: "{{ prefect.blocks.aws-credentials.dev-credentials }}" + ``` + + """ + s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters) + + local_path = Path.cwd() + + included_files = None + if ignore_file and Path(ignore_file).exists(): + with open(ignore_file, "r") as f: + ignore_patterns = f.readlines() + + included_files = filter_files(str(local_path), ignore_patterns) + + for local_file_path in local_path.expanduser().rglob("*"): + if ( + included_files is not None + and str(local_file_path.relative_to(local_path)) not in included_files + ): + continue + elif not local_file_path.is_dir(): + remote_file_path = Path(folder) / local_file_path.relative_to(local_path) + s3.upload_file( + str(local_file_path), bucket, str(remote_file_path.as_posix()) + ) + + return { + "bucket": bucket, + "folder": folder, + } + + +@deprecated_callable(start_date="Jun 2023", help="Use `pull_from_s3` instead.") +def pull_project_from_s3(*args, **kwargs): + """Deprecated. Use `pull_from_s3` instead.""" + pull_from_s3(*args, **kwargs) + + +def pull_from_s3( + bucket: str, + folder: str, + credentials: Optional[Dict] = None, + client_parameters: Optional[Dict] = None, +) -> PullFromS3Output: + """ + Pulls the contents of an S3 bucket folder to the current working directory. + + Args: + bucket: The name of the S3 bucket where files are stored. + folder: The folder in the S3 bucket where files are stored. + credentials: A dictionary of AWS credentials (aws_access_key_id, + aws_secret_access_key, aws_session_token) or MinIO credentials + (minio_root_user, minio_root_password). + client_parameters: A dictionary of additional parameters to pass to the + boto3 client. + + Returns: + A dictionary containing the bucket, folder, and local directory where + files were downloaded. 
+ + Examples: + Pull files from S3 using the default credentials and client parameters: + ```yaml + pull: + - prefect_aws.deployments.steps.pull_from_s3: + requires: prefect-aws + bucket: my-bucket + folder: my-project + ``` + + Pull files from S3 using credentials stored in a block: + ```yaml + pull: + - prefect_aws.deployments.steps.pull_from_s3: + requires: prefect-aws + bucket: my-bucket + folder: my-project + credentials: "{{ prefect.blocks.aws-credentials.dev-credentials }}" + ``` + """ + s3 = get_s3_client(credentials=credentials, client_parameters=client_parameters) + + local_path = Path.cwd() + + paginator = s3.get_paginator("list_objects_v2") + for result in paginator.paginate(Bucket=bucket, Prefix=folder): + for obj in result.get("Contents", []): + remote_key = obj["Key"] + + if remote_key[-1] == "/": + # object is a folder and will be created if it contains any objects + continue + + target = PurePosixPath( + local_path + / relative_path_to_current_platform(remote_key).relative_to(folder) + ) + Path.mkdir(Path(target.parent), parents=True, exist_ok=True) + s3.download_file(bucket, remote_key, str(target)) + + return { + "bucket": bucket, + "folder": folder, + "directory": str(local_path), + } + + +def get_s3_client( + credentials: Optional[Dict] = None, + client_parameters: Optional[Dict] = None, +) -> dict: + if credentials is None: + credentials = {} + if client_parameters is None: + client_parameters = {} + + # Get credentials from credentials (regardless if block or not) + aws_access_key_id = credentials.get( + "aws_access_key_id", credentials.get("minio_root_user", None) + ) + aws_secret_access_key = credentials.get( + "aws_secret_access_key", credentials.get("minio_root_password", None) + ) + aws_session_token = credentials.get("aws_session_token", None) + + # Get remaining session info from credentials, or client_parameters + profile_name = credentials.get( + "profile_name", client_parameters.get("profile_name", None) + ) + region_name = credentials.get( + "region_name", client_parameters.get("region_name", None) + ) + + # Get additional info from client_parameters, otherwise credentials input (if block) + aws_client_parameters = credentials.get("aws_client_parameters", client_parameters) + api_version = aws_client_parameters.get("api_version", None) + endpoint_url = aws_client_parameters.get("endpoint_url", None) + use_ssl = aws_client_parameters.get("use_ssl", True) + verify = aws_client_parameters.get("verify", None) + config_params = aws_client_parameters.get("config", {}) + config = Config(**config_params) + + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + profile_name=profile_name, + region_name=region_name, + ) + return session.client( + "s3", + api_version=api_version, + endpoint_url=endpoint_url, + use_ssl=use_ssl, + verify=verify, + config=config, + ) diff --git a/src/integrations/prefect-aws/prefect_aws/ecs.py b/src/integrations/prefect-aws/prefect_aws/ecs.py new file mode 100644 index 000000000000..dbc168b73345 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/ecs.py @@ -0,0 +1,1600 @@ +""" +DEPRECATION WARNING: + +This module is deprecated as of March 2024 and will not be available after September 2024. +It has been replaced by the ECS worker, which offers enhanced functionality and better performance. + +For upgrade instructions, see https://docs.prefect.io/latest/guides/upgrade-guide-agents-to-workers/. 
+ +Integrations with the Amazon Elastic Container Service. + +Examples: + + Run a task using ECS Fargate + ```python + ECSTask(command=["echo", "hello world"]).run() + ``` + + Run a task using ECS Fargate with a spot container instance + ```python + ECSTask(command=["echo", "hello world"], launch_type="FARGATE_SPOT").run() + ``` + + Run a task using ECS with an EC2 container instance + ```python + ECSTask(command=["echo", "hello world"], launch_type="EC2").run() + ``` + + Run a task on a specific VPC using ECS Fargate + ```python + ECSTask(command=["echo", "hello world"], vpc_id="vpc-01abcdf123456789a").run() + ``` + + Run a task and stream the container's output to the local terminal. Note an + execution role must be provided with permissions: logs:CreateLogStream, + logs:CreateLogGroup, and logs:PutLogEvents. + ```python + ECSTask( + command=["echo", "hello world"], + stream_output=True, + execution_role_arn="..." + ) + ``` + + Run a task using an existing task definition as a base + ```python + ECSTask(command=["echo", "hello world"], task_definition_arn="arn:aws:ecs:...") + ``` + + Run a task with a specific image + ```python + ECSTask(command=["echo", "hello world"], image="alpine:latest") + ``` + + Run a task with custom memory and CPU requirements + ```python + ECSTask(command=["echo", "hello world"], memory=4096, cpu=2048) + ``` + + Run a task with custom environment variables + ```python + ECSTask(command=["echo", "hello $PLANET"], env={"PLANET": "earth"}) + ``` + + Run a task in a specific ECS cluster + ```python + ECSTask(command=["echo", "hello world"], cluster="my-cluster-name") + ``` + + Run a task with custom VPC subnets + ```python + ECSTask( + command=["echo", "hello world"], + task_customizations=[ + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/subnets", + "value": ["subnet-80b6fbcd", "subnet-42a6fdgd"], + }, + ] + ) + ``` + + Run a task without a public IP assigned + ```python + ECSTask( + command=["echo", "hello world"], + vpc_id="vpc-01abcdf123456789a", + task_customizations=[ + { + "op": "replace", + "path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp", + "value": "DISABLED", + }, + ] + ) + ``` + + Run a task with custom VPC security groups + ```python + ECSTask( + command=["echo", "hello world"], + vpc_id="vpc-01abcdf123456789a", + task_customizations=[ + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/securityGroups", + "value": ["sg-d72e9599956a084f5"], + }, + ], + ) + ``` +""" # noqa + +import copy +import difflib +import json +import logging +import pprint +import shlex +import sys +import time +import warnings +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union + +import boto3 +import yaml +from anyio.abc import TaskStatus +from jsonpointer import JsonPointerException +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect._internal.compatibility.deprecated import deprecated_class +from prefect.blocks.core import BlockNotSavedError +from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound +from prefect.infrastructure.base import Infrastructure, InfrastructureResult +from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible +from prefect.utilities.dockerutils import get_prefect_image_name +from prefect.utilities.pydantic import JsonPatch +from prefect_aws.utilities import assemble_document_for_patches + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field, root_validator, validator +else: + 
from pydantic import Field, root_validator, validator + +from slugify import slugify +from typing_extensions import Literal, Self + +from prefect_aws import AwsCredentials +from prefect_aws.workers.ecs_worker import _TAG_REGEX, ECSWorker + +# Internal type alias for ECS clients which are generated dynamically in botocore +_ECSClient = Any + + +if TYPE_CHECKING: + from prefect.client.schemas import FlowRun + from prefect.server.schemas.core import Deployment, Flow + + +class ECSTaskResult(InfrastructureResult): + """The result of a run of an ECS task""" + + +PREFECT_ECS_CONTAINER_NAME = "prefect" +ECS_DEFAULT_CPU = 1024 +ECS_DEFAULT_MEMORY = 2048 +ECS_DEFAULT_FAMILY = "prefect" +POST_REGISTRATION_FIELDS = [ + "compatibilities", + "taskDefinitionArn", + "revision", + "status", + "requiresAttributes", + "registeredAt", + "registeredBy", + "deregisteredAt", +] + + +def get_prefect_container(containers: List[dict]) -> Optional[dict]: + """ + Extract the Prefect container from a list of containers or container definitions. + If not found, `None` is returned. + """ + return get_container(containers, PREFECT_ECS_CONTAINER_NAME) + + +def get_container(containers: List[dict], name: str) -> Optional[dict]: + """ + Extract a container from a list of containers or container definitions. + If not found, `None` is returned. + """ + for container in containers: + if container.get("name") == name: + return container + return None + + +def parse_task_identifier(identifier: str) -> Tuple[str, str]: + """ + Splits identifier into its cluster and task components, e.g. + input "cluster_name::task_arn" outputs ("cluster_name", "task_arn"). + """ + cluster, task = identifier.split("::", maxsplit=1) + return cluster, task + + +def _pretty_diff(d1: dict, d2: dict) -> str: + """ + Return a string with a pretty printed difference between two dictionaries. + """ + return "\n" + "\n".join( + difflib.ndiff(pprint.pformat(d1).splitlines(), pprint.pformat(d2).splitlines()) + ) + + +@deprecated_class( + start_date="Mar 2024", + help=( + "Use the ECS worker instead." + " Refer to the upgrade guide for more information:" + " https://docs.prefect.io/latest/guides/upgrade-guide-agents-to-workers/." + ), +) +class ECSTask(Infrastructure): + """ + Run a command as an ECS task. + + Attributes: + type: The slug for this task type with a default value of "ecs-task". + aws_credentials: The AWS credentials to use to connect to ECS with a + default factory of AwsCredentials. + task_definition_arn: An optional identifier for an existing task definition + to use. If fields are set on the ECSTask that conflict with the task + definition, a new copy will be registered with the required values. + Cannot be used with task_definition. If not provided, Prefect will + generate and register a minimal task definition. + task_definition: An optional ECS task definition to use. Prefect may set + defaults or override fields on this task definition to match other + ECSTask fields. Cannot be used with task_definition_arn. + If not provided, Prefect will generate and register + a minimal task definition. + family: An optional family for the task definition. If not provided, + it will be inferred from the task definition. If the task definition + does not have a family, the name will be generated. When flow and + deployment metadata is available, the generated name will include + their names. Values for this field will be slugified to match + AWS character requirements. + image: An optional image to use for the Prefect container in the task. 
+ If this value is not null, it will override the value in the task + definition. This value defaults to a Prefect base image matching + your local versions. + auto_deregister_task_definition: A boolean that controls if any task + definitions that are created by this block will be deregistered + or not. Existing task definitions linked by ARN will never be + deregistered. Deregistering a task definition does not remove + it from your AWS account, instead it will be marked as INACTIVE. + cpu: The amount of CPU to provide to the ECS task. Valid amounts are + specified in the AWS documentation. If not provided, a default + value of ECS_DEFAULT_CPU will be used unless present on + the task definition. + memory: The amount of memory to provide to the ECS task. + Valid amounts are specified in the AWS documentation. + If not provided, a default value of ECS_DEFAULT_MEMORY + will be used unless present on the task definition. + execution_role_arn: An execution role to use for the task. + This controls the permissions of the task when it is launching. + If this value is not null, it will override the value in the task + definition. An execution role must be provided to capture logs + from the container. + configure_cloudwatch_logs: A boolean that controls if the Prefect + container will be configured to send its output to the + AWS CloudWatch logs service or not. This functionality requires + an execution role with permissions to create log streams and groups. + cloudwatch_logs_options: A dictionary of options to pass to + the CloudWatch logs configuration. + stream_output: A boolean indicating whether logs will be + streamed from the Prefect container to the local console. + launch_type: An optional launch type for the ECS task run infrastructure. + vpc_id: An optional VPC ID to link the task run to. + This is only applicable when using the 'awsvpc' network mode for your task. + cluster: An optional ECS cluster to run the task in. + The ARN or name may be provided. If not provided, + the default cluster will be used. + env: A dictionary of environment variables to provide to + the task run. These variables are set on the + Prefect container at task runtime. + task_role_arn: An optional role to attach to the task run. + This controls the permissions of the task while it is running. + task_customizations: A list of JSON 6902 patches to apply to the task + run request. If a string is given, it will parsed as a JSON expression. + task_start_timeout_seconds: The amount of time to watch for the + start of the ECS task before marking it as failed. The task must + enter a RUNNING state to be considered started. + task_watch_poll_interval: The amount of time to wait between AWS API + calls while monitoring the state of an ECS task. + """ + + _block_type_slug = "ecs-task" + _block_type_name = "ECS Task" + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + _description = "Run a command as an ECS task." # noqa + _documentation_url = ( + "https://prefecthq.github.io/prefect-aws/ecs/#prefect_aws.ecs.ECSTask" # noqa + ) + + type: Literal["ecs-task"] = Field( + "ecs-task", description="The slug for this task type." 
+ ) + + aws_credentials: AwsCredentials = Field( + title="AWS Credentials", + default_factory=AwsCredentials, + description="The AWS credentials to use to connect to ECS.", + ) + + # Task definition settings + task_definition_arn: Optional[str] = Field( + default=None, + description=( + "An identifier for an existing task definition to use. If fields are set " + "on the `ECSTask` that conflict with the task definition, a new copy " + "will be registered with the required values. " + "Cannot be used with `task_definition`. If not provided, Prefect will " + "generate and register a minimal task definition." + ), + ) + task_definition: Optional[dict] = Field( + default=None, + description=( + "An ECS task definition to use. Prefect may set defaults or override " + "fields on this task definition to match other `ECSTask` fields. " + "Cannot be used with `task_definition_arn`. If not provided, Prefect will " + "generate and register a minimal task definition." + ), + ) + family: Optional[str] = Field( + default=None, + description=( + "A family for the task definition. If not provided, it will be inferred " + "from the task definition. If the task definition does not have a family, " + "the name will be generated. When flow and deployment metadata is " + "available, the generated name will include their names. Values for this " + "field will be slugified to match AWS character requirements." + ), + ) + image: Optional[str] = Field( + default=None, + description=( + "The image to use for the Prefect container in the task. If this value is " + "not null, it will override the value in the task definition. This value " + "defaults to a Prefect base image matching your local versions." + ), + ) + auto_deregister_task_definition: bool = Field( + default=True, + description=( + "If set, any task definitions that are created by this block will be " + "deregistered. Existing task definitions linked by ARN will never be " + "deregistered. Deregistering a task definition does not remove it from " + "your AWS account, instead it will be marked as INACTIVE." + ), + ) + + # Mixed task definition / run settings + cpu: int = Field( + title="CPU", + default=None, + description=( + "The amount of CPU to provide to the ECS task. Valid amounts are " + "specified in the AWS documentation. If not provided, a default value of " + f"{ECS_DEFAULT_CPU} will be used unless present on the task definition." + ), + ) + memory: int = Field( + default=None, + description=( + "The amount of memory to provide to the ECS task. Valid amounts are " + "specified in the AWS documentation. If not provided, a default value of " + f"{ECS_DEFAULT_MEMORY} will be used unless present on the task definition." + ), + ) + execution_role_arn: str = Field( + title="Execution Role ARN", + default=None, + description=( + "An execution role to use for the task. This controls the permissions of " + "the task when it is launching. If this value is not null, it will " + "override the value in the task definition. An execution role must be " + "provided to capture logs from the container." + ), + ) + configure_cloudwatch_logs: bool = Field( + default=None, + description=( + "If `True`, the Prefect container will be configured to send its output " + "to the AWS CloudWatch logs service. This functionality requires an " + "execution role with logs:CreateLogStream, logs:CreateLogGroup, and " + "logs:PutLogEvents permissions. The default for this field is `False` " + "unless `stream_output` is set." 
+ ), + ) + cloudwatch_logs_options: Dict[str, str] = Field( + default_factory=dict, + description=( + "When `configure_cloudwatch_logs` is enabled, this setting may be used to " + "pass additional options to the CloudWatch logs configuration or override " + "the default options. See the AWS documentation for available options. " + "https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html#create_awslogs_logdriver_options." # noqa + ), + ) + stream_output: bool = Field( + default=None, + description=( + "If `True`, logs will be streamed from the Prefect container to the local " + "console. Unless you have configured AWS CloudWatch logs manually on your " + "task definition, this requires the same prerequisites outlined in " + "`configure_cloudwatch_logs`." + ), + ) + + # Task run settings + launch_type: Optional[ + Literal["FARGATE", "EC2", "EXTERNAL", "FARGATE_SPOT"] + ] = Field( + default="FARGATE", + description=( + "The type of ECS task run infrastructure that should be used. Note that" + " 'FARGATE_SPOT' is not a formal ECS launch type, but we will configure" + " the proper capacity provider strategy if set here." + ), + ) + vpc_id: Optional[str] = Field( + title="VPC ID", + default=None, + description=( + "The AWS VPC to link the task run to. This is only applicable when using " + "the 'awsvpc' network mode for your task. FARGATE tasks require this " + "network mode, but for EC2 tasks the default network mode is 'bridge'. " + "If using the 'awsvpc' network mode and this field is null, your default " + "VPC will be used. If no default VPC can be found, the task run will fail." + ), + ) + cluster: Optional[str] = Field( + default=None, + description=( + "The ECS cluster to run the task in. The ARN or name may be provided. If " + "not provided, the default cluster will be used." + ), + ) + env: Dict[str, Optional[str]] = Field( + title="Environment Variables", + default_factory=dict, + description=( + "Environment variables to provide to the task run. These variables are set " + "on the Prefect container at task runtime. These will not be set on the " + "task definition." + ), + ) + task_role_arn: str = Field( + title="Task Role ARN", + default=None, + description=( + "A role to attach to the task run. This controls the permissions of the " + "task while it is running." + ), + ) + task_customizations: JsonPatch = Field( + default_factory=lambda: JsonPatch([]), + description=( + "A list of JSON 6902 patches to apply to the task run request. " + "If a string is given, it will parsed as a JSON expression." + ), + ) + + # Execution settings + task_start_timeout_seconds: int = Field( + default=120, + description=( + "The amount of time to watch for the start of the ECS task " + "before marking it as failed. The task must enter a RUNNING state to be " + "considered started." + ), + ) + task_watch_poll_interval: float = Field( + default=5.0, + description=( + "The amount of time to wait between AWS API calls while monitoring the " + "state of an ECS task." + ), + ) + + @root_validator(pre=True) + def set_default_configure_cloudwatch_logs(cls, values: dict) -> dict: + """ + Streaming output generally requires CloudWatch logs to be configured. + + To avoid entangled arguments in the simple case, `configure_cloudwatch_logs` + defaults to matching the value of `stream_output`. 
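+
+        For example (a sketch; the role ARN is a placeholder), enabling
+        `stream_output` alone implicitly enables CloudWatch log configuration,
+        which is why an execution role is still required:
+
+        ```python
+        ECSTask(
+            stream_output=True,
+            execution_role_arn="arn:aws:iam::123456789012:role/my-execution-role",
+        )
+        ```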
+ """ + configure_cloudwatch_logs = values.get("configure_cloudwatch_logs") + if configure_cloudwatch_logs is None: + values["configure_cloudwatch_logs"] = values.get("stream_output") + return values + + @root_validator + def configure_cloudwatch_logs_requires_execution_role_arn( + cls, values: dict + ) -> dict: + """ + Enforces that an execution role arn is provided (or could be provided by a + runtime task definition) when configuring logging. + """ + if ( + values.get("configure_cloudwatch_logs") + and not values.get("execution_role_arn") + # Do not raise if they've linked to another task definition or provided + # it without using our shortcuts + and not values.get("task_definition_arn") + and not (values.get("task_definition") or {}).get("executionRoleArn") + ): + raise ValueError( + "An `execution_role_arn` must be provided to use " + "`configure_cloudwatch_logs` or `stream_logs`." + ) + return values + + @root_validator + def cloudwatch_logs_options_requires_configure_cloudwatch_logs( + cls, values: dict + ) -> dict: + """ + Enforces that an execution role arn is provided (or could be provided by a + runtime task definition) when configuring logging. + """ + if values.get("cloudwatch_logs_options") and not values.get( + "configure_cloudwatch_logs" + ): + raise ValueError( + "`configure_cloudwatch_log` must be enabled to use " + "`cloudwatch_logs_options`." + ) + return values + + @root_validator(pre=True) + def image_is_required(cls, values: dict) -> dict: + """ + Enforces that an image is available if image is `None`. + """ + has_image = bool(values.get("image")) + has_task_definition_arn = bool(values.get("task_definition_arn")) + + # The image can only be null when the task_definition_arn is set + if has_image or has_task_definition_arn: + return values + + prefect_container = ( + get_prefect_container( + (values.get("task_definition") or {}).get("containerDefinitions", []) + ) + or {} + ) + image_in_task_definition = prefect_container.get("image") + + # If a task_definition is given with a prefect container image, use that value + if image_in_task_definition: + values["image"] = image_in_task_definition + # Otherwise, it should default to the Prefect base image + else: + values["image"] = get_prefect_image_name() + return values + + @validator("task_customizations", pre=True) + def cast_customizations_to_a_json_patch( + cls, value: Union[List[Dict], JsonPatch, str] + ) -> JsonPatch: + """ + Casts lists to JsonPatch instances. + """ + if isinstance(value, str): + value = json.loads(value) + if isinstance(value, list): + return JsonPatch(value) + return value # type: ignore + + class Config: + """Configuration of pydantic.""" + + # Support serialization of the 'JsonPatch' type + arbitrary_types_allowed = True + json_encoders = {JsonPatch: lambda p: p.patch} + + def dict(self, *args, **kwargs) -> Dict: + """ + Convert to a dictionary. + """ + # Support serialization of the 'JsonPatch' type + d = super().dict(*args, **kwargs) + d["task_customizations"] = self.task_customizations.patch + return d + + def prepare_for_flow_run( + self: Self, + flow_run: "FlowRun", + deployment: Optional["Deployment"] = None, + flow: Optional["Flow"] = None, + ) -> Self: + """ + Return an copy of the block that is prepared to execute a flow run. 
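+
+        As an illustration (the flow and deployment names are hypothetical),
+        when no family is set on the block or its task definition, the
+        generated family combines the default prefix with the flow and
+        deployment names:
+
+        ```python
+        # assuming a flow named "etl" and a deployment named "prod"
+        prepared = ecs_task.prepare_for_flow_run(flow_run, deployment=deployment, flow=flow)
+        prepared.family  # "prefect__etl__prod"
+        ```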
+ """ + new_family = None + + # Update the family if not specified elsewhere + if ( + not self.family + and not self.task_definition_arn + and not (self.task_definition and self.task_definition.get("family")) + ): + if flow and deployment: + new_family = f"{ECS_DEFAULT_FAMILY}__{flow.name}__{deployment.name}" + elif flow and not deployment: + new_family = f"{ECS_DEFAULT_FAMILY}__{flow.name}" + elif deployment and not flow: + # This is a weird case and should not be see in the wild + new_family = f"{ECS_DEFAULT_FAMILY}__unknown-flow__{deployment.name}" + + new = super().prepare_for_flow_run(flow_run, deployment=deployment, flow=flow) + + if new_family: + return new.copy(update={"family": new_family}) + else: + # Avoid an extra copy if not needed + return new + + @sync_compatible + async def run(self, task_status: Optional[TaskStatus] = None) -> ECSTaskResult: + """ + Run the configured task on ECS. + """ + boto_session, ecs_client = await run_sync_in_worker_thread( + self._get_session_and_client + ) + + ( + task_arn, + cluster_arn, + task_definition, + is_new_task_definition, + ) = await run_sync_in_worker_thread( + self._create_task_and_wait_for_start, boto_session, ecs_client + ) + + # Display a nice message indicating the command and image + command = self.command or get_prefect_container( + task_definition["containerDefinitions"] + ).get("command", []) + self.logger.info( + f"{self._log_prefix}: Running command {' '.join(command)!r} " + f"in container {PREFECT_ECS_CONTAINER_NAME!r} ({self.image})..." + ) + + # The task identifier is "{cluster}::{task}" where we use the configured cluster + # if set to preserve matching by name rather than arn + # Note "::" is used despite the Prefect standard being ":" because ARNs contain + # single colons. + identifier = (self.cluster if self.cluster else cluster_arn) + "::" + task_arn + + if task_status: + task_status.started(identifier) + + status_code = await run_sync_in_worker_thread( + self._watch_task_and_get_exit_code, + task_arn, + cluster_arn, + task_definition, + is_new_task_definition and self.auto_deregister_task_definition, + boto_session, + ecs_client, + ) + + return ECSTaskResult( + identifier=identifier, + # If the container does not start the exit code can be null but we must + # still report a status code. We use a -1 to indicate a special code. + status_code=status_code if status_code is not None else -1, + ) + + @sync_compatible + async def kill(self, identifier: str, grace_seconds: int = 30) -> None: + """ + Kill a task running on ECS. + + Args: + identifier: A cluster and task arn combination. This should match a value + yielded by `ECSTask.run`. + """ + if grace_seconds != 30: + self.logger.warning( + f"Kill grace period of {grace_seconds}s requested, but AWS does not " + "support dynamic grace period configuration so 30s will be used. " + "See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ecs-agent-config.html for configuration of grace periods." # noqa + ) + cluster, task = parse_task_identifier(identifier) + await run_sync_in_worker_thread(self._stop_task, cluster, task) + + @staticmethod + def get_corresponding_worker_type() -> str: + """Return the corresponding worker type for this infrastructure block.""" + return ECSWorker.type + + async def generate_work_pool_base_job_template(self) -> dict: + """ + Generate a base job template for a cloud-run work pool with the same + configuration as this block. 
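+
+        A hedged usage sketch (the block document name is hypothetical; the
+        block must already be saved so its credentials can be referenced):
+
+        ```python
+        ecs_task = ECSTask.load("my-ecs-task")
+        template = await ecs_task.generate_work_pool_base_job_template()
+        ```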
+ + Returns: + - dict: a base job template for a cloud-run work pool + """ + base_job_template = copy.deepcopy(ECSWorker.get_default_base_job_template()) + for key, value in self.dict(exclude_unset=True, exclude_defaults=True).items(): + if key == "command": + base_job_template["variables"]["properties"]["command"][ + "default" + ] = shlex.join(value) + elif key in [ + "type", + "block_type_slug", + "_block_document_id", + "_block_document_name", + "_is_anonymous", + "task_customizations", + ]: + continue + elif key == "aws_credentials": + if not self.aws_credentials._block_document_id: + raise BlockNotSavedError( + "It looks like you are trying to use a block that" + " has not been saved. Please call `.save` on your block" + " before publishing it as a work pool." + ) + base_job_template["variables"]["properties"]["aws_credentials"][ + "default" + ] = { + "$ref": { + "block_document_id": str( + self.aws_credentials._block_document_id + ) + } + } + elif key == "task_definition": + base_job_template["job_configuration"]["task_definition"] = value + elif key in base_job_template["variables"]["properties"]: + base_job_template["variables"]["properties"][key]["default"] = value + else: + self.logger.warning( + f"Variable {key!r} is not supported by Cloud Run work pools." + " Skipping." + ) + + if self.task_customizations: + network_config_patches = JsonPatch( + [ + patch + for patch in self.task_customizations + if "networkConfiguration" in patch["path"] + ] + ) + minimal_network_config = assemble_document_for_patches( + network_config_patches + ) + if minimal_network_config: + minimal_network_config_with_patches = network_config_patches.apply( + minimal_network_config + ) + base_job_template["variables"]["properties"]["network_configuration"][ + "default" + ] = minimal_network_config_with_patches["networkConfiguration"] + try: + base_job_template["job_configuration"][ + "task_run_request" + ] = self.task_customizations.apply( + base_job_template["job_configuration"]["task_run_request"] + ) + except JsonPointerException: + self.logger.warning( + "Unable to apply task customizations to the base job template." + "You may need to update the template manually." + ) + + return base_job_template + + def _stop_task(self, cluster: str, task: str) -> None: + """ + Stop a running ECS task. + """ + if self.cluster is not None and cluster != self.cluster: + raise InfrastructureNotAvailable( + "Cannot stop ECS task: this infrastructure block has access to " + f"cluster {self.cluster!r} but the task is running in cluster " + f"{cluster!r}." + ) + + _, ecs_client = self._get_session_and_client() + try: + ecs_client.stop_task(cluster=cluster, task=task) + except Exception as exc: + # Raise a special exception if the task does not exist + if "ClusterNotFound" in str(exc): + raise InfrastructureNotFound( + f"Cannot stop ECS task: the cluster {cluster!r} could not be found." + ) from exc + if "not find task" in str(exc) or "referenced task was not found" in str( + exc + ): + raise InfrastructureNotFound( + f"Cannot stop ECS task: the task {task!r} could not be found in " + f"cluster {cluster!r}." + ) from exc + if "no registered tasks" in str(exc): + raise InfrastructureNotFound( + f"Cannot stop ECS task: the cluster {cluster!r} has no tasks." 
+ ) from exc + + # Reraise unknown exceptions + raise + + @property + def _log_prefix(self) -> str: + """ + Internal property for generating a prefix for logs where `name` may be null + """ + if self.name is not None: + return f"ECSTask {self.name!r}" + else: + return "ECSTask" + + def _get_session_and_client(self) -> Tuple[boto3.Session, _ECSClient]: + """ + Retrieve a boto3 session and ECS client + """ + boto_session = self.aws_credentials.get_boto3_session() + ecs_client = boto_session.client("ecs") + return boto_session, ecs_client + + def _create_task_and_wait_for_start( + self, boto_session: boto3.Session, ecs_client: _ECSClient + ) -> Tuple[str, str, dict, bool]: + """ + Register the task definition, create the task run, and wait for it to start. + + Returns a tuple of + - The task ARN + - The task's cluster ARN + - The task definition + - A bool indicating if the task definition is newly registered + """ + new_task_definition_registered = False + requested_task_definition = ( + self._retrieve_task_definition(ecs_client, self.task_definition_arn) + if self.task_definition_arn + else self.task_definition + ) or {} + task_definition_arn = requested_task_definition.get("taskDefinitionArn", None) + + task_definition = self._prepare_task_definition( + requested_task_definition, region=ecs_client.meta.region_name + ) + + # We must register the task definition if the arn is null or changes were made + if task_definition != requested_task_definition or not task_definition_arn: + # Before registering, check if the latest task definition in the family + # can be used + latest_task_definition = self._retrieve_latest_task_definition( + ecs_client, task_definition["family"] + ) + if self._task_definitions_equal(latest_task_definition, task_definition): + self.logger.debug( + f"{self._log_prefix}: The latest task definition matches the " + "required task definition; using that instead of registering a new " + " one." + ) + task_definition_arn = latest_task_definition["taskDefinitionArn"] + else: + if task_definition_arn: + self.logger.warning( + f"{self._log_prefix}: Settings require changes to the linked " + "task definition. A new task definition will be registered. " + + ( + "Enable DEBUG level logs to see the difference." + if self.logger.level > logging.DEBUG + else "" + ) + ) + self.logger.debug( + f"{self._log_prefix}: Diff for requested task definition" + + _pretty_diff(requested_task_definition, task_definition) + ) + else: + self.logger.info( + f"{self._log_prefix}: Registering task definition..." 
+ ) + self.logger.debug( + "Task definition payload\n" + yaml.dump(task_definition) + ) + + task_definition_arn = self._register_task_definition( + ecs_client, task_definition + ) + new_task_definition_registered = True + + if task_definition.get("networkMode") == "awsvpc": + network_config = self._load_vpc_network_config(self.vpc_id, boto_session) + else: + network_config = None + + task_run = self._prepare_task_run( + network_config=network_config, + task_definition_arn=task_definition_arn, + ) + self.logger.info(f"{self._log_prefix}: Creating task run...") + self.logger.debug("Task run payload\n" + yaml.dump(task_run)) + + try: + task = self._run_task(ecs_client, task_run) + task_arn = task["taskArn"] + cluster_arn = task["clusterArn"] + except Exception as exc: + self._report_task_run_creation_failure(task_run, exc) + + # Raises an exception if the task does not start + self.logger.info(f"{self._log_prefix}: Waiting for task run to start...") + self._wait_for_task_start( + task_arn, cluster_arn, ecs_client, timeout=self.task_start_timeout_seconds + ) + + return task_arn, cluster_arn, task_definition, new_task_definition_registered + + def _watch_task_and_get_exit_code( + self, + task_arn: str, + cluster_arn: str, + task_definition: dict, + deregister_task_definition: bool, + boto_session: boto3.Session, + ecs_client: _ECSClient, + ) -> Optional[int]: + """ + Wait for the task run to complete and retrieve the exit code of the Prefect + container. + """ + + # Wait for completion and stream logs + task = self._wait_for_task_finish( + task_arn, cluster_arn, task_definition, ecs_client, boto_session + ) + + if deregister_task_definition: + ecs_client.deregister_task_definition( + taskDefinition=task["taskDefinitionArn"] + ) + + # Check the status code of the Prefect container + prefect_container = get_prefect_container(task["containers"]) + assert ( + prefect_container is not None + ), f"'prefect' container missing from task: {task}" + status_code = prefect_container.get("exitCode") + self._report_container_status_code(PREFECT_ECS_CONTAINER_NAME, status_code) + + return status_code + + def _task_definitions_equal(self, taskdef_1, taskdef_2) -> bool: + """ + Compare two task definitions. + + Since one may come from the AWS API and have populated defaults, we do our best + to homogenize the definitions without changing their meaning. 
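+
+        A minimal sketch (with hypothetical values) of two definitions that
+        compare as equal once AWS-populated defaults and registration-only
+        fields are normalized away:
+
+        ```python
+        taskdef_1 = {"family": "prefect", "containerDefinitions": [{"name": "prefect"}]}
+        taskdef_2 = {
+            "family": "prefect",
+            "networkMode": "bridge",  # default AWS applies at registration
+            "status": "ACTIVE",  # post-registration field, ignored here
+            "containerDefinitions": [{"name": "prefect", "essential": True}],
+        }
+        ```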
+ """ + if taskdef_1 == taskdef_2: + return True + + if taskdef_1 is None or taskdef_2 is None: + return False + + taskdef_1 = copy.deepcopy(taskdef_1) + taskdef_2 = copy.deepcopy(taskdef_2) + + def _set_aws_defaults(taskdef): + """Set defaults that AWS would set after registration""" + container_definitions = taskdef.get("containerDefinitions", []) + essential = any( + container.get("essential") for container in container_definitions + ) + if not essential: + container_definitions[0].setdefault("essential", True) + + taskdef.setdefault("networkMode", "bridge") + + _set_aws_defaults(taskdef_1) + _set_aws_defaults(taskdef_2) + + def _drop_empty_keys(dict_): + """Recursively drop keys with 'empty' values""" + for key, value in tuple(dict_.items()): + if not value: + dict_.pop(key) + if isinstance(value, dict): + _drop_empty_keys(value) + if isinstance(value, list): + for v in value: + if isinstance(v, dict): + _drop_empty_keys(v) + + _drop_empty_keys(taskdef_1) + _drop_empty_keys(taskdef_2) + + # Clear fields that change on registration for comparison + for field in POST_REGISTRATION_FIELDS: + taskdef_1.pop(field, None) + taskdef_2.pop(field, None) + + return taskdef_1 == taskdef_2 + + def preview(self) -> str: + """ + Generate a preview of the task definition and task run that will be sent to AWS. + """ + preview = "" + + task_definition_arn = self.task_definition_arn or "" + + if self.task_definition or not self.task_definition_arn: + task_definition = self._prepare_task_definition( + self.task_definition or {}, + region=self.aws_credentials.region_name + or "", + ) + preview += "---\n# Task definition\n" + preview += yaml.dump(task_definition) + preview += "\n" + else: + task_definition = None + + if task_definition and task_definition.get("networkMode") == "awsvpc": + vpc = "the default VPC" if not self.vpc_id else self.vpc_id + network_config = { + "awsvpcConfiguration": { + "subnets": f"", + "assignPublicIp": "ENABLED", + } + } + else: + network_config = None + + task_run = self._prepare_task_run(network_config, task_definition_arn) + preview += "---\n# Task run request\n" + preview += yaml.dump(task_run) + + return preview + + def _report_container_status_code( + self, name: str, status_code: Optional[int] + ) -> None: + """ + Display a log for the given container status code. + """ + if status_code is None: + self.logger.error( + f"{self._log_prefix}: Task exited without reporting an exit status " + f"for container {name!r}." + ) + elif status_code == 0: + self.logger.info( + f"{self._log_prefix}: Container {name!r} exited successfully." + ) + else: + self.logger.warning( + f"{self._log_prefix}: Container {name!r} exited with non-zero exit " + f"code {status_code}." + ) + + def _report_task_run_creation_failure(self, task_run: dict, exc: Exception) -> None: + """ + Wrap common AWS task run creation failures with nicer user-facing messages. + """ + # AWS generates exception types at runtime so they must be captured a bit + # differently than normal. + if "ClusterNotFoundException" in str(exc): + cluster = task_run.get("cluster", "default") + raise RuntimeError( + f"Failed to run ECS task, cluster {cluster!r} not found. " + "Confirm that the cluster is configured in your region." + ) from exc + elif "No Container Instances" in str(exc) and self.launch_type == "EC2": + cluster = task_run.get("cluster", "default") + raise RuntimeError( + f"Failed to run ECS task, cluster {cluster!r} does not appear to " + "have any container instances associated with it. 
Confirm that you " + "have EC2 container instances available." + ) from exc + elif ( + "failed to validate logger args" in str(exc) + and "AccessDeniedException" in str(exc) + and self.configure_cloudwatch_logs + ): + raise RuntimeError( + "Failed to run ECS task, the attached execution role does not appear " + "to have sufficient permissions. Ensure that the execution role " + f"{self.execution_role!r} has permissions logs:CreateLogStream, " + "logs:CreateLogGroup, and logs:PutLogEvents." + ) + else: + raise + + def _watch_task_run( + self, + task_arn: str, + cluster_arn: str, + ecs_client: _ECSClient, + current_status: str = "UNKNOWN", + until_status: str = None, + timeout: int = None, + ) -> Generator[None, None, dict]: + """ + Watches an ECS task run by querying every `poll_interval` seconds. After each + query, the retrieved task is yielded. This function returns when the task run + reaches a STOPPED status or the provided `until_status`. + + Emits a log each time the status changes. + """ + last_status = status = current_status + t0 = time.time() + while status != until_status: + tasks = ecs_client.describe_tasks( + tasks=[task_arn], cluster=cluster_arn, include=["TAGS"] + )["tasks"] + + if tasks: + task = tasks[0] + + status = task["lastStatus"] + if status != last_status: + self.logger.info(f"{self._log_prefix}: Status is {status}.") + + yield task + + # No point in continuing if the status is final + if status == "STOPPED": + break + + last_status = status + + else: + # Intermittently, the task will not be described. We wat to respect the + # watch timeout though. + self.logger.debug(f"{self._log_prefix}: Task not found.") + + elapsed_time = time.time() - t0 + if timeout is not None and elapsed_time > timeout: + raise RuntimeError( + f"Timed out after {elapsed_time}s while watching task for status " + f"{until_status or 'STOPPED'}" + ) + time.sleep(self.task_watch_poll_interval) + + def _wait_for_task_start( + self, task_arn: str, cluster_arn: str, ecs_client: _ECSClient, timeout: int + ) -> dict: + """ + Waits for an ECS task run to reach a RUNNING status. + + If a STOPPED status is reached instead, an exception is raised indicating the + reason that the task run did not start. + """ + for task in self._watch_task_run( + task_arn, cluster_arn, ecs_client, until_status="RUNNING", timeout=timeout + ): + # TODO: It is possible that the task has passed _through_ a RUNNING + # status during the polling interval. In this case, there is not an + # exception to raise. + if task["lastStatus"] == "STOPPED": + code = task.get("stopCode") + reason = task.get("stoppedReason") + # Generate a dynamic exception type from the AWS name + raise type(code, (RuntimeError,), {})(reason) + + return task + + def _wait_for_task_finish( + self, + task_arn: str, + cluster_arn: str, + task_definition: dict, + ecs_client: _ECSClient, + boto_session: boto3.Session, + ): + """ + Watch an ECS task until it reaches a STOPPED status. + + If configured, logs from the Prefect container are streamed to stderr. + + Returns a description of the task on completion. + """ + can_stream_output = False + + if self.stream_output: + container_def = get_prefect_container( + task_definition["containerDefinitions"] + ) + if not container_def: + self.logger.warning( + f"{self._log_prefix}: Prefect container definition not found in " + "task definition. Output cannot be streamed." + ) + elif not container_def.get("logConfiguration"): + self.logger.warning( + f"{self._log_prefix}: Logging configuration not found on task. 
" + "Output cannot be streamed." + ) + elif not container_def["logConfiguration"].get("logDriver") == "awslogs": + self.logger.warning( + f"{self._log_prefix}: Logging configuration uses unsupported " + " driver {container_def['logConfiguration'].get('logDriver')!r}. " + "Output cannot be streamed." + ) + else: + # Prepare to stream the output + log_config = container_def["logConfiguration"]["options"] + logs_client = boto_session.client("logs") + can_stream_output = True + # Track the last log timestamp to prevent double display + last_log_timestamp: Optional[int] = None + # Determine the name of the stream as "prefix/container/run-id" + stream_name = "/".join( + [ + log_config["awslogs-stream-prefix"], + PREFECT_ECS_CONTAINER_NAME, + task_arn.rsplit("/")[-1], + ] + ) + self.logger.info( + f"{self._log_prefix}: Streaming output from container " + f"{PREFECT_ECS_CONTAINER_NAME!r}..." + ) + + for task in self._watch_task_run( + task_arn, cluster_arn, ecs_client, current_status="RUNNING" + ): + if self.stream_output and can_stream_output: + # On each poll for task run status, also retrieve available logs + last_log_timestamp = self._stream_available_logs( + logs_client, + log_group=log_config["awslogs-group"], + log_stream=stream_name, + last_log_timestamp=last_log_timestamp, + ) + + return task + + def _stream_available_logs( + self, + logs_client: Any, + log_group: str, + log_stream: str, + last_log_timestamp: Optional[int] = None, + ) -> Optional[int]: + """ + Stream logs from the given log group and stream since the last log timestamp. + + Will continue on paginated responses until all logs are returned. + + Returns the last log timestamp which can be used to call this method in the + future. + """ + last_log_stream_token = "NO-TOKEN" + next_log_stream_token = None + + # AWS will return the same token that we send once the end of the paginated + # response is reached + while last_log_stream_token != next_log_stream_token: + last_log_stream_token = next_log_stream_token + + request = { + "logGroupName": log_group, + "logStreamName": log_stream, + } + + if last_log_stream_token is not None: + request["nextToken"] = last_log_stream_token + + if last_log_timestamp is not None: + # Bump the timestamp by one ms to avoid retrieving the last log again + request["startTime"] = last_log_timestamp + 1 + + try: + response = logs_client.get_log_events(**request) + except Exception: + self.logger.error( + ( + f"{self._log_prefix}: Failed to read log events with request " + f"{request}" + ), + exc_info=True, + ) + return last_log_timestamp + + log_events = response["events"] + for log_event in log_events: + # TODO: This doesn't forward to the local logger, which can be + # bad for customizing handling and understanding where the + # log is coming from, but it avoid nesting logger information + # when the content is output from a Prefect logger on the + # running infrastructure + print(log_event["message"], file=sys.stderr) + + if ( + last_log_timestamp is None + or log_event["timestamp"] > last_log_timestamp + ): + last_log_timestamp = log_event["timestamp"] + + next_log_stream_token = response.get("nextForwardToken") + if not log_events: + # Stop reading pages if there was no data + break + + return last_log_timestamp + + def _retrieve_latest_task_definition( + self, ecs_client: _ECSClient, task_definition_family: str + ) -> Optional[dict]: + try: + latest_task_definition = self._retrieve_task_definition( + ecs_client, task_definition_family + ) + except Exception: + # The family does not exist... 
+ return None + + return latest_task_definition + + def _retrieve_task_definition( + self, ecs_client: _ECSClient, task_definition_arn: str + ): + """ + Retrieve an existing task definition from AWS. + """ + self.logger.info( + f"{self._log_prefix}: Retrieving task definition {task_definition_arn!r}..." + ) + response = ecs_client.describe_task_definition( + taskDefinition=task_definition_arn + ) + return response["taskDefinition"] + + def _register_task_definition( + self, ecs_client: _ECSClient, task_definition: dict + ) -> str: + """ + Register a new task definition with AWS. + """ + # TODO: Consider including a global cache for this task definition since + # registration of task definitions is frequently rate limited + task_definition_request = copy.deepcopy(task_definition) + + # We need to remove some fields here if copying an existing task definition + for field in POST_REGISTRATION_FIELDS: + task_definition_request.pop(field, None) + + response = ecs_client.register_task_definition(**task_definition_request) + return response["taskDefinition"]["taskDefinitionArn"] + + def _prepare_task_definition(self, task_definition: dict, region: str) -> dict: + """ + Prepare a task definition by inferring any defaults and merging overrides. + """ + task_definition = copy.deepcopy(task_definition) + + # Configure the Prefect runtime container + task_definition.setdefault("containerDefinitions", []) + container = get_prefect_container(task_definition["containerDefinitions"]) + if container is None: + container = {"name": PREFECT_ECS_CONTAINER_NAME} + task_definition["containerDefinitions"].append(container) + + if self.image: + container["image"] = self.image + + # Remove any keys that have been explicitly "unset" + unset_keys = {key for key, value in self.env.items() if value is None} + for item in tuple(container.get("environment", [])): + if item["name"] in unset_keys: + container["environment"].remove(item) + + if self.configure_cloudwatch_logs: + container["logConfiguration"] = { + "logDriver": "awslogs", + "options": { + "awslogs-create-group": "true", + "awslogs-group": "prefect", + "awslogs-region": region, + "awslogs-stream-prefix": self.name or "prefect", + **self.cloudwatch_logs_options, + }, + } + + family = self.family or task_definition.get("family") or ECS_DEFAULT_FAMILY + task_definition["family"] = slugify( + family, + max_length=255, + regex_pattern=r"[^a-zA-Z0-9-_]+", + ) + + # CPU and memory are required in some cases, retrieve the value to use + cpu = self.cpu or task_definition.get("cpu") or ECS_DEFAULT_CPU + memory = self.memory or task_definition.get("memory") or ECS_DEFAULT_MEMORY + + if self.launch_type == "FARGATE" or self.launch_type == "FARGATE_SPOT": + # Task level memory and cpu are required when using fargate + task_definition["cpu"] = str(cpu) + task_definition["memory"] = str(memory) + + # The FARGATE compatibility is required if it will be used as as launch type + requires_compatibilities = task_definition.setdefault( + "requiresCompatibilities", [] + ) + if "FARGATE" not in requires_compatibilities: + task_definition["requiresCompatibilities"].append("FARGATE") + + # Only the 'awsvpc' network mode is supported when using FARGATE + # However, we will not enforce that here if the user has set it + network_mode = task_definition.setdefault("networkMode", "awsvpc") + + if network_mode != "awsvpc": + warnings.warn( + f"Found network mode {network_mode!r} which is not compatible with " + f"launch type {self.launch_type!r}. 
Use either the 'EC2' launch " + "type or the 'awsvpc' network mode." + ) + + elif self.launch_type == "EC2": + # Container level memory and cpu are required when using ec2 + container.setdefault("cpu", int(cpu)) + container.setdefault("memory", int(memory)) + + if self.execution_role_arn and not self.task_definition_arn: + task_definition["executionRoleArn"] = self.execution_role_arn + + if self.configure_cloudwatch_logs and not task_definition.get( + "executionRoleArn" + ): + raise ValueError( + "An execution role arn must be set on the task definition to use " + "`configure_cloudwatch_logs` or `stream_logs` but no execution role " + "was found on the task definition." + ) + + return task_definition + + def _prepare_task_run_overrides(self) -> dict: + """ + Prepare the 'overrides' payload for a task run request. + """ + overrides = { + "containerOverrides": [ + { + "name": PREFECT_ECS_CONTAINER_NAME, + "environment": [ + {"name": key, "value": value} + for key, value in { + **self._base_environment(), + **self.env, + }.items() + if value is not None + ], + } + ], + } + + prefect_container_overrides = overrides["containerOverrides"][0] + + if self.command: + prefect_container_overrides["command"] = self.command + + if self.execution_role_arn: + overrides["executionRoleArn"] = self.execution_role_arn + + if self.task_role_arn: + overrides["taskRoleArn"] = self.task_role_arn + + if self.memory: + overrides["memory"] = str(self.memory) + prefect_container_overrides.setdefault("memory", self.memory) + + if self.cpu: + overrides["cpu"] = str(self.cpu) + prefect_container_overrides.setdefault("cpu", self.cpu) + + return overrides + + def _load_vpc_network_config( + self, vpc_id: Optional[str], boto_session: boto3.Session + ) -> dict: + """ + Load settings from a specific VPC or the default VPC and generate a task + run request's network configuration. + """ + ec2_client = boto_session.client("ec2") + vpc_message = "the default VPC" if not vpc_id else f"VPC with ID {vpc_id}" + + if not vpc_id: + # Retrieve the default VPC + describe = {"Filters": [{"Name": "isDefault", "Values": ["true"]}]} + else: + describe = {"VpcIds": [vpc_id]} + + vpcs = ec2_client.describe_vpcs(**describe)["Vpcs"] + if not vpcs: + help_message = ( + "Pass an explicit `vpc_id` or configure a default VPC." + if not vpc_id + else "Check that the VPC exists in the current region." + ) + raise ValueError( + f"Failed to find {vpc_message}. " + "Network configuration cannot be inferred. " + help_message + ) + + vpc_id = vpcs[0]["VpcId"] + subnets = ec2_client.describe_subnets( + Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] + )["Subnets"] + if not subnets: + raise ValueError( + f"Failed to find subnets for {vpc_message}. " + "Network configuration cannot be inferred." + ) + + return { + "awsvpcConfiguration": { + "subnets": [s["SubnetId"] for s in subnets], + "assignPublicIp": "ENABLED", + "securityGroups": [], + } + } + + def _prepare_task_run( + self, + network_config: Optional[dict], + task_definition_arn: str, + ) -> dict: + """ + Prepare a task run request payload. 
+ """ + task_run = { + "overrides": self._prepare_task_run_overrides(), + "tags": [ + { + "key": slugify( + key, + regex_pattern=_TAG_REGEX, + allow_unicode=True, + lowercase=False, + ), + "value": slugify( + value, + regex_pattern=_TAG_REGEX, + allow_unicode=True, + lowercase=False, + ), + } + for key, value in self.labels.items() + ], + "taskDefinition": task_definition_arn, + } + + if self.cluster: + task_run["cluster"] = self.cluster + + if self.launch_type: + if self.launch_type == "FARGATE_SPOT": + task_run["capacityProviderStrategy"] = [ + {"capacityProvider": "FARGATE_SPOT", "weight": 1} + ] + else: + task_run["launchType"] = self.launch_type + + if network_config: + task_run["networkConfiguration"] = network_config + + task_run = self.task_customizations.apply(task_run) + return task_run + + def _run_task(self, ecs_client: _ECSClient, task_run: dict): + """ + Run the task using the ECS client. + + This is isolated as a separate method for testing purposes. + """ + return ecs_client.run_task(**task_run)["tasks"][0] diff --git a/src/integrations/prefect-aws/prefect_aws/glue_job.py b/src/integrations/prefect-aws/prefect_aws/glue_job.py new file mode 100644 index 000000000000..c131265027c6 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/glue_job.py @@ -0,0 +1,190 @@ +""" +Integrations with the AWS Glue Job. + +""" + +import time +from typing import Any, Optional + +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect.blocks.abstract import JobBlock, JobRun + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import BaseModel, Field +else: + from pydantic import BaseModel, Field + +from prefect_aws import AwsCredentials + +_GlueJobClient = Any + + +class GlueJobRun(JobRun, BaseModel): + """Execute a Glue Job""" + + job_name: str = Field( + ..., + title="AWS Glue Job Name", + description="The name of the job definition to use.", + ) + + job_id: str = Field( + ..., + title="AWS Glue Job ID", + description="The ID of the job run.", + ) + + job_watch_poll_interval: float = Field( + default=60.0, + description=( + "The amount of time to wait between AWS API calls while monitoring the " + "state of an Glue Job." + ), + ) + + _error_states = ["FAILED", "STOPPED", "ERROR", "TIMEOUT"] + + aws_credentials: AwsCredentials = Field( + title="AWS Credentials", + default_factory=AwsCredentials, + description="The AWS credentials to use to connect to Glue.", + ) + + client: _GlueJobClient = Field(default=None, description="") + + async def fetch_result(self) -> str: + """fetch glue job state""" + job = self._get_job_run() + return job["JobRun"]["JobRunState"] + + def wait_for_completion(self) -> None: + """ + Wait for the job run to complete and get exit code + """ + self.logger.info(f"watching job {self.job_name} with run id {self.job_id}") + while True: + job = self._get_job_run() + job_state = job["JobRun"]["JobRunState"] + if job_state in self._error_states: + # Generate a dynamic exception type from the AWS name + self.logger.error(f"job failed: {job['JobRun']['ErrorMessage']}") + raise RuntimeError(job["JobRun"]["ErrorMessage"]) + elif job_state == "SUCCEEDED": + self.logger.info(f"job succeeded: {self.job_id}") + break + + time.sleep(self.job_watch_poll_interval) + + def _get_job_run(self): + """get glue job""" + return self.client.get_job_run(JobName=self.job_name, RunId=self.job_id) + + +class GlueJobBlock(JobBlock): + """Execute a job to the AWS Glue Job service. + + Attributes: + job_name: The name of the job definition to use. 
+ arguments: The job arguments associated with this run. + For this job run, they replace the default arguments set in the job + definition itself. + You can specify arguments here that your own job-execution script consumes, + as well as arguments that Glue itself consumes. + Job arguments may be logged. Do not pass plaintext secrets as arguments. + Retrieve secrets from a Glue Connection, Secrets Manager or other secret + management mechanism if you intend to keep them within the Job. + [doc](https://docs.aws.amazon.com/glue/latest/dg/aws-glue-programming-etl-glue-arguments.html) + job_watch_poll_interval: The amount of time to wait between AWS API + calls while monitoring the state of a Glue Job. + default is 60s because of jobs that use AWS Glue versions 2.0 and later + have a 1-minute minimum. + [AWS Glue Pricing](https://aws.amazon.com/glue/pricing/?nc1=h_ls) + + Example: + Start a job to AWS Glue Job. + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.glue_job import GlueJobBlock + + + @flow + def example_run_glue_job(): + aws_credentials = AwsCredentials( + aws_access_key_id="your_access_key_id", + aws_secret_access_key="your_secret_access_key" + ) + glue_job_run = GlueJobBlock( + job_name="your_glue_job_name", + arguments={"--YOUR_EXTRA_ARGUMENT": "YOUR_EXTRA_ARGUMENT_VALUE"}, + ).trigger() + + return glue_job_run.wait_for_completion() + + + example_run_glue_job() + ``` + """ + + job_name: str = Field( + ..., + title="AWS Glue Job Name", + description="The name of the job definition to use.", + ) + + arguments: Optional[dict] = Field( + default=None, + title="AWS Glue Job Arguments", + description="The job arguments associated with this run.", + ) + job_watch_poll_interval: float = Field( + default=60.0, + description=( + "The amount of time to wait between AWS API calls while monitoring the " + "state of an Glue Job." + ), + ) + + aws_credentials: AwsCredentials = Field( + title="AWS Credentials", + default_factory=AwsCredentials, + description="The AWS credentials to use to connect to Glue.", + ) + + async def trigger(self) -> GlueJobRun: + """trigger for GlueJobRun""" + client = self._get_client() + job_run_id = self._start_job(client) + return GlueJobRun( + job_name=self.job_name, + job_id=job_run_id, + job_watch_poll_interval=self.job_watch_poll_interval, + ) + + def _start_job(self, client: _GlueJobClient) -> str: + """ + Start the AWS Glue Job + [doc](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/glue/client/start_job_run.html) + """ + self.logger.info( + f"starting job {self.job_name} with arguments {self.arguments}" + ) + try: + response = client.start_job_run( + JobName=self.job_name, + Arguments=self.arguments, + ) + job_run_id = str(response["JobRunId"]) + self.logger.info(f"job started with job run id: {job_run_id}") + return job_run_id + except Exception as e: + self.logger.error(f"failed to start job: {e}") + raise RuntimeError + + def _get_client(self) -> _GlueJobClient: + """ + Retrieve a Glue Job Client + """ + boto_session = self.aws_credentials.get_boto3_session() + return boto_session.client("glue") diff --git a/src/integrations/prefect-aws/prefect_aws/lambda_function.py b/src/integrations/prefect-aws/prefect_aws/lambda_function.py new file mode 100644 index 000000000000..fb76f956cdd3 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/lambda_function.py @@ -0,0 +1,195 @@ +"""Integrations with AWS Lambda. 
+ +Examples: + + Run a lambda function with a payload + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(payload={"foo": "bar"}) + ``` + + Specify a version of a lambda function + + ```python + LambdaFunction( + function_name="test-function", + qualifier="1", + aws_credentials=aws_credentials, + ).invoke() + ``` + + Invoke a lambda function asynchronously + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(invocation_type="Event") + ``` + + Invoke a lambda function and return the last 4 KB of logs + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(tail=True) + ``` + + Invoke a lambda function with a client context + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(client_context={"bar": "foo"}) + ``` + +""" +import json +from typing import Literal, Optional + +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect.blocks.core import Block +from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field +else: + from pydantic import Field + +from prefect_aws.credentials import AwsCredentials + + +class LambdaFunction(Block): + """Invoke a Lambda function. This block is part of the prefect-aws + collection. Install prefect-aws with `pip install prefect-aws` to use this + block. + + Attributes: + function_name: The name, ARN, or partial ARN of the Lambda function to + run. This must be the name of a function that is already deployed + to AWS Lambda. + qualifier: The version or alias of the Lambda function to use when + invoked. If not specified, the latest (unqualified) version of the + Lambda function will be used. + aws_credentials: The AWS credentials to use to connect to AWS Lambda + with a default factory of AwsCredentials. + + """ + + _block_type_name = "Lambda Function" + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + _documentation_url = "https://prefecthq.github.io/prefect-aws/s3/#prefect_aws.lambda_function.LambdaFunction" # noqa + + function_name: str = Field( + title="Function Name", + description=( + "The name, ARN, or partial ARN of the Lambda function to run. This" + " must be the name of a function that is already deployed to AWS" + " Lambda." + ), + ) + qualifier: Optional[str] = Field( + default=None, + title="Qualifier", + description=( + "The version or alias of the Lambda function to use when invoked. " + "If not specified, the latest (unqualified) version of the Lambda " + "function will be used." 
+ ), + ) + aws_credentials: AwsCredentials = Field( + title="AWS Credentials", + default_factory=AwsCredentials, + description="The AWS credentials to invoke the Lambda with.", + ) + + class Config: + """Lambda's pydantic configuration.""" + + smart_union = True + + def _get_lambda_client(self): + """ + Retrieve a boto3 session and Lambda client + """ + boto_session = self.aws_credentials.get_boto3_session() + lambda_client = boto_session.client("lambda") + return lambda_client + + @sync_compatible + async def invoke( + self, + payload: dict = None, + invocation_type: Literal[ + "RequestResponse", "Event", "DryRun" + ] = "RequestResponse", + tail: bool = False, + client_context: Optional[dict] = None, + ) -> dict: + """ + [Invoke](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda/client/invoke.html) + the Lambda function with the given payload. + + Args: + payload: The payload to send to the Lambda function. + invocation_type: The invocation type of the Lambda function. This + can be one of "RequestResponse", "Event", or "DryRun". Uses + "RequestResponse" by default. + tail: If True, the response will include the base64-encoded last 4 + KB of log data produced by the Lambda function. + client_context: The client context to send to the Lambda function. + Limited to 3583 bytes. + + Returns: + The response from the Lambda function. + + Examples: + + ```python + from prefect_aws.lambda_function import LambdaFunction + from prefect_aws.credentials import AwsCredentials + + credentials = AwsCredentials() + lambda_function = LambdaFunction( + function_name="test_lambda_function", + aws_credentials=credentials, + ) + response = lambda_function.invoke( + payload={"foo": "bar"}, + invocation_type="RequestResponse", + ) + response["Payload"].read() + ``` + ```txt + b'{"foo": "bar"}' + ``` + + """ + # Add invocation arguments + kwargs = dict(FunctionName=self.function_name) + + if payload: + kwargs["Payload"] = json.dumps(payload).encode() + + # Let boto handle invalid invocation types + kwargs["InvocationType"] = invocation_type + + if self.qualifier is not None: + kwargs["Qualifier"] = self.qualifier + + if tail: + kwargs["LogType"] = "Tail" + + if client_context is not None: + # For some reason this is string, but payload is bytes + kwargs["ClientContext"] = json.dumps(client_context) + + # Get client and invoke + lambda_client = await run_sync_in_worker_thread(self._get_lambda_client) + return await run_sync_in_worker_thread(lambda_client.invoke, **kwargs) diff --git a/src/integrations/prefect-aws/prefect_aws/s3.py b/src/integrations/prefect-aws/prefect_aws/s3.py new file mode 100644 index 000000000000..84410f148fcc --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/s3.py @@ -0,0 +1,1357 @@ +"""Tasks for interacting with AWS S3""" +import asyncio +import io +import os +import uuid +from pathlib import Path +from typing import Any, BinaryIO, Dict, List, Optional, Union + +import boto3 +from botocore.paginate import PageIterator +from botocore.response import StreamingBody +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect import get_run_logger, task +from prefect.blocks.abstract import ObjectStorageBlock +from prefect.filesystems import WritableDeploymentStorage, WritableFileSystem +from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible +from prefect.utilities.filesystem import filter_files + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field +else: + from pydantic import Field + +from 
prefect_aws import AwsCredentials, MinIOCredentials +from prefect_aws.client_parameters import AwsClientParameters + + +@task +async def s3_download( + bucket: str, + key: str, + aws_credentials: AwsCredentials, + aws_client_parameters: AwsClientParameters = AwsClientParameters(), +) -> bytes: + """ + Downloads an object with a given key from a given S3 bucket. + + Args: + bucket: Name of bucket to download object from. Required if a default value was + not supplied when creating the task. + key: Key of object to download. Required if a default value was not supplied + when creating the task. + aws_credentials: Credentials to use for authentication with AWS. + aws_client_parameters: Custom parameter for the boto3 client initialization. + + + Returns: + A `bytes` representation of the downloaded object. + + Example: + Download a file from an S3 bucket: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.s3 import s3_download + + + @flow + async def example_s3_download_flow(): + aws_credentials = AwsCredentials( + aws_access_key_id="acccess_key_id", + aws_secret_access_key="secret_access_key" + ) + data = await s3_download( + bucket="bucket", + key="key", + aws_credentials=aws_credentials, + ) + + example_s3_download_flow() + ``` + """ + logger = get_run_logger() + logger.info("Downloading object from bucket %s with key %s", bucket, key) + + s3_client = aws_credentials.get_boto3_session().client( + "s3", **aws_client_parameters.get_params_override() + ) + stream = io.BytesIO() + await run_sync_in_worker_thread( + s3_client.download_fileobj, Bucket=bucket, Key=key, Fileobj=stream + ) + stream.seek(0) + output = stream.read() + + return output + + +@task +async def s3_upload( + data: bytes, + bucket: str, + aws_credentials: AwsCredentials, + aws_client_parameters: AwsClientParameters = AwsClientParameters(), + key: Optional[str] = None, +) -> str: + """ + Uploads data to an S3 bucket. + + Args: + data: Bytes representation of data to upload to S3. + bucket: Name of bucket to upload data to. Required if a default value was not + supplied when creating the task. + aws_credentials: Credentials to use for authentication with AWS. + aws_client_parameters: Custom parameter for the boto3 client initialization.. + key: Key of object to download. Defaults to a UUID string. 
+ + Returns: + The key of the uploaded object + + Example: + Read and upload a file to an S3 bucket: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.s3 import s3_upload + + + @flow + async def example_s3_upload_flow(): + aws_credentials = AwsCredentials( + aws_access_key_id="acccess_key_id", + aws_secret_access_key="secret_access_key" + ) + with open("data.csv", "rb") as file: + key = await s3_upload( + bucket="bucket", + key="data.csv", + data=file.read(), + aws_credentials=aws_credentials, + ) + + example_s3_upload_flow() + ``` + """ + logger = get_run_logger() + + key = key or str(uuid.uuid4()) + + logger.info("Uploading object to bucket %s with key %s", bucket, key) + + s3_client = aws_credentials.get_boto3_session().client( + "s3", **aws_client_parameters.get_params_override() + ) + stream = io.BytesIO(data) + await run_sync_in_worker_thread( + s3_client.upload_fileobj, stream, Bucket=bucket, Key=key + ) + + return key + + +@task +async def s3_copy( + source_path: str, + target_path: str, + source_bucket_name: str, + aws_credentials: AwsCredentials, + target_bucket_name: Optional[str] = None, + **copy_kwargs, +) -> str: + """Uses S3's internal + [CopyObject](https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html) + to copy objects within or between buckets. To copy objects between buckets, the + credentials must have permission to read the source object and write to the target + object. If the credentials do not have those permissions, try using + `S3Bucket.stream_from`. + + Args: + source_path: The path to the object to copy. Can be a string or `Path`. + target_path: The path to copy the object to. Can be a string or `Path`. + source_bucket_name: The bucket to copy the object from. + aws_credentials: Credentials to use for authentication with AWS. + target_bucket_name: The bucket to copy the object to. If not provided, defaults + to `source_bucket`. + **copy_kwargs: Additional keyword arguments to pass to `S3Client.copy_object`. + + Returns: + The path that the object was copied to. Excludes the bucket name. + + Examples: + + Copy notes.txt from s3://my-bucket/my_folder/notes.txt to + s3://my-bucket/my_folder/notes_copy.txt. + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.s3 import s3_copy + + aws_credentials = AwsCredentials.load("my-creds") + + @flow + async def example_copy_flow(): + await s3_copy( + source_path="my_folder/notes.txt", + target_path="my_folder/notes_copy.txt", + source_bucket_name="my-bucket", + aws_credentials=aws_credentials, + ) + + example_copy_flow() + ``` + + Copy notes.txt from s3://my-bucket/my_folder/notes.txt to + s3://other-bucket/notes_copy.txt. 
+ + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.s3 import s3_copy + + aws_credentials = AwsCredentials.load("shared-creds") + + @flow + async def example_copy_flow(): + await s3_copy( + source_path="my_folder/notes.txt", + target_path="notes_copy.txt", + source_bucket_name="my-bucket", + aws_credentials=aws_credentials, + target_bucket_name="other-bucket", + ) + + example_copy_flow() + ``` + + """ + logger = get_run_logger() + + s3_client = aws_credentials.get_s3_client() + + target_bucket_name = target_bucket_name or source_bucket_name + + logger.info( + "Copying object from bucket %s with key %s to bucket %s with key %s", + source_bucket_name, + source_path, + target_bucket_name, + target_path, + ) + + s3_client.copy_object( + CopySource={"Bucket": source_bucket_name, "Key": source_path}, + Bucket=target_bucket_name, + Key=target_path, + **copy_kwargs, + ) + + return target_path + + +@task +async def s3_move( + source_path: str, + target_path: str, + source_bucket_name: str, + aws_credentials: AwsCredentials, + target_bucket_name: Optional[str] = None, +) -> str: + """ + Move an object from one S3 location to another. To move objects between buckets, + the credentials must have permission to read and delete the source object and write + to the target object. If the credentials do not have those permissions, this method + will raise an error. If the credentials have permission to read the source object + but not delete it, the object will be copied but not deleted. + + Args: + source_path: The path of the object to move + target_path: The path to move the object to + source_bucket_name: The name of the bucket containing the source object + aws_credentials: Credentials to use for authentication with AWS. + target_bucket_name: The bucket to copy the object to. If not provided, defaults + to `source_bucket`. + + Returns: + The path that the object was moved to. Excludes the bucket name. + """ + logger = get_run_logger() + + s3_client = aws_credentials.get_s3_client() + + # If target bucket is not provided, assume it's the same as the source bucket + target_bucket_name = target_bucket_name or source_bucket_name + + logger.info( + "Moving object from s3://%s/%s s3://%s/%s", + source_bucket_name, + source_path, + target_bucket_name, + target_path, + ) + + # Copy the object to the new location + s3_client.copy_object( + Bucket=target_bucket_name, + CopySource={"Bucket": source_bucket_name, "Key": source_path}, + Key=target_path, + ) + + # Delete the original object + s3_client.delete_object(Bucket=source_bucket_name, Key=source_path) + + return target_path + + +def _list_objects_sync(page_iterator: PageIterator): + """ + Synchronous method to collect S3 objects into a list + + Args: + page_iterator: AWS Paginator for S3 objects + + Returns: + List[Dict]: List of object information + """ + return [content for page in page_iterator for content in page.get("Contents", [])] + + +@task +async def s3_list_objects( + bucket: str, + aws_credentials: AwsCredentials, + aws_client_parameters: AwsClientParameters = AwsClientParameters(), + prefix: str = "", + delimiter: str = "", + page_size: Optional[int] = None, + max_items: Optional[int] = None, + jmespath_query: Optional[str] = None, +) -> List[Dict[str, Any]]: + """ + Lists details of objects in a given S3 bucket. + + Args: + bucket: Name of bucket to list items from. Required if a default value was not + supplied when creating the task. 
+ aws_credentials: Credentials to use for authentication with AWS. + aws_client_parameters: Custom parameter for the boto3 client initialization.. + prefix: Used to filter objects with keys starting with the specified prefix. + delimiter: Character used to group keys of listed objects. + page_size: Number of objects to return in each request to the AWS API. + max_items: Maximum number of objects that to be returned by task. + jmespath_query: Query used to filter objects based on object attributes refer to + the [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html#filtering-results-with-jmespath) + for more information on how to construct queries. + + Returns: + A list of dictionaries containing information about the objects retrieved. Refer + to the boto3 docs for an example response. + + Example: + List all objects in a bucket: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.s3 import s3_list_objects + + + @flow + async def example_s3_list_objects_flow(): + aws_credentials = AwsCredentials( + aws_access_key_id="acccess_key_id", + aws_secret_access_key="secret_access_key" + ) + objects = await s3_list_objects( + bucket="data_bucket", + aws_credentials=aws_credentials + ) + + example_s3_list_objects_flow() + ``` + """ # noqa E501 + logger = get_run_logger() + logger.info("Listing objects in bucket %s with prefix %s", bucket, prefix) + + s3_client = aws_credentials.get_boto3_session().client( + "s3", **aws_client_parameters.get_params_override() + ) + paginator = s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate( + Bucket=bucket, + Prefix=prefix, + Delimiter=delimiter, + PaginationConfig={"PageSize": page_size, "MaxItems": max_items}, + ) + if jmespath_query: + page_iterator = page_iterator.search(f"{jmespath_query} | {{Contents: @}}") + + return await run_sync_in_worker_thread(_list_objects_sync, page_iterator) + + +class S3Bucket(WritableFileSystem, WritableDeploymentStorage, ObjectStorageBlock): + + """ + Block used to store data using AWS S3 or S3-compatible object storage like MinIO. + + Attributes: + bucket_name: Name of your bucket. + credentials: A block containing your credentials to AWS or MinIO. + bucket_folder: A default path to a folder within the S3 bucket to use + for reading and writing objects. + """ + + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + _block_type_name = "S3 Bucket" + _documentation_url = ( + "https://prefecthq.github.io/prefect-aws/s3/#prefect_aws.s3.S3Bucket" # noqa + ) + + bucket_name: str = Field(default=..., description="Name of your bucket.") + + credentials: Union[MinIOCredentials, AwsCredentials] = Field( + default_factory=AwsCredentials, + description="A block containing your credentials to AWS or MinIO.", + ) + + bucket_folder: str = Field( + default="", + description=( + "A default path to a folder within the S3 bucket to use " + "for reading and writing objects." + ), + ) + + # Property to maintain compatibility with storage block based deployments + @property + def basepath(self) -> str: + """ + The base path of the S3 bucket. + + Returns: + str: The base path of the S3 bucket. + """ + return self.bucket_folder + + @basepath.setter + def basepath(self, value: str) -> None: + self.bucket_folder = value + + def _resolve_path(self, path: str) -> str: + """ + A helper function used in write_path to join `self.basepath` and `path`. 
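For example (illustrative values), a configured `bucket_folder` is prepended to the key, while an empty `bucket_folder` leaves the key unchanged:

```python
from prefect_aws.s3 import S3Bucket

S3Bucket(bucket_name="bucket", bucket_folder="base")._resolve_path("file1")
# "base/file1"

S3Bucket(bucket_name="bucket")._resolve_path("file1")
# "file1"
```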
+ + Args: + + path: Name of the key, e.g. "file1". Each object in your + bucket has a unique key (or key name). + + """ + # If bucket_folder provided, it means we won't write to the root dir of + # the bucket. So we need to add it on the front of the path. + # + # AWS object key naming guidelines require '/' for bucket folders. + # Get POSIX path to prevent `pathlib` from inferring '\' on Windows OS + path = ( + (Path(self.bucket_folder) / path).as_posix() if self.bucket_folder else path + ) + + return path + + def _get_s3_client(self) -> boto3.client: + """ + Authenticate MinIO credentials or AWS credentials and return an S3 client. + This is a helper function called by read_path() or write_path(). + """ + return self.credentials.get_client("s3") + + def _get_bucket_resource(self) -> boto3.resource: + """ + Retrieves boto3 resource object for the configured bucket + """ + params_override = self.credentials.aws_client_parameters.get_params_override() + bucket = ( + self.credentials.get_boto3_session() + .resource("s3", **params_override) + .Bucket(self.bucket_name) + ) + return bucket + + @sync_compatible + async def get_directory( + self, from_path: Optional[str] = None, local_path: Optional[str] = None + ) -> None: + """ + Copies a folder from the configured S3 bucket to a local directory. + + Defaults to copying the entire contents of the block's basepath to the current + working directory. + + Args: + from_path: Path in S3 bucket to download from. Defaults to the block's + configured basepath. + local_path: Local path to download S3 contents to. Defaults to the current + working directory. + """ + bucket_folder = self.bucket_folder + if from_path is None: + from_path = str(bucket_folder) if bucket_folder else "" + + if local_path is None: + local_path = str(Path(".").absolute()) + else: + local_path = str(Path(local_path).expanduser()) + + bucket = self._get_bucket_resource() + for obj in bucket.objects.filter(Prefix=from_path): + if obj.key[-1] == "/": + # object is a folder and will be created if it contains any objects + continue + target = os.path.join( + local_path, + os.path.relpath(obj.key, from_path), + ) + os.makedirs(os.path.dirname(target), exist_ok=True) + bucket.download_file(obj.key, target) + + @sync_compatible + async def put_directory( + self, + local_path: Optional[str] = None, + to_path: Optional[str] = None, + ignore_file: Optional[str] = None, + ) -> int: + """ + Uploads a directory from a given local path to the configured S3 bucket in a + given folder. + + Defaults to uploading the entire contents the current working directory to the + block's basepath. + + Args: + local_path: Path to local directory to upload from. + to_path: Path in S3 bucket to upload to. Defaults to block's configured + basepath. + ignore_file: Path to file containing gitignore style expressions for + filepaths to ignore. + + """ + to_path = "" if to_path is None else to_path + + if local_path is None: + local_path = "." 
+ + included_files = None + if ignore_file: + with open(ignore_file, "r") as f: + ignore_patterns = f.readlines() + + included_files = filter_files(local_path, ignore_patterns) + + uploaded_file_count = 0 + for local_file_path in Path(local_path).expanduser().rglob("*"): + if ( + included_files is not None + and str(local_file_path.relative_to(local_path)) not in included_files + ): + continue + elif not local_file_path.is_dir(): + remote_file_path = Path(to_path) / local_file_path.relative_to( + local_path + ) + with open(local_file_path, "rb") as local_file: + local_file_content = local_file.read() + + await self.write_path( + remote_file_path.as_posix(), content=local_file_content + ) + uploaded_file_count += 1 + + return uploaded_file_count + + @sync_compatible + async def read_path(self, path: str) -> bytes: + """ + Read specified path from S3 and return contents. Provide the entire + path to the key in S3. + + Args: + path: Entire path to (and including) the key. + + Example: + Read "subfolder/file1" contents from an S3 bucket named "bucket": + ```python + from prefect_aws import AwsCredentials + from prefect_aws.s3 import S3Bucket + + aws_creds = AwsCredentials( + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY + ) + + s3_bucket_block = S3Bucket( + bucket_name="bucket", + credentials=aws_creds, + bucket_folder="subfolder" + ) + + key_contents = s3_bucket_block.read_path(path="subfolder/file1") + ``` + """ + path = self._resolve_path(path) + + return await run_sync_in_worker_thread(self._read_sync, path) + + def _read_sync(self, key: str) -> bytes: + """ + Called by read_path(). Creates an S3 client and retrieves the + contents from a specified path. + """ + + s3_client = self._get_s3_client() + + with io.BytesIO() as stream: + s3_client.download_fileobj(Bucket=self.bucket_name, Key=key, Fileobj=stream) + stream.seek(0) + output = stream.read() + return output + + @sync_compatible + async def write_path(self, path: str, content: bytes) -> str: + """ + Writes to an S3 bucket. + + Args: + + path: The key name. Each object in your bucket has a unique + key (or key name). + content: What you are uploading to S3. + + Example: + + Write data to the path `dogs/small_dogs/havanese` in an S3 Bucket: + ```python + from prefect_aws import MinioCredentials + from prefect_aws.s3 import S3Bucket + + minio_creds = MinIOCredentials( + minio_root_user = "minioadmin", + minio_root_password = "minioadmin", + ) + + s3_bucket_block = S3Bucket( + bucket_name="bucket", + minio_credentials=minio_creds, + bucket_folder="dogs/smalldogs", + endpoint_url="http://localhost:9000", + ) + s3_havanese_path = s3_bucket_block.write_path(path="havanese", content=data) + ``` + """ + + path = self._resolve_path(path) + + await run_sync_in_worker_thread(self._write_sync, path, content) + + return path + + def _write_sync(self, key: str, data: bytes) -> None: + """ + Called by write_path(). Creates an S3 client and uploads a file + object. 
+ """ + + s3_client = self._get_s3_client() + + with io.BytesIO(data) as stream: + s3_client.upload_fileobj(Fileobj=stream, Bucket=self.bucket_name, Key=key) + + # NEW BLOCK INTERFACE METHODS BELOW + @staticmethod + def _list_objects_sync(page_iterator: PageIterator) -> List[Dict[str, Any]]: + """ + Synchronous method to collect S3 objects into a list + + Args: + page_iterator: AWS Paginator for S3 objects + + Returns: + List[Dict]: List of object information + """ + return [ + content for page in page_iterator for content in page.get("Contents", []) + ] + + def _join_bucket_folder(self, bucket_path: str = "") -> str: + """ + Joins the base bucket folder to the bucket path. + NOTE: If a method reuses another method in this class, be careful to not + call this twice because it'll join the bucket folder twice. + See https://github.com/PrefectHQ/prefect-aws/issues/141 for a past issue. + """ + if not self.bucket_folder and not bucket_path: + # there's a difference between "." and "", at least in the tests + return "" + + bucket_path = str(bucket_path) + if self.bucket_folder != "" and bucket_path.startswith(self.bucket_folder): + self.logger.info( + f"Bucket path {bucket_path!r} is already prefixed with " + f"bucket folder {self.bucket_folder!r}; is this intentional?" + ) + + return (Path(self.bucket_folder) / bucket_path).as_posix() + ( + "" if not bucket_path.endswith("/") else "/" + ) + + @sync_compatible + async def list_objects( + self, + folder: str = "", + delimiter: str = "", + page_size: Optional[int] = None, + max_items: Optional[int] = None, + jmespath_query: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + Args: + folder: Folder to list objects from. + delimiter: Character used to group keys of listed objects. + page_size: Number of objects to return in each request to the AWS API. + max_items: Maximum number of objects that to be returned by task. + jmespath_query: Query used to filter objects based on object attributes refer to + the [boto3 docs](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/paginators.html#filtering-results-with-jmespath) + for more information on how to construct queries. + + Returns: + List of objects and their metadata in the bucket. + + Examples: + List objects under the `base_folder`. + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.list_objects("base_folder") + ``` + """ # noqa: E501 + bucket_path = self._join_bucket_folder(folder) + client = self.credentials.get_s3_client() + paginator = client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate( + Bucket=self.bucket_name, + Prefix=bucket_path, + Delimiter=delimiter, + PaginationConfig={"PageSize": page_size, "MaxItems": max_items}, + ) + if jmespath_query: + page_iterator = page_iterator.search(f"{jmespath_query} | {{Contents: @}}") + + self.logger.info(f"Listing objects in bucket {bucket_path}.") + objects = await run_sync_in_worker_thread( + self._list_objects_sync, page_iterator + ) + return objects + + @sync_compatible + async def download_object_to_path( + self, + from_path: str, + to_path: Optional[Union[str, Path]], + **download_kwargs: Dict[str, Any], + ) -> Path: + """ + Downloads an object from the S3 bucket to a path. + + Args: + from_path: The path to the object to download; this gets prefixed + with the bucket_folder. + to_path: The path to download the object to. If not provided, the + object's name will be used. 
+ **download_kwargs: Additional keyword arguments to pass to + `Client.download_file`. + + Returns: + The absolute path that the object was downloaded to. + + Examples: + Download my_folder/notes.txt object to notes.txt. + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.download_object_to_path("my_folder/notes.txt", "notes.txt") + ``` + """ + if to_path is None: + to_path = Path(from_path).name + + # making path absolute, but converting back to str here + # since !r looks nicer that way and filename arg expects str + to_path = str(Path(to_path).absolute()) + bucket_path = self._join_bucket_folder(from_path) + client = self.credentials.get_s3_client() + + self.logger.debug( + f"Preparing to download object from bucket {self.bucket_name!r} " + f"path {bucket_path!r} to {to_path!r}." + ) + await run_sync_in_worker_thread( + client.download_file, + Bucket=self.bucket_name, + Key=from_path, + Filename=to_path, + **download_kwargs, + ) + self.logger.info( + f"Downloaded object from bucket {self.bucket_name!r} path {bucket_path!r} " + f"to {to_path!r}." + ) + return Path(to_path) + + @sync_compatible + async def download_object_to_file_object( + self, + from_path: str, + to_file_object: BinaryIO, + **download_kwargs: Dict[str, Any], + ) -> BinaryIO: + """ + Downloads an object from the object storage service to a file-like object, + which can be a BytesIO object or a BufferedWriter. + + Args: + from_path: The path to the object to download from; this gets prefixed + with the bucket_folder. + to_file_object: The file-like object to download the object to. + **download_kwargs: Additional keyword arguments to pass to + `Client.download_fileobj`. + + Returns: + The file-like object that the object was downloaded to. + + Examples: + Download my_folder/notes.txt object to a BytesIO object. + ```python + from io import BytesIO + + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + with BytesIO() as buf: + s3_bucket.download_object_to_file_object("my_folder/notes.txt", buf) + ``` + + Download my_folder/notes.txt object to a BufferedWriter. + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + with open("notes.txt", "wb") as f: + s3_bucket.download_object_to_file_object("my_folder/notes.txt", f) + ``` + """ + client = self.credentials.get_s3_client() + bucket_path = self._join_bucket_folder(from_path) + + self.logger.debug( + f"Preparing to download object from bucket {self.bucket_name!r} " + f"path {bucket_path!r} to file object." + ) + await run_sync_in_worker_thread( + client.download_fileobj, + Bucket=self.bucket_name, + Key=bucket_path, + Fileobj=to_file_object, + **download_kwargs, + ) + self.logger.info( + f"Downloaded object from bucket {self.bucket_name!r} path {bucket_path!r} " + "to file object." + ) + return to_file_object + + @sync_compatible + async def download_folder_to_path( + self, + from_folder: str, + to_folder: Optional[Union[str, Path]] = None, + **download_kwargs: Dict[str, Any], + ) -> Path: + """ + Downloads objects *within* a folder (excluding the folder itself) + from the S3 bucket to a folder. + + Args: + from_folder: The path to the folder to download from. + to_folder: The path to download the folder to. + **download_kwargs: Additional keyword arguments to pass to + `Client.download_file`. + + Returns: + The absolute path that the folder was downloaded to. + + Examples: + Download my_folder to a local folder named my_folder. 
+ ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.download_folder_to_path("my_folder", "my_folder") + ``` + """ + if to_folder is None: + to_folder = "" + to_folder = Path(to_folder).absolute() + + client = self.credentials.get_s3_client() + objects = await self.list_objects(folder=from_folder) + + # do not call self._join_bucket_folder for filter + # because it's built-in to that method already! + # however, we still need to do it because we're using relative_to + bucket_folder = self._join_bucket_folder(from_folder) + + async_coros = [] + for object in objects: + bucket_path = Path(object["Key"]).relative_to(bucket_folder) + # this skips the actual directory itself, e.g. + # `my_folder/` will be skipped + # `my_folder/notes.txt` will be downloaded + if bucket_path.is_dir(): + continue + to_path = to_folder / bucket_path + to_path.parent.mkdir(parents=True, exist_ok=True) + to_path = str(to_path) # must be string + self.logger.info( + f"Downloading object from bucket {self.bucket_name!r} path " + f"{bucket_path.as_posix()!r} to {to_path!r}." + ) + async_coros.append( + run_sync_in_worker_thread( + client.download_file, + Bucket=self.bucket_name, + Key=object["Key"], + Filename=to_path, + **download_kwargs, + ) + ) + await asyncio.gather(*async_coros) + + return Path(to_folder) + + @sync_compatible + async def stream_from( + self, + bucket: "S3Bucket", + from_path: str, + to_path: Optional[str] = None, + **upload_kwargs: Dict[str, Any], + ) -> str: + """Streams an object from another bucket to this bucket. Requires the + object to be downloaded and uploaded in chunks. If `self`'s credentials + allow for writes to the other bucket, try using `S3Bucket.copy_object`. + + Args: + bucket: The bucket to stream from. + from_path: The path of the object to stream. + to_path: The path to stream the object to. Defaults to the object's name. + **upload_kwargs: Additional keyword arguments to pass to + `Client.upload_fileobj`. + + Returns: + The path that the object was uploaded to. + + Examples: + Stream notes.txt from your-bucket/notes.txt to my-bucket/landed/notes.txt. + + ```python + from prefect_aws.s3 import S3Bucket + + your_s3_bucket = S3Bucket.load("your-bucket") + my_s3_bucket = S3Bucket.load("my-bucket") + + my_s3_bucket.stream_from( + your_s3_bucket, + "notes.txt", + to_path="landed/notes.txt" + ) + ``` + + """ + if to_path is None: + to_path = Path(from_path).name + + # Get the source object's StreamingBody + from_path: str = bucket._join_bucket_folder(from_path) + from_client = bucket.credentials.get_s3_client() + obj = await run_sync_in_worker_thread( + from_client.get_object, Bucket=bucket.bucket_name, Key=from_path + ) + body: StreamingBody = obj["Body"] + + # Upload the StreamingBody to this bucket + bucket_path = str(self._join_bucket_folder(to_path)) + to_client = self.credentials.get_s3_client() + await run_sync_in_worker_thread( + to_client.upload_fileobj, + Fileobj=body, + Bucket=self.bucket_name, + Key=bucket_path, + **upload_kwargs, + ) + self.logger.info( + f"Streamed s3://{bucket.bucket_name}/{from_path} to the bucket " + f"{self.bucket_name!r} path {bucket_path!r}." + ) + return bucket_path + + @sync_compatible + async def upload_from_path( + self, + from_path: Union[str, Path], + to_path: Optional[str] = None, + **upload_kwargs: Dict[str, Any], + ) -> str: + """ + Uploads an object from a path to the S3 bucket. + + Args: + from_path: The path to the file to upload from. + to_path: The path to upload the file to. 
+ **upload_kwargs: Additional keyword arguments to pass to `Client.upload`. + + Returns: + The path that the object was uploaded to. + + Examples: + Upload notes.txt to my_folder/notes.txt. + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.upload_from_path("notes.txt", "my_folder/notes.txt") + ``` + """ + from_path = str(Path(from_path).absolute()) + if to_path is None: + to_path = Path(from_path).name + + bucket_path = str(self._join_bucket_folder(to_path)) + client = self.credentials.get_s3_client() + + await run_sync_in_worker_thread( + client.upload_file, + Filename=from_path, + Bucket=self.bucket_name, + Key=bucket_path, + **upload_kwargs, + ) + self.logger.info( + f"Uploaded from {from_path!r} to the bucket " + f"{self.bucket_name!r} path {bucket_path!r}." + ) + return bucket_path + + @sync_compatible + async def upload_from_file_object( + self, from_file_object: BinaryIO, to_path: str, **upload_kwargs: Dict[str, Any] + ) -> str: + """ + Uploads an object to the S3 bucket from a file-like object, + which can be a BytesIO object or a BufferedReader. + + Args: + from_file_object: The file-like object to upload from. + to_path: The path to upload the object to. + **upload_kwargs: Additional keyword arguments to pass to + `Client.upload_fileobj`. + + Returns: + The path that the object was uploaded to. + + Examples: + Upload BytesIO object to my_folder/notes.txt. + ```python + from io import BytesIO + + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + with open("notes.txt", "rb") as f: + s3_bucket.upload_from_file_object(f, "my_folder/notes.txt") + ``` + + Upload BufferedReader object to my_folder/notes.txt. + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + with open("notes.txt", "rb") as f: + s3_bucket.upload_from_file_object( + f, "my_folder/notes.txt" + ) + ``` + """ + bucket_path = str(self._join_bucket_folder(to_path)) + client = self.credentials.get_s3_client() + await run_sync_in_worker_thread( + client.upload_fileobj, + Fileobj=from_file_object, + Bucket=self.bucket_name, + Key=bucket_path, + **upload_kwargs, + ) + self.logger.info( + "Uploaded from file object to the bucket " + f"{self.bucket_name!r} path {bucket_path!r}." + ) + return bucket_path + + @sync_compatible + async def upload_from_folder( + self, + from_folder: Union[str, Path], + to_folder: Optional[str] = None, + **upload_kwargs: Dict[str, Any], + ) -> str: + """ + Uploads files *within* a folder (excluding the folder itself) + to the object storage service folder. + + Args: + from_folder: The path to the folder to upload from. + to_folder: The path to upload the folder to. + **upload_kwargs: Additional keyword arguments to pass to + `Client.upload_fileobj`. + + Returns: + The path that the folder was uploaded to. + + Examples: + Upload contents from my_folder to new_folder. + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.upload_from_folder("my_folder", "new_folder") + ``` + """ + from_folder = Path(from_folder) + bucket_folder = self._join_bucket_folder(to_folder or "") + + num_uploaded = 0 + client = self.credentials.get_s3_client() + + async_coros = [] + for from_path in from_folder.rglob("**/*"): + # this skips the actual directory itself, e.g. 
+ # `my_folder/` will be skipped + # `my_folder/notes.txt` will be uploaded + if from_path.is_dir(): + continue + bucket_path = ( + Path(bucket_folder) / from_path.relative_to(from_folder) + ).as_posix() + self.logger.info( + f"Uploading from {str(from_path)!r} to the bucket " + f"{self.bucket_name!r} path {bucket_path!r}." + ) + async_coros.append( + run_sync_in_worker_thread( + client.upload_file, + Filename=str(from_path), + Bucket=self.bucket_name, + Key=bucket_path, + **upload_kwargs, + ) + ) + num_uploaded += 1 + await asyncio.gather(*async_coros) + + if num_uploaded == 0: + self.logger.warning(f"No files were uploaded from {str(from_folder)!r}.") + else: + self.logger.info( + f"Uploaded {num_uploaded} files from {str(from_folder)!r} to " + f"the bucket {self.bucket_name!r} path {bucket_path!r}" + ) + + return to_folder + + @sync_compatible + async def copy_object( + self, + from_path: Union[str, Path], + to_path: Union[str, Path], + to_bucket: Optional[Union["S3Bucket", str]] = None, + **copy_kwargs, + ) -> str: + """Uses S3's internal + [CopyObject](https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html) + to copy objects within or between buckets. To copy objects between buckets, + `self`'s credentials must have permission to read the source object and write + to the target object. If the credentials do not have those permissions, try + using `S3Bucket.stream_from`. + + Args: + from_path: The path of the object to copy. + to_path: The path to copy the object to. + to_bucket: The bucket to copy to. Defaults to the current bucket. + **copy_kwargs: Additional keyword arguments to pass to + `S3Client.copy_object`. + + Returns: + The path that the object was copied to. Excludes the bucket name. + + Examples: + + Copy notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt. + + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.copy_object("my_folder/notes.txt", "my_folder/notes_copy.txt") + ``` + + Copy notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt in + another bucket. 
+ + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.copy_object( + "my_folder/notes.txt", + "my_folder/notes_copy.txt", + to_bucket="other-bucket" + ) + ``` + """ + s3_client = self.credentials.get_s3_client() + + source_bucket_name = self.bucket_name + source_path = self._resolve_path(Path(from_path).as_posix()) + + # Default to copying within the same bucket + to_bucket = to_bucket or self + + target_bucket_name: str + target_path: str + if isinstance(to_bucket, S3Bucket): + target_bucket_name = to_bucket.bucket_name + target_path = to_bucket._resolve_path(Path(to_path).as_posix()) + elif isinstance(to_bucket, str): + target_bucket_name = to_bucket + target_path = Path(to_path).as_posix() + else: + raise TypeError( + f"to_bucket must be a string or S3Bucket, not {type(to_bucket)}" + ) + + self.logger.info( + "Copying object from bucket %s with key %s to bucket %s with key %s", + source_bucket_name, + source_path, + target_bucket_name, + target_path, + ) + + s3_client.copy_object( + CopySource={"Bucket": source_bucket_name, "Key": source_path}, + Bucket=target_bucket_name, + Key=target_path, + **copy_kwargs, + ) + + return target_path + + @sync_compatible + async def move_object( + self, + from_path: Union[str, Path], + to_path: Union[str, Path], + to_bucket: Optional[Union["S3Bucket", str]] = None, + ) -> str: + """Uses S3's internal CopyObject and DeleteObject to move objects within or + between buckets. To move objects between buckets, `self`'s credentials must + have permission to read and delete the source object and write to the target + object. If the credentials do not have those permissions, this method will + raise an error. If the credentials have permission to read the source object + but not delete it, the object will be copied but not deleted. + + Args: + from_path: The path of the object to move. + to_path: The path to move the object to. + to_bucket: The bucket to move to. Defaults to the current bucket. + + Returns: + The path that the object was moved to. Excludes the bucket name. + + Examples: + + Move notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt. + + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.move_object("my_folder/notes.txt", "my_folder/notes_copy.txt") + ``` + + Move notes.txt from my_folder/notes.txt to my_folder/notes_copy.txt in + another bucket. 
+ + ```python + from prefect_aws.s3 import S3Bucket + + s3_bucket = S3Bucket.load("my-bucket") + s3_bucket.move_object( + "my_folder/notes.txt", + "my_folder/notes_copy.txt", + to_bucket="other-bucket" + ) + ``` + """ + s3_client = self.credentials.get_s3_client() + + source_bucket_name = self.bucket_name + source_path = self._resolve_path(Path(from_path).as_posix()) + + # Default to moving within the same bucket + to_bucket = to_bucket or self + + target_bucket_name: str + target_path: str + if isinstance(to_bucket, S3Bucket): + target_bucket_name = to_bucket.bucket_name + target_path = to_bucket._resolve_path(Path(to_path).as_posix()) + elif isinstance(to_bucket, str): + target_bucket_name = to_bucket + target_path = Path(to_path).as_posix() + else: + raise TypeError( + f"to_bucket must be a string or S3Bucket, not {type(to_bucket)}" + ) + + self.logger.info( + "Moving object from s3://%s/%s to s3://%s/%s", + source_bucket_name, + source_path, + target_bucket_name, + target_path, + ) + + # If invalid, should error and prevent next operation + s3_client.copy( + CopySource={"Bucket": source_bucket_name, "Key": source_path}, + Bucket=target_bucket_name, + Key=target_path, + ) + s3_client.delete_object(Bucket=source_bucket_name, Key=source_path) + return target_path diff --git a/src/integrations/prefect-aws/prefect_aws/secrets_manager.py b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py new file mode 100644 index 000000000000..a3af406b537c --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py @@ -0,0 +1,517 @@ +"""Tasks for interacting with AWS Secrets Manager""" + +from typing import Any, Dict, List, Optional, Union + +from botocore.exceptions import ClientError +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect import get_run_logger, task +from prefect.blocks.abstract import SecretBlock +from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field +else: + from pydantic import Field + +from prefect_aws import AwsCredentials + + +@task +async def read_secret( + secret_name: str, + aws_credentials: AwsCredentials, + version_id: Optional[str] = None, + version_stage: Optional[str] = None, +) -> Union[str, bytes]: + """ + Reads the value of a given secret from AWS Secrets Manager. + + Args: + secret_name: Name of stored secret. + aws_credentials: Credentials to use for authentication with AWS. + version_id: Specifies version of secret to read. Defaults to the most recent + version if not given. + version_stage: Specifies the version stage of the secret to read. Defaults to + AWS_CURRENT if not given. + + Returns: + The secret values as a `str` or `bytes` depending on the format in which the + secret was stored. 
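If the secret was stored as a JSON string (a common pattern for database credentials), the returned value can be parsed directly; a minimal sketch with illustrative names (`my-creds`, `db_credentials`) and a previously saved credentials block:

```python
import json

from prefect import flow
from prefect_aws import AwsCredentials
from prefect_aws.secrets_manager import read_secret


@flow
def example_read_json_secret():
    aws_credentials = AwsCredentials.load("my-creds")
    raw = read_secret(secret_name="db_credentials", aws_credentials=aws_credentials)
    # Parse the JSON payload stored in the secret (illustrative structure)
    return json.loads(raw)["username"]
```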
+ + Example: + Read a secret value: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.secrets_manager import read_secret + + @flow + def example_read_secret(): + aws_credentials = AwsCredentials( + aws_access_key_id="access_key_id", + aws_secret_access_key="secret_access_key" + ) + secret_value = read_secret( + secret_name="db_password", + aws_credentials=aws_credentials + ) + + example_read_secret() + ``` + """ + logger = get_run_logger() + logger.info("Getting value for secret %s", secret_name) + + client = aws_credentials.get_boto3_session().client("secretsmanager") + + get_secret_value_kwargs = dict(SecretId=secret_name) + if version_id is not None: + get_secret_value_kwargs["VersionId"] = version_id + if version_stage is not None: + get_secret_value_kwargs["VersionStage"] = version_stage + + try: + response = await run_sync_in_worker_thread( + client.get_secret_value, **get_secret_value_kwargs + ) + except ClientError: + logger.exception("Unable to get value for secret %s", secret_name) + raise + else: + return response.get("SecretString") or response.get("SecretBinary") + + +@task +async def update_secret( + secret_name: str, + secret_value: Union[str, bytes], + aws_credentials: AwsCredentials, + description: Optional[str] = None, +) -> Dict[str, str]: + """ + Updates the value of a given secret in AWS Secrets Manager. + + Args: + secret_name: Name of secret to update. + secret_value: Desired value of the secret. Can be either `str` or `bytes`. + aws_credentials: Credentials to use for authentication with AWS. + description: Desired description of the secret. + + Returns: + A dict containing the secret ARN (Amazon Resource Name), + name, and current version ID. + ```python + { + "ARN": str, + "Name": str, + "VersionId": str + } + ``` + + Example: + Update a secret value: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.secrets_manager import update_secret + + @flow + def example_update_secret(): + aws_credentials = AwsCredentials( + aws_access_key_id="access_key_id", + aws_secret_access_key="secret_access_key" + ) + update_secret( + secret_name="life_the_universe_and_everything", + secret_value="42", + aws_credentials=aws_credentials + ) + + example_update_secret() + ``` + + """ + update_secret_kwargs: Dict[str, Union[str, bytes]] = dict(SecretId=secret_name) + if description is not None: + update_secret_kwargs["Description"] = description + if isinstance(secret_value, bytes): + update_secret_kwargs["SecretBinary"] = secret_value + elif isinstance(secret_value, str): + update_secret_kwargs["SecretString"] = secret_value + else: + raise ValueError("Please provide a bytes or str value for secret_value") + + logger = get_run_logger() + logger.info("Updating value for secret %s", secret_name) + + client = aws_credentials.get_boto3_session().client("secretsmanager") + + try: + response = await run_sync_in_worker_thread( + client.update_secret, **update_secret_kwargs + ) + response.pop("ResponseMetadata", None) + return response + except ClientError: + logger.exception("Unable to update secret %s", secret_name) + raise + + +@task +async def create_secret( + secret_name: str, + secret_value: Union[str, bytes], + aws_credentials: AwsCredentials, + description: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, +) -> Dict[str, str]: + """ + Creates a secret in AWS Secrets Manager. + + Args: + secret_name: The name of the secret to create. 
+ secret_value: The value to store in the created secret. + aws_credentials: Credentials to use for authentication with AWS. + description: A description for the created secret. + tags: A list of tags to attach to the secret. Each tag should be specified as a + dictionary in the following format: + ```python + { + "Key": str, + "Value": str + } + ``` + + Returns: + A dict containing the secret ARN (Amazon Resource Name), + name, and current version ID. + ```python + { + "ARN": str, + "Name": str, + "VersionId": str + } + ``` + Example: + Create a secret: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.secrets_manager import create_secret + + @flow + def example_create_secret(): + aws_credentials = AwsCredentials( + aws_access_key_id="access_key_id", + aws_secret_access_key="secret_access_key" + ) + create_secret( + secret_name="life_the_universe_and_everything", + secret_value="42", + aws_credentials=aws_credentials + ) + + example_create_secret() + ``` + + + """ + create_secret_kwargs: Dict[str, Union[str, bytes, List[Dict[str, str]]]] = dict( + Name=secret_name + ) + if description is not None: + create_secret_kwargs["Description"] = description + if tags is not None: + create_secret_kwargs["Tags"] = tags + if isinstance(secret_value, bytes): + create_secret_kwargs["SecretBinary"] = secret_value + elif isinstance(secret_value, str): + create_secret_kwargs["SecretString"] = secret_value + else: + raise ValueError("Please provide a bytes or str value for secret_value") + + logger = get_run_logger() + logger.info("Creating secret named %s", secret_name) + + client = aws_credentials.get_boto3_session().client("secretsmanager") + + try: + response = await run_sync_in_worker_thread( + client.create_secret, **create_secret_kwargs + ) + print(response.pop("ResponseMetadata", None)) + return response + except ClientError: + logger.exception("Unable to create secret %s", secret_name) + raise + + +@task +async def delete_secret( + secret_name: str, + aws_credentials: AwsCredentials, + recovery_window_in_days: int = 30, + force_delete_without_recovery: bool = False, +) -> Dict[str, str]: + """ + Deletes a secret from AWS Secrets Manager. + + Secrets can either be deleted immediately by setting `force_delete_without_recovery` + equal to `True`. Otherwise, secrets will be marked for deletion and available for + recovery for the number of days specified in `recovery_window_in_days` + + Args: + secret_name: Name of the secret to be deleted. + aws_credentials: Credentials to use for authentication with AWS. + recovery_window_in_days: Number of days a secret should be recoverable for + before permanent deletion. Minimum window is 7 days and maximum window + is 30 days. If `force_delete_without_recovery` is set to `True`, this + value will be ignored. + force_delete_without_recovery: If `True`, the secret will be immediately + deleted and will not be recoverable. + + Returns: + A dict containing the secret ARN (Amazon Resource Name), + name, and deletion date of the secret. DeletionDate is the date and + time of the delete request plus the number of days in + `recovery_window_in_days`. 
+ ```python + { + "ARN": str, + "Name": str, + "DeletionDate": datetime.datetime + } + ``` + + Examples: + Delete a secret immediately: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.secrets_manager import delete_secret + + @flow + def example_delete_secret_immediately(): + aws_credentials = AwsCredentials( + aws_access_key_id="access_key_id", + aws_secret_access_key="secret_access_key" + ) + delete_secret( + secret_name="life_the_universe_and_everything", + aws_credentials=aws_credentials, + force_delete_without_recovery: True + ) + + example_delete_secret_immediately() + ``` + + Delete a secret with a 90 day recovery window: + + ```python + from prefect import flow + from prefect_aws import AwsCredentials + from prefect_aws.secrets_manager import delete_secret + + @flow + def example_delete_secret_with_recovery_window(): + aws_credentials = AwsCredentials( + aws_access_key_id="access_key_id", + aws_secret_access_key="secret_access_key" + ) + delete_secret( + secret_name="life_the_universe_and_everything", + aws_credentials=aws_credentials, + recovery_window_in_days=90 + ) + + example_delete_secret_with_recovery_window() + ``` + + + """ + if not force_delete_without_recovery and not (7 <= recovery_window_in_days <= 30): + raise ValueError("Recovery window must be between 7 and 30 days.") + + delete_secret_kwargs: Dict[str, Union[str, int, bool]] = dict(SecretId=secret_name) + if force_delete_without_recovery: + delete_secret_kwargs[ + "ForceDeleteWithoutRecovery" + ] = force_delete_without_recovery + else: + delete_secret_kwargs["RecoveryWindowInDays"] = recovery_window_in_days + + logger = get_run_logger() + logger.info("Deleting secret %s", secret_name) + + client = aws_credentials.get_boto3_session().client("secretsmanager") + + try: + response = await run_sync_in_worker_thread( + client.delete_secret, **delete_secret_kwargs + ) + response.pop("ResponseMetadata", None) + return response + except ClientError: + logger.exception("Unable to delete secret %s", secret_name) + raise + + +class AwsSecret(SecretBlock): + """ + Manages a secret in AWS's Secrets Manager. + + Attributes: + aws_credentials: The credentials to use for authentication with AWS. + secret_name: The name of the secret. + """ + + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + _block_type_name = "AWS Secret" + _documentation_url = "https://prefecthq.github.io/prefect-aws/secrets_manager/#prefect_aws.secrets_manager.AwsSecret" # noqa + + aws_credentials: AwsCredentials + secret_name: str = Field(default=..., description="The name of the secret.") + + @sync_compatible + async def read_secret( + self, + version_id: str = None, + version_stage: str = None, + **read_kwargs: Dict[str, Any], + ) -> bytes: + """ + Reads the secret from the secret storage service. + + Args: + version_id: The version of the secret to read. If not provided, the latest + version will be read. + version_stage: The version stage of the secret to read. If not provided, + the latest version will be read. + read_kwargs: Additional keyword arguments to pass to the + `get_secret_value` method of the boto3 client. + + Returns: + The secret data. + + Examples: + Reads a secret. 
+ ```python + secrets_manager = SecretsManager.load("MY_BLOCK") + secrets_manager.read_secret() + ``` + """ + client = self.aws_credentials.get_secrets_manager_client() + if version_id is not None: + read_kwargs["VersionId"] = version_id + if version_stage is not None: + read_kwargs["VersionStage"] = version_stage + response = await run_sync_in_worker_thread( + client.get_secret_value, SecretId=self.secret_name, **read_kwargs + ) + if "SecretBinary" in response: + secret = response["SecretBinary"] + elif "SecretString" in response: + secret = response["SecretString"] + arn = response["ARN"] + self.logger.info(f"The secret {arn!r} data was successfully read.") + return secret + + @sync_compatible + async def write_secret( + self, secret_data: bytes, **put_or_create_secret_kwargs: Dict[str, Any] + ) -> str: + """ + Writes the secret to the secret storage service as a SecretBinary; + if it doesn't exist, it will be created. + + Args: + secret_data: The secret data to write. + **put_or_create_secret_kwargs: Additional keyword arguments to pass to + put_secret_value or create_secret method of the boto3 client. + + Returns: + The path that the secret was written to. + + Examples: + Write some secret data. + ```python + secrets_manager = SecretsManager.load("MY_BLOCK") + secrets_manager.write_secret(b"my_secret_data") + ``` + """ + client = self.aws_credentials.get_secrets_manager_client() + try: + response = await run_sync_in_worker_thread( + client.put_secret_value, + SecretId=self.secret_name, + SecretBinary=secret_data, + **put_or_create_secret_kwargs, + ) + except client.exceptions.ResourceNotFoundException: + self.logger.info( + f"The secret {self.secret_name!r} does not exist yet, creating it now." + ) + response = await run_sync_in_worker_thread( + client.create_secret, + Name=self.secret_name, + SecretBinary=secret_data, + **put_or_create_secret_kwargs, + ) + arn = response["ARN"] + self.logger.info(f"The secret data was written successfully to {arn!r}.") + return arn + + @sync_compatible + async def delete_secret( + self, + recovery_window_in_days: int = 30, + force_delete_without_recovery: bool = False, + **delete_kwargs: Dict[str, Any], + ) -> str: + """ + Deletes the secret from the secret storage service. + + Args: + recovery_window_in_days: The number of days to wait before permanently + deleting the secret. Must be between 7 and 30 days. + force_delete_without_recovery: If True, the secret will be deleted + immediately without a recovery window. + **delete_kwargs: Additional keyword arguments to pass to the + delete_secret method of the boto3 client. + + Returns: + The path that the secret was deleted from. + + Examples: + Deletes the secret with a recovery window of 15 days. + ```python + secrets_manager = SecretsManager.load("MY_BLOCK") + secrets_manager.delete_secret(recovery_window_in_days=15) + ``` + """ + if force_delete_without_recovery and recovery_window_in_days: + raise ValueError( + "Cannot specify recovery window and force delete without recovery." + ) + elif not (7 <= recovery_window_in_days <= 30): + raise ValueError( + "Recovery window must be between 7 and 30 days, got " + f"{recovery_window_in_days}." 
+ ) + + client = self.aws_credentials.get_secrets_manager_client() + response = await run_sync_in_worker_thread( + client.delete_secret, + SecretId=self.secret_name, + RecoveryWindowInDays=recovery_window_in_days, + ForceDeleteWithoutRecovery=force_delete_without_recovery, + **delete_kwargs, + ) + arn = response["ARN"] + self.logger.info(f"The secret {arn} was deleted successfully.") + return arn diff --git a/src/integrations/prefect-aws/prefect_aws/utilities.py b/src/integrations/prefect-aws/prefect_aws/utilities.py new file mode 100644 index 000000000000..b57c88ee869a --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/utilities.py @@ -0,0 +1,116 @@ +"""Utilities for working with AWS services.""" + +from typing import Dict, List, Union + +from prefect.utilities.collections import visit_collection + + +def hash_collection(collection) -> int: + """Use visit_collection to transform and hash a collection. + + Args: + collection (Any): The collection to hash. + + Returns: + int: The hash of the transformed collection. + + Example: + ```python + from prefect_aws.utilities import hash_collection + + hash_collection({"a": 1, "b": 2}) + ``` + + """ + + def make_hashable(item): + """Make an item hashable by converting it to a tuple.""" + if isinstance(item, dict): + return tuple(sorted((k, make_hashable(v)) for k, v in item.items())) + elif isinstance(item, list): + return tuple(make_hashable(v) for v in item) + return item + + hashable_collection = visit_collection( + collection, visit_fn=make_hashable, return_data=True + ) + return hash(hashable_collection) + + +def ensure_path_exists(doc: Union[Dict, List], path: List[str]): + """ + Ensures the path exists in the document, creating empty dictionaries or lists as + needed. + + Args: + doc: The current level of the document or sub-document. + path: The remaining path parts to ensure exist. + """ + if not path: + return + current_path = path.pop(0) + # Check if the next path part exists and is a digit + next_path_is_digit = path and path[0].isdigit() + + # Determine if the current path is for an array or an object + if isinstance(doc, list): # Path is for an array index + current_path = int(current_path) + # Ensure the current level of the document is a list and long enough + + while len(doc) <= current_path: + doc.append({}) + next_level = doc[current_path] + else: # Path is for an object + if current_path not in doc or ( + next_path_is_digit and not isinstance(doc.get(current_path), list) + ): + doc[current_path] = [] if next_path_is_digit else {} + next_level = doc[current_path] + + ensure_path_exists(next_level, path) + + +def assemble_document_for_patches(patches): + """ + Assembles an initial document that can successfully accept the given JSON Patch + operations. + + Args: + patches: A list of JSON Patch operations. + + Returns: + An initial document structured to accept the patches. 
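Numeric path segments are treated as list indices by `ensure_path_exists`, so a patch that targets an array element gets an empty list scaffolded for it; a small illustrative sketch with a hypothetical patch:

```python
from prefect_aws.utilities import assemble_document_for_patches

patches = [{"op": "add", "path": "/tags/0/key", "value": "env"}]

assemble_document_for_patches(patches)
# {"tags": [{}]}
```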
+ + Example: + + ```python + patches = [ + {"op": "replace", "path": "/name", "value": "Jane"}, + {"op": "add", "path": "/contact/address", "value": "123 Main St"}, + {"op": "remove", "path": "/age"} + ] + + initial_document = assemble_document_for_patches(patches) + + #output + { + "name": {}, + "contact": {}, + "age": {} + } + ``` + """ + document = {} + + for patch in patches: + operation = patch["op"] + path = patch["path"].lstrip("/").split("/") + + if operation == "add": + # Ensure all but the last element of the path exists + ensure_path_exists(document, path[:-1]) + elif operation in ["remove", "replace"]: + # For remove and replace, the entire path should exist + ensure_path_exists(document, path) + + return document diff --git a/src/integrations/prefect-aws/prefect_aws/workers/__init__.py b/src/integrations/prefect-aws/prefect_aws/workers/__init__.py new file mode 100644 index 000000000000..c1f409ea855c --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/workers/__init__.py @@ -0,0 +1,3 @@ +from .ecs_worker import ECSWorker + +__all__ = ["ECSWorker"] diff --git a/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py b/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py new file mode 100644 index 000000000000..b1216d72b3f6 --- /dev/null +++ b/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py @@ -0,0 +1,1737 @@ +""" +Prefect worker for executing flow runs as ECS tasks. + +Get started by creating a work pool: + +``` +$ prefect work-pool create --type ecs my-ecs-pool +``` + +Then, you can start a worker for the pool: + +``` +$ prefect worker start --pool my-ecs-pool +``` + +It's common to deploy the worker as an ECS task as well. However, you can run the worker +locally to get started. + +The worker may work without any additional configuration, but it is dependent on your +specific AWS setup and we'd recommend opening the work pool editor in the UI to see the +available options. + +By default, the worker will register a task definition for each flow run and run a task +in your default ECS cluster using AWS Fargate. Fargate requires tasks to configure +subnets, which we will infer from your default VPC. If you do not have a default VPC, +you must provide a VPC ID or manually setup the network configuration for your tasks. + +Note, the worker caches task definitions for each deployment to avoid excessive +registration. The worker will check that the cached task definition is compatible with +your configuration before using it. + +The launch type option can be used to run your tasks in different modes. For example, +`FARGATE_SPOT` can be used to use spot instances for your Fargate tasks or `EC2` can be +used to run your tasks on a cluster backed by EC2 instances. + +Generally, it is very useful to enable CloudWatch logging for your ECS tasks; this can +help you debug task failures. To enable CloudWatch logging, you must provide an +execution role ARN with permissions to create and write to log streams. See the +`configure_cloudwatch_logs` field documentation for details. + +The worker can be configured to use an existing task definition by setting the task +definition arn variable or by providing a "taskDefinition" in the task run request. When +a task definition is provided, the worker will never create a new task definition which +may result in variables that are templated into the task definition payload being +ignored. 
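+
+For example, the `task_definition_arn` job variable can pin a work pool or deployment
+to an existing task definition. A minimal sketch of the relevant `prefect.yaml`
+fragment is shown below; the ARN is illustrative only:
+
+```yaml
+work_pool:
+  name: my-ecs-pool
+  job_variables:
+    task_definition_arn: arn:aws:ecs:us-east-1:123456789012:task-definition/prefect:1
+```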
+""" +import copy +import json +import logging +import shlex +import sys +import time +from copy import deepcopy +from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union +from uuid import UUID + +import anyio +import anyio.abc +import yaml +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound +from prefect.server.schemas.core import FlowRun +from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect.utilities.dockerutils import get_prefect_image_name +from prefect.workers.base import ( + BaseJobConfiguration, + BaseVariables, + BaseWorker, + BaseWorkerResult, +) + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import BaseModel, Field, root_validator +else: + from pydantic import BaseModel, Field, root_validator + +from slugify import slugify +from tenacity import retry, stop_after_attempt, wait_fixed, wait_random +from typing_extensions import Literal + +from prefect_aws.credentials import AwsCredentials, ClientType + +# Internal type alias for ECS clients which are generated dynamically in botocore +_ECSClient = Any + +ECS_DEFAULT_CONTAINER_NAME = "prefect" +ECS_DEFAULT_CPU = 1024 +ECS_DEFAULT_COMMAND = "python -m prefect.engine" +ECS_DEFAULT_MEMORY = 2048 +ECS_DEFAULT_LAUNCH_TYPE = "FARGATE" +ECS_DEFAULT_FAMILY = "prefect" +ECS_POST_REGISTRATION_FIELDS = [ + "compatibilities", + "taskDefinitionArn", + "revision", + "status", + "requiresAttributes", + "registeredAt", + "registeredBy", + "deregisteredAt", +] + + +DEFAULT_TASK_DEFINITION_TEMPLATE = """ +containerDefinitions: +- image: "{{ image }}" + name: "{{ container_name }}" +cpu: "{{ cpu }}" +family: "{{ family }}" +memory: "{{ memory }}" +executionRoleArn: "{{ execution_role_arn }}" +""" + +DEFAULT_TASK_RUN_REQUEST_TEMPLATE = """ +launchType: "{{ launch_type }}" +cluster: "{{ cluster }}" +overrides: + containerOverrides: + - name: "{{ container_name }}" + command: "{{ command }}" + environment: "{{ env }}" + cpu: "{{ cpu }}" + memory: "{{ memory }}" + cpu: "{{ cpu }}" + memory: "{{ memory }}" + taskRoleArn: "{{ task_role_arn }}" +tags: "{{ labels }}" +taskDefinition: "{{ task_definition_arn }}" +capacityProviderStrategy: "{{ capacity_provider_strategy }}" +""" + +# Create task run retry settings +MAX_CREATE_TASK_RUN_ATTEMPTS = 3 +CREATE_TASK_RUN_MIN_DELAY_SECONDS = 1 +CREATE_TASK_RUN_MIN_DELAY_JITTER_SECONDS = 0 +CREATE_TASK_RUN_MAX_DELAY_JITTER_SECONDS = 3 + +_TASK_DEFINITION_CACHE: Dict[UUID, str] = {} +_TAG_REGEX = r"[^a-zA-Z0-9-_.=+-@: ]+" + + +class ECSIdentifier(NamedTuple): + """ + The identifier for a running ECS task. + """ + + cluster: str + task_arn: str + + +def _default_task_definition_template() -> dict: + """ + The default task definition template for ECS jobs. + """ + return yaml.safe_load(DEFAULT_TASK_DEFINITION_TEMPLATE) + + +def _default_task_run_request_template() -> dict: + """ + The default task run request template for ECS jobs. + """ + return yaml.safe_load(DEFAULT_TASK_RUN_REQUEST_TEMPLATE) + + +def _drop_empty_keys_from_task_definition(taskdef: dict): + """ + Recursively drop keys with 'empty' values from a task definition dict. + + Mutates the task definition in place. Only supports recursion into dicts and lists. 
+ """ + for key, value in tuple(taskdef.items()): + if not value: + taskdef.pop(key) + if isinstance(value, dict): + _drop_empty_keys_from_task_definition(value) + if isinstance(value, list): + for v in value: + if isinstance(v, dict): + _drop_empty_keys_from_task_definition(v) + + +def _get_container(containers: List[dict], name: str) -> Optional[dict]: + """ + Extract a container from a list of containers or container definitions. + If not found, `None` is returned. + """ + for container in containers: + if container.get("name") == name: + return container + return None + + +def _container_name_from_task_definition(task_definition: dict) -> Optional[str]: + """ + Attempt to infer the container name from a task definition. + + If not found, `None` is returned. + """ + if task_definition: + container_definitions = task_definition.get("containerDefinitions", []) + else: + container_definitions = [] + + if _get_container(container_definitions, ECS_DEFAULT_CONTAINER_NAME): + # Use the default container name if present + return ECS_DEFAULT_CONTAINER_NAME + elif container_definitions: + # Otherwise, if there's at least one container definition try to get the + # name from that + return container_definitions[0].get("name") + + return None + + +def parse_identifier(identifier: str) -> ECSIdentifier: + """ + Splits identifier into its cluster and task components, e.g. + input "cluster_name::task_arn" outputs ("cluster_name", "task_arn"). + """ + cluster, task = identifier.split("::", maxsplit=1) + return ECSIdentifier(cluster, task) + + +def mask_sensitive_env_values( + task_run_request: dict, values: List[str], keep_length=3, replace_with="***" +): + for container in task_run_request.get("overrides", {}).get( + "containerOverrides", [] + ): + for env_var in container.get("environment", []): + if ( + "name" not in env_var + or "value" not in env_var + or env_var["name"] not in values + ): + continue + if len(env_var["value"]) > keep_length: + # Replace characters beyond the keep length + env_var["value"] = env_var["value"][:keep_length] + replace_with + return task_run_request + + +def mask_api_key(task_run_request): + return mask_sensitive_env_values( + deepcopy(task_run_request), ["PREFECT_API_KEY"], keep_length=6 + ) + + +class CapacityProvider(BaseModel): + """ + The capacity provider strategy to use when running the task. + """ + + capacityProvider: str + weight: int + base: int + + +class ECSJobConfiguration(BaseJobConfiguration): + """ + Job configuration for an ECS worker. 
+ """ + + aws_credentials: Optional[AwsCredentials] = Field(default_factory=AwsCredentials) + task_definition: Optional[Dict[str, Any]] = Field( + template=_default_task_definition_template() + ) + task_run_request: Dict[str, Any] = Field( + template=_default_task_run_request_template() + ) + configure_cloudwatch_logs: Optional[bool] = Field(default=None) + cloudwatch_logs_options: Dict[str, str] = Field(default_factory=dict) + cloudwatch_logs_prefix: Optional[str] = Field(default=None) + network_configuration: Dict[str, Any] = Field(default_factory=dict) + stream_output: Optional[bool] = Field(default=None) + task_start_timeout_seconds: int = Field(default=300) + task_watch_poll_interval: float = Field(default=5.0) + auto_deregister_task_definition: bool = Field(default=False) + vpc_id: Optional[str] = Field(default=None) + container_name: Optional[str] = Field(default=None) + cluster: Optional[str] = Field(default=None) + match_latest_revision_in_family: bool = Field(default=False) + + @root_validator + def task_run_request_requires_arn_if_no_task_definition_given(cls, values) -> dict: + """ + If no task definition is provided, a task definition ARN must be present on the + task run request. + """ + if not values.get("task_run_request", {}).get( + "taskDefinition" + ) and not values.get("task_definition"): + raise ValueError( + "A task definition must be provided if a task definition ARN is not " + "present on the task run request." + ) + return values + + @root_validator + def container_name_default_from_task_definition(cls, values) -> dict: + """ + Infers the container name from the task definition if not provided. + """ + if values.get("container_name") is None: + values["container_name"] = _container_name_from_task_definition( + values.get("task_definition") + ) + + # We may not have a name here still; for example if someone is using a task + # definition arn. In that case, we'll perform similar logic later to find + # the name to treat as the "orchestration" container. + + return values + + @root_validator(pre=True) + def set_default_configure_cloudwatch_logs(cls, values: dict) -> dict: + """ + Streaming output generally requires CloudWatch logs to be configured. + + To avoid entangled arguments in the simple case, `configure_cloudwatch_logs` + defaults to matching the value of `stream_output`. + """ + configure_cloudwatch_logs = values.get("configure_cloudwatch_logs") + if configure_cloudwatch_logs is None: + values["configure_cloudwatch_logs"] = values.get("stream_output") + return values + + @root_validator + def configure_cloudwatch_logs_requires_execution_role_arn( + cls, values: dict + ) -> dict: + """ + Enforces that an execution role arn is provided (or could be provided by a + runtime task definition) when configuring logging. + """ + if ( + values.get("configure_cloudwatch_logs") + and not values.get("execution_role_arn") + # TODO: Does not match + # Do not raise if they've linked to another task definition or provided + # it without using our shortcuts + and not values.get("task_run_request", {}).get("taskDefinition") + and not (values.get("task_definition") or {}).get("executionRoleArn") + ): + raise ValueError( + "An `execution_role_arn` must be provided to use " + "`configure_cloudwatch_logs` or `stream_logs`." 
+ ) + return values + + @root_validator + def cloudwatch_logs_options_requires_configure_cloudwatch_logs( + cls, values: dict + ) -> dict: + """ + Enforces that an execution role arn is provided (or could be provided by a + runtime task definition) when configuring logging. + """ + if values.get("cloudwatch_logs_options") and not values.get( + "configure_cloudwatch_logs" + ): + raise ValueError( + "`configure_cloudwatch_log` must be enabled to use " + "`cloudwatch_logs_options`." + ) + return values + + @root_validator + def network_configuration_requires_vpc_id(cls, values: dict) -> dict: + """ + Enforces a `vpc_id` is provided when custom network configuration mode is + enabled for network settings. + """ + if values.get("network_configuration") and not values.get("vpc_id"): + raise ValueError( + "You must provide a `vpc_id` to enable custom `network_configuration`." + ) + return values + + +class ECSVariables(BaseVariables): + """ + Variables for templating an ECS job. + """ + + task_definition_arn: Optional[str] = Field( + default=None, + description=( + "An identifier for an existing task definition to use. If set, options that" + " require changes to the task definition will be ignored. All contents of " + "the task definition in the job configuration will be ignored." + ), + ) + env: Dict[str, Optional[str]] = Field( + title="Environment Variables", + default_factory=dict, + description=( + "Environment variables to provide to the task run. These variables are set " + "on the Prefect container at task runtime. These will not be set on the " + "task definition." + ), + ) + aws_credentials: AwsCredentials = Field( + title="AWS Credentials", + default_factory=AwsCredentials, + description=( + "The AWS credentials to use to connect to ECS. If not provided, credentials" + " will be inferred from the local environment following AWS's boto client's" + " rules." + ), + ) + cluster: Optional[str] = Field( + default=None, + description=( + "The ECS cluster to run the task in. An ARN or name may be provided. If " + "not provided, the default cluster will be used." + ), + ) + family: Optional[str] = Field( + default=None, + description=( + "A family for the task definition. If not provided, it will be inferred " + "from the task definition. If the task definition does not have a family, " + "the name will be generated. When flow and deployment metadata is " + "available, the generated name will include their names. Values for this " + "field will be slugified to match AWS character requirements." + ), + ) + launch_type: Optional[ + Literal["FARGATE", "EC2", "EXTERNAL", "FARGATE_SPOT"] + ] = Field( + default=ECS_DEFAULT_LAUNCH_TYPE, + description=( + "The type of ECS task run infrastructure that should be used. Note that" + " 'FARGATE_SPOT' is not a formal ECS launch type, but we will configure" + " the proper capacity provider strategy if set here." + ), + ) + capacity_provider_strategy: Optional[List[CapacityProvider]] = Field( + default_factory=list, + description=( + "The capacity provider strategy to use when running the task. " + "If a capacity provider strategy is specified, the selected launch" + " type will be ignored." + ), + ) + image: Optional[str] = Field( + default=None, + description=( + "The image to use for the Prefect container in the task. If this value is " + "not null, it will override the value in the task definition. This value " + "defaults to a Prefect base image matching your local versions." 
+ ), + ) + cpu: int = Field( + title="CPU", + default=None, + description=( + "The amount of CPU to provide to the ECS task. Valid amounts are " + "specified in the AWS documentation. If not provided, a default value of " + f"{ECS_DEFAULT_CPU} will be used unless present on the task definition." + ), + ) + memory: int = Field( + default=None, + description=( + "The amount of memory to provide to the ECS task. Valid amounts are " + "specified in the AWS documentation. If not provided, a default value of " + f"{ECS_DEFAULT_MEMORY} will be used unless present on the task definition." + ), + ) + container_name: str = Field( + default=None, + description=( + "The name of the container flow run orchestration will occur in. If not " + f"specified, a default value of {ECS_DEFAULT_CONTAINER_NAME} will be used " + "and if that is not found in the task definition the first container will " + "be used." + ), + ) + task_role_arn: str = Field( + title="Task Role ARN", + default=None, + description=( + "A role to attach to the task run. This controls the permissions of the " + "task while it is running." + ), + ) + execution_role_arn: str = Field( + title="Execution Role ARN", + default=None, + description=( + "An execution role to use for the task. This controls the permissions of " + "the task when it is launching. If this value is not null, it will " + "override the value in the task definition. An execution role must be " + "provided to capture logs from the container." + ), + ) + vpc_id: Optional[str] = Field( + title="VPC ID", + default=None, + description=( + "The AWS VPC to link the task run to. This is only applicable when using " + "the 'awsvpc' network mode for your task. FARGATE tasks require this " + "network mode, but for EC2 tasks the default network mode is 'bridge'. " + "If using the 'awsvpc' network mode and this field is null, your default " + "VPC will be used. If no default VPC can be found, the task run will fail." + ), + ) + configure_cloudwatch_logs: bool = Field( + default=None, + description=( + "If enabled, the Prefect container will be configured to send its output " + "to the AWS CloudWatch logs service. This functionality requires an " + "execution role with logs:CreateLogStream, logs:CreateLogGroup, and " + "logs:PutLogEvents permissions. The default for this field is `False` " + "unless `stream_output` is set." + ), + ) + cloudwatch_logs_options: Dict[str, str] = Field( + default_factory=dict, + description=( + "When `configure_cloudwatch_logs` is enabled, this setting may be used to" + " pass additional options to the CloudWatch logs configuration or override" + " the default options. See the [AWS" + " documentation](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html#create_awslogs_logdriver_options)" # noqa + " for available options. " + ), + ) + cloudwatch_logs_prefix: Optional[str] = Field( + default=None, + description=( + "When `configure_cloudwatch_logs` is enabled, this setting may be used to" + " set a prefix for the log group. If not provided, the default prefix will" + " be `prefect-logs__`. If" + " `awslogs-stream-prefix` is present in `Cloudwatch logs options` this" + " setting will be ignored." + ), + ) + + network_configuration: Dict[str, Any] = Field( + default_factory=dict, + description=( + "When `network_configuration` is supplied it will override ECS Worker's" + "awsvpcConfiguration that defined in the ECS task executing your workload. 
" + "See the [AWS documentation](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-ecs-service-awsvpcconfiguration.html)" # noqa + " for available options." + ), + ) + + stream_output: bool = Field( + default=None, + description=( + "If enabled, logs will be streamed from the Prefect container to the local " + "console. Unless you have configured AWS CloudWatch logs manually on your " + "task definition, this requires the same prerequisites outlined in " + "`configure_cloudwatch_logs`." + ), + ) + task_start_timeout_seconds: int = Field( + default=300, + description=( + "The amount of time to watch for the start of the ECS task " + "before marking it as failed. The task must enter a RUNNING state to be " + "considered started." + ), + ) + task_watch_poll_interval: float = Field( + default=5.0, + description=( + "The amount of time to wait between AWS API calls while monitoring the " + "state of an ECS task." + ), + ) + auto_deregister_task_definition: bool = Field( + default=False, + description=( + "If enabled, any task definitions that are created by this block will be " + "deregistered. Existing task definitions linked by ARN will never be " + "deregistered. Deregistering a task definition does not remove it from " + "your AWS account, instead it will be marked as INACTIVE." + ), + ) + match_latest_revision_in_family: bool = Field( + default=False, + description=( + "If enabled, the most recent active revision in the task definition " + "family will be compared against the desired ECS task configuration. " + "If they are equal, the existing task definition will be used instead " + "of registering a new one. If no family is specified the default family " + f'"{ECS_DEFAULT_FAMILY}" will be used.' + ), + ) + + +class ECSWorkerResult(BaseWorkerResult): + """ + The result of an ECS job. + """ + + +class ECSWorker(BaseWorker): + """ + A Prefect worker to run flow runs as ECS tasks. + """ + + type = "ecs" + job_configuration = ECSJobConfiguration + job_configuration_variables = ECSVariables + _description = ( + "Execute flow runs within containers on AWS ECS. Works with EC2 " + "and Fargate clusters. Requires an AWS account." + ) + _display_name = "AWS Elastic Container Service" + _documentation_url = "https://prefecthq.github.io/prefect-aws/ecs_worker/" + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + + async def run( + self, + flow_run: "FlowRun", + configuration: ECSJobConfiguration, + task_status: Optional[anyio.abc.TaskStatus] = None, + ) -> BaseWorkerResult: + """ + Runs a given flow run on the current worker. + """ + ecs_client = await run_sync_in_worker_thread( + self._get_client, configuration, "ecs" + ) + + logger = self.get_flow_run_logger(flow_run) + + ( + task_arn, + cluster_arn, + task_definition, + is_new_task_definition, + ) = await run_sync_in_worker_thread( + self._create_task_and_wait_for_start, + logger, + ecs_client, + configuration, + flow_run, + ) + + # The task identifier is "{cluster}::{task}" where we use the configured cluster + # if set to preserve matching by name rather than arn + # Note "::" is used despite the Prefect standard being ":" because ARNs contain + # single colons. 
+ identifier = ( + (configuration.cluster if configuration.cluster else cluster_arn) + + "::" + + task_arn + ) + + if task_status: + task_status.started(identifier) + + status_code = await run_sync_in_worker_thread( + self._watch_task_and_get_exit_code, + logger, + configuration, + task_arn, + cluster_arn, + task_definition, + is_new_task_definition and configuration.auto_deregister_task_definition, + ecs_client, + ) + + return ECSWorkerResult( + identifier=identifier, + # If the container does not start the exit code can be null but we must + # still report a status code. We use a -1 to indicate a special code. + status_code=status_code if status_code is not None else -1, + ) + + def _get_client( + self, configuration: ECSJobConfiguration, client_type: Union[str, ClientType] + ) -> _ECSClient: + """ + Get a boto3 client of client_type. Will use a cached client if one exists. + """ + return configuration.aws_credentials.get_client(client_type) + + def _create_task_and_wait_for_start( + self, + logger: logging.Logger, + ecs_client: _ECSClient, + configuration: ECSJobConfiguration, + flow_run: FlowRun, + ) -> Tuple[str, str, dict, bool]: + """ + Register the task definition, create the task run, and wait for it to start. + + Returns a tuple of + - The task ARN + - The task's cluster ARN + - The task definition + - A bool indicating if the task definition is newly registered + """ + task_definition_arn = configuration.task_run_request.get("taskDefinition") + new_task_definition_registered = False + + if not task_definition_arn: + task_definition = self._prepare_task_definition( + configuration, region=ecs_client.meta.region_name, flow_run=flow_run + ) + ( + task_definition_arn, + new_task_definition_registered, + ) = self._get_or_register_task_definition( + logger, ecs_client, configuration, flow_run, task_definition + ) + else: + task_definition = self._retrieve_task_definition( + logger, ecs_client, task_definition_arn + ) + if configuration.task_definition: + logger.warning( + "Ignoring task definition in configuration since task definition" + " ARN is provided on the task run request." + ) + + self._validate_task_definition(task_definition, configuration) + + _TASK_DEFINITION_CACHE[flow_run.deployment_id] = task_definition_arn + + logger.info(f"Using ECS task definition {task_definition_arn!r}...") + logger.debug( + f"Task definition {json.dumps(task_definition, indent=2, default=str)}" + ) + + task_run_request = self._prepare_task_run_request( + configuration, + task_definition, + task_definition_arn, + ) + + logger.info("Creating ECS task run...") + logger.debug( + "Task run request" + f"{json.dumps(mask_api_key(task_run_request), indent=2, default=str)}" + ) + + try: + task = self._create_task_run(ecs_client, task_run_request) + task_arn = task["taskArn"] + cluster_arn = task["clusterArn"] + except Exception as exc: + self._report_task_run_creation_failure(configuration, task_run_request, exc) + raise + + logger.info("Waiting for ECS task run to start...") + self._wait_for_task_start( + logger, + configuration, + task_arn, + cluster_arn, + ecs_client, + timeout=configuration.task_start_timeout_seconds, + ) + + return task_arn, cluster_arn, task_definition, new_task_definition_registered + + def _get_or_register_task_definition( + self, + logger: logging.Logger, + ecs_client: _ECSClient, + configuration: ECSJobConfiguration, + flow_run: FlowRun, + task_definition: dict, + ) -> Tuple[str, bool]: + """Get or register a task definition for the given flow run. 
+ + Returns a tuple of the task definition ARN and a bool indicating if the task + definition is newly registered. + """ + + cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(flow_run.deployment_id) + new_task_definition_registered = False + + if cached_task_definition_arn: + try: + cached_task_definition = self._retrieve_task_definition( + logger, ecs_client, cached_task_definition_arn + ) + if not cached_task_definition[ + "status" + ] == "ACTIVE" or not self._task_definitions_equal( + task_definition, cached_task_definition + ): + cached_task_definition_arn = None + except Exception: + cached_task_definition_arn = None + + if ( + not cached_task_definition_arn + and configuration.match_latest_revision_in_family + ): + family_name = task_definition.get("family", ECS_DEFAULT_FAMILY) + try: + task_definition_from_family = self._retrieve_task_definition( + logger, ecs_client, family_name + ) + if task_definition_from_family and self._task_definitions_equal( + task_definition, task_definition_from_family + ): + cached_task_definition_arn = task_definition_from_family[ + "taskDefinitionArn" + ] + except Exception: + cached_task_definition_arn = None + + if not cached_task_definition_arn: + task_definition_arn = self._register_task_definition( + logger, ecs_client, task_definition + ) + new_task_definition_registered = True + else: + task_definition_arn = cached_task_definition_arn + + return task_definition_arn, new_task_definition_registered + + def _watch_task_and_get_exit_code( + self, + logger: logging.Logger, + configuration: ECSJobConfiguration, + task_arn: str, + cluster_arn: str, + task_definition: dict, + deregister_task_definition: bool, + ecs_client: _ECSClient, + ) -> Optional[int]: + """ + Wait for the task run to complete and retrieve the exit code of the Prefect + container. + """ + + # Wait for completion and stream logs + task = self._wait_for_task_finish( + logger, + configuration, + task_arn, + cluster_arn, + task_definition, + ecs_client, + ) + + if deregister_task_definition: + ecs_client.deregister_task_definition( + taskDefinition=task["taskDefinitionArn"] + ) + + container_name = ( + configuration.container_name + or _container_name_from_task_definition(task_definition) + or ECS_DEFAULT_CONTAINER_NAME + ) + + # Check the status code of the Prefect container + container = _get_container(task["containers"], container_name) + assert ( + container is not None + ), f"'{container_name}' container missing from task: {task}" + status_code = container.get("exitCode") + self._report_container_status_code(logger, container_name, status_code) + + return status_code + + def _report_container_status_code( + self, logger: logging.Logger, name: str, status_code: Optional[int] + ) -> None: + """ + Display a log for the given container status code. + """ + if status_code is None: + logger.error( + f"Task exited without reporting an exit status for container {name!r}." + ) + elif status_code == 0: + logger.info(f"Container {name!r} exited successfully.") + else: + logger.warning( + f"Container {name!r} exited with non-zero exit code {status_code}." + ) + + def _report_task_run_creation_failure( + self, configuration: ECSJobConfiguration, task_run: dict, exc: Exception + ) -> None: + """ + Wrap common AWS task run creation failures with nicer user-facing messages. + """ + # AWS generates exception types at runtime so they must be captured a bit + # differently than normal. 
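+        # For instance, botocore only exposes classes such as
+        # `ecs_client.exceptions.ClusterNotFoundException` on a live client object,
+        # so we match on the exception's string form below instead.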
+ if "ClusterNotFoundException" in str(exc): + cluster = task_run.get("cluster", "default") + raise RuntimeError( + f"Failed to run ECS task, cluster {cluster!r} not found. " + "Confirm that the cluster is configured in your region." + ) from exc + elif ( + "No Container Instances" in str(exc) and task_run.get("launchType") == "EC2" + ): + cluster = task_run.get("cluster", "default") + raise RuntimeError( + f"Failed to run ECS task, cluster {cluster!r} does not appear to " + "have any container instances associated with it. Confirm that you " + "have EC2 container instances available." + ) from exc + elif ( + "failed to validate logger args" in str(exc) + and "AccessDeniedException" in str(exc) + and configuration.configure_cloudwatch_logs + ): + raise RuntimeError( + "Failed to run ECS task, the attached execution role does not appear" + " to have sufficient permissions. Ensure that the execution role" + f" {configuration.execution_role!r} has permissions" + " logs:CreateLogStream, logs:CreateLogGroup, and logs:PutLogEvents." + ) + else: + raise + + def _validate_task_definition( + self, task_definition: dict, configuration: ECSJobConfiguration + ) -> None: + """ + Ensure that the task definition is compatible with the configuration. + + Raises `ValueError` on incompatibility. Returns `None` on success. + """ + launch_type = configuration.task_run_request.get( + "launchType", ECS_DEFAULT_LAUNCH_TYPE + ) + if ( + launch_type != "EC2" + and "FARGATE" not in task_definition["requiresCompatibilities"] + ): + raise ValueError( + "Task definition does not have 'FARGATE' in 'requiresCompatibilities'" + f" and cannot be used with launch type {launch_type!r}" + ) + + if launch_type == "FARGATE" or launch_type == "FARGATE_SPOT": + # Only the 'awsvpc' network mode is supported when using FARGATE + network_mode = task_definition.get("networkMode") + if network_mode != "awsvpc": + raise ValueError( + f"Found network mode {network_mode!r} which is not compatible with " + f"launch type {launch_type!r}. Use either the 'EC2' launch " + "type or the 'awsvpc' network mode." + ) + + if configuration.configure_cloudwatch_logs and not task_definition.get( + "executionRoleArn" + ): + raise ValueError( + "An execution role arn must be set on the task definition to use " + "`configure_cloudwatch_logs` or `stream_logs` but no execution role " + "was found on the task definition." + ) + + def _register_task_definition( + self, + logger: logging.Logger, + ecs_client: _ECSClient, + task_definition: dict, + ) -> str: + """ + Register a new task definition with AWS. + + Returns the ARN. + """ + logger.info("Registering ECS task definition...") + logger.debug( + "Task definition request" + f"{json.dumps(task_definition, indent=2, default=str)}" + ) + response = ecs_client.register_task_definition(**task_definition) + return response["taskDefinition"]["taskDefinitionArn"] + + def _retrieve_task_definition( + self, + logger: logging.Logger, + ecs_client: _ECSClient, + task_definition: str, + ): + """ + Retrieve an existing task definition from AWS. + """ + if task_definition.startswith("arn:aws:ecs:"): + logger.info(f"Retrieving ECS task definition {task_definition!r}...") + else: + logger.info( + "Retrieving most recent active revision from " + f"ECS task family {task_definition!r}..." 
+ ) + response = ecs_client.describe_task_definition(taskDefinition=task_definition) + return response["taskDefinition"] + + def _wait_for_task_start( + self, + logger: logging.Logger, + configuration: ECSJobConfiguration, + task_arn: str, + cluster_arn: str, + ecs_client: _ECSClient, + timeout: int, + ) -> dict: + """ + Waits for an ECS task run to reach a RUNNING status. + + If a STOPPED status is reached instead, an exception is raised indicating the + reason that the task run did not start. + """ + for task in self._watch_task_run( + logger, + configuration, + task_arn, + cluster_arn, + ecs_client, + until_status="RUNNING", + timeout=timeout, + ): + # TODO: It is possible that the task has passed _through_ a RUNNING + # status during the polling interval. In this case, there is not an + # exception to raise. + if task["lastStatus"] == "STOPPED": + code = task.get("stopCode") + reason = task.get("stoppedReason") + # Generate a dynamic exception type from the AWS name + raise type(code, (RuntimeError,), {})(reason) + + return task + + def _wait_for_task_finish( + self, + logger: logging.Logger, + configuration: ECSJobConfiguration, + task_arn: str, + cluster_arn: str, + task_definition: dict, + ecs_client: _ECSClient, + ): + """ + Watch an ECS task until it reaches a STOPPED status. + + If configured, logs from the Prefect container are streamed to stderr. + + Returns a description of the task on completion. + """ + can_stream_output = False + container_name = ( + configuration.container_name + or _container_name_from_task_definition(task_definition) + or ECS_DEFAULT_CONTAINER_NAME + ) + + if configuration.stream_output: + container_def = _get_container( + task_definition["containerDefinitions"], container_name + ) + if not container_def: + logger.warning( + "Prefect container definition not found in " + "task definition. Output cannot be streamed." + ) + elif not container_def.get("logConfiguration"): + logger.warning( + "Logging configuration not found on task. " + "Output cannot be streamed." + ) + elif not container_def["logConfiguration"].get("logDriver") == "awslogs": + logger.warning( + "Logging configuration uses unsupported " + " driver {container_def['logConfiguration'].get('logDriver')!r}. " + "Output cannot be streamed." + ) + else: + # Prepare to stream the output + log_config = container_def["logConfiguration"]["options"] + logs_client = self._get_client(configuration, "logs") + can_stream_output = True + # Track the last log timestamp to prevent double display + last_log_timestamp: Optional[int] = None + # Determine the name of the stream as "prefix/container/run-id" + stream_name = "/".join( + [ + log_config["awslogs-stream-prefix"], + container_name, + task_arn.rsplit("/")[-1], + ] + ) + self._logger.info( + f"Streaming output from container {container_name!r}..." 
+ ) + + for task in self._watch_task_run( + logger, + configuration, + task_arn, + cluster_arn, + ecs_client, + current_status="RUNNING", + ): + if configuration.stream_output and can_stream_output: + # On each poll for task run status, also retrieve available logs + last_log_timestamp = self._stream_available_logs( + logger, + logs_client, + log_group=log_config["awslogs-group"], + log_stream=stream_name, + last_log_timestamp=last_log_timestamp, + ) + + return task + + def _stream_available_logs( + self, + logger: logging.Logger, + logs_client: Any, + log_group: str, + log_stream: str, + last_log_timestamp: Optional[int] = None, + ) -> Optional[int]: + """ + Stream logs from the given log group and stream since the last log timestamp. + + Will continue on paginated responses until all logs are returned. + + Returns the last log timestamp which can be used to call this method in the + future. + """ + last_log_stream_token = "NO-TOKEN" + next_log_stream_token = None + + # AWS will return the same token that we send once the end of the paginated + # response is reached + while last_log_stream_token != next_log_stream_token: + last_log_stream_token = next_log_stream_token + + request = { + "logGroupName": log_group, + "logStreamName": log_stream, + } + + if last_log_stream_token is not None: + request["nextToken"] = last_log_stream_token + + if last_log_timestamp is not None: + # Bump the timestamp by one ms to avoid retrieving the last log again + request["startTime"] = last_log_timestamp + 1 + + try: + response = logs_client.get_log_events(**request) + except Exception: + logger.error( + f"Failed to read log events with request {request}", + exc_info=True, + ) + return last_log_timestamp + + log_events = response["events"] + for log_event in log_events: + # TODO: This doesn't forward to the local logger, which can be + # bad for customizing handling and understanding where the + # log is coming from, but it avoid nesting logger information + # when the content is output from a Prefect logger on the + # running infrastructure + print(log_event["message"], file=sys.stderr) + + if ( + last_log_timestamp is None + or log_event["timestamp"] > last_log_timestamp + ): + last_log_timestamp = log_event["timestamp"] + + next_log_stream_token = response.get("nextForwardToken") + if not log_events: + # Stop reading pages if there was no data + break + + return last_log_timestamp + + def _watch_task_run( + self, + logger: logging.Logger, + configuration: ECSJobConfiguration, + task_arn: str, + cluster_arn: str, + ecs_client: _ECSClient, + current_status: str = "UNKNOWN", + until_status: str = None, + timeout: int = None, + ) -> Generator[None, None, dict]: + """ + Watches an ECS task run by querying every `poll_interval` seconds. After each + query, the retrieved task is yielded. This function returns when the task run + reaches a STOPPED status or the provided `until_status`. + + Emits a log each time the status changes. + """ + last_status = status = current_status + t0 = time.time() + while status != until_status: + tasks = ecs_client.describe_tasks( + tasks=[task_arn], cluster=cluster_arn, include=["TAGS"] + )["tasks"] + + if tasks: + task = tasks[0] + + status = task["lastStatus"] + if status != last_status: + logger.info(f"ECS task status is {status}.") + + yield task + + # No point in continuing if the status is final + if status == "STOPPED": + break + + last_status = status + + else: + # Intermittently, the task will not be described. We wat to respect the + # watch timeout though. 
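+                # (A task that was just created can be briefly absent from
+                # `describe_tasks` responses, so treat this as transient.)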
+ logger.debug("Task not found.") + + elapsed_time = time.time() - t0 + if timeout is not None and elapsed_time > timeout: + raise RuntimeError( + f"Timed out after {elapsed_time}s while watching task for status " + f"{until_status or 'STOPPED'}." + ) + time.sleep(configuration.task_watch_poll_interval) + + def _get_or_generate_family(self, task_definition: dict, flow_run: FlowRun) -> str: + """ + Gets or generate a family for the task definition. + """ + family = task_definition.get("family") + if not family: + assert self._work_pool_name and flow_run.deployment_id + family = ( + f"{ECS_DEFAULT_FAMILY}_{self._work_pool_name}_{flow_run.deployment_id}" + ) + slugify( + family, + max_length=255, + regex_pattern=r"[^a-zA-Z0-9-_]+", + ) + return family + + def _prepare_task_definition( + self, + configuration: ECSJobConfiguration, + region: str, + flow_run: FlowRun, + ) -> dict: + """ + Prepare a task definition by inferring any defaults and merging overrides. + """ + task_definition = copy.deepcopy(configuration.task_definition) + + # Configure the Prefect runtime container + task_definition.setdefault("containerDefinitions", []) + + # Remove empty container definitions + task_definition["containerDefinitions"] = [ + d for d in task_definition["containerDefinitions"] if d + ] + + container_name = configuration.container_name + if not container_name: + container_name = ( + _container_name_from_task_definition(task_definition) + or ECS_DEFAULT_CONTAINER_NAME + ) + + container = _get_container( + task_definition["containerDefinitions"], container_name + ) + if container is None: + if container_name != ECS_DEFAULT_CONTAINER_NAME: + raise ValueError( + f"Container {container_name!r} not found in task definition." + ) + + # Look for a container without a name + for container in task_definition["containerDefinitions"]: + if "name" not in container: + container["name"] = container_name + break + else: + container = {"name": container_name} + task_definition["containerDefinitions"].append(container) + + # Image is required so make sure it's present + container.setdefault("image", get_prefect_image_name()) + + # Remove any keys that have been explicitly "unset" + unset_keys = {key for key, value in configuration.env.items() if value is None} + for item in tuple(container.get("environment", [])): + if item["name"] in unset_keys or item["value"] is None: + container["environment"].remove(item) + + if configuration.configure_cloudwatch_logs: + prefix = f"prefect-logs_{self._work_pool_name}_{flow_run.deployment_id}" + container["logConfiguration"] = { + "logDriver": "awslogs", + "options": { + "awslogs-create-group": "true", + "awslogs-group": "prefect", + "awslogs-region": region, + "awslogs-stream-prefix": ( + configuration.cloudwatch_logs_prefix or prefix + ), + **configuration.cloudwatch_logs_options, + }, + } + + task_definition["family"] = self._get_or_generate_family( + task_definition, flow_run + ) + # CPU and memory are required in some cases, retrieve the value to use + cpu = task_definition.get("cpu") or ECS_DEFAULT_CPU + memory = task_definition.get("memory") or ECS_DEFAULT_MEMORY + + launch_type = configuration.task_run_request.get( + "launchType", ECS_DEFAULT_LAUNCH_TYPE + ) + + if launch_type == "FARGATE" or launch_type == "FARGATE_SPOT": + # Task level memory and cpu are required when using fargate + task_definition["cpu"] = str(cpu) + task_definition["memory"] = str(memory) + + # The FARGATE compatibility is required if it will be used as as launch type + requires_compatibilities = 
task_definition.setdefault( + "requiresCompatibilities", [] + ) + if "FARGATE" not in requires_compatibilities: + task_definition["requiresCompatibilities"].append("FARGATE") + + # Only the 'awsvpc' network mode is supported when using FARGATE + # However, we will not enforce that here if the user has set it + task_definition.setdefault("networkMode", "awsvpc") + + elif launch_type == "EC2": + # Container level memory and cpu are required when using ec2 + container.setdefault("cpu", cpu) + container.setdefault("memory", memory) + + # Ensure set values are cast to integers + container["cpu"] = int(container["cpu"]) + container["memory"] = int(container["memory"]) + + # Ensure set values are cast to strings + if task_definition.get("cpu"): + task_definition["cpu"] = str(task_definition["cpu"]) + if task_definition.get("memory"): + task_definition["memory"] = str(task_definition["memory"]) + + return task_definition + + def _load_network_configuration( + self, vpc_id: Optional[str], configuration: ECSJobConfiguration + ) -> dict: + """ + Load settings from a specific VPC or the default VPC and generate a task + run request's network configuration. + """ + ec2_client = self._get_client(configuration, "ec2") + vpc_message = "the default VPC" if not vpc_id else f"VPC with ID {vpc_id}" + + if not vpc_id: + # Retrieve the default VPC + describe = {"Filters": [{"Name": "isDefault", "Values": ["true"]}]} + else: + describe = {"VpcIds": [vpc_id]} + + vpcs = ec2_client.describe_vpcs(**describe)["Vpcs"] + if not vpcs: + help_message = ( + "Pass an explicit `vpc_id` or configure a default VPC." + if not vpc_id + else "Check that the VPC exists in the current region." + ) + raise ValueError( + f"Failed to find {vpc_message}. " + "Network configuration cannot be inferred. " + help_message + ) + + vpc_id = vpcs[0]["VpcId"] + subnets = ec2_client.describe_subnets( + Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] + )["Subnets"] + if not subnets: + raise ValueError( + f"Failed to find subnets for {vpc_message}. " + "Network configuration cannot be inferred." + ) + + return { + "awsvpcConfiguration": { + "subnets": [s["SubnetId"] for s in subnets], + "assignPublicIp": "ENABLED", + "securityGroups": [], + } + } + + def _custom_network_configuration( + self, + vpc_id: str, + network_configuration: dict, + configuration: ECSJobConfiguration, + ) -> dict: + """ + Load settings from a specific VPC or the default VPC and generate a task + run request's network configuration. + """ + ec2_client = self._get_client(configuration, "ec2") + vpc_message = f"VPC with ID {vpc_id}" + + vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get("Vpcs") + + if not vpcs: + raise ValueError( + f"Failed to find {vpc_message}. " + + "Network configuration cannot be inferred. " + + "Pass an explicit `vpc_id`." + ) + + vpc_id = vpcs[0]["VpcId"] + subnets = ec2_client.describe_subnets( + Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] + )["Subnets"] + + if not subnets: + raise ValueError( + f"Failed to find subnets for {vpc_message}. " + + "Network configuration cannot be inferred." + ) + + subnet_ids = [subnet["SubnetId"] for subnet in subnets] + + config_subnets = network_configuration.get("subnets", []) + if not all(conf_sn in subnet_ids for conf_sn in config_subnets): + raise ValueError( + f"Subnets {config_subnets} not found within {vpc_message}." + + "Please check that VPC is associated with supplied subnets." 
+ ) + + return {"awsvpcConfiguration": network_configuration} + + def _prepare_task_run_request( + self, + configuration: ECSJobConfiguration, + task_definition: dict, + task_definition_arn: str, + ) -> dict: + """ + Prepare a task run request payload. + """ + task_run_request = deepcopy(configuration.task_run_request) + + task_run_request.setdefault("taskDefinition", task_definition_arn) + assert task_run_request["taskDefinition"] == task_definition_arn + capacityProviderStrategy = task_run_request.get("capacityProviderStrategy") + + if capacityProviderStrategy: + # Should not be provided at all if capacityProviderStrategy is set, see https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html#ECS-RunTask-request-capacityProviderStrategy # noqa + self._logger.warning( + "Found capacityProviderStrategy. " + "Removing launchType from task run request." + ) + task_run_request.pop("launchType", None) + + elif task_run_request.get("launchType") == "FARGATE_SPOT": + # Should not be provided at all for FARGATE SPOT + task_run_request.pop("launchType", None) + + # A capacity provider strategy is required for FARGATE SPOT + task_run_request["capacityProviderStrategy"] = [ + {"capacityProvider": "FARGATE_SPOT", "weight": 1} + ] + overrides = task_run_request.get("overrides", {}) + container_overrides = overrides.get("containerOverrides", []) + + # Ensure the network configuration is present if using awsvpc for network mode + if ( + task_definition.get("networkMode") == "awsvpc" + and not task_run_request.get("networkConfiguration") + and not configuration.network_configuration + ): + task_run_request["networkConfiguration"] = self._load_network_configuration( + configuration.vpc_id, configuration + ) + + # Use networkConfiguration if supplied by user + if ( + task_definition.get("networkMode") == "awsvpc" + and configuration.network_configuration + and configuration.vpc_id + ): + task_run_request[ + "networkConfiguration" + ] = self._custom_network_configuration( + configuration.vpc_id, + configuration.network_configuration, + configuration, + ) + + # Ensure the container name is set if not provided at template time + + container_name = ( + configuration.container_name + or _container_name_from_task_definition(task_definition) + or ECS_DEFAULT_CONTAINER_NAME + ) + + if container_overrides and not container_overrides[0].get("name"): + container_overrides[0]["name"] = container_name + + # Ensure configuration command is respected post-templating + + orchestration_container = _get_container(container_overrides, container_name) + + if orchestration_container: + # Override the command if given on the configuration + if configuration.command: + orchestration_container["command"] = configuration.command + + # Clean up templated variable formatting + + for container in container_overrides: + if isinstance(container.get("command"), str): + container["command"] = shlex.split(container["command"]) + if isinstance(container.get("environment"), dict): + container["environment"] = [ + {"name": k, "value": v} for k, v in container["environment"].items() + ] + + # Remove null values — they're not allowed by AWS + container["environment"] = [ + item + for item in container.get("environment", []) + if item["value"] is not None + ] + + if isinstance(task_run_request.get("tags"), dict): + task_run_request["tags"] = [ + {"key": k, "value": v} for k, v in task_run_request["tags"].items() + ] + + if overrides.get("cpu"): + overrides["cpu"] = str(overrides["cpu"]) + + if overrides.get("memory"): + 
overrides["memory"] = str(overrides["memory"]) + + # Ensure configuration tags and env are respected post-templating + + tags = [ + item + for item in task_run_request.get("tags", []) + if item["key"] not in configuration.labels.keys() + ] + [ + {"key": k, "value": v} + for k, v in configuration.labels.items() + if v is not None + ] + + # Slugify tags keys and values + tags = [ + { + "key": slugify( + item["key"], + regex_pattern=_TAG_REGEX, + allow_unicode=True, + lowercase=False, + ), + "value": slugify( + item["value"], + regex_pattern=_TAG_REGEX, + allow_unicode=True, + lowercase=False, + ), + } + for item in tags + ] + + if tags: + task_run_request["tags"] = tags + + if orchestration_container: + environment = [ + item + for item in orchestration_container.get("environment", []) + if item["name"] not in configuration.env.keys() + ] + [ + {"name": k, "value": v} + for k, v in configuration.env.items() + if v is not None + ] + if environment: + orchestration_container["environment"] = environment + + # Remove empty container overrides + + overrides["containerOverrides"] = [v for v in container_overrides if v] + + return task_run_request + + @retry( + stop=stop_after_attempt(MAX_CREATE_TASK_RUN_ATTEMPTS), + wait=wait_fixed(CREATE_TASK_RUN_MIN_DELAY_SECONDS) + + wait_random( + CREATE_TASK_RUN_MIN_DELAY_JITTER_SECONDS, + CREATE_TASK_RUN_MAX_DELAY_JITTER_SECONDS, + ), + reraise=True, + ) + def _create_task_run(self, ecs_client: _ECSClient, task_run_request: dict) -> str: + """ + Create a run of a task definition. + + Returns the task run ARN. + """ + task = ecs_client.run_task(**task_run_request) + if task["failures"]: + raise RuntimeError( + f"Failed to run ECS task: {task['failures'][0]['reason']}" + ) + elif not task["tasks"]: + raise RuntimeError( + "Failed to run ECS task: no tasks or failures were returned." + ) + return task["tasks"][0] + + def _task_definitions_equal(self, taskdef_1, taskdef_2) -> bool: + """ + Compare two task definitions. + + Since one may come from the AWS API and have populated defaults, we do our best + to homogenize the definitions without changing their meaning. + """ + if taskdef_1 == taskdef_2: + return True + + if taskdef_1 is None or taskdef_2 is None: + return False + + taskdef_1 = copy.deepcopy(taskdef_1) + taskdef_2 = copy.deepcopy(taskdef_2) + + for taskdef in (taskdef_1, taskdef_2): + # Set defaults that AWS would set after registration + container_definitions = taskdef.get("containerDefinitions", []) + essential = any( + container.get("essential") for container in container_definitions + ) + if not essential: + container_definitions[0].setdefault("essential", True) + + taskdef.setdefault("networkMode", "bridge") + + _drop_empty_keys_from_task_definition(taskdef_1) + _drop_empty_keys_from_task_definition(taskdef_2) + + # Clear fields that change on registration for comparison + for field in ECS_POST_REGISTRATION_FIELDS: + taskdef_1.pop(field, None) + taskdef_2.pop(field, None) + + return taskdef_1 == taskdef_2 + + async def kill_infrastructure( + self, + configuration: ECSJobConfiguration, + infrastructure_pid: str, + grace_seconds: int = 30, + ) -> None: + """ + Kill a task running on ECS. + + Args: + infrastructure_pid: A cluster and task arn combination. This should match a + value yielded by `ECSWorker.run`. + """ + if grace_seconds != 30: + self._logger.warning( + f"Kill grace period of {grace_seconds}s requested, but AWS does not " + "support dynamic grace period configuration so 30s will be used. 
" + "See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/ecs-agent-config.html for configuration of grace periods." # noqa + ) + cluster, task = parse_identifier(infrastructure_pid) + await run_sync_in_worker_thread(self._stop_task, configuration, cluster, task) + + def _stop_task( + self, configuration: ECSJobConfiguration, cluster: str, task: str + ) -> None: + """ + Stop a running ECS task. + """ + if configuration.cluster is not None and cluster != configuration.cluster: + raise InfrastructureNotAvailable( + "Cannot stop ECS task: this infrastructure block has access to " + f"cluster {configuration.cluster!r} but the task is running in cluster " + f"{cluster!r}." + ) + + ecs_client = self._get_client(configuration, "ecs") + try: + ecs_client.stop_task(cluster=cluster, task=task) + except Exception as exc: + # Raise a special exception if the task does not exist + if "ClusterNotFound" in str(exc): + raise InfrastructureNotFound( + f"Cannot stop ECS task: the cluster {cluster!r} could not be found." + ) from exc + if "not find task" in str(exc) or "referenced task was not found" in str( + exc + ): + raise InfrastructureNotFound( + f"Cannot stop ECS task: the task {task!r} could not be found in " + f"cluster {cluster!r}." + ) from exc + if "no registered tasks" in str(exc): + raise InfrastructureNotFound( + f"Cannot stop ECS task: the cluster {cluster!r} has no tasks." + ) from exc + + # Reraise unknown exceptions + raise diff --git a/src/integrations/prefect-aws/pyproject.toml b/src/integrations/prefect-aws/pyproject.toml new file mode 100644 index 000000000000..9104a88b9e13 --- /dev/null +++ b/src/integrations/prefect-aws/pyproject.toml @@ -0,0 +1,91 @@ +[build-system] +requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"] +build-backend = "setuptools.build_meta" + +[project] +name = "prefect-aws" +description = "Prefect integrations for interacting with Amazon Web Services." +readme = "README.md" +requires-python = ">=3.8" +license = {text = "Apache License 2.0"} +keywords = ["prefect"] +authors = [ + {name = "Prefect Technologies, Inc.", email = "help@prefect.io"} +] +classifiers = [ + "Natural Language :: English", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development :: Libraries", +] +dependencies = [ + "boto3>=1.24.53", + "botocore>=1.27.53", + "mypy_boto3_s3>=1.24.94", + "mypy_boto3_secretsmanager>=1.26.49", + "prefect>=2.16.4", + "pyparsing>=3.1.1", + "tenacity>=8.0.0", +] +dynamic = ["version"] + +[project.optional-dependencies] +dev = [ + "black", + "boto3-stubs >= 1.24.39", + "coverage", + "flake8", + "interrogate", + "isort", + "mkdocs", + "mkdocs-gen-files", + "mkdocs-material", + "mkdocstrings-python-legacy", + "mock; python_version < '3.8'", # moto 4.2.5 broke something fairly deep in our test suite https://github.com/PrefectHQ/prefect-aws/issues/318 + "moto >= 3.1.16, < 4.2.5", + "mypy", + "pillow", + "pre-commit", + "pytest", + "pytest-asyncio >= 0.18.2, != 0.22.0, < 0.23.0", # Cannot override event loop in 0.23.0. See https://github.com/pytest-dev/pytest-asyncio/issues/706 for more details. 
+ "pytest-cov", + "pytest-xdist", + "types-boto3 >= 1.0.2", +] + +[project.urls] +Homepage = "https://github.com/PrefectHQ/prefect/tree/main/src/integrations/prefect-aws" + +[project.entry-points."prefect.collections"] +prefect_aws = "prefect_aws" + +[tool.setuptools_scm] +version_file = "prefect_aws/_version.py" +root = "../../.." +tag_regex = "^prefect-aws-(?P\\d+\\.\\d+\\.\\d+)$" +fallback_version = "0.0.0" +git_describe_command = 'git describe --dirty --tags --long --match "prefect-aws-*[0-9]*"' + +[tool.interrogate] +ignore-init-module = true +ignore_init_method = true +exclude = ["prefect_aws/_version.py", "tests"] +fail-under = 95 +omit-covered-files = true + +[tool.coverage.run] +omit = ["tests/*", "prefect_aws/_version.py"] + +[tool.coverage.report] +fail_under = 80 +show_missing = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" \ No newline at end of file diff --git a/src/integrations/prefect-aws/tests/__init__.py b/src/integrations/prefect-aws/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/integrations/prefect-aws/tests/conftest.py b/src/integrations/prefect-aws/tests/conftest.py new file mode 100644 index 000000000000..a6ac4e6c75b0 --- /dev/null +++ b/src/integrations/prefect-aws/tests/conftest.py @@ -0,0 +1,57 @@ +import pytest +from botocore import UNSIGNED +from botocore.client import Config +from prefect_aws import AwsCredentials +from prefect_aws.client_parameters import AwsClientParameters + +from prefect.testing.utilities import prefect_test_harness + + +# added to eliminate warnings +def pytest_configure(config): + config.addinivalue_line( + "markers", "is_public: mark test as using public S3 bucket or not" + ) + + +@pytest.fixture(scope="session", autouse=True) +def prefect_db(): + with prefect_test_harness(): + yield + + +@pytest.fixture +def aws_credentials(): + block = AwsCredentials( + aws_access_key_id="access_key_id", + aws_secret_access_key="secret_access_key", + region_name="us-east-1", + ) + block.save("test-creds-block", overwrite=True) + return block + + +@pytest.fixture +def aws_client_parameters_custom_endpoint(): + return AwsClientParameters(endpoint_url="http://custom.internal.endpoint.org") + + +@pytest.fixture +def aws_client_parameters_empty(): + return AwsClientParameters() + + +@pytest.fixture +def aws_client_parameters_public_bucket(): + return AwsClientParameters(config=Config(signature_version=UNSIGNED)) + + +@pytest.fixture(autouse=True) +def reset_object_registry(): + """ + Ensures each test has a clean object registry. 
+ """ + from prefect.context import PrefectObjectRegistry + + with PrefectObjectRegistry(): + yield diff --git a/src/integrations/prefect-aws/tests/deployments/test_steps.py b/src/integrations/prefect-aws/tests/deployments/test_steps.py new file mode 100644 index 000000000000..6c33639b9e2d --- /dev/null +++ b/src/integrations/prefect-aws/tests/deployments/test_steps.py @@ -0,0 +1,439 @@ +import os +import sys +from pathlib import Path, PurePath, PurePosixPath +from unittest.mock import patch + +import boto3 +import pytest +from moto import mock_s3 +from prefect_aws import AwsCredentials +from prefect_aws.deployments.steps import get_s3_client, pull_from_s3, push_to_s3 + + +@pytest.fixture(scope="module", autouse=True) +def set_custom_endpoint(): + original = os.environ.get("MOTO_S3_CUSTOM_ENDPOINTS") + os.environ["MOTO_S3_CUSTOM_ENDPOINTS"] = "http://custom.minio.endpoint:9000" + yield + os.environ.pop("MOTO_S3_CUSTOM_ENDPOINTS") + if original is not None: + os.environ["MOTO_S3_CUSTOM_ENDPOINTS"] = original + + +@pytest.fixture +def s3_setup(): + with mock_s3(): + bucket_name = "my-test-bucket" + s3 = boto3.client("s3") + s3.create_bucket(Bucket=bucket_name) + yield s3, bucket_name + + +@pytest.fixture +def tmp_files(tmp_path: Path): + files = [ + "testfile1.txt", + "testfile2.txt", + "testfile3.txt", + "testdir1/testfile4.txt", + "testdir2/testfile5.txt", + ] + + (tmp_path / ".prefectignore").write_text( + """ + testdir1/* + .prefectignore + """ + ) + + for file in files: + filepath = tmp_path / file + filepath.parent.mkdir(parents=True, exist_ok=True) + filepath.write_text("Sample text") + + return tmp_path + + +@pytest.fixture +def tmp_files_win(tmp_path: Path): + files = [ + "testfile1.txt", + "testfile2.txt", + "testfile3.txt", + r"testdir1\testfile4.txt", + r"testdir2\testfile5.txt", + ] + + for file in files: + filepath = tmp_path / file + filepath.parent.mkdir(parents=True, exist_ok=True) + filepath.write_text("Sample text") + + return tmp_path + + +@pytest.fixture +def mock_aws_credentials(monkeypatch): + # Set mock environment variables for AWS credentials + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test_access_key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test_secret_key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "test_session_token") + + # Yield control back to the test function + yield + + # Clean up by deleting the mock environment variables + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False) + + +def test_push_to_s3(s3_setup, tmp_files, mock_aws_credentials): + s3, bucket_name = s3_setup + folder = "my-project" + + os.chdir(tmp_files) + + push_to_s3(bucket_name, folder) + + s3_objects = s3.list_objects_v2(Bucket=bucket_name) + object_keys = [PurePath(item["Key"]).as_posix() for item in s3_objects["Contents"]] + + expected_keys = [ + f"{folder}/testfile1.txt", + f"{folder}/testfile2.txt", + f"{folder}/testfile3.txt", + f"{folder}/testdir2/testfile5.txt", + ] + + assert set(object_keys) == set(expected_keys) + + +@pytest.mark.skipif(sys.platform != "win32", reason="requires Windows") +def test_push_to_s3_as_posix(s3_setup, tmp_files_win, mock_aws_credentials): + s3, bucket_name = s3_setup + folder = "my-project" + + os.chdir(tmp_files_win) + + push_to_s3(bucket_name, folder) + + s3_objects = s3.list_objects_v2(Bucket=bucket_name) + object_keys = [item["Key"] for item in s3_objects["Contents"]] + + expected_keys = [ + f"{folder}/testfile1.txt", + 
f"{folder}/testfile2.txt", + f"{folder}/testfile3.txt", + f"{folder}/testdir1/testfile4.txt", + f"{folder}/testdir2/testfile5.txt", + ] + + assert set(object_keys) == set(expected_keys) + + +def test_pull_from_s3(s3_setup, tmp_path, mock_aws_credentials): + s3, bucket_name = s3_setup + folder = "my-project" + + files = { + f"{folder}/testfile1.txt": "Hello, world!", + f"{folder}/testfile2.txt": "Test content", + f"{folder}/testdir1/testfile3.txt": "Nested file", + } + + for key, content in files.items(): + s3.put_object(Bucket=bucket_name, Key=key, Body=content) + + os.chdir(tmp_path) + pull_from_s3(bucket_name, folder) + + for key, content in files.items(): + target = Path(tmp_path) / PurePosixPath(key).relative_to(folder) + assert target.exists() + assert target.read_text() == content + + +def test_push_pull_empty_folders(s3_setup, tmp_path, mock_aws_credentials): + s3, bucket_name = s3_setup + folder = "my-project" + + # Create empty folders + (tmp_path / "empty1").mkdir() + (tmp_path / "empty2").mkdir() + + # Create test files + (tmp_path / "testfile1.txt").write_text("Sample text") + (tmp_path / "testfile2.txt").write_text("Sample text") + + os.chdir(tmp_path) + + # Push to S3 + push_to_s3(bucket_name, folder) + + # Check if the empty folders are not uploaded + s3_objects = s3.list_objects_v2(Bucket=bucket_name) + object_keys = [item["Key"] for item in s3_objects["Contents"]] + + assert f"{folder}/empty1/" not in object_keys + assert f"{folder}/empty2/" not in object_keys + + # Pull from S3 + pull_from_s3(bucket_name, folder) + + # Check if the empty folders are not created + assert not (tmp_path / "empty1_copy").exists() + assert not (tmp_path / "empty2_copy").exists() + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires Python 3.8+") +def test_s3_session_with_params(): + with patch("boto3.Session") as mock_session: + get_s3_client( + credentials={ + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "profile_name": "foo", + "region_name": "us-weast-1", + "aws_client_parameters": { + "api_version": "v1", + "config": {"connect_timeout": 300}, + }, + } + ) + get_s3_client( + credentials={ + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + }, + client_parameters={ + "region_name": "us-west-1", + "config": {"signature_version": "s3v4"}, + }, + ) + creds_block = AwsCredentials( + aws_access_key_id="BlockKey", + aws_secret_access_key="BlockSecret", + aws_session_token="BlockToken", + profile_name="BlockProfile", + region_name="BlockRegion", + aws_client_parameters={ + "api_version": "v1", + "use_ssl": True, + "verify": True, + "endpoint_url": "BlockEndpoint", + "config": {"connect_timeout": 123}, + }, + ) + get_s3_client(credentials=creds_block.dict()) + get_s3_client( + credentials={ + "minio_root_user": "MY_USER", + "minio_root_password": "MY_PASSWORD", + }, + client_parameters={"endpoint_url": "http://custom.minio.endpoint:9000"}, + ) + all_calls = mock_session.mock_calls + assert len(all_calls) == 8 + assert all_calls[0].kwargs == { + "aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "aws_session_token": None, + "profile_name": "foo", + "region_name": "us-weast-1", + } + assert all_calls[1].args[0] == "s3" + assert { + "api_version": "v1", + "endpoint_url": None, + "use_ssl": True, + "verify": None, + }.items() <= all_calls[1].kwargs.items() + assert all_calls[1].kwargs.get("config").connect_timeout == 300 + assert all_calls[1].kwargs.get("config").signature_version is None + assert all_calls[2].kwargs == { + 
"aws_access_key_id": "THE_KEY", + "aws_secret_access_key": "SHHH!", + "aws_session_token": None, + "profile_name": None, + "region_name": "us-west-1", + } + assert all_calls[3].args[0] == "s3" + assert { + "api_version": None, + "endpoint_url": None, + "use_ssl": True, + "verify": None, + }.items() <= all_calls[3].kwargs.items() + assert all_calls[3].kwargs.get("config").connect_timeout == 60 + assert all_calls[3].kwargs.get("config").signature_version == "s3v4" + assert all_calls[4].kwargs == { + "aws_access_key_id": "BlockKey", + "aws_secret_access_key": creds_block.aws_secret_access_key, + "aws_session_token": "BlockToken", + "profile_name": "BlockProfile", + "region_name": "BlockRegion", + } + assert all_calls[5].args[0] == "s3" + assert { + "api_version": "v1", + "use_ssl": True, + "verify": True, + "endpoint_url": "BlockEndpoint", + }.items() <= all_calls[5].kwargs.items() + assert all_calls[5].kwargs.get("config").connect_timeout == 123 + assert all_calls[5].kwargs.get("config").signature_version is None + assert all_calls[6].kwargs == { + "aws_access_key_id": "MY_USER", + "aws_secret_access_key": "MY_PASSWORD", + "aws_session_token": None, + "profile_name": None, + "region_name": None, + } + assert all_calls[7].args[0] == "s3" + assert { + "api_version": None, + "use_ssl": True, + "verify": None, + "endpoint_url": "http://custom.minio.endpoint:9000", + }.items() <= all_calls[7].kwargs.items() + + +def test_custom_credentials_and_client_parameters(s3_setup, tmp_files): + s3, bucket_name = s3_setup + folder = "my-project" + + # Custom credentials and client parameters + custom_credentials = { + "aws_access_key_id": "fake_access_key", + "aws_secret_access_key": "fake_secret_key", + } + + custom_client_parameters = { + "region_name": "us-west-1", + "config": {"signature_version": "s3v4"}, + } + + os.chdir(tmp_files) + + # Test push_to_s3 with custom credentials and client parameters + push_to_s3( + bucket_name, + folder, + credentials=custom_credentials, + client_parameters=custom_client_parameters, + ) + + # Test pull_from_s3 with custom credentials and client parameters + tmp_path = tmp_files / "test_pull" + tmp_path.mkdir(parents=True, exist_ok=True) + os.chdir(tmp_path) + + pull_from_s3( + bucket_name, + folder, + credentials=custom_credentials, + client_parameters=custom_client_parameters, + ) + + for file in tmp_files.iterdir(): + if file.is_file() and file.name != ".prefectignore": + assert (tmp_path / file.name).exists() + + +def test_custom_credentials_and_client_parameters_minio(s3_setup, tmp_files): + s3, bucket_name = s3_setup + folder = "my-project" + + # Custom credentials and client parameters + custom_credentials = { + "minio_root_user": "fake_user", + "minio_root_password": "fake_password", + } + + custom_client_parameters = { + "endpoint_url": "http://custom.minio.endpoint:9000", + } + + os.chdir(tmp_files) + + # Test push_to_s3 with custom credentials and client parameters + push_to_s3( + bucket_name, + folder, + credentials=custom_credentials, + client_parameters=custom_client_parameters, + ) + + # Test pull_from_s3 with custom credentials and client parameters + tmp_path = tmp_files / "test_pull" + tmp_path.mkdir(parents=True, exist_ok=True) + os.chdir(tmp_path) + + pull_from_s3( + bucket_name, + folder, + credentials=custom_credentials, + client_parameters=custom_client_parameters, + ) + + for file in tmp_files.iterdir(): + if file.is_file() and file.name != ".prefectignore": + assert (tmp_path / file.name).exists() + + +def 
test_without_prefectignore_file(s3_setup, tmp_files: Path, mock_aws_credentials): + s3, bucket_name = s3_setup + folder = "my-project" + + # Remove the .prefectignore file + (tmp_files / ".prefectignore").unlink() + + os.chdir(tmp_files) + + # Test push_to_s3 without .prefectignore file + push_to_s3(bucket_name, folder) + + # Test pull_from_s3 without .prefectignore file + tmp_path = tmp_files / "test_pull" + tmp_path.mkdir(parents=True, exist_ok=True) + os.chdir(tmp_path) + + pull_from_s3(bucket_name, folder) + + for file in tmp_files.iterdir(): + if file.is_file(): + assert (tmp_path / file.name).exists() + + +def test_prefectignore_with_comments_and_empty_lines( + s3_setup, tmp_files: Path, mock_aws_credentials +): + s3, bucket_name = s3_setup + folder = "my-project" + + # Update the .prefectignore file with comments and empty lines + (tmp_files / ".prefectignore").write_text( + """ + # This is a comment + testdir1/* + + .prefectignore + """ + ) + + os.chdir(tmp_files) + + # Test push_to_s3 + push_to_s3(bucket_name, folder) + + # Test pull_from_s3 + tmp_path = tmp_files / "test_pull" + tmp_path.mkdir(parents=True, exist_ok=True) + os.chdir(tmp_path) + + pull_from_s3(bucket_name, folder) + + for file in tmp_files.iterdir(): + if file.is_file() and file.name != ".prefectignore": + assert (tmp_path / file.name).exists() diff --git a/src/integrations/prefect-aws/tests/mock_aws_credentials b/src/integrations/prefect-aws/tests/mock_aws_credentials new file mode 100644 index 000000000000..101e6ca68838 --- /dev/null +++ b/src/integrations/prefect-aws/tests/mock_aws_credentials @@ -0,0 +1,11 @@ +[TEST_PROFILE_1] +aws_access_key_id = mock +aws_secret_access_key = mock +aws_region = us - east - 1 +aws_default_region = us - east - 1 + +[TEST_PROFILE_2] +aws_access_key_id = mock +aws_secret_access_key = mock +aws_region = us - east - 1 +aws_default_region = us - east - 1 diff --git a/src/integrations/prefect-aws/tests/test_batch.py b/src/integrations/prefect-aws/tests/test_batch.py new file mode 100644 index 000000000000..404aa38dee57 --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_batch.py @@ -0,0 +1,81 @@ +from uuid import UUID + +import boto3 +import pytest +from moto import mock_batch, mock_iam +from prefect_aws.batch import batch_submit + +from prefect import flow + + +@pytest.fixture(scope="function") +def batch_client(aws_credentials): + with mock_batch(): + yield boto3.client("batch", region_name="us-east-1") + + +@pytest.fixture(scope="function") +def iam_client(aws_credentials): + with mock_iam(): + yield boto3.client("iam", region_name="us-east-1") + + +@pytest.fixture() +def job_queue_arn(iam_client, batch_client): + iam_role = iam_client.create_role( + RoleName="test_batch_client", + AssumeRolePolicyDocument="string", + ) + iam_arn = iam_role.get("Role").get("Arn") + + compute_environment = batch_client.create_compute_environment( + computeEnvironmentName="test_batch_ce", type="UNMANAGED", serviceRole=iam_arn + ) + + compute_environment_arn = compute_environment.get("computeEnvironmentArn") + + created_queue = batch_client.create_job_queue( + jobQueueName="test_batch_queue", + state="ENABLED", + priority=1, + computeEnvironmentOrder=[ + {"order": 1, "computeEnvironment": compute_environment_arn}, + ], + ) + job_queue_arn = created_queue.get("jobQueueArn") + return job_queue_arn + + +@pytest.fixture +def job_definition_arn(batch_client): + job_definition = batch_client.register_job_definition( + jobDefinitionName="test_batch_jobdef", + type="container", + 
containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "2"], + }, + ) + job_definition_arn = job_definition.get("jobDefinitionArn") + return job_definition_arn + + +def test_batch_submit(job_queue_arn, job_definition_arn, aws_credentials): + @flow + def test_flow(): + return batch_submit( + "batch_test_job", + job_queue_arn, + job_definition_arn, + aws_credentials, + ) + + job_id = test_flow() + + try: + UUID(str(job_id)) + assert True, f"{job_id} is a valid UUID" + except ValueError: + assert False, f"{job_id} is not a valid UUID" diff --git a/src/integrations/prefect-aws/tests/test_block_standards.py b/src/integrations/prefect-aws/tests/test_block_standards.py new file mode 100644 index 000000000000..b0f0d0a1a5ae --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_block_standards.py @@ -0,0 +1,23 @@ +import pytest + +from prefect.blocks.core import Block +from prefect.testing.standard_test_suites import BlockStandardTestSuite +from prefect.utilities.dispatch import get_registry_for_type +from prefect.utilities.importtools import to_qualified_name + + +def find_module_blocks(): + blocks = get_registry_for_type(Block) + module_blocks = [ + block + for block in blocks.values() + if to_qualified_name(block).startswith("prefect_aws") + ] + return module_blocks + + +@pytest.mark.parametrize("block", find_module_blocks()) +class TestAllBlocksAdhereToStandards(BlockStandardTestSuite): + @pytest.fixture + def block(self, block): + return block diff --git a/src/integrations/prefect-aws/tests/test_client_parameters.py b/src/integrations/prefect-aws/tests/test_client_parameters.py new file mode 100644 index 000000000000..d70945965445 --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_client_parameters.py @@ -0,0 +1,133 @@ +from typing import Any, Dict + +import pytest +from botocore import UNSIGNED +from botocore.client import Config +from prefect_aws.client_parameters import AwsClientParameters + + +class TestAwsClientParameters: + @pytest.mark.parametrize( + "params,result", + [ + (AwsClientParameters(), {}), + ( + AwsClientParameters( + use_ssl=False, verify=False, endpoint_url="http://localhost:9000" + ), + { + "use_ssl": False, + "verify": False, + "endpoint_url": "http://localhost:9000", + }, + ), + ( + AwsClientParameters(endpoint_url="https://localhost:9000"), + {"endpoint_url": "https://localhost:9000"}, + ), + ( + AwsClientParameters(api_version="1.0.0"), + {"api_version": "1.0.0"}, + ), + ], + ) + def test_get_params_override_expected_output( + self, params: AwsClientParameters, result: Dict[str, Any], tmp_path + ): + if "use_ssl" not in result: + result["use_ssl"] = True + if "verify" not in result: + result["verify"] = True + assert result == params.get_params_override() + + @pytest.mark.parametrize( + "params,result", + [ + ( + AwsClientParameters( + config=dict( + region_name="eu_west_1", + retries={"max_attempts": 10, "mode": "standard"}, + signature_version="unsigned", + ) + ), + { + "config": { + "region_name": "eu_west_1", + "retries": {"max_attempts": 10, "mode": "standard"}, + "signature_version": UNSIGNED, + }, + }, + ), + ], + ) + def test_with_custom_config( + self, params: AwsClientParameters, result: Dict[str, Any] + ): + assert ( + result["config"]["region_name"] + == params.get_params_override()["config"].region_name + ) + assert ( + result["config"]["retries"] + == params.get_params_override()["config"].retries + ) + + def test_with_not_verify_and_verify_cert_path(self, tmp_path): + cert_path = tmp_path / 
"ca-bundle.crt" + cert_path.touch() + with pytest.warns( + UserWarning, match="verify_cert_path is set but verify is False" + ): + params = AwsClientParameters(verify=False, verify_cert_path=cert_path) + assert params.verify_cert_path is None + assert not params.verify + + def test_get_params_override_with_config_with_deprecated_verify(self, tmp_path): + cert_path = tmp_path / "ca-bundle.crt" + cert_path.touch() + with pytest.warns(DeprecationWarning, match="verify should be a boolean"): + params = AwsClientParameters(verify=cert_path) + assert params.verify + assert not params.verify_cert_path + override_params = params.get_params_override() + override_params["verify"] == cert_path + + def test_get_params_override_with_config(self, tmp_path): + cert_path = tmp_path / "ca-bundle.crt" + cert_path.touch() + params = AwsClientParameters( + config=Config( + region_name="eu_west_1", + retries={"max_attempts": 10, "mode": "standard"}, + ), + verify_cert_path=cert_path, + ) + override_params = params.get_params_override() + override_params["config"].region_name == "eu_west_1" + override_params["config"].retries == { + "max_attempts": 10, + "mode": "standard", + } + + def test_get_params_override_with_verify_cert_path(self, tmp_path): + cert_path = tmp_path / "ca-bundle.crt" + cert_path.touch() + params = AwsClientParameters(verify_cert_path=cert_path) + override_params = params.get_params_override() + assert override_params["verify"] == cert_path + + def test_get_params_override_with_both_cert_path(self, tmp_path): + old_cert_path = tmp_path / "old-ca-bundle.crt" + old_cert_path.touch() + + cert_path = tmp_path / "ca-bundle.crt" + cert_path.touch() + with pytest.warns( + UserWarning, match="verify_cert_path is set but verify is also set" + ): + params = AwsClientParameters( + verify=old_cert_path, verify_cert_path=cert_path + ) + override_params = params.get_params_override() + assert override_params["verify"] == cert_path diff --git a/src/integrations/prefect-aws/tests/test_client_waiter.py b/src/integrations/prefect-aws/tests/test_client_waiter.py new file mode 100644 index 000000000000..6981f0182ba1 --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_client_waiter.py @@ -0,0 +1,69 @@ +from unittest.mock import MagicMock + +import pytest +from moto import mock_ec2 +from prefect_aws.client_waiter import client_waiter + +from prefect import flow + + +@pytest.fixture +def mock_waiter(monkeypatch): + waiter = MagicMock(name="mock_waiter") + monkeypatch.setattr( + "prefect_aws.client_waiter.create_waiter_with_client", + waiter, + ) + return waiter + + +@pytest.fixture +def mock_client(monkeypatch, mock_waiter): + client_mock = MagicMock( + waiter_names=["instance_exists"], get_waiter=lambda waiter_name: mock_waiter + ) + client_creator_mock = MagicMock( + ClientCreator=lambda *args, **kwargs: MagicMock( + create_client=lambda *args, **kwargs: client_mock + ) + ) + monkeypatch.setattr("botocore.client", client_creator_mock) + return client_creator_mock + + +@mock_ec2 +def test_client_waiter_custom(mock_waiter, aws_credentials): + @flow + def test_flow(): + waiter = client_waiter( + "batch", + "JobExists", + aws_credentials, + waiter_definition={"waiters": {"JobExists": ["definition"]}, "version": 2}, + ) + return waiter + + test_flow() + mock_waiter().wait.assert_called_once_with() + + +@mock_ec2 +def test_client_waiter_custom_no_definition(mock_waiter, aws_credentials): + @flow + def test_flow(): + waiter = client_waiter("batch", "JobExists", aws_credentials) + return waiter + + with 
pytest.raises(ValueError, match="The waiter name, JobExists"): + test_flow() + + +@mock_ec2 +def test_client_waiter_boto(mock_waiter, mock_client, aws_credentials): + @flow + def test_flow(): + waiter = client_waiter("ec2", "instance_exists", aws_credentials) + return waiter + + test_flow() + mock_waiter.wait.assert_called_once_with() diff --git a/src/integrations/prefect-aws/tests/test_credentials.py b/src/integrations/prefect-aws/tests/test_credentials.py new file mode 100644 index 000000000000..32e6c1c2a812 --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_credentials.py @@ -0,0 +1,190 @@ +import pytest +from boto3.session import Session +from botocore.client import BaseClient +from moto import mock_s3 +from prefect_aws.credentials import ( + AwsCredentials, + ClientType, + MinIOCredentials, + _get_client_cached, +) + + +def test_aws_credentials_get_boto3_session(): + """ + Asserts that instantiated AwsCredentials block creates an + authenticated boto3 session. + """ + + with mock_s3(): + aws_credentials_block = AwsCredentials() + boto3_session = aws_credentials_block.get_boto3_session() + assert isinstance(boto3_session, Session) + + +def test_minio_credentials_get_boto3_session(): + """ + Asserts that instantiated MinIOCredentials block creates + an authenticated boto3 session. + """ + + minio_credentials_block = MinIOCredentials( + minio_root_user="root_user", minio_root_password="root_password" + ) + boto3_session = minio_credentials_block.get_boto3_session() + assert isinstance(boto3_session, Session) + + +@pytest.mark.parametrize( + "credentials", + [ + AwsCredentials(), + MinIOCredentials( + minio_root_user="root_user", minio_root_password="root_password" + ), + ], +) +@pytest.mark.parametrize("client_type", ["s3", ClientType.S3]) +def test_credentials_get_client(credentials, client_type): + with mock_s3(): + assert isinstance(credentials.get_client(client_type), BaseClient) + + +@pytest.mark.parametrize( + "credentials", + [ + AwsCredentials(region_name="us-east-1"), + MinIOCredentials( + minio_root_user="root_user", + minio_root_password="root_password", + region_name="us-east-1", + ), + ], +) +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_get_client_cached(credentials, client_type): + """ + Test to ensure that _get_client_cached function returns the same instance + for multiple calls with the same parameters and properly utilizes lru_cache. + """ + + _get_client_cached.cache_clear() + + assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0" + + credentials.get_client(client_type) + credentials.get_client(client_type) + credentials.get_client(client_type) + + assert _get_client_cached.cache_info().misses == 1 + assert _get_client_cached.cache_info().hits == 2 + + +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_aws_credentials_change_causes_cache_miss(client_type): + """ + Test to ensure that changing configuration on an AwsCredentials instance + after fetching a client causes a cache miss. 
+ """ + + _get_client_cached.cache_clear() + + credentials = AwsCredentials(region_name="us-east-1") + + initial_client = credentials.get_client(client_type) + + credentials.region_name = "us-west-2" + + new_client = credentials.get_client(client_type) + + assert ( + initial_client is not new_client + ), "Client should be different after configuration change" + + assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice" + + +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_minio_credentials_change_causes_cache_miss(client_type): + """ + Test to ensure that changing configuration on an AwsCredentials instance + after fetching a client causes a cache miss. + """ + + _get_client_cached.cache_clear() + + credentials = MinIOCredentials( + minio_root_user="root_user", + minio_root_password="root_password", + region_name="us-east-1", + ) + + initial_client = credentials.get_client(client_type) + + credentials.region_name = "us-west-2" + + new_client = credentials.get_client(client_type) + + assert ( + initial_client is not new_client + ), "Client should be different after configuration change" + + assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice" + + +@pytest.mark.parametrize( + "credentials_type, initial_field, new_field", + [ + ( + AwsCredentials, + {"region_name": "us-east-1"}, + {"region_name": "us-east-2"}, + ), + ( + MinIOCredentials, + { + "region_name": "us-east-1", + "minio_root_user": "root_user", + "minio_root_password": "root_password", + }, + { + "region_name": "us-east-2", + "minio_root_user": "root_user", + "minio_root_password": "root_password", + }, + ), + ], +) +def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field): + credentials = credentials_type(**initial_field) + initial_hash = hash(credentials) + + setattr(credentials, list(new_field.keys())[0], list(new_field.values())[0]) + new_hash = hash(credentials) + + assert initial_hash != new_hash, "Hash should change when region_name changes" + + +def test_aws_credentials_nested_client_parameters_are_hashable(): + """ + Test to ensure that nested client parameters are hashable. 
+ """ + + creds = AwsCredentials( + region_name="us-east-1", + aws_client_parameters=dict( + config=dict( + connect_timeout=5, + read_timeout=5, + retries=dict(max_attempts=10, mode="standard"), + ) + ), + ) + + assert hash(creds) is not None + + client = creds.get_client("s3") + + _client = creds.get_client("s3") + + assert client is _client diff --git a/src/integrations/prefect-aws/tests/test_ecs.py b/src/integrations/prefect-aws/tests/test_ecs.py new file mode 100644 index 000000000000..a4b1bc703cad --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_ecs.py @@ -0,0 +1,2263 @@ +import json +import logging +import textwrap +from copy import deepcopy +from functools import partial +from typing import Any, Awaitable, Callable, Dict, List, Optional +from unittest.mock import MagicMock + +import anyio +import pytest +import yaml +from botocore.exceptions import ClientError +from moto import mock_ec2, mock_ecs, mock_logs +from moto.ec2.utils import generate_instance_identity_document +from prefect_aws.workers.ecs_worker import ECSWorker +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect._internal.compatibility.deprecated import PrefectDeprecationWarning +from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound +from prefect.logging.configuration import setup_logging +from prefect.server.schemas.core import Deployment, Flow, FlowRun +from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect.utilities.dockerutils import get_prefect_image_name + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import ValidationError +else: + from pydantic import ValidationError + +from prefect_aws.ecs import ( + ECS_DEFAULT_CPU, + ECS_DEFAULT_MEMORY, + ECSTask, + get_container, + get_prefect_container, + parse_task_identifier, +) + + +def test_ecs_task_emits_deprecation_warning(): + with pytest.warns( + PrefectDeprecationWarning, + match=( + "prefect_aws.ecs.ECSTask has been deprecated." + " It will not be available after Sep 2024." + " Use the ECS worker instead." + " Refer to the upgrade guide for more information" + ), + ): + ECSTask() + + +setup_logging() + + +BASE_TASK_DEFINITION_YAML = """ +containerDefinitions: +- cpu: 1024 + image: prefecthq/prefect:2.1.0-python3.8 + memory: 2048 + name: prefect +family: prefect +""" + +BASE_TASK_DEFINITION = yaml.safe_load(BASE_TASK_DEFINITION_YAML) + + +def inject_moto_patches(moto_mock, patches: Dict[str, List[Callable]]): + def injected_call(method, patch_list, *args, **kwargs): + for patch in patch_list: + result = patch(method, *args, **kwargs) + return result + + for account in moto_mock.backends: + for region in moto_mock.backends[account]: + backend = moto_mock.backends[account][region] + + for attr, attr_patches in patches.items(): + original_method = getattr(backend, attr) + setattr( + backend, attr, partial(injected_call, original_method, attr_patches) + ) + + +def patch_run_task(mock, run_task, *args, **kwargs): + """ + Track calls to `run_task` by calling a mock as well. + """ + mock(*args, **kwargs) + return run_task(*args, **kwargs) + + +def patch_describe_tasks_add_prefect_container(describe_tasks, *args, **kwargs): + """ + Adds the minimal prefect container to moto's task description. 
+ """ + result = describe_tasks(*args, **kwargs) + for task in result: + if not task.containers: + task.containers = [] + if not get_prefect_container(task.containers): + task.containers.append({"name": "prefect"}) + return result + + +def patch_calculate_task_resource_requirements( + _calculate_task_resource_requirements, task_definition +): + """ + Adds support for non-EC2 execution modes to moto's calculation of task definition. + """ + for container_definition in task_definition.container_definitions: + container_definition.setdefault("memory", 0) + return _calculate_task_resource_requirements(task_definition) + + +def create_log_stream(session, run_task, *args, **kwargs): + """ + When running a task, create the log group and stream if logging is configured on + containers. + + See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html + """ + tasks = run_task(*args, **kwargs) + if not tasks: + return tasks + task = tasks[0] + + ecs_client = session.client("ecs") + logs_client = session.client("logs") + + task_definition = ecs_client.describe_task_definition( + taskDefinition=task.task_definition_arn + )["taskDefinition"] + + for container in task_definition.get("containerDefinitions", []): + log_config = container.get("logConfiguration", {}) + if log_config: + if log_config.get("logDriver") != "awslogs": + continue + + options = log_config.get("options", {}) + if not options: + raise ValueError("logConfiguration does not include options.") + + group_name = options.get("awslogs-group") + if not group_name: + raise ValueError( + "logConfiguration.options does not include awslogs-group" + ) + + if options.get("awslogs-create-group") == "true": + logs_client.create_log_group(logGroupName=group_name) + + stream_prefix = options.get("awslogs-stream-prefix") + if not stream_prefix: + raise ValueError( + "logConfiguration.options does not include awslogs-stream-prefix" + ) + + logs_client.create_log_stream( + logGroupName=group_name, + logStreamName=f"{stream_prefix}/{container['name']}/{task.id}", + ) + + return tasks + + +def add_ec2_instance_to_ecs_cluster(session, cluster_name): + ecs_client = session.client("ecs") + ec2_client = session.client("ec2") + ec2_resource = session.resource("ec2") + + ecs_client.create_cluster(clusterName=cluster_name) + + images = ec2_client.describe_images() + image_id = images["Images"][0]["ImageId"] + + test_instance = ec2_resource.create_instances( + ImageId=image_id, MinCount=1, MaxCount=1 + )[0] + + ecs_client.register_container_instance( + cluster=cluster_name, + instanceIdentityDocument=json.dumps( + generate_instance_identity_document(test_instance) + ), + ) + + +def create_test_ecs_cluster(ecs_client, cluster_name) -> str: + """ + Create an ECS cluster and return its ARN + """ + return ecs_client.create_cluster(clusterName=cluster_name)["cluster"]["clusterArn"] + + +def describe_task(ecs_client, task_arn, **kwargs) -> dict: + """ + Describe a single ECS task + """ + return ecs_client.describe_tasks(tasks=[task_arn], include=["TAGS"], **kwargs)[ + "tasks" + ][0] + + +async def stop_task(ecs_client, task_arn, **kwargs): + """ + Stop an ECS task. + + Additional keyword arguments are passed to `ECSClient.stop_task`. 
+ """ + task = await run_sync_in_worker_thread(describe_task, ecs_client, task_arn) + # Check that the task started successfully + assert task["lastStatus"] == "RUNNING", "Task should be RUNNING before stopping" + print("Stopping task...") + await run_sync_in_worker_thread(ecs_client.stop_task, task=task_arn, **kwargs) + + +def describe_task_definition(ecs_client, task): + return ecs_client.describe_task_definition( + taskDefinition=task["taskDefinitionArn"] + )["taskDefinition"] + + +async def run_then_stop_task( + task: ECSTask, after_start: Optional[Callable[[str], Awaitable[Any]]] = None +) -> str: + """ + Run an ECS Task then stop it. + + Moto will not advance the state of tasks, so `ECSTask.run` would hang forever if + the run is created successfully and not stopped. + + `after_start` can be used to run something after the task starts but before it is + stopped. It will be passed the task arn. + """ + session = task.aws_credentials.get_boto3_session() + + with anyio.fail_after(20): + async with anyio.create_task_group() as tg: + identifier = await tg.start(task.run) + cluster, task_arn = parse_task_identifier(identifier) + + if after_start: + await after_start(task_arn) + + # Stop the task after it starts to prevent the test from running forever + tg.start_soon( + partial(stop_task, session.client("ecs"), task_arn, cluster=cluster) + ) + + return task_arn + + +@pytest.fixture(autouse=True) +def patch_task_watch_poll_interval(monkeypatch): + # Patch the poll interval to be way shorter for speed during testing! + monkeypatch.setattr(ECSTask.__fields__["task_watch_poll_interval"], "default", 0.05) + + +@pytest.fixture +def ecs_mocks(aws_credentials): + with mock_ecs() as ecs: + with mock_ec2(): + with mock_logs(): + session = aws_credentials.get_boto3_session() + + inject_moto_patches( + ecs, + { + # Ensure container is created in described tasks + "describe_tasks": [patch_describe_tasks_add_prefect_container], + # Fix moto internal resource requirement calculations + "_calculate_task_resource_requirements": [ + patch_calculate_task_resource_requirements + ], + # Add log group creation + "run_task": [partial(create_log_stream, session)], + }, + ) + + create_test_ecs_cluster(session.client("ecs"), "default") + + # NOTE: Even when using FARGATE, moto requires container instances to be + # registered. This differs from AWS behavior. 
+ add_ec2_instance_to_ecs_cluster(session, "default") + + yield ecs + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_launch_types(aws_credentials, launch_type: str): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + launch_type=launch_type, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + if launch_type != "FARGATE_SPOT": + assert launch_type in task_definition["compatibilities"] + assert task["launchType"] == launch_type + else: + assert "FARGATE" in task_definition["compatibilities"] + # FARGATE SPOT requires a null launch type + assert not task.get("launchType") + # Instead, it requires a capacity provider strategy but this is not supported + # by moto and is not present on the task even when provided + # assert task["capacityProviderStrategy"] == [ + # {"capacityProvider": "FARGATE_SPOT", "weight": 1} + # ] + + requires_capabilities = task_definition.get("requiresCompatibilities", []) + if launch_type != "EC2": + assert "FARGATE" in requires_capabilities + else: + assert not requires_capabilities + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +@pytest.mark.parametrize( + "cpu,memory", [(None, None), (1024, None), (None, 2048), (2048, 4096)] +) +async def test_cpu_and_memory(aws_credentials, launch_type: str, cpu: int, memory: int): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + launch_type=launch_type, + cpu=cpu, + memory=memory, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + container_definition = get_prefect_container( + task_definition["containerDefinitions"] + ) + overrides = task["overrides"] + container_overrides = get_prefect_container(overrides["containerOverrides"]) + + if launch_type == "EC2": + # EC2 requires CPU and memory to be defined at the container level + assert container_definition["cpu"] == cpu or ECS_DEFAULT_CPU + assert container_definition["memory"] == memory or ECS_DEFAULT_MEMORY + else: + # Fargate requires CPU and memory to be defined at the task definition level + assert task_definition["cpu"] == str(cpu or ECS_DEFAULT_CPU) + assert task_definition["memory"] == str(memory or ECS_DEFAULT_MEMORY) + + # We always provide non-null values as overrides on the task run + assert overrides.get("cpu") == (str(cpu) if cpu else None) + assert overrides.get("memory") == (str(memory) if memory else None) + # And as overrides for the Prefect container + assert container_overrides.get("cpu") == cpu + assert container_overrides.get("memory") == memory + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_network_mode_default(aws_credentials, launch_type: str): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + launch_type=launch_type, + ) + 
print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + if launch_type == "EC2": + assert task_definition["networkMode"] == "bridge" + else: + assert task_definition["networkMode"] == "awsvpc" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_container_command(aws_credentials, launch_type: str): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + launch_type=launch_type, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + + container_overrides = get_prefect_container(task["overrides"]["containerOverrides"]) + assert container_overrides["command"] == ["prefect", "version"] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_environment_variables(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + env={"FOO": "BAR"}, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + prefect_container_definition = get_prefect_container( + task_definition["containerDefinitions"] + ) + assert not prefect_container_definition[ + "environment" + ], "Variables should not be passed until runtime" + + prefect_container_overrides = get_prefect_container( + task["overrides"]["containerOverrides"] + ) + expected = [ + {"name": key, "value": value} + for key, value in ECSTask._base_environment().items() + ] + expected.append({"name": "FOO", "value": "BAR"}) + assert prefect_container_overrides.get("environment") == expected + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_labels(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + labels={"foo": "bar"}, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert not task_definition.get("tags"), "Labels should not be passed until runtime" + + assert task.get("tags") == [{"key": "foo", "value": "bar"}] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_container_command_from_task_definition(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [{"name": "prefect", "command": ["echo", "hello"]}] + }, + command=[], + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + + container_overrides = get_prefect_container(task["overrides"]["containerOverrides"]) + assert "command" not in container_overrides + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_extra_containers_in_task_definition(aws_credentials): + 
task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [ + {"name": "secondary", "command": ["echo", "hello"], "image": "alpine"} + ] + }, + command=["prefect", "version"], + image="test", + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + user_container = get_container(task_definition["containerDefinitions"], "secondary") + assert ( + user_container is not None + ), "The user-specified container should be present still" + assert user_container["command"] == ["echo", "hello"] + assert user_container["image"] == "alpine", "The image should be left unchanged" + + prefect_container = get_prefect_container(task_definition["containerDefinitions"]) + assert prefect_container is not None, "The prefect container should be added" + assert ( + prefect_container["image"] == "test" + ), "The prefect container should use the image field" + + container_overrides = task["overrides"]["containerOverrides"] + user_container_overrides = get_container(container_overrides, "secondary") + prefect_container_overrides = get_prefect_container(container_overrides) + assert ( + user_container_overrides is None + ), "The user container should not be included in overrides" + assert ( + prefect_container_overrides + ), "The prefect container should have overrides still" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_prefect_container_in_task_definition(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [ + { + "name": "prefect", + "command": ["should", "be", "gone"], + "image": "should-be-gone", + "privileged": True, + } + ] + }, + command=["prefect", "version"], + image="test", + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + prefect_container = get_prefect_container(task_definition["containerDefinitions"]) + + assert ( + prefect_container["image"] == "test" + ), "The prefect container should use the image field" + + assert prefect_container["command"] == [ + "should", + "be", + "gone", + ], "The command should be left unchanged on the task definition" + + assert ( + prefect_container["privileged"] is True + ), "Extra attributes should be retained" + + container_overrides = get_prefect_container(task["overrides"]["containerOverrides"]) + assert container_overrides["command"] == [ + "prefect", + "version", + ], "The command should be passed as an override" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_default_image_in_task_definition(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [ + { + "name": "prefect", + "image": "use-this-image", + } + ] + }, + command=["prefect", "version"], + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + # The image on the block is inferred from the task defintinion + assert task.image == "use-this-image" + + task = 
describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + prefect_container = get_prefect_container(task_definition["containerDefinitions"]) + assert ( + prefect_container["image"] == "use-this-image" + ), "The image from the task definition should be used" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_image_overrides_task_definition(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [ + { + "name": "prefect", + "image": "use-this-image", + } + ] + }, + command=["prefect", "version"], + image="override-image", + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + prefect_container = get_prefect_container(task_definition["containerDefinitions"]) + assert ( + prefect_container["image"] == "override-image" + ), "The provided image should override task definition" + + +@pytest.mark.parametrize( + "task_definition", + [ + # Empty task definition + {}, + # Task definition with prefect container but no image + { + "containerDefinitions": [ + { + "name": "prefect", + } + ] + }, + # Task definition with other container with image + {"containerDefinitions": [{"name": "foo", "image": "not-me-image"}]}, + ], +) +@pytest.mark.usefixtures("ecs_mocks") +async def test_default_image(aws_credentials, task_definition): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition=task_definition, + command=["prefect", "version"], + ) + print(task.preview()) + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + # The image on the block is inferred from Prefect/Python versions + assert task.image == get_prefect_image_name() + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + prefect_container = get_prefect_container(task_definition["containerDefinitions"]) + assert ( + prefect_container["image"] == get_prefect_image_name() + ), "The image should be the default Prefect tag" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_default_cpu_and_memory_in_task_definition( + aws_credentials, launch_type: str +): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [ + { + "name": "prefect", + "command": ["should", "be", "gone"], + "image": "should-be-gone", + "cpu": 2048, + "memory": 4096, + } + ], + "cpu": "4096", + "memory": "8192", + }, + command=["prefect", "version"], + image="test", + launch_type=launch_type, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + container_definition = get_prefect_container( + task_definition["containerDefinitions"] + ) + overrides = task["overrides"] + container_overrides = get_prefect_container(overrides["containerOverrides"]) + + # All of these values should be retained + assert container_definition["cpu"] == 2048 + assert 
container_definition["memory"] == 4096 + assert task_definition["cpu"] == str(4096) + assert task_definition["memory"] == str(8192) + + # No values should be overridden at runtime + assert overrides.get("cpu") is None + assert overrides.get("memory") is None + assert container_overrides.get("cpu") is None + assert container_overrides.get("memory") is None + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_environment_variables_in_task_definition(aws_credentials): + # See also, `test_unset_environment_variables_in_task_definition` + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [ + { + "name": "prefect", + "environment": [ + {"name": "BAR", "value": "FOO"}, + {"name": "OVERRIDE", "value": "OLD"}, + ], + } + ], + }, + env={"FOO": "BAR", "OVERRIDE": "NEW"}, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + prefect_container_definition = get_prefect_container( + task_definition["containerDefinitions"] + ) + + assert prefect_container_definition["environment"] == [ + {"name": "BAR", "value": "FOO"}, + {"name": "OVERRIDE", "value": "OLD"}, + ] + + prefect_container_overrides = get_prefect_container( + task["overrides"]["containerOverrides"] + ) + expected_base = [ + {"name": key, "value": value} + for key, value in ECSTask._base_environment().items() + ] + assert prefect_container_overrides.get("environment") == expected_base + [ + {"name": "FOO", "value": "BAR"}, + {"name": "OVERRIDE", "value": "NEW"}, + ] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_unset_environment_variables_in_task_definition(aws_credentials): + # In contrast to `test_environment_variables_in_task_definition`, this tests the + # use of `None` in `ECSTask.env` values to signal _removal_ of an environment + # variable instead of overriding a value. 
+ task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={ + "containerDefinitions": [ + { + "name": "prefect", + "environment": [ + {"name": "FOO", "value": "FOO"}, + {"name": "BAR", "value": "BAR"}, + ], + } + ] + }, + env={"FOO": None}, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + prefect_container_definition = get_prefect_container( + task_definition["containerDefinitions"] + ) + assert prefect_container_definition["environment"] == [ + {"name": "BAR", "value": "BAR"} + ], "FOO should be removed from the task definition" + + expected_base = [ + {"name": key, "value": value} + for key, value in ECSTask._base_environment().items() + ] + prefect_container_overrides = get_prefect_container( + task["overrides"]["containerOverrides"] + ) + assert ( + prefect_container_overrides.get("environment") == expected_base + ), "FOO should not be passed at runtime" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("provided_as_field", [True, False]) +async def test_execution_role_arn_in_task_definition( + aws_credentials, provided_as_field: bool +): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition={"executionRoleArn": "test"}, + execution_role_arn="override" if provided_as_field else None, + ) + print(task.preview()) + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + # Check if it is overridden if provided as a field + assert ( + task_definition["executionRoleArn"] == "test" + if not provided_as_field + else "override" + ) + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("default_cluster", [True, False]) +async def test_cluster(aws_credentials, default_cluster: bool): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + # Construct a non-default cluster. We build this in either case since otherwise + # there is only one cluster and there's no choice but to use the default. 
+ second_cluster_arn = create_test_ecs_cluster(ecs_client, "second-cluster") + add_ec2_instance_to_ecs_cluster(session, "second-cluster") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + cluster=None if default_cluster else "second-cluster", + ) + print(task.preview()) + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + + if default_cluster: + assert task["clusterArn"].endswith("default") + else: + assert task["clusterArn"] == second_cluster_arn + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_execution_role_arn(aws_credentials): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + execution_role_arn="test", + ) + print(task.preview()) + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + assert task_definition["executionRoleArn"] == "test" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_task_role_arn(aws_credentials): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_role_arn="test", + ) + print(task.preview()) + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + + assert task["overrides"]["taskRoleArn"] == "test" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_vpc_id(aws_credentials): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + + task = ECSTask(aws_credentials=aws_credentials, vpc_id=vpc.id) + + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = task._run_task + mock_run_task = MagicMock(side_effect=original_run_task) + task._run_task = mock_run_task + + print(task.preview()) + + await run_then_stop_task(task) + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "ENABLED", + "securityGroups": [], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_default_vpc(aws_credentials): + session = aws_credentials.get_boto3_session() + ec2_client = session.client("ec2") + + default_vpc_id = ec2_client.describe_vpcs( + Filters=[{"Name": "isDefault", "Values": ["true"]}] + )["Vpcs"][0]["VpcId"] + default_subnets = ec2_client.describe_subnets( + Filters=[{"Name": "vpc-id", "Values": [default_vpc_id]}] + )["Subnets"] + + task = ECSTask(aws_credentials=aws_credentials) + + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = task._run_task + mock_run_task = MagicMock(side_effect=original_run_task) + task._run_task = mock_run_task + + print(task.preview()) + + await run_then_stop_task(task) + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet["SubnetId"] for subnet in default_subnets], + "assignPublicIp": 
"ENABLED", + "securityGroups": [], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("explicit_network_mode", [True, False]) +async def test_network_config_is_empty_without_awsvpc_network_mode( + aws_credentials, explicit_network_mode +): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + # EC2 uses the 'bridge' network mode by default but we want to have test + # coverage for when it is set on the task definition + task_definition={"networkMode": "bridge"} if explicit_network_mode else None, + # FARGATE requires the 'awsvpc' network mode + launch_type="EC2", + ) + + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = task._run_task + mock_run_task = MagicMock(side_effect=original_run_task) + task._run_task = mock_run_task + + print(task.preview()) + + await run_then_stop_task(task) + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + assert network_configuration is None + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_missing_default_vpc(aws_credentials): + session = aws_credentials.get_boto3_session() + ec2_client = session.client("ec2") + + default_vpc_id = ec2_client.describe_vpcs( + Filters=[{"Name": "isDefault", "Values": ["true"]}] + )["Vpcs"][0]["VpcId"] + ec2_client.delete_vpc(VpcId=default_vpc_id) + + task = ECSTask(aws_credentials=aws_credentials) + + with pytest.raises(ValueError, match="Failed to find the default VPC"): + await run_then_stop_task(task) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_vpc_with_no_subnets(aws_credentials): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="172.16.0.0/16") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + vpc_id=vpc.id, + ) + print(task.preview()) + + with pytest.raises( + ValueError, match=f"Failed to find subnets for VPC with ID {vpc.id}" + ): + await run_then_stop_task(task) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_logging_requires_execution_role_arn(aws_credentials): + with pytest.raises( + ValidationError, + match="`execution_role_arn` must be provided", + ): + ECSTask( + aws_credentials=aws_credentials, + command=["prefect", "version"], + configure_cloudwatch_logs=True, + ) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_log_options_requires_logging(aws_credentials): + with pytest.raises( + ValidationError, + match=( # noqa + "`configure_cloudwatch_log` must be enabled to use" + " `cloudwatch_logs_options`" + ), + ): + ECSTask( + aws_credentials=aws_credentials, + command=["prefect", "version"], + configure_cloudwatch_logs=False, + cloudwatch_logs_options={"foo": " bar"}, + ) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_logging_requires_execution_role_arn_at_runtime(aws_credentials): + # In contrast to `test_logging_requires_execution_role_arn`, a task definition + # has been provided by ARN reference and we do not know if the execution role is + # missing until runtime. 
+ + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + task_definition_arn = ecs_client.register_task_definition(**BASE_TASK_DEFINITION)[ + "taskDefinition" + ]["taskDefinitionArn"] + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + configure_cloudwatch_logs=True, + task_definition_arn=task_definition_arn, + # This test is launch type agnostic but the task definition we register receives + # the default network mode type of 'bridge' which is not compatible with FARGATE + launch_type="EC2", + ) + with pytest.raises(ValueError, match="An execution role arn must be set"): + await task.run() + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_configure_cloudwatch_logging(aws_credentials): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + configure_cloudwatch_logs=True, + execution_role_arn="test", + ) + + task_arn = await run_then_stop_task(task) + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + for container in task_definition["containerDefinitions"]: + if container["name"] == "prefect": + # Assert that the 'prefect' container has logging configured + assert container["logConfiguration"] == { + "logDriver": "awslogs", + "options": { + "awslogs-create-group": "true", + "awslogs-group": "prefect", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "prefect", + }, + } + else: + # Other containers should not be modified + assert "logConfiguration" not in container + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_stream_output(aws_credentials, caplog): + session = aws_credentials.get_boto3_session() + logs_client = session.client("logs") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + configure_cloudwatch_logs=True, + stream_output=True, + execution_role_arn="test", + # Override the family so it does not match the container name + family="test-family", + # Override the prefix so it does not match the container name + cloudwatch_logs_options={"awslogs-stream-prefix": "test-prefix"}, + # Avoid slow polling during the test + task_watch_poll_interval=0.1, + ) + + async def write_fake_log(task_arn): + # TODO: moto does not appear to support actually reading these logs + # as they do not appear during `get_log_event` calls + # prefix/container-name/task-id + stream_name = f"test-prefix/prefect/{task_arn.rsplit('/')[-1]}" + logs_client.put_log_events( + logGroupName="prefect", + logStreamName=stream_name, + logEvents=[ + {"timestamp": i, "message": f"test-message-{i}"} for i in range(100) + ], + ) + + await run_then_stop_task(task, after_start=write_fake_log) + + logs_client = session.client("logs") + streams = logs_client.describe_log_streams(logGroupName="prefect")["logStreams"] + + assert len(streams) == 1 + + # Ensure we did not encounter any logging errors + assert "Failed to read log events" not in caplog.text + + # TODO: When moto supports reading logs, fix this + # out, err = capsys.readouterr() + # assert "test-message-{i}" in err + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_cloudwatch_log_options(aws_credentials): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task = ECSTask( + 
aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + configure_cloudwatch_logs=True, + execution_role_arn="test", + cloudwatch_logs_options={ + "max-buffer-size": "2m", + "awslogs-stream-prefix": "override-prefix", + }, + ) + + task_arn = await run_then_stop_task(task) + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + for container in task_definition["containerDefinitions"]: + if container["name"] == "prefect": + # Assert that the 'prefect' container has logging configured with user + # provided options + assert container["logConfiguration"] == { + "logDriver": "awslogs", + "options": { + "awslogs-create-group": "true", + "awslogs-group": "prefect", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "override-prefix", + "max-buffer-size": "2m", + }, + } + else: + # Other containers should not be modified + assert "logConfiguration" not in container + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["FARGATE", "FARGATE_SPOT"]) +async def test_bridge_network_mode_warns_on_fargate(aws_credentials, launch_type: str): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["prefect", "version"], + task_definition={"networkMode": "bridge"}, + launch_type=launch_type, + ) + with pytest.warns( + UserWarning, + match=( + "Found network mode 'bridge' which is not compatible with launch type " + f"{launch_type!r}" + ), + ): + with pytest.raises(ClientError): + await run_then_stop_task(task) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_deregister_task_definition(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=True, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["status"] == "INACTIVE" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_latest_task_definition_used_if_equal(aws_credentials): + task = ECSTask(aws_credentials=aws_credentials) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn_1 = await run_then_stop_task(task) + task_arn_2 = await run_then_stop_task(task) + + task_1 = describe_task(ecs_client, task_arn_1) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_latest_task_definition_not_used_if_in_another_family( + aws_credentials, +): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_1 = ECSTask(aws_credentials=aws_credentials, family="test1") + task_2 = ECSTask(aws_credentials=aws_credentials, family="test2") + + task_arn_1 = await run_then_stop_task(task_1) + task_arn_2 = await run_then_stop_task(task_2) + + task_1 = describe_task(ecs_client, task_arn_1) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] != task_2["taskDefinitionArn"] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_latest_task_definition_not_used_if_inequal( + aws_credentials, +): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + # Place it in the same family + task_1 = 
ECSTask( + aws_credentials=aws_credentials, + family="test", + image="image1", + auto_deregister_task_definition=False, + ) + task_2 = ECSTask( + aws_credentials=aws_credentials, + family="test", + image="image2", + auto_deregister_task_definition=False, + ) + + task_arn_1 = await run_then_stop_task(task_1) + task_arn_2 = await run_then_stop_task(task_2) + + task_1 = describe_task(ecs_client, task_arn_1) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] != task_2["taskDefinitionArn"] + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE"]) +@pytest.mark.parametrize( + "overrides", + [ + {"env": {"FOO": "BAR"}}, + {"command": ["test"]}, + {"labels": {"FOO": "BAR"}}, + {"stream_output": True, "configure_cloudwatch_logs": False}, + {"cluster": "test"}, + {"task_role_arn": "test"}, + # Note: null environment variables can cause override, but not when missing + # from the base task definition + {"env": {"FOO": None}}, + # The following would not result in a copy when using a task_definition_arn + # but will be eagerly set on the new task definition and result in a cache miss + # {"cpu": 2048}, + # {"memory": 4096}, + # {"execution_role_arn": "test"}, + # {"launch_type": "EXTERNAL"}, + ], + ids=lambda item: str(sorted(list(set(item.keys())))), +) +async def test_latest_task_definition_with_overrides_that_do_not_require_copy( + aws_credentials, overrides, launch_type +): + """ + Any of these overrides should be configured at runtime and not require a new + task definition to be registered + """ + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + if "cluster" in overrides: + create_test_ecs_cluster(ecs_client, overrides["cluster"]) + add_ec2_instance_to_ecs_cluster(session, overrides["cluster"]) + + task_1 = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + family="test", + launch_type=launch_type, + ) + task_2 = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + family="test", + launch_type=launch_type, + **overrides, + ) + task_arn_1 = await run_then_stop_task(task_1) + task_arn_2 = await run_then_stop_task(task_2) + + task_1 = describe_task(ecs_client, task_arn_1) + task_2 = describe_task(ecs_client, task_arn_2) + assert ( + task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"] + ), "The existing task definition should be used" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_task_definition_arn(aws_credentials): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_definition_arn = ecs_client.register_task_definition(**BASE_TASK_DEFINITION)[ + "taskDefinition" + ]["taskDefinitionArn"] + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition_arn=task_definition_arn, + launch_type="EC2", + image=None, + ) + print(task.preview()) + task_arn = await run_then_stop_task(task) + + assert task.image is None, "Image option can be null when using task definition arn" + + task = describe_task(ecs_client, task_arn) + assert ( + task["taskDefinitionArn"] == task_definition_arn + ), "The task definition should be used without registering a new one" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "overrides", + [{"image": "new-image"}, {"configure_cloudwatch_logs": True}, {"family": "foobar"}], +) +async def test_task_definition_arn_with_overrides_that_require_copy( + 
aws_credentials, overrides, caplog +): + """ + Any of these overrides should cause the task definition to be copied and + registered as a new version + """ + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_definition_arn = ecs_client.register_task_definition( + **BASE_TASK_DEFINITION, executionRoleArn="base" + )["taskDefinition"]["taskDefinitionArn"] + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition_arn=task_definition_arn, + launch_type="EC2", + **overrides, + ) + print(task.preview()) + with caplog.at_level(logging.INFO, logger=task.logger.name): + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + assert ( + task["taskDefinitionArn"] != task_definition_arn + ), "A new task definition should be registered" + + assert ( + "Settings require changes to the linked task definition. " + "A new task definition will be registered. " + "Enable DEBUG level logs to see the difference." in caplog.text + ) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_task_definition_arn_with_overrides_requiring_copy_shows_diff( + aws_credentials, caplog +): + """ + Any of these overrides should cause the task definition to be copied and + registered as a new version + """ + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_definition_arn = ecs_client.register_task_definition( + **BASE_TASK_DEFINITION, executionRoleArn="base" + )["taskDefinition"]["taskDefinitionArn"] + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition_arn=task_definition_arn, + launch_type="EC2", + image="foobar", + ) + print(task.preview()) + with caplog.at_level(logging.DEBUG, logger=task.logger.name): + await run_then_stop_task(task) + + assert ( + "Settings require changes to the linked task definition. " + "A new task definition will be registered. " in caplog.text + ) + + assert "Enable DEBUG level logs to see the difference." 
not in caplog.text
+
+    expected_diff = textwrap.dedent(
+        """
+    - 'image': 'prefecthq/prefect:2.1.0-python3.8',
+    + 'image': 'foobar',
+        """
+    )
+    assert expected_diff in caplog.text
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+@pytest.mark.parametrize(
+    "overrides",
+    [
+        {"env": {"FOO": "BAR"}},
+        {"command": ["test"]},
+        {"labels": {"FOO": "BAR"}},
+        {"cpu": 2048},
+        {"memory": 4096},
+        {"execution_role_arn": "test"},
+        {"stream_output": True, "configure_cloudwatch_logs": False},
+        {"launch_type": "EXTERNAL"},
+        {"cluster": "test"},
+        {"task_role_arn": "test"},
+        # Note: null environment variables can cause override, but not when missing
+        # from the base task definition
+        {"env": {"FOO": None}},
+    ],
+    ids=lambda item: str(sorted(list(set(item.keys())))),
+)
+async def test_task_definition_arn_with_overrides_that_do_not_require_copy(
+    aws_credentials, overrides
+):
+    """
+    Any of these overrides should be configured at runtime and not require a new
+    task definition to be registered
+    """
+    session = aws_credentials.get_boto3_session()
+    ecs_client = session.client("ecs")
+
+    if "cluster" in overrides:
+        create_test_ecs_cluster(ecs_client, overrides["cluster"])
+        add_ec2_instance_to_ecs_cluster(session, overrides["cluster"])
+
+    task_definition_arn = ecs_client.register_task_definition(
+        **BASE_TASK_DEFINITION,
+    )["taskDefinition"]["taskDefinitionArn"]
+
+    # Set the default launch type for compatibility with the base task definition
+    overrides.setdefault("launch_type", "EC2")
+
+    task = ECSTask(
+        aws_credentials=aws_credentials,
+        auto_deregister_task_definition=False,
+        task_definition_arn=task_definition_arn,
+        image=None,
+        **overrides,
+    )
+    print(task.preview())
+    task_arn = await run_then_stop_task(task)
+
+    task = describe_task(ecs_client, task_arn)
+    assert (
+        task["taskDefinitionArn"] == task_definition_arn
+    ), "The existing task definition should be used"
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+async def test_deregister_task_definition_does_not_apply_to_linked_arn(aws_credentials):
+    session = aws_credentials.get_boto3_session()
+    ecs_client = session.client("ecs")
+
+    task_definition_arn = ecs_client.register_task_definition(**BASE_TASK_DEFINITION)[
+        "taskDefinition"
+    ]["taskDefinitionArn"]
+
+    task = ECSTask(
+        aws_credentials=aws_credentials,
+        auto_deregister_task_definition=True,
+        task_definition_arn=task_definition_arn,
+        launch_type="EC2",
+        image=None,
+    )
+    print(task.preview())
+    task_arn = await run_then_stop_task(task)
+
+    task = describe_task(ecs_client, task_arn)
+    # The linked task definition should not be deregistered
+    assert describe_task_definition(ecs_client, task)["status"] == "ACTIVE"
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+async def test_adding_security_groups_to_network_config(aws_credentials):
+    session = aws_credentials.get_boto3_session()
+    ec2_resource = session.resource("ec2")
+    vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16")
+    subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id)
+    ec2_client = session.client("ec2")
+    security_group_id = ec2_client.create_security_group(
+        GroupName="test", Description="testing"
+    )["GroupId"]
+
+    task = ECSTask(
+        aws_credentials=aws_credentials,
+        vpc_id=vpc.id,
+        task_customizations=[
+            {
+                "op": "add",
+                "path": "/networkConfiguration/awsvpcConfiguration/securityGroups",
+                "value": [security_group_id],
+            },
+        ],
+    )
+
+    # Capture the task run call because moto does not track 'networkConfiguration'
+    original_run_task = task._run_task
+    mock_run_task = MagicMock(side_effect=original_run_task)
+    task._run_task = mock_run_task
+
+    
print(task.preview()) + + await run_then_stop_task(task) + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "ENABLED", + "securityGroups": [security_group_id], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_disable_public_ip_in_network_config(aws_credentials): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + + task = ECSTask( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + task_customizations=[ + { + "op": "replace", + "path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp", + "value": "DISABLED", + }, + ], + ) + + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = task._run_task + mock_run_task = MagicMock(side_effect=original_run_task) + task._run_task = mock_run_task + + print(task.preview()) + + await run_then_stop_task(task) + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_custom_subnets_in_the_network_configuration(aws_credentials): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + + task = ECSTask( + aws_credentials=aws_credentials, + task_customizations=[ + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/subnets", + "value": [subnet.id], + }, + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp", + "value": "DISABLED", + }, + ], + ) + + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = task._run_task + mock_run_task = MagicMock(side_effect=original_run_task) + task._run_task = mock_run_task + + print(task.preview()) + + await run_then_stop_task(task) + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_task_customizations_as_string(aws_credentials): + tc = ( + '[{"op": "replace", "path": "/overrides/cpu", "value": "512"}, ' + '{"op": "replace", "path": "/overrides/memory", "value": "1024"}]' + ) + + task = ECSTask( + aws_credentials=aws_credentials, memory=512, cpu=256, task_customizations=tc + ) # type: ignore + + original_run_task = task._run_task + mock_run_task = MagicMock(side_effect=original_run_task) + task._run_task = mock_run_task + + await run_then_stop_task(task) + + overrides = mock_run_task.call_args[0][1].get("overrides") + + assert overrides["memory"] == "1024" + assert overrides["cpu"] == "512" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "fields,prepare_inputs,expected_family", + [ + # Default + ({}, {}, "prefect"), + # Only flow + ({}, {"flow": 
Flow(name="foo")}, "prefect__foo"), + # Only deployment + ( + {}, + {"deployment": Deployment.construct(name="foo")}, + "prefect__unknown-flow__foo", + ), + # Flow and deployment + ( + {}, + { + "flow": Flow(name="foo"), + "deployment": Deployment.construct(name="bar"), + }, + "prefect__foo__bar", + ), + # Family provided as a field + ( + {"family": "test"}, + { + "flow": Flow(name="foo"), + "deployment": Deployment.construct(name="bar"), + }, + "test", + ), + # Family provided in a task definition + ( + {"task_definition": {"family": "test"}}, + { + "flow": Flow(name="foo"), + "deployment": Deployment.construct(name="bar"), + }, + "test", + ), + ], +) +async def test_family_from_flow_run_metadata( + aws_credentials, fields, prepare_inputs, expected_family +): + prepare_inputs.setdefault("flow_run", FlowRun.construct()) + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + **fields, + ).prepare_for_flow_run(**prepare_inputs) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["family"] == expected_family + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "given_family,expected_family", + [ + # Default + (None, "prefect"), + ("", "prefect"), + # Length limited to 255 + ("x" * 300, "x" * 255), + # Spaces are not allowed + ("foo bar", "foo-bar"), + # Special characters are not allowed + ("foo*bar&!", "foo-bar"), + ], +) +async def test_user_provided_family(aws_credentials, given_family, expected_family): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + family=given_family, + ) + print(task.preview()) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["family"] == expected_family + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("prepare_for_flow_run", [True, False]) +async def test_family_from_task_definition_arn(aws_credentials, prepare_for_flow_run): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_definition_arn = ecs_client.register_task_definition( + **{**BASE_TASK_DEFINITION, "family": "test-family"} + )["taskDefinition"]["taskDefinitionArn"] + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + task_definition_arn=task_definition_arn, + launch_type="EC2", + image=None, + ) + if prepare_for_flow_run: + task = task.prepare_for_flow_run( + flow_run=FlowRun.construct(), + flow=Flow(name="foo"), + deployment=Deployment.construct(name="bar"), + ) + + print(task.preview()) + + task_arn = await run_then_stop_task(task) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["family"] == "test-family" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "cluster", [None, "default", "second-cluster", "second-cluster-arn"] +) +async def test_kill(aws_credentials, cluster: str): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + # Kill requires cluster-specificity so we test with variable clusters + 
second_cluster_arn = create_test_ecs_cluster(ecs_client, "second-cluster") + add_ec2_instance_to_ecs_cluster(session, "second-cluster") + + if cluster == "second-cluster-arn": + # Use the actual arn for this test case + cluster = second_cluster_arn + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + cluster=cluster, + ) + print(task.preview()) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + identifier = await tg.start(task.run) + + await task.kill(identifier) + + _, task_arn = parse_task_identifier(identifier) + task = describe_task(ecs_client, task_arn) + assert task["lastStatus"] == "STOPPED" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_with_invalid_identifier(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["sleep", "1000"], + ) + print(task.preview()) + + with pytest.raises(ValueError): + await task.kill("test") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_with_mismatched_cluster(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["sleep", "1000"], + cluster="foo", + ) + print(task.preview()) + + with pytest.raises( + InfrastructureNotAvailable, + match=( + "Cannot stop ECS task: this infrastructure block has access to cluster " + "'foo' but the task is running in cluster 'bar'." + ), + ): + await task.kill("bar:::task_arn") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_with_cluster_that_does_not_exist(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["sleep", "1000"], + cluster="foo", + ) + print(task.preview()) + + with pytest.raises( + InfrastructureNotFound, + match="Cannot stop ECS task: the cluster 'foo' could not be found.", + ): + await task.kill("foo::task_arn") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_with_task_that_does_not_exist(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["sleep", "1000"], + cluster="default", + ) + print(task.preview()) + + # Run the task so that a task definition is registered in the cluster + await run_then_stop_task(task) + + with pytest.raises( + InfrastructureNotFound, + match=( + "Cannot stop ECS task: the task 'foo' could not be found in cluster" + " 'default'" + ), + ): + await task.kill("default::foo") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_with_cluster_that_has_no_tasks(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["sleep", "1000"], + cluster="default", + ) + print(task.preview()) + + with pytest.raises( + InfrastructureNotFound, + match="Cannot stop ECS task: the cluster 'default' has no tasks.", + ): + await task.kill("default::foo") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_with_task_that_is_already_stopped(aws_credentials): + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + command=["sleep", "1000"], + cluster="default", + ) + print(task.preview()) + + # Run and stop the task + task_arn = await run_then_stop_task(task) + + # AWS will happily stop the task "again" + await task.kill(f"default::{task_arn}") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_with_grace_period(aws_credentials, caplog): + session = 
aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task = ECSTask( + aws_credentials=aws_credentials, + auto_deregister_task_definition=False, + ) + print(task.preview()) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + identifier = await tg.start(task.run) + + await task.kill(identifier, grace_seconds=60) + + # Task stops correctly + _, task_arn = parse_task_identifier(identifier) + task = describe_task(ecs_client, task_arn) + assert task["lastStatus"] == "STOPPED" + + # Logs warning + assert "grace period of 60s requested, but AWS does not support" in caplog.text + + +@pytest.fixture +def default_base_job_template(): + return deepcopy(ECSWorker.get_default_base_job_template()) + + +@pytest.fixture +def base_job_template_with_defaults(default_base_job_template, aws_credentials): + base_job_template_with_defaults = deepcopy(default_base_job_template) + base_job_template_with_defaults["variables"]["properties"]["command"][ + "default" + ] = "python my_script.py" + base_job_template_with_defaults["variables"]["properties"]["env"]["default"] = { + "VAR1": "value1", + "VAR2": "value2", + } + base_job_template_with_defaults["variables"]["properties"]["labels"]["default"] = { + "label1": "value1", + "label2": "value2", + } + base_job_template_with_defaults["variables"]["properties"]["name"][ + "default" + ] = "prefect-job" + base_job_template_with_defaults["variables"]["properties"]["image"][ + "default" + ] = "docker.io/my_image:latest" + base_job_template_with_defaults["variables"]["properties"]["aws_credentials"][ + "default" + ] = {"$ref": {"block_document_id": str(aws_credentials._block_document_id)}} + base_job_template_with_defaults["variables"]["properties"]["launch_type"][ + "default" + ] = "FARGATE_SPOT" + base_job_template_with_defaults["variables"]["properties"]["vpc_id"][ + "default" + ] = "vpc-123456" + base_job_template_with_defaults["variables"]["properties"]["task_role_arn"][ + "default" + ] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" + base_job_template_with_defaults["variables"]["properties"]["execution_role_arn"][ + "default" + ] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" + base_job_template_with_defaults["variables"]["properties"]["cluster"][ + "default" + ] = "test-cluster" + base_job_template_with_defaults["variables"]["properties"]["cpu"]["default"] = 2048 + base_job_template_with_defaults["variables"]["properties"]["memory"][ + "default" + ] = 4096 + + base_job_template_with_defaults["variables"]["properties"]["family"][ + "default" + ] = "test-family" + base_job_template_with_defaults["variables"]["properties"]["task_definition_arn"][ + "default" + ] = "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1" + base_job_template_with_defaults["variables"]["properties"][ + "cloudwatch_logs_options" + ]["default"] = { + "awslogs-group": "prefect", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "prefect", + } + base_job_template_with_defaults["variables"]["properties"][ + "configure_cloudwatch_logs" + ]["default"] = True + base_job_template_with_defaults["variables"]["properties"]["stream_output"][ + "default" + ] = True + base_job_template_with_defaults["variables"]["properties"][ + "task_watch_poll_interval" + ]["default"] = 5.1 + base_job_template_with_defaults["variables"]["properties"][ + "task_start_timeout_seconds" + ]["default"] = 60 + base_job_template_with_defaults["variables"]["properties"][ + "auto_deregister_task_definition" + ]["default"] = False + 
base_job_template_with_defaults["variables"]["properties"]["network_configuration"][ + "default" + ] = { + "awsvpcConfiguration": { + "subnets": ["subnet-***"], + "assignPublicIp": "DISABLED", + "securityGroups": ["sg-***"], + } + } + return base_job_template_with_defaults + + +@pytest.fixture +def base_job_template_with_task_arn(default_base_job_template, aws_credentials): + base_job_template_with_task_arn = deepcopy(default_base_job_template) + base_job_template_with_task_arn["variables"]["properties"]["image"][ + "default" + ] = "docker.io/my_image:latest" + + base_job_template_with_task_arn["job_configuration"]["task_definition"] = { + "containerDefinitions": [ + {"image": "docker.io/my_image:latest", "name": "prefect-job"} + ], + "cpu": "2048", + "family": "test-family", + "memory": "2024", + "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + } + return base_job_template_with_task_arn + + +@pytest.mark.parametrize( + "job_config", + [ + "default", + "custom", + "task_definition_arn", + ], +) +async def test_generate_work_pool_base_job_template( + job_config, + base_job_template_with_defaults, + aws_credentials, + default_base_job_template, + base_job_template_with_task_arn, + caplog, +): + job = ECSTask() + expected_template = default_base_job_template + expected_template["variables"]["properties"]["image"][ + "default" + ] = get_prefect_image_name() + if job_config == "custom": + expected_template = base_job_template_with_defaults + job = ECSTask( + command=["python", "my_script.py"], + env={"VAR1": "value1", "VAR2": "value2"}, + labels={"label1": "value1", "label2": "value2"}, + name="prefect-job", + image="docker.io/my_image:latest", + aws_credentials=aws_credentials, + launch_type="FARGATE_SPOT", + vpc_id="vpc-123456", + task_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + execution_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + cluster="test-cluster", + cpu=2048, + memory=4096, + task_customizations=[ + { + "op": "replace", + "path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp", + "value": "DISABLED", + }, + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/subnets", + "value": ["subnet-***"], + }, + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/securityGroups", + "value": ["sg-***"], + }, + ], + family="test-family", + task_definition_arn=( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1" + ), + cloudwatch_logs_options={ + "awslogs-group": "prefect", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "prefect", + }, + configure_cloudwatch_logs=True, + stream_output=True, + task_watch_poll_interval=5.1, + task_start_timeout_seconds=60, + auto_deregister_task_definition=False, + ) + elif job_config == "task_definition_arn": + expected_template = base_job_template_with_task_arn + job = ECSTask( + image="docker.io/my_image:latest", + task_definition={ + "containerDefinitions": [ + {"image": "docker.io/my_image:latest", "name": "prefect-job"} + ], + "cpu": "2048", + "family": "test-family", + "memory": "2024", + "executionRoleArn": ( + "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" + ), + }, + ) + + template = await job.generate_work_pool_base_job_template() + + assert template == expected_template diff --git a/src/integrations/prefect-aws/tests/test_glue_job.py b/src/integrations/prefect-aws/tests/test_glue_job.py new file mode 100644 index 000000000000..7b59846e4b57 --- /dev/null +++ 
b/src/integrations/prefect-aws/tests/test_glue_job.py @@ -0,0 +1,153 @@ +from unittest.mock import MagicMock + +import pytest +from moto import mock_glue +from prefect_aws.glue_job import GlueJobBlock, GlueJobRun + + +@pytest.fixture(scope="function") +def glue_job_client(aws_credentials): + with mock_glue(): + boto_session = aws_credentials.get_boto3_session() + yield boto_session.client("glue", region_name="us-east-1") + + +async def test_fetch_result(aws_credentials, glue_job_client): + glue_job_client.create_job( + Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} + ) + job_run_id = glue_job_client.start_job_run( + JobName="test_job_name", + Arguments={}, + )["JobRunId"] + glue_job_run = GlueJobRun( + job_name="test_job_name", job_id=job_run_id, client=glue_job_client + ) + result = await glue_job_run.fetch_result() + assert result == "SUCCEEDED" + + +def test_wait_for_completion(aws_credentials, glue_job_client): + with mock_glue(): + glue_job_client.create_job( + Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} + ) + job_run_id = glue_job_client.start_job_run( + JobName="test_job_name", + Arguments={}, + )["JobRunId"] + + glue_job_run = GlueJobRun( + job_name="test_job_name", + job_id=job_run_id, + job_watch_poll_interval=0.1, + client=glue_job_client, + ) + + glue_job_client.get_job_run = MagicMock( + side_effect=[ + { + "JobRun": { + "JobName": "test_job_name", + "JobRunState": "RUNNING", + } + }, + { + "JobRun": { + "JobName": "test_job_name", + "JobRunState": "SUCCEEDED", + } + }, + ] + ) + glue_job_run.wait_for_completion() + + +def test_wait_for_completion_fail(aws_credentials, glue_job_client): + with mock_glue(): + glue_job_client.create_job( + Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} + ) + job_run_id = glue_job_client.start_job_run( + JobName="test_job_name", + Arguments={}, + )["JobRunId"] + glue_job_client.get_job_run = MagicMock( + side_effect=[ + { + "JobRun": { + "JobName": "test_job_name", + "JobRunState": "FAILED", + "ErrorMessage": "err", + } + }, + ] + ) + + glue_job_run = GlueJobRun( + job_name="test_job_name", job_id=job_run_id, client=glue_job_client + ) + with pytest.raises(RuntimeError): + glue_job_run.wait_for_completion() + + +def test__get_job_run(aws_credentials, glue_job_client): + with mock_glue(): + glue_job_client.create_job( + Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} + ) + job_run_id = glue_job_client.start_job_run( + JobName="test_job_name", + Arguments={}, + )["JobRunId"] + + glue_job_run = GlueJobRun( + job_name="test_job_name", job_id=job_run_id, client=glue_job_client + ) + response = glue_job_run._get_job_run() + assert response["JobRun"]["JobRunState"] == "SUCCEEDED" + + +async def test_trigger(aws_credentials, glue_job_client): + glue_job_client.create_job( + Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} + ) + glue_job = GlueJobBlock( + job_name="test_job_name", + arguments={"arg1": "value1"}, + aws_credential=aws_credentials, + ) + glue_job._get_client = MagicMock(side_effect=[glue_job_client]) + glue_job._start_job = MagicMock(side_effect=["test_job_id"]) + glue_job_run = await glue_job.trigger() + assert isinstance(glue_job_run, GlueJobRun) + + +def test_start_job(aws_credentials, glue_job_client): + with mock_glue(): + glue_job_client.create_job( + Name="test_job_name", Role="test-role", Command={}, DefaultArguments={} + ) + glue_job = GlueJobBlock(job_name="test_job_name", arguments={"arg1": "value1"}) + + 
glue_job_client.start_job_run = MagicMock( + side_effect=[{"JobRunId": "test_job_run_id"}] + ) + job_run_id = glue_job._start_job(glue_job_client) + assert job_run_id == "test_job_run_id" + + +def test_start_job_fail_because_not_exist_job(aws_credentials, glue_job_client): + with mock_glue(): + glue_job = GlueJobBlock(job_name="test_job_name", arguments={"arg1": "value1"}) + with pytest.raises(RuntimeError): + glue_job._start_job(glue_job_client) + + +def test_get_client(aws_credentials): + with mock_glue(): + glue_job_run = GlueJobBlock( + job_name="test_job_name", aws_credentials=aws_credentials + ) + client = glue_job_run._get_client() + assert hasattr(client, "get_job_run") diff --git a/src/integrations/prefect-aws/tests/test_lambda_function.py b/src/integrations/prefect-aws/tests/test_lambda_function.py new file mode 100644 index 000000000000..349586b08957 --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_lambda_function.py @@ -0,0 +1,253 @@ +import inspect +import io +import json +import zipfile +from typing import Optional + +import boto3 +import pytest +from botocore.response import StreamingBody +from moto import mock_iam, mock_lambda +from prefect_aws.credentials import AwsCredentials +from prefect_aws.lambda_function import LambdaFunction + + +@pytest.fixture +def lambda_mock(aws_credentials: AwsCredentials): + with mock_lambda(): + yield boto3.client( + "lambda", + region_name=aws_credentials.region_name, + ) + + +@pytest.fixture +def iam_mock(aws_credentials: AwsCredentials): + with mock_iam(): + yield boto3.client( + "iam", + region_name=aws_credentials.region_name, + ) + + +@pytest.fixture +def mock_iam_rule(iam_mock): + yield iam_mock.create_role( + RoleName="test-role", + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "lambda.amazonaws.com"}, + "Action": "sts:AssumeRole", + } + ], + } + ), + ) + + +def handler_a(event, context): + if isinstance(event, dict): + if "error" in event: + raise Exception(event["error"]) + event["foo"] = "bar" + else: + event = {"foo": "bar"} + return event + + +LAMBDA_TEST_CODE = inspect.getsource(handler_a) + + +@pytest.fixture +def mock_lambda_code(): + with io.BytesIO() as f: + with zipfile.ZipFile(f, mode="w") as z: + z.writestr("foo.py", LAMBDA_TEST_CODE) + f.seek(0) + yield f.read() + + +@pytest.fixture +def mock_lambda_function(lambda_mock, mock_iam_rule, mock_lambda_code): + r = lambda_mock.create_function( + FunctionName="test-function", + Runtime="python3.10", + Role=mock_iam_rule["Role"]["Arn"], + Handler="foo.handler", + Code={"ZipFile": mock_lambda_code}, + ) + r2 = lambda_mock.publish_version( + FunctionName="test-function", + ) + r["Version"] = r2["Version"] + yield r + + +def handler_b(event, context): + event = {"data": [1, 2, 3]} + return event + + +LAMBDA_TEST_CODE_V2 = inspect.getsource(handler_b) + + +@pytest.fixture +def mock_lambda_code_v2(): + with io.BytesIO() as f: + with zipfile.ZipFile(f, mode="w") as z: + z.writestr("foo.py", LAMBDA_TEST_CODE_V2) + f.seek(0) + yield f.read() + + +@pytest.fixture +def add_lambda_version(mock_lambda_function, lambda_mock, mock_lambda_code_v2): + r = mock_lambda_function.copy() + lambda_mock.update_function_code( + FunctionName="test-function", + ZipFile=mock_lambda_code_v2, + ) + r2 = lambda_mock.publish_version( + FunctionName="test-function", + ) + r["Version"] = r2["Version"] + yield r + + +@pytest.fixture +def lambda_function(aws_credentials): + return LambdaFunction( + 
function_name="test-function", + aws_credentials=aws_credentials, + ) + + +def make_patched_invocation(client, handler): + """Creates a patched invoke method for moto lambda. The method replaces + the response 'Payload' with the result of the handler function. + """ + true_invoke = client.invoke + + def invoke(*args, **kwargs): + """Calls the true invoke and replaces the Payload with its result.""" + result = true_invoke(*args, **kwargs) + blob = json.dumps( + handler( + event=kwargs.get("Payload"), + context=kwargs.get("ClientContext"), + ) + ).encode() + result["Payload"] = StreamingBody(io.BytesIO(blob), len(blob)) + return result + + return invoke + + +@pytest.fixture +def mock_invoke( + lambda_function: LambdaFunction, handler, monkeypatch: pytest.MonkeyPatch +): + """Fixture to patch the invocation response's 'Payload' field. + + When `result["Payload"].read` is called, moto attempts to run the function + in a Docker container and return the result. This is total overkill, so + we actually call the handler with the given arguments. + """ + client = lambda_function._get_lambda_client() + + monkeypatch.setattr( + client, + "invoke", + make_patched_invocation(client, handler), + ) + + def _get_lambda_client(): + return client + + monkeypatch.setattr( + lambda_function, + "_get_lambda_client", + _get_lambda_client, + ) + + yield + + +class TestLambdaFunction: + def test_init(self, aws_credentials): + function = LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ) + assert function.function_name == "test-function" + assert function.qualifier is None + + @pytest.mark.parametrize( + "payload,expected,handler", + [ + ({"foo": "baz"}, {"foo": "bar"}, handler_a), + (None, {"foo": "bar"}, handler_a), + ], + ) + def test_invoke_lambda_payloads( + self, + payload: Optional[dict], + expected: dict, + handler, + mock_lambda_function, + lambda_function: LambdaFunction, + mock_invoke, + ): + result = lambda_function.invoke(payload) + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == expected + + @pytest.mark.parametrize("handler", [handler_a]) + def test_invoke_lambda_tail( + self, lambda_function: LambdaFunction, mock_lambda_function, mock_invoke + ): + result = lambda_function.invoke(tail=True) + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == {"foo": "bar"} + assert "LogResult" in result + + @pytest.mark.parametrize("handler", [handler_a]) + def test_invoke_lambda_client_context( + self, lambda_function: LambdaFunction, mock_lambda_function, mock_invoke + ): + # Just making sure boto doesn't throw an error + result = lambda_function.invoke(client_context={"bar": "foo"}) + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == {"foo": "bar"} + + @pytest.mark.parametrize( + "func_fixture,expected,handler", + [ + ("mock_lambda_function", {"foo": "bar"}, handler_a), + ("add_lambda_version", {"data": [1, 2, 3]}, handler_b), + ], + ) + def test_invoke_lambda_qualifier( + self, + func_fixture, + expected, + lambda_function: LambdaFunction, + mock_invoke, + request, + ): + func_fixture = request.getfixturevalue(func_fixture) + try: + lambda_function.qualifier = func_fixture["Version"] + result = lambda_function.invoke() + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == expected + finally: 
+ lambda_function.qualifier = None diff --git a/src/integrations/prefect-aws/tests/test_s3.py b/src/integrations/prefect-aws/tests/test_s3.py new file mode 100644 index 000000000000..7a14ce839492 --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_s3.py @@ -0,0 +1,1120 @@ +import io +import os +from pathlib import Path, PurePosixPath, PureWindowsPath + +import boto3 +import pytest +from botocore.exceptions import ClientError, EndpointConnectionError +from moto import mock_s3 +from prefect_aws import AwsCredentials, MinIOCredentials +from prefect_aws.client_parameters import AwsClientParameters +from prefect_aws.s3 import ( + S3Bucket, + s3_copy, + s3_download, + s3_list_objects, + s3_move, + s3_upload, +) + +from prefect import flow +from prefect.deployments import Deployment + +aws_clients = [ + "aws_client_parameters_custom_endpoint", + "aws_client_parameters_empty", + "aws_client_parameters_public_bucket", +] + + +@pytest.fixture +def s3_mock(monkeypatch, client_parameters): + if client_parameters.endpoint_url: + monkeypatch.setenv("MOTO_S3_CUSTOM_ENDPOINTS", client_parameters.endpoint_url) + with mock_s3(): + yield + + +@pytest.fixture +def client_parameters(request): + client_parameters = request.getfixturevalue(request.param) + return client_parameters + + +@pytest.fixture +def bucket(s3_mock, request): + s3 = boto3.resource("s3") + bucket = s3.Bucket("bucket") + marker = request.node.get_closest_marker("is_public", None) + if marker and marker.args[0]: + bucket.create(ACL="public-read") + else: + bucket.create() + return bucket + + +@pytest.fixture +def bucket_2(s3_mock, request): + s3 = boto3.resource("s3") + bucket = s3.Bucket("bucket_2") + marker = request.node.get_closest_marker("is_public", None) + if marker and marker.args[0]: + bucket.create(ACL="public-read") + else: + bucket.create() + return bucket + + +@pytest.fixture +def object(bucket, tmp_path): + file = tmp_path / "object.txt" + file.write_text("TEST") + with open(file, "rb") as f: + return bucket.upload_fileobj(f, "object") + + +@pytest.fixture +def object_in_folder(bucket, tmp_path): + file = tmp_path / "object_in_folder.txt" + file.write_text("TEST OBJECT IN FOLDER") + with open(file, "rb") as f: + return bucket.upload_fileobj(f, "folder/object") + + +@pytest.fixture +def objects_in_folder(bucket, tmp_path): + objects = [] + for filename in [ + "folderobject/foo.txt", + "folderobject/bar.txt", + "folder/object/foo.txt", + "folder/object/bar.txt", + ]: + file = tmp_path / filename + file.parent.mkdir(parents=True, exist_ok=True) + file.write_text("TEST OBJECTS IN FOLDER") + with open(file, "rb") as f: + filename = Path(filename) + obj = bucket.upload_fileobj(f, (filename.parent / filename.stem).as_posix()) + objects.append(obj) + return objects + + +@pytest.fixture +def a_lot_of_objects(bucket, tmp_path): + objects = [] + for i in range(0, 20): + file = tmp_path / f"object{i}.txt" + file.write_text("TEST") + with open(file, "rb") as f: + objects.append(bucket.upload_fileobj(f, f"object{i}")) + return objects + + +@pytest.mark.parametrize( + "client_parameters", + ["aws_client_parameters_custom_endpoint"], + indirect=True, +) +async def test_s3_download_failed_with_wrong_endpoint_setup( + object, client_parameters, aws_credentials +): + client_parameters_wrong_endpoint = AwsClientParameters( + endpoint_url="http://something" + ) + + @flow + async def test_flow(): + return await s3_download( + bucket="bucket", + key="object", + aws_credentials=aws_credentials, + 
aws_client_parameters=client_parameters_wrong_endpoint, + ) + + with pytest.raises(EndpointConnectionError): + await test_flow() + + +@pytest.mark.parametrize( + "client_parameters", + [ + pytest.param( + "aws_client_parameters_custom_endpoint", + marks=pytest.mark.is_public(False), + ), + pytest.param( + "aws_client_parameters_custom_endpoint", + marks=pytest.mark.is_public(True), + ), + pytest.param( + "aws_client_parameters_empty", + marks=pytest.mark.is_public(False), + ), + pytest.param( + "aws_client_parameters_empty", + marks=pytest.mark.is_public(True), + ), + pytest.param( + "aws_client_parameters_public_bucket", + marks=[ + pytest.mark.is_public(False), + pytest.mark.xfail(reason="Bucket is not a public one"), + ], + ), + pytest.param( + "aws_client_parameters_public_bucket", + marks=pytest.mark.is_public(True), + ), + ], + indirect=True, +) +async def test_s3_download(object, client_parameters, aws_credentials): + @flow + async def test_flow(): + return await s3_download( + bucket="bucket", + key="object", + aws_credentials=aws_credentials, + aws_client_parameters=client_parameters, + ) + + result = await test_flow() + assert result == b"TEST" + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_download_object_not_found(object, client_parameters, aws_credentials): + @flow + async def test_flow(): + return await s3_download( + key="unknown_object", + bucket="bucket", + aws_credentials=aws_credentials, + aws_client_parameters=client_parameters, + ) + + with pytest.raises(ClientError): + await test_flow() + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_upload(bucket, client_parameters, tmp_path, aws_credentials): + @flow + async def test_flow(): + test_file = tmp_path / "test.txt" + test_file.write_text("NEW OBJECT") + with open(test_file, "rb") as f: + return await s3_upload( + data=f.read(), + bucket="bucket", + key="new_object", + aws_credentials=aws_credentials, + aws_client_parameters=client_parameters, + ) + + await test_flow() + + stream = io.BytesIO() + bucket.download_fileobj("new_object", stream) + stream.seek(0) + output = stream.read() + + assert output == b"NEW OBJECT" + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_copy(object, bucket, bucket_2, aws_credentials): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + stream.seek(0) + return stream.read() + + @flow + async def test_flow(): + # Test cross-bucket copy + await s3_copy( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + target_bucket_name="bucket_2", + ) + + # Test within-bucket copy + await s3_copy( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + ) + + await test_flow() + assert read(bucket_2, "subfolder/new_object") == b"TEST" + assert read(bucket, "subfolder/new_object") == b"TEST" + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_move(object, bucket, bucket_2, aws_credentials): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + stream.seek(0) + return stream.read() + + @flow + async def test_flow(): + # Test within-bucket move + await s3_move( + source_path="object", + target_path="subfolder/object_copy", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + ) + + # 
Test cross-bucket move + await s3_move( + source_path="subfolder/object_copy", + target_path="object_copy_2", + source_bucket_name="bucket", + target_bucket_name="bucket_2", + aws_credentials=aws_credentials, + ) + + await test_flow() + + assert read(bucket_2, "object_copy_2") == b"TEST" + + with pytest.raises(ClientError): + read(bucket, "object") + + with pytest.raises(ClientError): + read(bucket, "subfolder/object_copy") + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_move_object_to_nonexistent_bucket_fails( + object, + bucket, + aws_credentials, +): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + stream.seek(0) + return stream.read() + + @flow + async def test_flow(): + # Test cross-bucket move + await s3_move( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + aws_credentials=aws_credentials, + target_bucket_name="nonexistent-bucket", + ) + + with pytest.raises(ClientError): + await test_flow() + + assert read(bucket, "object") == b"TEST" + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_move_object_fail_cases( + object, + bucket, + aws_credentials, +): + def read(bucket, key): + stream = io.BytesIO() + bucket.download_fileobj(key, stream) + stream.seek(0) + return stream.read() + + @flow + async def test_flow( + source_path, target_path, source_bucket_name, target_bucket_name + ): + # Test cross-bucket move + await s3_move( + source_path=source_path, + target_path=target_path, + source_bucket_name=source_bucket_name, + aws_credentials=aws_credentials, + target_bucket_name=target_bucket_name, + ) + + # Move to non-existent bucket + with pytest.raises(ClientError): + await test_flow( + source_path="object", + target_path="subfolder/new_object", + source_bucket_name="bucket", + target_bucket_name="nonexistent-bucket", + ) + assert read(bucket, "object") == b"TEST" + + # Move onto self + with pytest.raises(ClientError): + await test_flow( + source_path="object", + target_path="object", + source_bucket_name="bucket", + target_bucket_name="bucket", + ) + assert read(bucket, "object") == b"TEST" + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_list_objects( + object, client_parameters, object_in_folder, aws_credentials +): + @flow + async def test_flow(): + return await s3_list_objects( + bucket="bucket", + aws_credentials=aws_credentials, + aws_client_parameters=client_parameters, + ) + + objects = await test_flow() + assert len(objects) == 2 + assert [object["Key"] for object in objects] == ["folder/object", "object"] + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_list_objects_multiple_pages( + a_lot_of_objects, client_parameters, aws_credentials +): + @flow + async def test_flow(): + return await s3_list_objects( + bucket="bucket", + aws_credentials=aws_credentials, + aws_client_parameters=client_parameters, + page_size=2, + ) + + objects = await test_flow() + assert len(objects) == 20 + assert sorted([object["Key"] for object in objects]) == sorted( + [f"object{i}" for i in range(0, 20)] + ) + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_list_objects_prefix( + object, client_parameters, object_in_folder, aws_credentials +): + @flow + async def test_flow(): + return await s3_list_objects( + bucket="bucket", + prefix="folder", + aws_credentials=aws_credentials, + 
aws_client_parameters=client_parameters, + ) + + objects = await test_flow() + assert len(objects) == 1 + assert [object["Key"] for object in objects] == ["folder/object"] + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_list_objects_prefix_slashes( + object, client_parameters, objects_in_folder, aws_credentials +): + @flow + async def test_flow(slash=False): + return await s3_list_objects( + bucket="bucket", + prefix="folder" + ("/" if slash else ""), + aws_credentials=aws_credentials, + aws_client_parameters=client_parameters, + ) + + objects = await test_flow(slash=True) + assert len(objects) == 2 + assert [object["Key"] for object in objects] == [ + "folder/object/bar", + "folder/object/foo", + ] + + objects = await test_flow(slash=False) + assert len(objects) == 4 + assert [object["Key"] for object in objects] == [ + "folder/object/bar", + "folder/object/foo", + "folderobject/bar", + "folderobject/foo", + ] + + +@pytest.mark.parametrize("client_parameters", aws_clients, indirect=True) +async def test_s3_list_objects_filter( + object, client_parameters, object_in_folder, aws_credentials +): + @flow + async def test_flow(): + return await s3_list_objects( + bucket="bucket", + jmespath_query="Contents[?Size > `10`][]", + aws_credentials=aws_credentials, + aws_client_parameters=client_parameters, + ) + + objects = await test_flow() + assert len(objects) == 1 + assert [object["Key"] for object in objects] == ["folder/object"] + + +# S3 BUCKET TESTS BELOW + + +@pytest.fixture +def aws_creds_block(): + return AwsCredentials(aws_access_key_id="testing", aws_secret_access_key="testing") + + +@pytest.fixture +def minio_creds_block(): + return MinIOCredentials( + minio_root_user="minioadmin", minio_root_password="minioadmin" + ) + + +BUCKET_NAME = "test_bucket" + + +@pytest.fixture +def s3(): + """Mock connection to AWS S3 with boto3 client.""" + + with mock_s3(): + yield boto3.client( + service_name="s3", + region_name="us-east-1", + aws_access_key_id="minioadmin", + aws_secret_access_key="testing", + aws_session_token="testing", + ) + + +@pytest.fixture +def nested_s3_bucket_structure(s3, s3_bucket, tmp_path: Path): + """Creates an S3 bucket with multiple files in a nested structure""" + + file = tmp_path / "object.txt" + file.write_text("TEST") + + s3.upload_file(str(file), BUCKET_NAME, "object.txt") + s3.upload_file(str(file), BUCKET_NAME, "level1/object_level1.txt") + s3.upload_file(str(file), BUCKET_NAME, "level1/level2/object_level2.txt") + s3.upload_file(str(file), BUCKET_NAME, "level1/level2/object2_level2.txt") + + file.unlink() + assert not file.exists() + + +@pytest.fixture(params=["aws_credentials", "minio_credentials"]) +def s3_bucket(s3, request, aws_creds_block, minio_creds_block): + key = request.param + + if key == "aws_credentials": + fs = S3Bucket(bucket_name=BUCKET_NAME, credentials=aws_creds_block) + elif key == "minio_credentials": + fs = S3Bucket(bucket_name=BUCKET_NAME, credentials=minio_creds_block) + + s3.create_bucket(Bucket=BUCKET_NAME) + + return fs + + +@pytest.fixture +def s3_bucket_with_file(s3_bucket): + key = s3_bucket.write_path("test.txt", content=b"hello") + return s3_bucket, key + + +async def test_read_write_roundtrip(s3_bucket): + """ + Create an S3 bucket, instantiate S3Bucket block, write to and read from + bucket. 
+ """ + + key = await s3_bucket.write_path("test.txt", content=b"hello") + assert await s3_bucket.read_path(key) == b"hello" + + +async def test_write_with_missing_directory_succeeds(s3_bucket): + """ + Create an S3 bucket, instantiate S3Bucket block, write to path with + missing directory. + """ + + key = await s3_bucket.write_path("folder/test.txt", content=b"hello") + assert await s3_bucket.read_path(key) == b"hello" + + +async def test_read_fails_does_not_exist(s3_bucket): + """ + Create an S3 bucket, instantiate S3Bucket block, assert read from + nonexistent path fails. + """ + + with pytest.raises(ClientError): + await s3_bucket.read_path("test_bucket/foo/bar") + + +@pytest.mark.parametrize("type_", [PureWindowsPath, PurePosixPath, str]) +@pytest.mark.parametrize("delimiter", ["\\", "/"]) +async def test_aws_bucket_folder(s3_bucket, aws_creds_block, delimiter, type_): + """Test the bucket folder functionality.""" + + # create a new block with a subfolder + s3_bucket_block = S3Bucket( + bucket_name=BUCKET_NAME, + credentials=aws_creds_block, + bucket_folder="subfolder/subsubfolder", + ) + + key = await s3_bucket_block.write_path("test.txt", content=b"hello") + assert await s3_bucket_block.read_path("test.txt") == b"hello" + + expected: str = "subfolder/subsubfolder/test.txt" + assert key == expected + + +async def test_get_directory( + nested_s3_bucket_structure, s3_bucket: S3Bucket, tmp_path: Path +): + await s3_bucket.get_directory(local_path=str(tmp_path)) + + assert (tmp_path / "object.txt").exists() + assert (tmp_path / "level1" / "object_level1.txt").exists() + assert (tmp_path / "level1" / "level2" / "object_level2.txt").exists() + assert (tmp_path / "level1" / "level2" / "object2_level2.txt").exists() + + +async def test_get_directory_respects_bucket_folder( + nested_s3_bucket_structure, s3_bucket: S3Bucket, tmp_path: Path, aws_creds_block +): + s3_bucket_block = S3Bucket( + bucket_name=BUCKET_NAME, + credentials=aws_creds_block, + bucket_folder="level1/level2", + ) + + await s3_bucket_block.get_directory(local_path=str(tmp_path)) + + assert (len(list(tmp_path.glob("*")))) == 2 + assert (tmp_path / "object_level2.txt").exists() + assert (tmp_path / "object2_level2.txt").exists() + + +async def test_get_directory_respects_from_path( + nested_s3_bucket_structure, s3_bucket: S3Bucket, tmp_path: Path, aws_creds_block +): + await s3_bucket.get_directory(local_path=str(tmp_path), from_path="level1") + + assert (tmp_path / "object_level1.txt").exists() + assert (tmp_path / "level2" / "object_level2.txt").exists() + assert (tmp_path / "level2" / "object2_level2.txt").exists() + + +async def test_put_directory(s3_bucket: S3Bucket, tmp_path: Path): + (tmp_path / "file1.txt").write_text("FILE 1") + (tmp_path / "file2.txt").write_text("FILE 2") + (tmp_path / "folder1").mkdir() + (tmp_path / "folder1" / "file3.txt").write_text("FILE 3") + (tmp_path / "folder1" / "file4.txt").write_text("FILE 4") + (tmp_path / "folder1" / "folder2").mkdir() + (tmp_path / "folder1" / "folder2" / "file5.txt").write_text("FILE 5") + + uploaded_file_count = await s3_bucket.put_directory(local_path=str(tmp_path)) + assert uploaded_file_count == 5 + + (tmp_path / "downloaded_files").mkdir() + + await s3_bucket.get_directory(local_path=str(tmp_path / "downloaded_files")) + + assert (tmp_path / "downloaded_files" / "file1.txt").exists() + assert (tmp_path / "downloaded_files" / "file2.txt").exists() + assert (tmp_path / "downloaded_files" / "folder1" / "file3.txt").exists() + assert (tmp_path / "downloaded_files" / 
"folder1" / "file4.txt").exists() + assert ( + tmp_path / "downloaded_files" / "folder1" / "folder2" / "file5.txt" + ).exists() + + +async def test_put_directory_respects_basepath( + s3_bucket: S3Bucket, tmp_path: Path, aws_creds_block +): + (tmp_path / "file1.txt").write_text("FILE 1") + (tmp_path / "file2.txt").write_text("FILE 2") + (tmp_path / "folder1").mkdir() + (tmp_path / "folder1" / "file3.txt").write_text("FILE 3") + (tmp_path / "folder1" / "file4.txt").write_text("FILE 4") + (tmp_path / "folder1" / "folder2").mkdir() + (tmp_path / "folder1" / "folder2" / "file5.txt").write_text("FILE 5") + + s3_bucket_block = S3Bucket( + bucket_name=BUCKET_NAME, + aws_credentials=aws_creds_block, + basepath="subfolder", + ) + + uploaded_file_count = await s3_bucket_block.put_directory(local_path=str(tmp_path)) + assert uploaded_file_count == 5 + + (tmp_path / "downloaded_files").mkdir() + + await s3_bucket_block.get_directory(local_path=str(tmp_path / "downloaded_files")) + + assert (tmp_path / "downloaded_files" / "file1.txt").exists() + assert (tmp_path / "downloaded_files" / "file2.txt").exists() + assert (tmp_path / "downloaded_files" / "folder1" / "file3.txt").exists() + assert (tmp_path / "downloaded_files" / "folder1" / "file4.txt").exists() + assert ( + tmp_path / "downloaded_files" / "folder1" / "folder2" / "file5.txt" + ).exists() + + +async def test_put_directory_with_ignore_file( + s3_bucket: S3Bucket, tmp_path: Path, aws_creds_block +): + (tmp_path / "file1.txt").write_text("FILE 1") + (tmp_path / "file2.txt").write_text("FILE 2") + (tmp_path / "folder1").mkdir() + (tmp_path / "folder1" / "file3.txt").write_text("FILE 3") + (tmp_path / "folder1" / "file4.txt").write_text("FILE 4") + (tmp_path / "folder1" / "folder2").mkdir() + (tmp_path / "folder1" / "folder2" / "file5.txt").write_text("FILE 5") + (tmp_path / ".prefectignore").write_text("folder2/*") + + uploaded_file_count = await s3_bucket.put_directory( + local_path=str(tmp_path / "folder1"), + ignore_file=str(tmp_path / ".prefectignore"), + ) + assert uploaded_file_count == 2 + + (tmp_path / "downloaded_files").mkdir() + + await s3_bucket.get_directory(local_path=str(tmp_path / "downloaded_files")) + + assert (tmp_path / "downloaded_files" / "file3.txt").exists() + assert (tmp_path / "downloaded_files" / "file4.txt").exists() + assert not (tmp_path / "downloaded_files" / "folder2").exists() + assert not (tmp_path / "downloaded_files" / "folder2" / "file5.txt").exists() + + +async def test_put_directory_respects_local_path( + s3_bucket: S3Bucket, tmp_path: Path, aws_creds_block +): + (tmp_path / "file1.txt").write_text("FILE 1") + (tmp_path / "file2.txt").write_text("FILE 2") + (tmp_path / "folder1").mkdir() + (tmp_path / "folder1" / "file3.txt").write_text("FILE 3") + (tmp_path / "folder1" / "file4.txt").write_text("FILE 4") + (tmp_path / "folder1" / "folder2").mkdir() + (tmp_path / "folder1" / "folder2" / "file5.txt").write_text("FILE 5") + + uploaded_file_count = await s3_bucket.put_directory( + local_path=str(tmp_path / "folder1") + ) + assert uploaded_file_count == 3 + + (tmp_path / "downloaded_files").mkdir() + + await s3_bucket.get_directory(local_path=str(tmp_path / "downloaded_files")) + + assert (tmp_path / "downloaded_files" / "file3.txt").exists() + assert (tmp_path / "downloaded_files" / "file4.txt").exists() + assert (tmp_path / "downloaded_files" / "folder2" / "file5.txt").exists() + + +def test_read_path_in_sync_context(s3_bucket_with_file): + """Test that read path works in a sync context.""" + s3_bucket, key = 
s3_bucket_with_file + content = s3_bucket.read_path(key) + assert content == b"hello" + + +def test_write_path_in_sync_context(s3_bucket): + """Test that write path works in a sync context.""" + key = s3_bucket.write_path("test.txt", content=b"hello") + content = s3_bucket.read_path(key) + assert content == b"hello" + + +def test_deployment_default_basepath(s3_bucket): + deployment = Deployment(name="testing", storage=s3_bucket) + assert deployment.location == "/" + + +def test_deployment_set_basepath(aws_creds_block): + s3_bucket_block = S3Bucket( + bucket_name=BUCKET_NAME, + credentials=aws_creds_block, + bucket_folder="home", + ) + deployment = Deployment(name="testing", storage=s3_bucket_block) + assert deployment.location == "home/" + + +def test_resolve_path(s3_bucket): + assert s3_bucket._resolve_path("") == "" + + +class TestS3Bucket: + @pytest.fixture( + params=[ + AwsCredentials(), + MinIOCredentials(minio_root_user="root", minio_root_password="password"), + ] + ) + def credentials(self, request): + with mock_s3(): + yield request.param + + @pytest.fixture + def s3_bucket_empty(self, credentials, bucket): + _s3_bucket = S3Bucket(bucket_name="bucket", credentials=credentials) + return _s3_bucket + + @pytest.fixture + def s3_bucket_2_empty(self, credentials, bucket_2): + _s3_bucket = S3Bucket( + bucket_name="bucket_2", + credentials=credentials, + bucket_folder="subfolder", + ) + return _s3_bucket + + @pytest.fixture + def s3_bucket_with_object(self, s3_bucket_empty, object): + _s3_bucket_with_object = s3_bucket_empty # object will be added + return _s3_bucket_with_object + + @pytest.fixture + def s3_bucket_2_with_object(self, s3_bucket_2_empty): + _s3_bucket_with_object = s3_bucket_2_empty + s3_bucket_2_empty.write_path("object", content=b"TEST") + return _s3_bucket_with_object + + @pytest.fixture + def s3_bucket_with_objects(self, s3_bucket_with_object, object_in_folder): + _s3_bucket_with_objects = ( + s3_bucket_with_object # object in folder will be added + ) + return _s3_bucket_with_objects + + @pytest.fixture + def s3_bucket_with_similar_objects(self, s3_bucket_with_objects, objects_in_folder): + _s3_bucket_with_multiple_objects = ( + s3_bucket_with_objects # objects in folder will be added + ) + return _s3_bucket_with_multiple_objects + + def test_credentials_are_correct_type(self, credentials): + s3_bucket = S3Bucket(bucket_name="bucket", credentials=credentials) + s3_bucket_parsed = S3Bucket.parse_obj( + {"bucket_name": "bucket", "credentials": dict(credentials)} + ) + assert isinstance(s3_bucket.credentials, type(credentials)) + assert isinstance(s3_bucket_parsed.credentials, type(credentials)) + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_list_objects_empty(self, s3_bucket_empty, client_parameters): + assert s3_bucket_empty.list_objects() == [] + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_list_objects_one(self, s3_bucket_with_object, client_parameters): + objects = s3_bucket_with_object.list_objects() + assert len(objects) == 1 + assert [object["Key"] for object in objects] == ["object"] + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_list_objects(self, s3_bucket_with_objects, client_parameters): + objects = s3_bucket_with_objects.list_objects() + assert len(objects) == 2 + assert [object["Key"] for object in objects] == ["folder/object", "object"] + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) 
+ def test_list_objects_with_params( + self, s3_bucket_with_similar_objects, client_parameters + ): + objects = s3_bucket_with_similar_objects.list_objects("folder/object/") + assert len(objects) == 2 + assert [object["Key"] for object in objects] == [ + "folder/object/bar", + "folder/object/foo", + ] + + objects = s3_bucket_with_similar_objects.list_objects("folder") + assert len(objects) == 5 + assert [object["Key"] for object in objects] == [ + "folder/object", + "folder/object/bar", + "folder/object/foo", + "folderobject/bar", + "folderobject/foo", + ] + + @pytest.mark.parametrize("to_path", [Path("to_path"), "to_path", None]) + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_download_object_to_path( + self, s3_bucket_with_object: S3Bucket, to_path, client_parameters, tmp_path + ): + os.chdir(tmp_path) + s3_bucket_with_object.download_object_to_path("object", to_path) + if to_path is None: + to_path = tmp_path / "object" + to_path = Path(to_path) + assert to_path.read_text() == "TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_download_object_to_file_object( + self, s3_bucket_with_object: S3Bucket, client_parameters, tmp_path + ): + to_path = tmp_path / "object" + with open(to_path, "wb") as f: + s3_bucket_with_object.download_object_to_file_object("object", f) + assert to_path.read_text() == "TEST" + + @pytest.mark.parametrize("to_path", [Path("to_path"), "to_path", None]) + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_download_folder_to_path( + self, s3_bucket_with_objects: S3Bucket, client_parameters, tmp_path, to_path + ): + os.chdir(tmp_path) + s3_bucket_with_objects.download_folder_to_path("folder", to_path) + if to_path is None: + to_path = "" + to_path = Path(to_path) + assert (to_path / "object").read_text() == "TEST OBJECT IN FOLDER" + + @pytest.mark.parametrize("to_path", ["to_path", None]) + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_stream_from( + self, + s3_bucket_2_with_object: S3Bucket, + s3_bucket_empty: S3Bucket, + client_parameters, + to_path, + ): + path = s3_bucket_empty.stream_from(s3_bucket_2_with_object, "object", to_path) + data: bytes = s3_bucket_empty.read_path(path) + assert data == b"TEST" + + @pytest.mark.parametrize("to_path", ["new_object", None]) + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_upload_from_path( + self, s3_bucket_empty: S3Bucket, client_parameters, tmp_path, to_path + ): + from_path = tmp_path / "new_object" + from_path.write_text("NEW OBJECT") + s3_bucket_empty.upload_from_path(from_path, to_path) + + with io.BytesIO() as buf: + s3_bucket_empty.download_object_to_file_object("new_object", buf) + buf.seek(0) + output = buf.read() + assert output == b"NEW OBJECT" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_upload_from_file_object( + self, s3_bucket_empty: S3Bucket, client_parameters, tmp_path + ): + with open(tmp_path / "hello", "wb") as f: + f.write(b"NEW OBJECT") + + with open(tmp_path / "hello", "rb") as f: + s3_bucket_empty.upload_from_file_object(f, "new_object") + + with io.BytesIO() as buf: + s3_bucket_empty.download_object_to_file_object("new_object", buf) + buf.seek(0) + output = buf.read() + assert output == b"NEW OBJECT" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_upload_from_folder( + self, 
s3_bucket_empty: S3Bucket, client_parameters, tmp_path, caplog
+    ):
+        from_path = tmp_path / "new_object"
+        from_path.write_text("NEW OBJECT")
+
+        other_path = tmp_path / "other_object"
+        other_path.write_text("OTHER OBJECT")
+
+        folder_dir = tmp_path / "folder"
+        folder_dir.mkdir()
+        folder_path = folder_dir / "other_object"
+        folder_path.write_text("FOLDER OBJECT")
+
+        s3_bucket_empty.upload_from_folder(tmp_path)
+
+        new_from_path = tmp_path / "downloaded_new_object"
+        s3_bucket_empty.download_object_to_path("new_object", new_from_path)
+        assert new_from_path.read_text() == "NEW OBJECT"
+
+        new_other_path = tmp_path / "downloaded_other_object"
+        s3_bucket_empty.download_object_to_path("other_object", new_other_path)
+        assert new_other_path.read_text() == "OTHER OBJECT"
+
+        new_folder_path = tmp_path / "downloaded_folder_object"
+        s3_bucket_empty.download_object_to_path("folder/other_object", new_folder_path)
+        assert new_folder_path.read_text() == "FOLDER OBJECT"
+
+        empty_folder = tmp_path / "empty_folder"
+        empty_folder.mkdir()
+        s3_bucket_empty.upload_from_folder(empty_folder)
+        for record in caplog.records:
+            if f"No files were uploaded from {empty_folder}" in record.getMessage():
+                break
+        else:
+            raise AssertionError("Files did upload")
+
+    @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True)
+    def test_copy_object(
+        self,
+        s3_bucket_with_object: S3Bucket,
+        s3_bucket_2_empty: S3Bucket,
+    ):
+        s3_bucket_with_object.copy_object("object", "object_copy_1")
+        assert s3_bucket_with_object.read_path("object_copy_1") == b"TEST"
+
+        s3_bucket_with_object.copy_object("object", "folder/object_copy_2")
+        assert s3_bucket_with_object.read_path("folder/object_copy_2") == b"TEST"
+
+        # S3Bucket for second bucket has a basepath
+        s3_bucket_with_object.copy_object(
+            "object",
+            s3_bucket_2_empty._resolve_path("object_copy_3"),
+            to_bucket="bucket_2",
+        )
+        assert s3_bucket_2_empty.read_path("object_copy_3") == b"TEST"
+
+        s3_bucket_with_object.copy_object("object", "object_copy_4", s3_bucket_2_empty)
+        assert s3_bucket_2_empty.read_path("object_copy_4") == b"TEST"
+
+    @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True)
+    @pytest.mark.parametrize(
+        "to_bucket, bucket_folder, expected_path",
+        [
+            # to_bucket=None uses the s3_bucket_2_empty fixture
+            (None, None, "object"),
+            (None, "subfolder", "subfolder/object"),
+            ("bucket_2", None, "object"),
+            (None, None, "object"),
+            (None, "subfolder", "subfolder/object"),
+            ("bucket_2", None, "object"),
+        ],
+    )
+    def test_copy_subpaths(
+        self,
+        s3_bucket_with_object: S3Bucket,
+        s3_bucket_2_empty: S3Bucket,
+        to_bucket,
+        bucket_folder,
+        expected_path,
+    ):
+        if to_bucket is None:
+            to_bucket = s3_bucket_2_empty
+        if bucket_folder is not None:
+            to_bucket.bucket_folder = bucket_folder
+        else:
+            # For testing purposes, don't use bucket folder unless specified
+            to_bucket.bucket_folder = None
+
+        key = s3_bucket_with_object.copy_object(
+            "object",
+            "object",
+            to_bucket=to_bucket,
+        )
+        assert key == expected_path
+
+    @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True)
+    def test_move_object_within_bucket(
+        self,
+        s3_bucket_with_object: S3Bucket,
+    ):
+        s3_bucket_with_object.move_object("object", "object_copy_1")
+        assert s3_bucket_with_object.read_path("object_copy_1") == b"TEST"
+
+        with pytest.raises(ClientError):
+            assert s3_bucket_with_object.read_path("object") == b"TEST"
+
+    @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True)
+    def
test_move_object_to_nonexistent_bucket_fails( + self, + s3_bucket_with_object: S3Bucket, + ): + with pytest.raises(ClientError): + s3_bucket_with_object.move_object( + "object", "object_copy_1", to_bucket="nonexistent-bucket" + ) + assert s3_bucket_with_object.read_path("object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_move_object_onto_itself_fails( + self, + s3_bucket_with_object: S3Bucket, + ): + with pytest.raises(ClientError): + s3_bucket_with_object.move_object("object", "object") + assert s3_bucket_with_object.read_path("object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + def test_move_object_between_buckets( + self, + s3_bucket_with_object: S3Bucket, + s3_bucket_2_empty: S3Bucket, + ): + s3_bucket_with_object.move_object( + "object", "object_copy_1", to_bucket=s3_bucket_2_empty + ) + assert s3_bucket_2_empty.read_path("object_copy_1") == b"TEST" + + with pytest.raises(ClientError): + assert s3_bucket_with_object.read_path("object") == b"TEST" + + @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) + @pytest.mark.parametrize( + "to_bucket, bucket_folder, expected_path", + [ + # to_bucket=None uses the s3_bucket_2_empty fixture + (None, None, "object"), + (None, "subfolder", "subfolder/object"), + ("bucket_2", None, "object"), + (None, None, "object"), + (None, "subfolder", "subfolder/object"), + ("bucket_2", None, "object"), + ], + ) + def test_move_subpaths( + self, + s3_bucket_with_object: S3Bucket, + s3_bucket_2_empty: S3Bucket, + to_bucket, + bucket_folder, + expected_path, + ): + if to_bucket is None: + to_bucket = s3_bucket_2_empty + if bucket_folder is not None: + to_bucket.bucket_folder = bucket_folder + else: + # For testing purposes, don't use bucket folder unless specified + to_bucket.bucket_folder = None + + key = s3_bucket_with_object.move_object( + "object", + "object", + to_bucket=to_bucket, + ) + assert key == expected_path diff --git a/src/integrations/prefect-aws/tests/test_secrets_manager.py b/src/integrations/prefect-aws/tests/test_secrets_manager.py new file mode 100644 index 000000000000..3d08be0d1c9b --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_secrets_manager.py @@ -0,0 +1,208 @@ +from datetime import datetime, timedelta + +import boto3 +import pytest +from moto import mock_secretsmanager +from prefect_aws.secrets_manager import ( + AwsSecret, + create_secret, + delete_secret, + read_secret, + update_secret, +) + +from prefect import flow + + +@pytest.fixture +def secretsmanager_client(): + with mock_secretsmanager(): + yield boto3.client("secretsmanager", "us-east-1") + + +@pytest.fixture( + params=[ + dict(Name="secret_string_no_version", SecretString="1"), + dict( + Name="secret_string_with_version_id", SecretString="2", should_version=True + ), + dict(Name="secret_binary_no_version", SecretBinary=b"3"), + dict( + Name="secret_binary_with_version_id", SecretBinary=b"4", should_version=True + ), + ] +) +def secret_under_test(secretsmanager_client, request): + should_version = request.param.pop("should_version", False) + secretsmanager_client.create_secret(**request.param) + + update_result = None + if should_version: + if "SecretString" in request.param: + request.param["SecretString"] = request.param["SecretString"] + "-versioned" + elif "SecretBinary" in request.param: + request.param["SecretBinary"] = ( + request.param["SecretBinary"] + b"-versioned" + ) + update_secret_kwargs = 
request.param.copy() + update_secret_kwargs["SecretId"] = update_secret_kwargs.pop("Name") + update_result = secretsmanager_client.update_secret(**update_secret_kwargs) + + return dict( + secret_name=request.param.get("Name"), + version_id=update_result.get("VersionId") if update_result else None, + expected_value=request.param.get("SecretString") + or request.param.get("SecretBinary"), + ) + + +async def test_read_secret(secret_under_test, aws_credentials): + expected_value = secret_under_test.pop("expected_value") + + @flow + async def test_flow(): + return await read_secret( + aws_credentials=aws_credentials, + **secret_under_test, + ) + + assert (await test_flow()) == expected_value + + +async def test_update_secret(secret_under_test, aws_credentials, secretsmanager_client): + current_secret_value = secret_under_test["expected_value"] + new_secret_value = ( + current_secret_value + "2" + if isinstance(current_secret_value, str) + else current_secret_value + b"2" + ) + + @flow + async def test_flow(): + return await update_secret( + aws_credentials=aws_credentials, + secret_name=secret_under_test["secret_name"], + secret_value=new_secret_value, + ) + + flow_state = await test_flow() + assert flow_state.get("Name") == secret_under_test["secret_name"] + + updated_secret = secretsmanager_client.get_secret_value( + SecretId=secret_under_test["secret_name"] + ) + assert ( + updated_secret.get("SecretString") == new_secret_value + or updated_secret.get("SecretBinary") == new_secret_value + ) + + +@pytest.mark.parametrize( + ["secret_name", "secret_value"], [["string_secret", "42"], ["binary_secret", b"42"]] +) +async def test_create_secret( + aws_credentials, secret_name, secret_value, secretsmanager_client +): + @flow + async def test_flow(): + return await create_secret( + secret_name=secret_name, + secret_value=secret_value, + aws_credentials=aws_credentials, + ) + + flow_state = await test_flow() + assert flow_state.get("Name") == secret_name + + updated_secret = secretsmanager_client.get_secret_value(SecretId=secret_name) + assert ( + updated_secret.get("SecretString") == secret_value + or updated_secret.get("SecretBinary") == secret_value + ) + + +@pytest.mark.parametrize( + ["recovery_window_in_days", "force_delete_without_recovery"], + [ + [30, False], + [20, False], + [7, False], + [8, False], + [10, False], + [15, True], + [29, True], + ], +) +async def test_delete_secret( + aws_credentials, + secret_under_test, + recovery_window_in_days, + force_delete_without_recovery, +): + @flow + async def test_flow(): + return await delete_secret( + secret_name=secret_under_test["secret_name"], + aws_credentials=aws_credentials, + recovery_window_in_days=recovery_window_in_days, + force_delete_without_recovery=force_delete_without_recovery, + ) + + result = await test_flow() + if not force_delete_without_recovery and not 7 <= recovery_window_in_days <= 30: + with pytest.raises(ValueError): + result.get() + else: + assert result.get("Name") == secret_under_test["secret_name"] + deletion_date = result.get("DeletionDate") + + if not force_delete_without_recovery: + assert deletion_date.date() == ( + datetime.utcnow().date() + timedelta(days=recovery_window_in_days) + ) + else: + assert deletion_date.date() == datetime.utcnow().date() + + +class TestAwsSecret: + @pytest.fixture + def aws_secret(self, aws_credentials, secretsmanager_client): + yield AwsSecret(aws_credentials=aws_credentials, secret_name="my-test") + + def test_roundtrip_read_write_delete(self, aws_secret): + arn = 
"arn:aws:secretsmanager:us-east-1:123456789012:secret" + assert aws_secret.write_secret("my-secret").startswith(arn) + assert aws_secret.read_secret() == b"my-secret" + assert aws_secret.write_secret("my-updated-secret").startswith(arn) + assert aws_secret.read_secret() == b"my-updated-secret" + assert aws_secret.delete_secret().startswith(arn) + + def test_read_secret_version_id(self, aws_secret: AwsSecret): + client = aws_secret.aws_credentials.get_secrets_manager_client() + client.create_secret(Name="my-test", SecretBinary="my-secret") + response = client.update_secret( + SecretId="my-test", SecretBinary="my-updated-secret" + ) + assert ( + aws_secret.read_secret(version_id=response["VersionId"]) + == b"my-updated-secret" + ) + + def test_delete_secret_conflict(self, aws_secret: AwsSecret): + with pytest.raises(ValueError, match="Cannot specify recovery window"): + aws_secret.delete_secret( + force_delete_without_recovery=True, recovery_window_in_days=10 + ) + + def test_delete_secret_recovery_window(self, aws_secret: AwsSecret): + with pytest.raises( + ValueError, match="Recovery window must be between 7 and 30 days" + ): + aws_secret.delete_secret(recovery_window_in_days=42) + + async def test_read_secret(self, secret_under_test, aws_credentials): + secret = AwsSecret( + aws_credentials=aws_credentials, + secret_name=secret_under_test["secret_name"], + ) + assert await secret.read_secret() == secret_under_test["expected_value"] diff --git a/src/integrations/prefect-aws/tests/test_utilities.py b/src/integrations/prefect-aws/tests/test_utilities.py new file mode 100644 index 000000000000..cdff07b8458a --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_utilities.py @@ -0,0 +1,90 @@ +import pytest +from prefect_aws.utilities import ( + assemble_document_for_patches, + ensure_path_exists, + hash_collection, +) + + +class TestHashCollection: + def test_simple_dict(self): + simple_dict = {"key1": "value1", "key2": "value2"} + assert hash_collection(simple_dict) == hash_collection( + simple_dict + ), "Simple dictionary hashing failed" + + def test_nested_dict(self): + nested_dict = {"key1": {"subkey1": "subvalue1"}, "key2": "value2"} + assert hash_collection(nested_dict) == hash_collection( + nested_dict + ), "Nested dictionary hashing failed" + + def test_complex_structure(self): + complex_structure = { + "key1": [1, 2, 3], + "key2": {"subkey1": {"subsubkey1": "value"}}, + } + assert hash_collection(complex_structure) == hash_collection( + complex_structure + ), "Complex structure hashing failed" + + def test_unhashable_structure(self): + typically_unhashable_structure = dict(key=dict(subkey=[1, 2, 3])) + with pytest.raises(TypeError): + hash(typically_unhashable_structure) + assert hash_collection(typically_unhashable_structure) == hash_collection( + typically_unhashable_structure + ), "Unhashable structure hashing failed after transformation" + + +class TestAssembleDocumentForPatches: + def test_initial_document(self): + patches = [ + {"op": "replace", "path": "/name", "value": "Jane"}, + {"op": "add", "path": "/contact/address", "value": "123 Main St"}, + {"op": "remove", "path": "/age"}, + ] + + initial_document = assemble_document_for_patches(patches) + + expected_document = {"name": {}, "contact": {}, "age": {}} + + assert initial_document == expected_document, "Initial document assembly failed" + + +class TestEnsurePathExists: + def test_existing_path(self): + doc = {"key1": {"subkey1": "value1"}} + path = ["key1", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { 
+ "key1": {"subkey1": "value1"} + }, "Existing path modification failed" + + def test_new_path_object(self): + doc = {} + path = ["key1", "subkey1"] + ensure_path_exists(doc, path) + assert doc == {"key1": {"subkey1": {}}}, "New path creation for object failed" + + def test_new_path_array(self): + doc = {} + path = ["key1", "0"] + ensure_path_exists(doc, path) + assert doc == {"key1": [{}]}, "New path creation for array failed" + + def test_existing_path_array(self): + doc = {"key1": [{"subkey1": "value1"}]} + path = ["key1", "0", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": [{"subkey1": "value1"}] + }, "Existing path modification for array failed" + + def test_existing_path_array_index_out_of_range(self): + doc = {"key1": []} + path = ["key1", "0", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": [{"subkey1": {}}] + }, "Existing path modification for array index out of range failed" diff --git a/src/integrations/prefect-aws/tests/test_version.py b/src/integrations/prefect-aws/tests/test_version.py new file mode 100644 index 000000000000..e35743215c9f --- /dev/null +++ b/src/integrations/prefect-aws/tests/test_version.py @@ -0,0 +1,9 @@ +from packaging.version import Version + + +def test_version(): + from prefect_aws import __version__ + + assert isinstance(__version__, str) + assert Version(__version__) + assert __version__.startswith("0.") diff --git a/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py b/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py new file mode 100644 index 000000000000..5749e1bb06aa --- /dev/null +++ b/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py @@ -0,0 +1,2428 @@ +import json +import logging +from functools import partial +from typing import Any, Awaitable, Callable, Dict, List, Optional +from unittest.mock import ANY, MagicMock +from unittest.mock import patch as mock_patch +from uuid import uuid4 + +import anyio +import botocore +import pytest +import yaml +from moto import mock_ec2, mock_ecs, mock_logs +from moto.ec2.utils import generate_instance_identity_document +from pydantic import VERSION as PYDANTIC_VERSION + +from prefect.server.schemas.core import FlowRun +from prefect.utilities.asyncutils import run_sync_in_worker_thread + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import ValidationError +else: + from pydantic import ValidationError + +from prefect_aws.credentials import _get_client_cached +from prefect_aws.workers.ecs_worker import ( + _TASK_DEFINITION_CACHE, + ECS_DEFAULT_CONTAINER_NAME, + ECS_DEFAULT_CPU, + ECS_DEFAULT_FAMILY, + ECS_DEFAULT_MEMORY, + AwsCredentials, + ECSJobConfiguration, + ECSVariables, + ECSWorker, + InfrastructureNotAvailable, + InfrastructureNotFound, + _get_container, + get_prefect_image_name, + mask_sensitive_env_values, + parse_identifier, +) + +TEST_TASK_DEFINITION_YAML = """ +containerDefinitions: +- cpu: 1024 + image: prefecthq/prefect:2.1.0-python3.8 + memory: 2048 + name: prefect +family: prefect +""" + +TEST_TASK_DEFINITION = yaml.safe_load(TEST_TASK_DEFINITION_YAML) + + +@pytest.fixture +def flow_run(): + return FlowRun(flow_id=uuid4(), deployment_id=uuid4()) + + +@pytest.fixture +def container_status_code(): + yield MagicMock(return_value=0) + + +@pytest.fixture(autouse=True) +def reset_task_definition_cache(): + _TASK_DEFINITION_CACHE.clear() + yield + + +@pytest.fixture(autouse=True) +def patch_task_watch_poll_interval(monkeypatch): + # Patch the poll interval to be way shorter for speed during testing! 
+ monkeypatch.setattr( + ECSVariables.__fields__["task_watch_poll_interval"], "default", 0.05 + ) + + +def inject_moto_patches(moto_mock, patches: Dict[str, List[Callable]]): + def injected_call(method, patch_list, *args, **kwargs): + for patch in patch_list: + result = patch(method, *args, **kwargs) + return result + + for account in moto_mock.backends: + for region in moto_mock.backends[account]: + backend = moto_mock.backends[account][region] + + for attr, attr_patches in patches.items(): + original_method = getattr(backend, attr) + setattr( + backend, attr, partial(injected_call, original_method, attr_patches) + ) + + +def patch_run_task(mock, run_task, *args, **kwargs): + """ + Track calls to `run_task` by calling a mock as well. + """ + mock(*args, **kwargs) + return run_task(*args, **kwargs) + + +def patch_describe_tasks_add_containers( + session, container_status_code, describe_tasks, *args, **kwargs +): + """ + Adds the containers to moto's task description. + + Otherwise, containers is always empty. + """ + ecs_client = session.client("ecs") + + result = describe_tasks(*args, **kwargs) + for task in result: + if not task.containers: + # Pull containers from the task definition + task_definition = ecs_client.describe_task_definition( + taskDefinition=task.task_definition_arn + )["taskDefinition"] + task.containers = [ + { + "name": container["name"], + "exitCode": container_status_code.return_value, + } + for container in task_definition.get("containerDefinitions", []) + ] + + # Populate all the containers in overrides + if task.overrides.get("container_overrides"): + for container in task.overrides["container_overrides"]: + if not _get_container(task.containers, container.name): + task.containers.append( + { + "name": container.name, + "exitCode": container_status_code.return_value, + } + ) + + # Or add the default container + else: + if not _get_container(task.containers, ECS_DEFAULT_CONTAINER_NAME): + task.containers.append( + { + "name": ECS_DEFAULT_CONTAINER_NAME, + "exitCode": container_status_code.return_value, + } + ) + + return result + + +def patch_calculate_task_resource_requirements( + _calculate_task_resource_requirements, task_definition +): + """ + Adds support for non-EC2 execution modes to moto's calculation of task definition. + """ + for container_definition in task_definition.container_definitions: + container_definition.setdefault("memory", 0) + return _calculate_task_resource_requirements(task_definition) + + +def create_log_stream(session, run_task, *args, **kwargs): + """ + When running a task, create the log group and stream if logging is configured on + containers. 
+ + See https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html + """ + tasks = run_task(*args, **kwargs) + if not tasks: + return tasks + task = tasks[0] + + ecs_client = session.client("ecs") + logs_client = session.client("logs") + + task_definition = ecs_client.describe_task_definition( + taskDefinition=task.task_definition_arn + )["taskDefinition"] + + for container in task_definition.get("containerDefinitions", []): + log_config = container.get("logConfiguration", {}) + if log_config: + if log_config.get("logDriver") != "awslogs": + continue + + options = log_config.get("options", {}) + if not options: + raise ValueError("logConfiguration does not include options.") + + group_name = options.get("awslogs-group") + if not group_name: + raise ValueError( + "logConfiguration.options does not include awslogs-group" + ) + + if options.get("awslogs-create-group") == "true": + logs_client.create_log_group(logGroupName=group_name) + + stream_prefix = options.get("awslogs-stream-prefix") + if not stream_prefix: + raise ValueError( + "logConfiguration.options does not include awslogs-stream-prefix" + ) + + logs_client.create_log_stream( + logGroupName=group_name, + logStreamName=f"{stream_prefix}/{container['name']}/{task.id}", + ) + + return tasks + + +def add_ec2_instance_to_ecs_cluster(session, cluster_name): + ecs_client = session.client("ecs") + ec2_client = session.client("ec2") + ec2_resource = session.resource("ec2") + + ecs_client.create_cluster(clusterName=cluster_name) + + images = ec2_client.describe_images() + image_id = images["Images"][0]["ImageId"] + + test_instance = ec2_resource.create_instances( + ImageId=image_id, MinCount=1, MaxCount=1 + )[0] + + ecs_client.register_container_instance( + cluster=cluster_name, + instanceIdentityDocument=json.dumps( + generate_instance_identity_document(test_instance) + ), + ) + + +def create_test_ecs_cluster(ecs_client, cluster_name) -> str: + """ + Create an ECS cluster and return its ARN + """ + return ecs_client.create_cluster(clusterName=cluster_name)["cluster"]["clusterArn"] + + +def describe_task(ecs_client, task_arn, **kwargs) -> dict: + """ + Describe a single ECS task + """ + return ecs_client.describe_tasks(tasks=[task_arn], include=["TAGS"], **kwargs)[ + "tasks" + ][0] + + +async def stop_task(ecs_client, task_arn, **kwargs): + """ + Stop an ECS task. + + Additional keyword arguments are passed to `ECSClient.stop_task`. 
+ """ + task = await run_sync_in_worker_thread(describe_task, ecs_client, task_arn) + # Check that the task started successfully + assert task["lastStatus"] == "RUNNING", "Task should be RUNNING before stopping" + print("Stopping task...") + await run_sync_in_worker_thread(ecs_client.stop_task, task=task_arn, **kwargs) + + +def describe_task_definition(ecs_client, task): + return ecs_client.describe_task_definition( + taskDefinition=task["taskDefinitionArn"] + )["taskDefinition"] + + +@pytest.fixture +def ecs_mocks( + aws_credentials: AwsCredentials, flow_run: FlowRun, container_status_code +): + with mock_ecs() as ecs: + with mock_ec2(): + with mock_logs(): + session = aws_credentials.get_boto3_session() + + inject_moto_patches( + ecs, + { + # Add containers to running tasks — otherwise not included + "describe_tasks": [ + partial( + patch_describe_tasks_add_containers, + session, + container_status_code, + ) + ], + # Fix moto internal resource requirement calculations + "_calculate_task_resource_requirements": [ + patch_calculate_task_resource_requirements + ], + # Add log group creation + "run_task": [partial(create_log_stream, session)], + }, + ) + + create_test_ecs_cluster(session.client("ecs"), "default") + + # NOTE: Even when using FARGATE, moto requires container instances to be + # registered. This differs from AWS behavior. + add_ec2_instance_to_ecs_cluster(session, "default") + + yield ecs + + +async def construct_configuration(**options): + variables = ECSVariables(**options) + print(f"Using variables: {variables.json(indent=2)}") + + configuration = await ECSJobConfiguration.from_template_and_values( + base_job_template=ECSWorker.get_default_base_job_template(), + values={**variables.dict(exclude_none=True)}, + ) + print(f"Constructed test configuration: {configuration.json(indent=2)}") + + return configuration + + +async def construct_configuration_with_job_template( + template_overrides: dict, **variables: dict +): + variables = ECSVariables(**variables) + print(f"Using variables: {variables.json(indent=2)}") + + base_template = ECSWorker.get_default_base_job_template() + for key in template_overrides: + base_template["job_configuration"][key] = template_overrides[key] + + print( + "Using base template configuration:" + f" {json.dumps(base_template['job_configuration'], indent=2)}" + ) + + configuration = await ECSJobConfiguration.from_template_and_values( + base_job_template=base_template, + values={**variables.dict(exclude_none=True)}, + ) + print(f"Constructed test configuration: {configuration.json(indent=2)}") + + return configuration + + +async def run_then_stop_task( + worker: ECSWorker, + configuration: ECSJobConfiguration, + flow_run: FlowRun, + after_start: Optional[Callable[[str], Awaitable[Any]]] = None, +) -> str: + """ + Run an ECS Task then stop it. + + Moto will not advance the state of tasks, so `ECSTask.run` would hang forever if + the run is created successfully and not stopped. + + `after_start` can be used to run something after the task starts but before it is + stopped. It will be passed the task arn. 
+ """ + session = configuration.aws_credentials.get_boto3_session() + result = None + + async def run(task_status): + nonlocal result + result = await worker.run(flow_run, configuration, task_status=task_status) + return + + with anyio.fail_after(20): + async with anyio.create_task_group() as tg: + identifier = await tg.start(run) + cluster, task_arn = parse_identifier(identifier) + + if after_start: + await after_start(task_arn) + + # Stop the task after it starts to prevent the test from running forever + tg.start_soon( + partial(stop_task, session.client("ecs"), task_arn, cluster=cluster) + ) + + return result + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_default(aws_credentials: AwsCredentials, flow_run: FlowRun): + configuration = await construct_configuration( + aws_credentials=aws_credentials, command="echo test" + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + task = describe_task(ecs_client, task_arn) + + assert task == { + "attachments": ANY, + "clusterArn": ANY, + # Note: This container list is not generated by moto and our test suite injects + # reasonable values + "containers": [{"exitCode": 0, "name": "prefect"}], + "desiredStatus": "STOPPED", + "lastStatus": "STOPPED", + "launchType": "FARGATE", + "overrides": { + "containerOverrides": [ + {"name": "prefect", "environment": [], "command": ["echo", "test"]} + ] + }, + "startedBy": ANY, + "tags": [], + "taskArn": ANY, + "taskDefinitionArn": ANY, + } + + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["containerDefinitions"] == [ + { + "name": ECS_DEFAULT_CONTAINER_NAME, + "image": get_prefect_image_name(), + "cpu": 0, + "memory": 0, + "portMappings": [], + "essential": True, + "environment": [], + "mountPoints": [], + "volumesFrom": [], + } + ] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_image(aws_credentials: AwsCredentials, flow_run: FlowRun): + configuration = await construct_configuration( + aws_credentials=aws_credentials, image="prefecthq/prefect-dev:main-python3.9" + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + task = describe_task(ecs_client, task_arn) + assert task["lastStatus"] == "STOPPED" + + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["containerDefinitions"] == [ + { + "name": ECS_DEFAULT_CONTAINER_NAME, + "image": "prefecthq/prefect-dev:main-python3.9", + "cpu": 0, + "memory": 0, + "portMappings": [], + "essential": True, + "environment": [], + "mountPoints": [], + "volumesFrom": [], + } + ] + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_launch_types( + aws_credentials: AwsCredentials, launch_type: str, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, launch_type=launch_type + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call 
because moto does not track + # 'capacityProviderStrategy' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + if launch_type != "FARGATE_SPOT": + assert launch_type in task_definition["compatibilities"] + assert task["launchType"] == launch_type + else: + assert "FARGATE" in task_definition["compatibilities"] + # FARGATE SPOT requires a null launch type + assert not task.get("launchType") + # Instead, it requires a capacity provider strategy but this is not supported + # by moto and is not present on the task even when provided so we assert on the + # mock call to ensure it is sent + + assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [ + {"capacityProvider": "FARGATE_SPOT", "weight": 1} + ] + + requires_capabilities = task_definition.get("requiresCompatibilities", []) + if launch_type != "EC2": + assert "FARGATE" in requires_capabilities + else: + assert not requires_capabilities + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +@pytest.mark.parametrize( + "cpu,memory", [(None, None), (1024, None), (None, 2048), (2048, 4096)] +) +async def test_cpu_and_memory( + aws_credentials: AwsCredentials, + launch_type: str, + flow_run: FlowRun, + cpu: int, + memory: int, +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, launch_type=launch_type, cpu=cpu, memory=memory + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + container_definition = _get_container( + task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME + ) + overrides = task["overrides"] + container_overrides = _get_container( + overrides["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + + if launch_type == "EC2": + # EC2 requires CPU and memory to be defined at the container level + assert container_definition["cpu"] == cpu or ECS_DEFAULT_CPU + assert container_definition["memory"] == memory or ECS_DEFAULT_MEMORY + else: + # Fargate requires CPU and memory to be defined at the task definition level + assert task_definition["cpu"] == str(cpu or ECS_DEFAULT_CPU) + assert task_definition["memory"] == str(memory or ECS_DEFAULT_MEMORY) + + # We always provide non-null values as overrides on the task run + assert overrides.get("cpu") == (str(cpu) if cpu else None) + assert overrides.get("memory") == (str(memory) if memory else None) + # And as overrides for the Prefect container + assert container_overrides.get("cpu") == cpu + assert container_overrides.get("memory") == memory + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_network_mode_default( + aws_credentials: AwsCredentials, + launch_type: str, + flow_run: FlowRun, +): + configuration = await construct_configuration( + 
aws_credentials=aws_credentials, launch_type=launch_type + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + if launch_type == "EC2": + assert task_definition["networkMode"] == "bridge" + else: + assert task_definition["networkMode"] == "awsvpc" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_container_command( + aws_credentials: AwsCredentials, + launch_type: str, + flow_run: FlowRun, +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + launch_type=launch_type, + command="prefect version", + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + + container_overrides = _get_container( + task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + assert container_overrides["command"] == ["prefect", "version"] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_task_definition_arn(aws_credentials: AwsCredentials, flow_run: FlowRun): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_definition_arn = ecs_client.register_task_definition(**TEST_TASK_DEFINITION)[ + "taskDefinition" + ]["taskDefinitionArn"] + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + task_definition_arn=task_definition_arn, + launch_type="EC2", + ) + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + print(task) + assert ( + task["taskDefinitionArn"] == task_definition_arn + ), "The task definition should be used without registering a new one" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "overrides", + [{"image": "new-image"}, {"configure_cloudwatch_logs": True}, {"family": "foobar"}], +) +async def test_task_definition_arn_with_variables_that_are_ignored( + aws_credentials, overrides, caplog, flow_run +): + """ + Any of these overrides should cause the task definition to be copied and + registered as a new version + """ + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_definition_arn = ecs_client.register_task_definition( + **TEST_TASK_DEFINITION, executionRoleArn="base" + )["taskDefinition"]["taskDefinitionArn"] + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + task_definition_arn=task_definition_arn, + launch_type="EC2", + **overrides, + ) + + async with ECSWorker(work_pool_name="test") as worker: + with caplog.at_level( + logging.INFO, logger=worker.get_flow_run_logger(flow_run).name + ): + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = 
parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + assert ( + task["taskDefinitionArn"] == task_definition_arn + ), "A new task definition should not be registered" + + # TODO: Add logging for this case + # assert ( + # "Settings require changes to the linked task definition. " + # "The settings will be ignored. " + # "Enable DEBUG level logs to see the difference." + # in caplog.text + # ) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_environment_variables( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + env={"FOO": "BAR"}, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + prefect_container_definition = _get_container( + task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME + ) + assert not prefect_container_definition[ + "environment" + ], "Variables should not be passed until runtime" + + prefect_container_overrides = _get_container( + task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + expected = [{"name": "FOO", "value": "BAR"}] + assert prefect_container_overrides.get("environment") == expected + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_labels( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + labels={ + "foo": "bar", + "af_sn253@!$@&$%@(bfausfg!#!*&):@cas{}[]'XY": ( + "af_sn253@!$@&$%@(bfausfg!#!*&):@cas{}[]'XY" + ), + }, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert not task_definition.get("tags"), "Labels should not be passed until runtime" + assert task.get("tags") == [ + { + "key": "foo", + "value": "bar", + }, + { + # Slugified to remove invalid characters + "key": "af_sn253@-@-@-bfausfg-:@cas-XY", + "value": "af_sn253@-@-@-bfausfg-:@cas-XY", + }, + ] + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("default_cluster", [True, False]) +async def test_cluster( + aws_credentials: AwsCredentials, flow_run: FlowRun, default_cluster: bool +): + configuration = configuration = await construct_configuration( + cluster=None if default_cluster else "second-cluster", + aws_credentials=aws_credentials, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + # Construct a non-default cluster. We build this in either case since otherwise + # there is only one cluster and there's no choice but to use the default. 
+ second_cluster_arn = create_test_ecs_cluster(ecs_client, "second-cluster") + add_ec2_instance_to_ecs_cluster(session, "second-cluster") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + + if default_cluster: + assert task["clusterArn"].endswith("default") + else: + assert task["clusterArn"] == second_cluster_arn + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_execution_role_arn( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + execution_role_arn="test", + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + assert task_definition["executionRoleArn"] == "test" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_task_role_arn( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + task_role_arn="test", + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + task = describe_task(ecs_client, task_arn) + + assert task["overrides"]["taskRoleArn"] == "test" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_vpc_id( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, vpc_id=vpc.id + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "ENABLED", + "securityGroups": [], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_1_subnet_in_custom_settings_1_in_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + security_group = ec2_resource.create_security_group( + 
GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_1_sn_in_custom_settings_many_in_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + ec2_resource.create_subnet(CidrBlock="10.0.3.0/24", VpcId=vpc.id) + ec2_resource.create_subnet(CidrBlock="10.0.4.0/24", VpcId=vpc.id) + + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_many_subnet_in_custom_settings_many_in_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + subnets = [ + ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id), + ec2_resource.create_subnet(CidrBlock="10.0.33.0/24", VpcId=vpc.id), + ec2_resource.create_subnet(CidrBlock="10.0.44.0/24", VpcId=vpc.id), + ] + subnet_ids = [subnet.id for subnet in subnets] + + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + + configuration = await 
construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": subnet_ids, + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": subnet_ids, + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_custom_settings_invalid_subnet( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": ["sn-8asdas"], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + with pytest.raises( + ValueError, + match=( + r"Subnets \['sn-8asdas'\] not found within VPC with ID " + + vpc.id + + r"\.Please check that VPC is associated with supplied subnets\." + ), + ): + async with ECSWorker(work_pool_name="test") as worker: + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + await run_then_stop_task(worker, configuration, flow_run) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_custom_settings_invalid_subnet_multiple_vpc_subnets( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="10.0.0.0/16") + security_group = ec2_resource.create_security_group( + GroupName="ECSWorkerTestSG", Description="ECS Worker test SG", VpcId=vpc.id + ) + subnet = ec2_resource.create_subnet(CidrBlock="10.0.2.0/24", VpcId=vpc.id) + invalid_subnet_id = "subnet-3bf19de7" + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + override_network_configuration=True, + network_configuration={ + "subnets": [invalid_subnet_id, subnet.id], + "assignPublicIp": "DISABLED", + "securityGroups": [security_group.id], + }, + ) + + session = aws_credentials.get_boto3_session() + + with pytest.raises( + ValueError, + match=( + rf"Subnets \['{invalid_subnet_id}', '{subnet.id}'\] not found within VPC" + f" with ID {vpc.id}.Please check that VPC is associated with supplied" + " subnets." 
+ ), + ): + async with ECSWorker(work_pool_name="test") as worker: + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + await run_then_stop_task(worker, configuration, flow_run) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_configure_network_requires_vpc_id( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + with pytest.raises( + ValidationError, + match="You must provide a `vpc_id` to enable custom `network_configuration`.", + ): + await construct_configuration( + aws_credentials=aws_credentials, + override_network_configuration=True, + network_configuration={ + "subnets": [], + "assignPublicIp": "ENABLED", + "securityGroups": [], + }, + ) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_default_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_client = session.client("ec2") + + default_vpc_id = ec2_client.describe_vpcs( + Filters=[{"Name": "isDefault", "Values": ["true"]}] + )["Vpcs"][0]["VpcId"] + default_subnets = ec2_client.describe_subnets( + Filters=[{"Name": "vpc-id", "Values": [default_vpc_id]}] + )["Subnets"] + + configuration = await construct_configuration(aws_credentials=aws_credentials) + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + + # Subnet ids are copied from the vpc + assert network_configuration == { + "awsvpcConfiguration": { + "subnets": [subnet["SubnetId"] for subnet in default_subnets], + "assignPublicIp": "ENABLED", + "securityGroups": [], + } + } + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("explicit_network_mode", [True, False]) +async def test_network_config_is_empty_without_awsvpc_network_mode( + aws_credentials: AwsCredentials, explicit_network_mode: bool, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + # EC2 uses the 'bridge' network mode by default but we want to have test + # coverage for when it is set on the task definition + task_definition={"networkMode": "bridge"} if explicit_network_mode else None, + # FARGATE requires the 'awsvpc' network mode + launch_type="EC2", + ) + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track 'networkConfiguration' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + + network_configuration = mock_run_task.call_args[0][1].get("networkConfiguration") + assert network_configuration is None + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_missing_default_vpc( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_client = session.client("ec2") + + default_vpc_id = ec2_client.describe_vpcs( + Filters=[{"Name": "isDefault", "Values": ["true"]}] + 
)["Vpcs"][0]["VpcId"] + ec2_client.delete_vpc(VpcId=default_vpc_id) + + configuration = await construct_configuration(aws_credentials=aws_credentials) + + with pytest.raises(ValueError, match="Failed to find the default VPC"): + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration, flow_run) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_network_config_from_vpc_with_no_subnets( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ec2_resource = session.resource("ec2") + vpc = ec2_resource.create_vpc(CidrBlock="172.16.0.0/16") + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + vpc_id=vpc.id, + ) + + with pytest.raises( + ValueError, match=f"Failed to find subnets for VPC with ID {vpc.id}" + ): + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration, flow_run) + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["FARGATE", "FARGATE_SPOT"]) +async def test_bridge_network_mode_raises_on_fargate( + aws_credentials: AwsCredentials, + flow_run: FlowRun, + launch_type: str, +): + configuration = await construct_configuration_with_job_template( + aws_credentials=aws_credentials, + launch_type=launch_type, + template_overrides=dict(task_definition={"networkMode": "bridge"}), + ) + + with pytest.raises( + ValueError, + match=( + "Found network mode 'bridge' which is not compatible with launch type " + f"{launch_type!r}" + ), + ): + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration, flow_run) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_stream_output( + aws_credentials: AwsCredentials, flow_run: FlowRun, caplog +): + session = aws_credentials.get_boto3_session() + logs_client = session.client("logs") + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + configure_cloudwatch_logs=True, + stream_output=True, + execution_role_arn="test", + # Override the family so it does not match the container name + family="test-family", + # Override the prefix so it does not match the container name + cloudwatch_logs_options={"awslogs-stream-prefix": "test-prefix"}, + cluster="default", + ) + + async def write_fake_log(task_arn): + # TODO: moto does not appear to support actually reading these logs + # as they do not appear during `get_log_event` calls + # prefix/container-name/task-id + stream_name = f"test-prefix/prefect/{task_arn.rsplit('/')[-1]}" + logs_client.put_log_events( + logGroupName="prefect", + logStreamName=stream_name, + logEvents=[ + {"timestamp": i, "message": f"test-message-{i}"} for i in range(100) + ], + ) + + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task( + worker, configuration, flow_run, after_start=write_fake_log + ) + + logs_client = session.client("logs") + streams = logs_client.describe_log_streams(logGroupName="prefect")["logStreams"] + + assert len(streams) == 1 + + # Ensure we did not encounter any logging errors + assert "Failed to read log events" not in caplog.text + + # TODO: When moto supports reading logs, fix this + # out, err = capsys.readouterr() + # assert "test-message-{i}" in err + + +orig = botocore.client.BaseClient._make_api_call + + +def mock_make_api_call(self, operation_name, kwarg): + if operation_name == "RunTask": + return { + "failures": [ + {"arn": "string", "reason": "string", "detail": "string"}, 
+ ] + } + return orig(self, operation_name, kwarg) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_run_task_error_handling( + aws_credentials: AwsCredentials, + flow_run: FlowRun, + capsys, +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + task_role_arn="test", + ) + + with mock_patch( + "botocore.client.BaseClient._make_api_call", new=mock_make_api_call + ): + async with ECSWorker(work_pool_name="test") as worker: + with pytest.raises(RuntimeError, match="Failed to run ECS task") as exc: + await run_then_stop_task(worker, configuration, flow_run) + + assert exc.value.args[0] == "Failed to run ECS task: string" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "cloudwatch_logs_options", + [ + { + "awslogs-stream-prefix": "override-prefix", + "max-buffer-size": "2m", + }, + { + "max-buffer-size": "2m", + }, + ], +) +async def test_cloudwatch_log_options( + aws_credentials: AwsCredentials, flow_run: FlowRun, cloudwatch_logs_options: dict +): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + configure_cloudwatch_logs=True, + execution_role_arn="test", + cloudwatch_logs_options=cloudwatch_logs_options, + ) + work_pool_name = "test" + async with ECSWorker(work_pool_name=work_pool_name) as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + for container in task_definition["containerDefinitions"]: + prefix = f"prefect-logs_{work_pool_name}_{flow_run.deployment_id}" + if cloudwatch_logs_options.get("awslogs-stream-prefix"): + prefix = cloudwatch_logs_options["awslogs-stream-prefix"] + if container["name"] == ECS_DEFAULT_CONTAINER_NAME: + # Assert that the container has logging configured with user + # provided options + assert container["logConfiguration"] == { + "logDriver": "awslogs", + "options": { + "awslogs-create-group": "true", + "awslogs-group": "prefect", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": prefix, + "max-buffer-size": "2m", + }, + } + else: + # Other containers should not be modified + assert "logConfiguration" not in container + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_deregister_task_definition( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + auto_deregister_task_definition=True, + ) + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["status"] == "INACTIVE" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_deregister_task_definition_does_not_apply_to_linked_arn( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + task_definition_arn = ecs_client.register_task_definition(**TEST_TASK_DEFINITION)[ + "taskDefinition" + ]["taskDefinitionArn"] + + configuration = 
await construct_configuration(
+        aws_credentials=aws_credentials,
+        auto_deregister_task_definition=True,
+        task_definition_arn=task_definition_arn,
+        launch_type="EC2",
+    )
+    async with ECSWorker(work_pool_name="test") as worker:
+        result = await run_then_stop_task(worker, configuration, flow_run)
+
+    assert result.status_code == 0
+    _, task_arn = parse_identifier(result.identifier)
+
+    task = describe_task(ecs_client, task_arn)
+    assert describe_task_definition(ecs_client, task)["status"] == "ACTIVE"
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+async def test_match_latest_revision_in_family(
+    aws_credentials: AwsCredentials, flow_run: FlowRun
+):
+    session = aws_credentials.get_boto3_session()
+    ecs_client = session.client("ecs")
+
+    configuration_1 = await construct_configuration(
+        aws_credentials=aws_credentials,
+    )
+
+    configuration_2 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        execution_role_arn="test",
+    )
+
+    configuration_3 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        match_latest_revision_in_family=True,
+        execution_role_arn="test",
+    )
+
+    # Let the first worker run and register two task definitions
+    async with ECSWorker(work_pool_name="test") as worker:
+        await run_then_stop_task(worker, configuration_1, flow_run)
+        result_1 = await run_then_stop_task(worker, configuration_2, flow_run)
+
+    # Start a new worker with an empty cache
+    async with ECSWorker(work_pool_name="test") as worker:
+        result_2 = await run_then_stop_task(worker, configuration_3, flow_run)
+
+    assert result_1.status_code == 0
+    _, task_arn_1 = parse_identifier(result_1.identifier)
+
+    assert result_2.status_code == 0
+    _, task_arn_2 = parse_identifier(result_2.identifier)
+
+    task_1 = describe_task(ecs_client, task_arn_1)
+    task_2 = describe_task(ecs_client, task_arn_2)
+
+    assert task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"]
+    assert task_2["taskDefinitionArn"].endswith(":2")
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+async def test_match_latest_revision_in_family_custom_family(
+    aws_credentials: AwsCredentials, flow_run: FlowRun
+):
+    session = aws_credentials.get_boto3_session()
+    ecs_client = session.client("ecs")
+
+    configuration_1 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        family="test-family",
+    )
+
+    configuration_2 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        execution_role_arn="test",
+        family="test-family",
+    )
+
+    configuration_3 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        match_latest_revision_in_family=True,
+        execution_role_arn="test",
+        family="test-family",
+    )
+
+    # Let the first worker run and register two task definitions
+    async with ECSWorker(work_pool_name="test") as worker:
+        await run_then_stop_task(worker, configuration_1, flow_run)
+        result_1 = await run_then_stop_task(worker, configuration_2, flow_run)
+
+    # Start a new worker with an empty cache
+    async with ECSWorker(work_pool_name="test") as worker:
+        result_2 = await run_then_stop_task(worker, configuration_3, flow_run)
+
+    assert result_1.status_code == 0
+    _, task_arn_1 = parse_identifier(result_1.identifier)
+
+    assert result_2.status_code == 0
+    _, task_arn_2 = parse_identifier(result_2.identifier)
+
+    task_1 = describe_task(ecs_client, task_arn_1)
+    task_2 = describe_task(ecs_client, task_arn_2)
+
+    assert task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"]
+    assert task_2["taskDefinitionArn"].endswith(":2")
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+async def
test_worker_caches_registered_task_definitions( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, command="echo test" + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result_1 = await run_then_stop_task(worker, configuration, flow_run) + result_2 = await run_then_stop_task(worker, configuration, flow_run) + + assert result_2.status_code == 0 + + _, task_arn_1 = parse_identifier(result_1.identifier) + task_1 = describe_task(ecs_client, task_arn_1) + _, task_arn_2 = parse_identifier(result_2.identifier) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"] + assert flow_run.deployment_id in _TASK_DEFINITION_CACHE + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_worker_cache_miss_for_registered_task_definitions_clears_from_cache( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, command="echo test" + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result_1 = await run_then_stop_task(worker, configuration, flow_run) + + # Fail to retrieve from cache on next run + worker._retrieve_task_definition = MagicMock( + side_effect=RuntimeError("failure retrieving from cache") + ) + + result_2 = await run_then_stop_task(worker, configuration, flow_run) + + assert result_2.status_code == 0 + + _, task_arn_1 = parse_identifier(result_1.identifier) + task_1 = describe_task(ecs_client, task_arn_1) + _, task_arn_2 = parse_identifier(result_2.identifier) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] != task_2["taskDefinitionArn"] + assert ( + task_1["taskDefinitionArn"] not in _TASK_DEFINITION_CACHE.values() + ), _TASK_DEFINITION_CACHE + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_worker_task_definition_cache_is_per_deployment_id( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, command="echo test" + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result_1 = await run_then_stop_task(worker, configuration, flow_run) + result_2 = await run_then_stop_task( + worker, configuration, flow_run.copy(update=dict(deployment_id=uuid4())) + ) + + assert result_2.status_code == 0 + + _, task_arn_1 = parse_identifier(result_1.identifier) + task_1 = describe_task(ecs_client, task_arn_1) + _, task_arn_2 = parse_identifier(result_2.identifier) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] != task_2["taskDefinitionArn"] + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "overrides", + [{"image": "new-image"}, {"configure_cloudwatch_logs": True}, {"family": "foobar"}], +) +async def test_worker_task_definition_cache_miss_on_config_changes( + aws_credentials: AwsCredentials, flow_run: FlowRun, overrides: dict +): + configuration_1 = await construct_configuration( + aws_credentials=aws_credentials, execution_role_arn="test" + ) + configuration_2 = await construct_configuration( + aws_credentials=aws_credentials, execution_role_arn="test", **overrides + ) + + session = 
aws_credentials.get_boto3_session()
+    ecs_client = session.client("ecs")
+
+    async with ECSWorker(work_pool_name="test") as worker:
+        result_1 = await run_then_stop_task(worker, configuration_1, flow_run)
+        result_2 = await run_then_stop_task(worker, configuration_2, flow_run)
+
+    assert result_2.status_code == 0
+
+    _, task_arn_1 = parse_identifier(result_1.identifier)
+    task_1 = describe_task(ecs_client, task_arn_1)
+    _, task_arn_2 = parse_identifier(result_2.identifier)
+    task_2 = describe_task(ecs_client, task_arn_2)
+
+    assert task_1["taskDefinitionArn"] != task_2["taskDefinitionArn"]
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+@pytest.mark.parametrize(
+    "overrides",
+    [{"image": "new-image"}, {"configure_cloudwatch_logs": True}, {"family": "foobar"}],
+)
+async def test_worker_task_definition_cache_miss_on_deregistered(
+    aws_credentials: AwsCredentials, flow_run: FlowRun, overrides: dict
+):
+    configuration_1 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        execution_role_arn="test",
+        auto_deregister_task_definition=True,
+    )
+    configuration_2 = await construct_configuration(
+        aws_credentials=aws_credentials, execution_role_arn="test", **overrides
+    )
+
+    session = aws_credentials.get_boto3_session()
+    ecs_client = session.client("ecs")
+
+    async with ECSWorker(work_pool_name="test") as worker:
+        result_1 = await run_then_stop_task(worker, configuration_1, flow_run)
+        result_2 = await run_then_stop_task(worker, configuration_2, flow_run)
+
+    assert result_2.status_code == 0
+
+    _, task_arn_1 = parse_identifier(result_1.identifier)
+    task_1 = describe_task(ecs_client, task_arn_1)
+    _, task_arn_2 = parse_identifier(result_2.identifier)
+    task_2 = describe_task(ecs_client, task_arn_2)
+
+    assert task_1["taskDefinitionArn"] != task_2["taskDefinitionArn"]
+
+
+@pytest.mark.usefixtures("ecs_mocks")
+@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE"])
+@pytest.mark.parametrize(
+    "overrides",
+    [
+        {"env": {"FOO": "BAR"}},
+        {"command": "test"},
+        {"labels": {"FOO": "BAR"}},
+        {"stream_output": True, "configure_cloudwatch_logs": False},
+        {"cluster": "test"},
+        {"task_role_arn": "test"},
+        # Note: null environment variables can cause override, but not when missing
+        # from the base task definition
+        {"env": {"FOO": None}},
+        # The following would not result in a copy when using a task_definition_arn
+        # but will be eagerly set on the new task definition and result in a cache miss
+        # {"cpu": 2048},
+        # {"memory": 4096},
+        # {"execution_role_arn": "test"},
+        # {"launch_type": "EXTERNAL"},
+    ],
+    ids=lambda item: str(sorted(list(set(item.keys())))),
+)
+async def test_worker_task_definition_cache_hit_on_config_changes(
+    aws_credentials: AwsCredentials,
+    flow_run: FlowRun,
+    overrides: dict,
+    launch_type: str,
+):
+    """
+    Any of these overrides should be configured at runtime and should not cause a
+    cache miss or the registration of a new task definition
+    """
+    configuration_1 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        execution_role_arn="test",
+        launch_type=launch_type,
+    )
+    configuration_2 = await construct_configuration(
+        aws_credentials=aws_credentials,
+        execution_role_arn="test",
+        launch_type=launch_type,
+        **overrides,
+    )
+
+    session = aws_credentials.get_boto3_session()
+    ecs_client = session.client("ecs")
+
+    if "cluster" in overrides:
+        create_test_ecs_cluster(ecs_client, overrides["cluster"])
+        add_ec2_instance_to_ecs_cluster(session, overrides["cluster"])
+
+    async with ECSWorker(work_pool_name="test") as worker:
+
result_1 = await run_then_stop_task(worker, configuration_1, flow_run) + result_2 = await run_then_stop_task(worker, configuration_2, flow_run) + + assert result_2.status_code == 0 + + _, task_arn_1 = parse_identifier(result_1.identifier) + task_1 = describe_task(ecs_client, task_arn_1) + _, task_arn_2 = parse_identifier(result_2.identifier) + task_2 = describe_task(ecs_client, task_arn_2) + + assert ( + task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"] + ), "The existing task definition should be used" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_container_command_in_task_definition_template( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_definition={ + "containerDefinitions": [ + {"name": ECS_DEFAULT_CONTAINER_NAME, "command": ["echo", "hello"]} + ] + } + ), + aws_credentials=aws_credentials, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + + container_overrides = _get_container( + task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + assert "command" not in container_overrides + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_container_command_in_task_definition_template_override( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_definition={ + "containerDefinitions": [ + {"name": ECS_DEFAULT_CONTAINER_NAME, "command": ["echo", "hello"]} + ] + } + ), + aws_credentials=aws_credentials, + command="echo goodbye", + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + + container_overrides = _get_container( + task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + assert container_overrides["command"] == ["echo", "goodbye"] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_container_in_task_definition_template( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_definition={ + "containerDefinitions": [ + { + "name": "user-defined-name", + "command": ["echo", "hello"], + "image": "alpine", + } + ] + }, + ), + aws_credentials=aws_credentials, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + user_container = _get_container( + task_definition["containerDefinitions"], "user-defined-name" + ) + assert user_container is not None, "The user-specified container should be present" + assert 
user_container["command"] == ["echo", "hello"] + assert user_container["image"] == "alpine", "The image should be left unchanged" + + default_container = _get_container( + task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME + ) + assert default_container is None, "The default container should be not be added" + + container_overrides = task["overrides"]["containerOverrides"] + user_container_overrides = _get_container(container_overrides, "user-defined-name") + default_container_overrides = _get_container( + container_overrides, ECS_DEFAULT_CONTAINER_NAME + ) + assert ( + user_container_overrides + ), "The user defined container should be included in overrides" + assert ( + default_container_overrides is None + ), "The default container should not be in overrides" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_container_image_in_task_definition_template( + aws_credentials: AwsCredentials, + flow_run: FlowRun, +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_definition={ + "containerDefinitions": [ + { + "name": ECS_DEFAULT_CONTAINER_NAME, + "image": "use-this-image", + } + ] + }, + ), + aws_credentials=aws_credentials, + image="not-templated-anywhere", + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + prefect_container = _get_container( + task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME + ) + assert ( + prefect_container["image"] == "use-this-image" + ), "The image from the task definition should be used" + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize("launch_type", ["EC2", "FARGATE", "FARGATE_SPOT"]) +async def test_user_defined_cpu_and_memory_in_task_definition_template( + aws_credentials: AwsCredentials, launch_type: str, flow_run: FlowRun +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_definition={ + "containerDefinitions": [ + { + "name": ECS_DEFAULT_CONTAINER_NAME, + "command": "{{ command }}", + "image": "{{ image }}", + "cpu": 2048, + "memory": 4096, + } + ], + "cpu": "4096", + "memory": "8192", + }, + ), + aws_credentials=aws_credentials, + launch_type=launch_type, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + container_definition = _get_container( + task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME + ) + overrides = task["overrides"] + container_overrides = _get_container( + overrides["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + + # All of these values should be retained + assert container_definition["cpu"] == 2048 + assert container_definition["memory"] == 4096 + assert task_definition["cpu"] == str(4096) + assert task_definition["memory"] == str(8192) + + # No values should be overridden at runtime + assert overrides.get("cpu") 
is None + assert overrides.get("memory") is None + assert container_overrides.get("cpu") is None + assert container_overrides.get("memory") is None + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_environment_variables_in_task_definition_template( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_definition={ + "containerDefinitions": [ + { + "name": ECS_DEFAULT_CONTAINER_NAME, + "environment": [ + {"name": "BAR", "value": "FOO"}, + {"name": "OVERRIDE", "value": "OLD"}, + ], + } + ], + }, + ), + aws_credentials=aws_credentials, + env={"FOO": "BAR", "OVERRIDE": "NEW"}, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + prefect_container_definition = _get_container( + task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME + ) + + assert prefect_container_definition["environment"] == [ + {"name": "BAR", "value": "FOO"}, + {"name": "OVERRIDE", "value": "OLD"}, + ] + + prefect_container_overrides = _get_container( + task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + assert prefect_container_overrides.get("environment") == [ + {"name": "FOO", "value": "BAR"}, + {"name": "OVERRIDE", "value": "NEW"}, + ] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_capacity_provider_strategy( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + capacity_provider_strategy=[ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"} + ], + ) + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + # Capture the task run call because moto does not track + # 'capacityProviderStrategy' + original_run_task = worker._create_task_run + mock_run_task = MagicMock(side_effect=original_run_task) + worker._create_task_run = mock_run_task + + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + assert not task.get("launchType") + # Instead, it requires a capacity provider strategy but this is not supported + # by moto and is not present on the task even when provided so we assert on the + # mock call to ensure it is sent + assert mock_run_task.call_args[0][1].get("capacityProviderStrategy") == [ + {"base": 0, "weight": 1, "capacityProvider": "r6i.large"}, + ] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_environment_variables_in_task_run_request_template( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_run_request={ + "overrides": { + "containerOverrides": [ + { + "name": ECS_DEFAULT_CONTAINER_NAME, + "environment": [ + {"name": "BAR", "value": "FOO"}, + {"name": "OVERRIDE", "value": "OLD"}, + {"name": "UNSET", "value": "GONE"}, + ], + } + ], + }, + }, + ), + aws_credentials=aws_credentials, + env={"FOO": "BAR", 
"OVERRIDE": "NEW", "UNSET": None}, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + + prefect_container_definition = _get_container( + task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME + ) + + assert ( + prefect_container_definition["environment"] == [] + ), "No environment variables in the task definition" + + prefect_container_overrides = _get_container( + task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME + ) + assert prefect_container_overrides.get("environment") == [ + {"name": "BAR", "value": "FOO"}, + {"name": "FOO", "value": "BAR"}, + {"name": "OVERRIDE", "value": "NEW"}, + ] + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_user_defined_tags_in_task_run_request_template( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration_with_job_template( + template_overrides=dict( + task_run_request={ + "tags": [ + {"key": "BAR", "value": "FOO"}, + {"key": "OVERRIDE", "value": "OLD"}, + ] + }, + ), + aws_credentials=aws_credentials, + labels={"FOO": "BAR", "OVERRIDE": "NEW"}, + ) + + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + async with ECSWorker(work_pool_name="test") as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + assert task.get("tags") == [ + {"key": "BAR", "value": "FOO"}, + {"key": "FOO", "value": "BAR"}, + {"key": "OVERRIDE", "value": "NEW"}, + ] + + +@pytest.mark.usefixtures("ecs_mocks") +@pytest.mark.parametrize( + "cluster", [None, "default", "second-cluster", "second-cluster-arn"] +) +async def test_kill_infrastructure(aws_credentials, cluster: str, flow_run): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + # Kill requires cluster-specificity so we test with variable clusters + second_cluster_arn = create_test_ecs_cluster(ecs_client, "second-cluster") + add_ec2_instance_to_ecs_cluster(session, "second-cluster") + + if cluster == "second-cluster-arn": + # Use the actual arn for this test case + cluster = second_cluster_arn + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + cluster=cluster, + ) + + with anyio.fail_after(5): + async with ECSWorker(work_pool_name="test") as worker: + async with anyio.create_task_group() as tg: + identifier = await tg.start(worker.run, flow_run, configuration) + + await worker.kill_infrastructure( + configuration=configuration, infrastructure_pid=identifier + ) + + _, task_arn = parse_identifier(identifier) + task = describe_task(ecs_client, task_arn) + assert task["lastStatus"] == "STOPPED" + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_infrastructure_with_invalid_identifier(aws_credentials): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + ) + + with pytest.raises(ValueError): + async with ECSWorker(work_pool_name="test") as worker: + await worker.kill_infrastructure(configuration, "test") + + +@pytest.mark.usefixtures("ecs_mocks") +async def 
test_kill_infrastructure_with_mismatched_cluster(aws_credentials): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + cluster="foo", + ) + + with pytest.raises( + InfrastructureNotAvailable, + match=( + "Cannot stop ECS task: this infrastructure block has access to cluster " + "'foo' but the task is running in cluster 'bar'." + ), + ): + async with ECSWorker(work_pool_name="test") as worker: + await worker.kill_infrastructure(configuration, "bar:::task_arn") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_infrastructure_with_cluster_that_does_not_exist(aws_credentials): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + cluster="foo", + ) + + with pytest.raises( + InfrastructureNotFound, + match="Cannot stop ECS task: the cluster 'foo' could not be found.", + ): + async with ECSWorker(work_pool_name="test") as worker: + await worker.kill_infrastructure(configuration, "foo::task_arn") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_infrastructure_with_task_that_does_not_exist( + aws_credentials, flow_run +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + cluster="default", + ) + + # Run the task so that a task definition is registered in the cluster + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration, flow_run) + + with pytest.raises( + InfrastructureNotFound, + match=( + "Cannot stop ECS task: the task 'foo' could not be found in cluster" + " 'default'" + ), + ): + await worker.kill_infrastructure(configuration, "default::foo") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_infrastructure_with_cluster_that_has_no_tasks(aws_credentials): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + cluster="default", + ) + + with pytest.raises( + InfrastructureNotFound, + match="Cannot stop ECS task: the cluster 'default' has no tasks.", + ): + async with ECSWorker(work_pool_name="test") as worker: + await worker.kill_infrastructure(configuration, "default::foo") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_infrastructure_with_task_that_is_already_stopped( + aws_credentials, flow_run +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + cluster="default", + ) + + async with ECSWorker(work_pool_name="test") as worker: + # Run and stop the task + result = await run_then_stop_task(worker, configuration, flow_run) + _, task_arn = parse_identifier(result.identifier) + + # AWS will happily stop the task "again" + await worker.kill_infrastructure(configuration, f"default::{task_arn}") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_kill_infrastructure_with_grace_period(aws_credentials, caplog, flow_run): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + configuration = await construct_configuration( + aws_credentials=aws_credentials, + ) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + async with ECSWorker(work_pool_name="test") as worker: + identifier = await tg.start(worker.run, flow_run, configuration) + + await worker.kill_infrastructure( + configuration, identifier, grace_seconds=60 + ) + + # Task stops correctly + _, task_arn = parse_identifier(identifier) + task = describe_task(ecs_client, task_arn) + assert task["lastStatus"] == "STOPPED" + + # Logs warning + assert "grace period of 60s requested, but AWS does not support" in 
caplog.text + + +async def test_retry_on_failed_task_start( + aws_credentials: AwsCredentials, flow_run, ecs_mocks +): + run_task_mock = MagicMock(return_value=[]) + + configuration = await construct_configuration( + aws_credentials=aws_credentials, command="echo test" + ) + + inject_moto_patches( + ecs_mocks, + { + "run_task": [run_task_mock], + }, + ) + + with pytest.raises(RuntimeError): + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration, flow_run) + + assert run_task_mock.call_count == 3 + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_worker_uses_cached_boto3_client(aws_credentials: AwsCredentials): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + ) + + _get_client_cached.cache_clear() + + assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0" + + async with ECSWorker(work_pool_name="test") as worker: + worker._get_client(configuration, "ecs") + worker._get_client(configuration, "ecs") + worker._get_client(configuration, "ecs") + + assert _get_client_cached.cache_info().misses == 1 + assert _get_client_cached.cache_info().hits == 2 + + +async def test_mask_sensitive_env_values(): + task_run_request = { + "overrides": { + "containerOverrides": [ + { + "environment": [ + {"name": "PREFECT_API_KEY", "value": "SeNsItiVe VaLuE"}, + {"name": "PREFECT_API_URL", "value": "NORMAL_VALUE"}, + ] + } + ] + } + } + + res = mask_sensitive_env_values(task_run_request, ["PREFECT_API_KEY"], 3, "***") + assert ( + res["overrides"]["containerOverrides"][0]["environment"][0]["value"] == "SeN***" + ) + assert ( + res["overrides"]["containerOverrides"][0]["environment"][1]["value"] + == "NORMAL_VALUE" + ) + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_get_or_generate_family( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + ) + + work_pool_name = "test" + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + family = f"{ECS_DEFAULT_FAMILY}_{work_pool_name}_{flow_run.deployment_id}" + + async with ECSWorker(work_pool_name=work_pool_name) as worker: + result = await run_then_stop_task(worker, configuration, flow_run) + + assert result.status_code == 0 + _, task_arn = parse_identifier(result.identifier) + + task = describe_task(ecs_client, task_arn) + task_definition = describe_task_definition(ecs_client, task) + assert task_definition["family"] == family