From 5735a492a421c924c0541f35463869a32b434131 Mon Sep 17 00:00:00 2001 From: jakekaplan <40362401+jakekaplan@users.noreply.github.com> Date: Mon, 27 Nov 2023 15:40:10 -0500 Subject: [PATCH 01/22] mask prefect api key (#341) --- prefect_aws/workers/ecs_worker.py | 29 ++++++++++++++++++++++++++++- tests/workers/test_ecs_worker.py | 25 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 02c4117c..afbe631b 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -221,6 +221,31 @@ def parse_identifier(identifier: str) -> ECSIdentifier: return ECSIdentifier(cluster, task) +def mask_sensitive_env_values( + task_run_request: dict, values: List[str], keep_length=3, replace_with="***" +): + for container in task_run_request.get("overrides", {}).get( + "containerOverrides", [] + ): + for env_var in container.get("environment", []): + if ( + "name" not in env_var + or "value" not in env_var + or env_var["name"] not in values + ): + continue + if len(env_var["value"]) > keep_length: + # Replace characters beyond the keep length + env_var["value"] = env_var["value"][:keep_length] + replace_with + return task_run_request + + +def mask_api_key(task_run_request): + return mask_sensitive_env_values( + task_run_request, ["PREFECT_API_KEY"], keep_length=6 + ) + + class ECSJobConfiguration(BaseJobConfiguration): """ Job configuration for an ECS worker. @@ -724,8 +749,10 @@ def _create_task_and_wait_for_start( logger.info("Creating ECS task run...") logger.debug( - f"Task run request {json.dumps(task_run_request, indent=2, default=str)}" + "Task run request" + f"{json.dumps(mask_api_key(task_run_request), indent=2, default=str)}" ) + try: task = self._create_task_run(ecs_client, task_run_request) task_arn = task["taskArn"] diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index b6a39b35..077a178a 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -34,6 +34,7 @@ InfrastructureNotFound, _get_container, get_prefect_image_name, + mask_sensitive_env_values, parse_identifier, ) @@ -2180,3 +2181,27 @@ async def test_retry_on_failed_task_start( await run_then_stop_task(worker, configuration, flow_run) assert run_task_mock.call_count == 3 + + +async def test_mask_sensitive_env_values(): + task_run_request = { + "overrides": { + "containerOverrides": [ + { + "environment": [ + {"name": "PREFECT_API_KEY", "value": "SeNsItiVe VaLuE"}, + {"name": "PREFECT_API_URL", "value": "NORMAL_VALUE"}, + ] + } + ] + } + } + + res = mask_sensitive_env_values(task_run_request, ["PREFECT_API_KEY"], 3, "***") + assert ( + res["overrides"]["containerOverrides"][0]["environment"][0]["value"] == "SeN***" + ) + assert ( + res["overrides"]["containerOverrides"][0]["environment"][1]["value"] + == "NORMAL_VALUE" + ) From 6ba701f4208c33f5c40e035f08a956deb7285432 Mon Sep 17 00:00:00 2001 From: jakekaplan <40362401+jakekaplan@users.noreply.github.com> Date: Thu, 30 Nov 2023 10:34:15 -0500 Subject: [PATCH 02/22] dont modify task run request (#347) --- prefect_aws/workers/ecs_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index afbe631b..c55e0f30 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -242,7 +242,7 @@ def mask_sensitive_env_values( def mask_api_key(task_run_request): return 
mask_sensitive_env_values( - task_run_request, ["PREFECT_API_KEY"], keep_length=6 + deepcopy(task_run_request), ["PREFECT_API_KEY"], keep_length=6 ) From 6865af76b7f0a0555eefc63fce41721fcf827407 Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Thu, 30 Nov 2023 12:15:19 -0600 Subject: [PATCH 03/22] Update CHANGELOG.md for v0.4.5 (#348) --- CHANGELOG.md | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d4e01351..514f2ffa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added ### Changed -- Added 'SecretBrinary' suport to `AwsSecret` block - [#274](https://github.com/PrefectHQ/prefect-aws/pull/274) ### Fixed @@ -18,6 +17,30 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed +## 0.4.5 + +Released November 30th, 2023. + +### Fixed + +- Bug where Prefect API key provided to ECS tasks was masked - [#347](https://github.com/PrefectHQ/prefect-aws/pull/347) + +## 0.4.4 + +Released November 29th, 2023. + +### Changed + +- Mask Prefect API key in logs - [#341](https://github.com/PrefectHQ/prefect-aws/pull/341) + +## 0.4.3 + +Released November 13th, 2023. + +### Added + +- `SecretBrinary` suport to `AwsSecret` block - [#274](https://github.com/PrefectHQ/prefect-aws/pull/274) + ## 0.4.2 Released November 6th, 2023. From 6aca613bf901d04541f53dc61874b732051448eb Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Mon, 11 Dec 2023 12:45:13 -0600 Subject: [PATCH 04/22] Adds ability to publish `ECSTask` block as a `ecs` work pool (#353) --- CHANGELOG.md | 8 ++ prefect_aws/ecs.py | 74 +++++++++++++++++- requirements.txt | 2 +- tests/conftest.py | 4 +- tests/test_ecs.py | 189 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 274 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 514f2ffa..653b1d94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed +## 0.4.6 + +Released December 11th, 2023. + +### Added + +Ability to publish `ECSTask`` block as an ecs work pool - [#353](https://github.com/PrefectHQ/prefect-aws/pull/353) + ## 0.4.5 Released November 30th, 2023. 
diff --git a/prefect_aws/ecs.py b/prefect_aws/ecs.py index a6ebe206..8e7052f2 100644 --- a/prefect_aws/ecs.py +++ b/prefect_aws/ecs.py @@ -108,6 +108,7 @@ import json import logging import pprint +import shlex import sys import time import warnings @@ -116,6 +117,8 @@ import boto3 import yaml from anyio.abc import TaskStatus +from jsonpointer import JsonPointerException +from prefect.blocks.core import BlockNotSavedError from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound from prefect.infrastructure.base import Infrastructure, InfrastructureResult from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible @@ -132,7 +135,7 @@ from typing_extensions import Literal, Self from prefect_aws import AwsCredentials -from prefect_aws.workers.ecs_worker import _TAG_REGEX +from prefect_aws.workers.ecs_worker import _TAG_REGEX, ECSWorker # Internal type alias for ECS clients which are generated dynamically in botocore _ECSClient = Any @@ -681,6 +684,75 @@ async def kill(self, identifier: str, grace_seconds: int = 30) -> None: cluster, task = parse_task_identifier(identifier) await run_sync_in_worker_thread(self._stop_task, cluster, task) + @staticmethod + def get_corresponding_worker_type() -> str: + """Return the corresponding worker type for this infrastructure block.""" + return ECSWorker.type + + async def generate_work_pool_base_job_template(self) -> dict: + """ + Generate a base job template for a cloud-run work pool with the same + configuration as this block. + + Returns: + - dict: a base job template for a cloud-run work pool + """ + base_job_template = copy.deepcopy(ECSWorker.get_default_base_job_template()) + for key, value in self.dict(exclude_unset=True, exclude_defaults=True).items(): + if key == "command": + base_job_template["variables"]["properties"]["command"]["default"] = ( + shlex.join(value) + ) + elif key in [ + "type", + "block_type_slug", + "_block_document_id", + "_block_document_name", + "_is_anonymous", + "task_customizations", + ]: + continue + elif key == "aws_credentials": + if not self.aws_credentials._block_document_id: + raise BlockNotSavedError( + "It looks like you are trying to use a block that" + " has not been saved. Please call `.save` on your block" + " before publishing it as a work pool." + ) + base_job_template["variables"]["properties"]["aws_credentials"][ + "default" + ] = { + "$ref": { + "block_document_id": str( + self.aws_credentials._block_document_id + ) + } + } + elif key == "task_definition": + base_job_template["job_configuration"]["task_definition"] = value + elif key in base_job_template["variables"]["properties"]: + base_job_template["variables"]["properties"][key]["default"] = value + else: + self.logger.warning( + f"Variable {key!r} is not supported by Cloud Run work pools." + " Skipping." + ) + + if self.task_customizations: + try: + base_job_template["job_configuration"]["task_run_request"] = ( + self.task_customizations.apply( + base_job_template["job_configuration"]["task_run_request"] + ) + ) + except JsonPointerException: + self.logger.warning( + "Unable to apply task customizations to the base job template." + "You may need to update the template manually." + ) + + return base_job_template + def _stop_task(self, cluster: str, task: str) -> None: """ Stop a running ECS task. 
diff --git a/requirements.txt b/requirements.txt index 919ce567..e5cfb0b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ boto3>=1.24.53 botocore>=1.27.53 mypy_boto3_s3>=1.24.94 mypy_boto3_secretsmanager>=1.26.49 -prefect>=2.13.5 +prefect>=2.14.10 tenacity>=8.0.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index ea17328f..9d2da154 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,11 +22,13 @@ def prefect_db(): @pytest.fixture def aws_credentials(): - return AwsCredentials( + block = AwsCredentials( aws_access_key_id="access_key_id", aws_secret_access_key="secret_access_key", region_name="us-east-1", ) + block.save("test-creds-block", overwrite=True) + return block @pytest.fixture diff --git a/tests/test_ecs.py b/tests/test_ecs.py index cf18bfe4..2f970116 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -1,6 +1,7 @@ import json import logging import textwrap +from copy import deepcopy from functools import partial from typing import Any, Awaitable, Callable, Dict, List, Optional from unittest.mock import MagicMock @@ -18,6 +19,8 @@ from prefect.utilities.dockerutils import get_prefect_image_name from pydantic import VERSION as PYDANTIC_VERSION +from prefect_aws.workers.ecs_worker import ECSWorker + if PYDANTIC_VERSION.startswith("2."): from pydantic.v1 import ValidationError else: @@ -2047,3 +2050,189 @@ async def test_kill_with_grace_period(aws_credentials, caplog): # Logs warning assert "grace period of 60s requested, but AWS does not support" in caplog.text + + +@pytest.fixture +def default_base_job_template(): + return deepcopy(ECSWorker.get_default_base_job_template()) + + +@pytest.fixture +def base_job_template_with_defaults(default_base_job_template, aws_credentials): + base_job_template_with_defaults = deepcopy(default_base_job_template) + base_job_template_with_defaults["variables"]["properties"]["command"][ + "default" + ] = "python my_script.py" + base_job_template_with_defaults["variables"]["properties"]["env"]["default"] = { + "VAR1": "value1", + "VAR2": "value2", + } + base_job_template_with_defaults["variables"]["properties"]["labels"]["default"] = { + "label1": "value1", + "label2": "value2", + } + base_job_template_with_defaults["variables"]["properties"]["name"][ + "default" + ] = "prefect-job" + base_job_template_with_defaults["variables"]["properties"]["image"][ + "default" + ] = "docker.io/my_image:latest" + base_job_template_with_defaults["variables"]["properties"]["aws_credentials"][ + "default" + ] = {"$ref": {"block_document_id": str(aws_credentials._block_document_id)}} + base_job_template_with_defaults["variables"]["properties"]["launch_type"][ + "default" + ] = "FARGATE_SPOT" + base_job_template_with_defaults["variables"]["properties"]["vpc_id"][ + "default" + ] = "vpc-123456" + base_job_template_with_defaults["variables"]["properties"]["task_role_arn"][ + "default" + ] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" + base_job_template_with_defaults["variables"]["properties"]["execution_role_arn"][ + "default" + ] = "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" + base_job_template_with_defaults["variables"]["properties"]["cluster"][ + "default" + ] = "test-cluster" + base_job_template_with_defaults["variables"]["properties"]["cpu"]["default"] = 2048 + base_job_template_with_defaults["variables"]["properties"]["memory"][ + "default" + ] = 4096 + + base_job_template_with_defaults["variables"]["properties"]["family"][ + "default" + ] = "test-family" + 
base_job_template_with_defaults["variables"]["properties"]["task_definition_arn"][ + "default" + ] = "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1" + base_job_template_with_defaults["variables"]["properties"][ + "cloudwatch_logs_options" + ]["default"] = { + "awslogs-group": "prefect", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "prefect", + } + base_job_template_with_defaults["variables"]["properties"][ + "configure_cloudwatch_logs" + ]["default"] = True + base_job_template_with_defaults["variables"]["properties"]["stream_output"][ + "default" + ] = True + base_job_template_with_defaults["variables"]["properties"][ + "task_watch_poll_interval" + ]["default"] = 5.1 + base_job_template_with_defaults["variables"]["properties"][ + "task_start_timeout_seconds" + ]["default"] = 60 + base_job_template_with_defaults["variables"]["properties"][ + "auto_deregister_task_definition" + ]["default"] = False + return base_job_template_with_defaults + + +@pytest.fixture +def base_job_template_with_task_arn(default_base_job_template, aws_credentials): + base_job_template_with_task_arn = deepcopy(default_base_job_template) + base_job_template_with_task_arn["variables"]["properties"]["image"][ + "default" + ] = "docker.io/my_image:latest" + + base_job_template_with_task_arn["job_configuration"]["task_definition"] = { + "containerDefinitions": [ + {"image": "docker.io/my_image:latest", "name": "prefect-job"} + ], + "cpu": "2048", + "family": "test-family", + "memory": "2024", + "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + } + return base_job_template_with_task_arn + + +@pytest.mark.parametrize( + "job_config", + [ + "default", + "custom", + "task_definition_arn", + ], +) +async def test_generate_work_pool_base_job_template( + job_config, + base_job_template_with_defaults, + aws_credentials, + default_base_job_template, + base_job_template_with_task_arn, + caplog, +): + job = ECSTask() + expected_template = default_base_job_template + expected_template["variables"]["properties"]["image"][ + "default" + ] = get_prefect_image_name() + if job_config == "custom": + expected_template = base_job_template_with_defaults + job = ECSTask( + command=["python", "my_script.py"], + env={"VAR1": "value1", "VAR2": "value2"}, + labels={"label1": "value1", "label2": "value2"}, + name="prefect-job", + image="docker.io/my_image:latest", + aws_credentials=aws_credentials, + launch_type="FARGATE_SPOT", + vpc_id="vpc-123456", + task_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + execution_role_arn="arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + cluster="test-cluster", + cpu=2048, + memory=4096, + task_customizations=[ + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/securityGroups", + "value": ["sg-d72e9599956a084f5"], + }, + ], + family="test-family", + task_definition_arn=( + "arn:aws:ecs:us-east-1:123456789012:task-definition/test-family:1" + ), + cloudwatch_logs_options={ + "awslogs-group": "prefect", + "awslogs-region": "us-east-1", + "awslogs-stream-prefix": "prefect", + }, + configure_cloudwatch_logs=True, + stream_output=True, + task_watch_poll_interval=5.1, + task_start_timeout_seconds=60, + auto_deregister_task_definition=False, + ) + elif job_config == "task_definition_arn": + expected_template = base_job_template_with_task_arn + job = ECSTask( + image="docker.io/my_image:latest", + task_definition={ + "containerDefinitions": [ + {"image": "docker.io/my_image:latest", "name": "prefect-job"} + ], + "cpu": 
"2048", + "family": "test-family", + "memory": "2024", + "executionRoleArn": ( + "arn:aws:iam::123456789012:role/ecsTaskExecutionRole" + ), + }, + ) + + template = await job.generate_work_pool_base_job_template() + + assert template == expected_template + + if job_config == "custom": + assert ( + "Unable to apply task customizations to the base job template." + "You may need to update the template manually." + in caplog.text + ) From 2d6dc288827ebaf08850e37e42750bc84723faf6 Mon Sep 17 00:00:00 2001 From: kevingrismore <146098880+kevingrismore@users.noreply.github.com> Date: Mon, 18 Dec 2023 08:56:58 -0600 Subject: [PATCH 05/22] Fix `S3Bucket.load()` for nested MinIO Credentials block (#359) --- CHANGELOG.md | 2 ++ prefect_aws/s3.py | 5 +---- tests/test_s3.py | 4 ++++ 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 653b1d94..cefe0b6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Bug where `S3Bucket.load()` constructed `AwsCredentials` instead of `MinIOCredentials` - [#359](https://github.com/PrefectHQ/prefect-aws/pull/359) + ### Deprecated ### Removed diff --git a/prefect_aws/s3.py b/prefect_aws/s3.py index 11ca0438..643d78ac 100644 --- a/prefect_aws/s3.py +++ b/prefect_aws/s3.py @@ -412,7 +412,7 @@ class S3Bucket(WritableFileSystem, WritableDeploymentStorage, ObjectStorageBlock bucket_name: str = Field(default=..., description="Name of your bucket.") - credentials: Union[AwsCredentials, MinIOCredentials] = Field( + credentials: Union[MinIOCredentials, AwsCredentials] = Field( default_factory=AwsCredentials, description="A block containing your credentials to AWS or MinIO.", ) @@ -425,9 +425,6 @@ class S3Bucket(WritableFileSystem, WritableDeploymentStorage, ObjectStorageBlock ), ) - class Config: - smart_union = True - # Property to maintain compatibility with storage block based deployments @property def basepath(self) -> str: diff --git a/tests/test_s3.py b/tests/test_s3.py index 89a39f7d..93d11cc1 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -822,7 +822,11 @@ def s3_bucket_with_similar_objects(self, s3_bucket_with_objects, objects_in_fold def test_credentials_are_correct_type(self, credentials): s3_bucket = S3Bucket(bucket_name="bucket", credentials=credentials) + s3_bucket_parsed = S3Bucket.parse_obj( + {"bucket_name": "bucket", "credentials": dict(credentials)} + ) assert isinstance(s3_bucket.credentials, type(credentials)) + assert isinstance(s3_bucket_parsed.credentials, type(credentials)) @pytest.mark.parametrize("client_parameters", aws_clients[-1:], indirect=True) def test_list_objects_empty(self, s3_bucket_empty, client_parameters): From c0a4f9c5f3695ee482aa7d0f9f307148f0d453a0 Mon Sep 17 00:00:00 2001 From: Dominic Tarro <57306102+dominictarro@users.noreply.github.com> Date: Wed, 3 Jan 2024 08:32:02 -0500 Subject: [PATCH 06/22] Added Lambda function block (#355) Co-authored-by: nate nowack Co-authored-by: Alexander Streed --- docs/lambda_function.md | 6 + mkdocs.yml | 1 + prefect_aws/__init__.py | 2 + prefect_aws/lambda_function.py | 194 +++++++++++++++++++++++++ tests/test_lambda_function.py | 253 +++++++++++++++++++++++++++++++++ 5 files changed, 456 insertions(+) create mode 100644 docs/lambda_function.md create mode 100644 prefect_aws/lambda_function.py create mode 100644 tests/test_lambda_function.py diff --git a/docs/lambda_function.md b/docs/lambda_function.md new file mode 100644 index 00000000..3f5c52e8 
--- /dev/null +++ b/docs/lambda_function.md @@ -0,0 +1,6 @@ +--- +description: Module handling AWS Lambda functions +notes: This documentation page is generated from source file docstrings. +--- + +::: prefect_aws.lambda_function \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 465f6407..d5084099 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -87,6 +87,7 @@ nav: - Credentials: credentials.md - ECS Worker: ecs_worker.md - ECS: ecs.md + - Lambda: lambda_function.md - Deployments: - Steps: deployments/steps.md - S3: s3.md diff --git a/prefect_aws/__init__.py b/prefect_aws/__init__.py index 9e39166b..b1de3fc5 100644 --- a/prefect_aws/__init__.py +++ b/prefect_aws/__init__.py @@ -1,6 +1,7 @@ from . import _version from .credentials import AwsCredentials, MinIOCredentials from .client_parameters import AwsClientParameters +from .lambda_function import LambdaFunction from .s3 import S3Bucket from .ecs import ECSTask from .secrets_manager import AwsSecret @@ -17,6 +18,7 @@ __all__ = [ "AwsCredentials", "AwsClientParameters", + "LambdaFunction", "MinIOCredentials", "S3Bucket", "ECSTask", diff --git a/prefect_aws/lambda_function.py b/prefect_aws/lambda_function.py new file mode 100644 index 00000000..b2cc631c --- /dev/null +++ b/prefect_aws/lambda_function.py @@ -0,0 +1,194 @@ +"""Integrations with AWS Lambda. + +Examples: + + Run a lambda function with a payload + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(payload={"foo": "bar"}) + ``` + + Specify a version of a lambda function + + ```python + LambdaFunction( + function_name="test-function", + qualifier="1", + aws_credentials=aws_credentials, + ).invoke() + ``` + + Invoke a lambda function asynchronously + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(invocation_type="Event") + ``` + + Invoke a lambda function and return the last 4 KB of logs + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(tail=True) + ``` + + Invoke a lambda function with a client context + + ```python + LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ).invoke(client_context={"bar": "foo"}) + ``` + +""" +import json +from typing import Literal, Optional + +from prefect.blocks.core import Block +from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field +else: + from pydantic import Field + +from prefect_aws.credentials import AwsCredentials + + +class LambdaFunction(Block): + """Invoke a Lambda function. This block is part of the prefect-aws + collection. Install prefect-aws with `pip install prefect-aws` to use this + block. + + Attributes: + function_name: The name, ARN, or partial ARN of the Lambda function to + run. This must be the name of a function that is already deployed + to AWS Lambda. + qualifier: The version or alias of the Lambda function to use when + invoked. If not specified, the latest (unqualified) version of the + Lambda function will be used. + aws_credentials: The AWS credentials to use to connect to AWS Lambda + with a default factory of AwsCredentials. 
+ + """ + + _block_type_name = "Lambda Function" + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/d74b16fe84ce626345adf235a47008fea2869a60-225x225.png" # noqa + _documentation_url = "https://prefecthq.github.io/prefect-aws/s3/#prefect_aws.lambda_function.LambdaFunction" # noqa + + function_name: str = Field( + title="Function Name", + description=( + "The name, ARN, or partial ARN of the Lambda function to run. This" + " must be the name of a function that is already deployed to AWS" + " Lambda." + ), + ) + qualifier: Optional[str] = Field( + default=None, + title="Qualifier", + description=( + "The version or alias of the Lambda function to use when invoked. " + "If not specified, the latest (unqualified) version of the Lambda " + "function will be used." + ), + ) + aws_credentials: AwsCredentials = Field( + title="AWS Credentials", + default_factory=AwsCredentials, + description="The AWS credentials to invoke the Lambda with.", + ) + + class Config: + """Lambda's pydantic configuration.""" + + smart_union = True + + def _get_lambda_client(self): + """ + Retrieve a boto3 session and Lambda client + """ + boto_session = self.aws_credentials.get_boto3_session() + lambda_client = boto_session.client("lambda") + return lambda_client + + @sync_compatible + async def invoke( + self, + payload: dict = None, + invocation_type: Literal[ + "RequestResponse", "Event", "DryRun" + ] = "RequestResponse", + tail: bool = False, + client_context: Optional[dict] = None, + ) -> dict: + """ + [Invoke](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/lambda/client/invoke.html) + the Lambda function with the given payload. + + Args: + payload: The payload to send to the Lambda function. + invocation_type: The invocation type of the Lambda function. This + can be one of "RequestResponse", "Event", or "DryRun". Uses + "RequestResponse" by default. + tail: If True, the response will include the base64-encoded last 4 + KB of log data produced by the Lambda function. + client_context: The client context to send to the Lambda function. + Limited to 3583 bytes. + + Returns: + The response from the Lambda function. 
+ + Examples: + + ```python + from prefect_aws.lambda_function import LambdaFunction + from prefect_aws.credentials import AwsCredentials + + credentials = AwsCredentials() + lambda_function = LambdaFunction( + function_name="test_lambda_function", + aws_credentials=credentials, + ) + response = lambda_function.invoke( + payload={"foo": "bar"}, + invocation_type="RequestResponse", + ) + response["Payload"].read() + ``` + ```txt + b'{"foo": "bar"}' + ``` + + """ + # Add invocation arguments + kwargs = dict(FunctionName=self.function_name) + + if payload: + kwargs["Payload"] = json.dumps(payload).encode() + + # Let boto handle invalid invocation types + kwargs["InvocationType"] = invocation_type + + if self.qualifier is not None: + kwargs["Qualifier"] = self.qualifier + + if tail: + kwargs["LogType"] = "Tail" + + if client_context is not None: + # For some reason this is string, but payload is bytes + kwargs["ClientContext"] = json.dumps(client_context) + + # Get client and invoke + lambda_client = await run_sync_in_worker_thread(self._get_lambda_client) + return await run_sync_in_worker_thread(lambda_client.invoke, **kwargs) diff --git a/tests/test_lambda_function.py b/tests/test_lambda_function.py new file mode 100644 index 00000000..32210629 --- /dev/null +++ b/tests/test_lambda_function.py @@ -0,0 +1,253 @@ +import inspect +import io +import json +import zipfile +from typing import Optional + +import boto3 +import pytest +from botocore.response import StreamingBody +from moto import mock_iam, mock_lambda +from pytest_lazyfixture import lazy_fixture + +from prefect_aws.credentials import AwsCredentials +from prefect_aws.lambda_function import LambdaFunction + + +@pytest.fixture +def lambda_mock(aws_credentials: AwsCredentials): + with mock_lambda(): + yield boto3.client( + "lambda", + region_name=aws_credentials.region_name, + ) + + +@pytest.fixture +def iam_mock(aws_credentials: AwsCredentials): + with mock_iam(): + yield boto3.client( + "iam", + region_name=aws_credentials.region_name, + ) + + +@pytest.fixture +def mock_iam_rule(iam_mock): + yield iam_mock.create_role( + RoleName="test-role", + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "lambda.amazonaws.com"}, + "Action": "sts:AssumeRole", + } + ], + } + ), + ) + + +def handler_a(event, context): + if isinstance(event, dict): + if "error" in event: + raise Exception(event["error"]) + event["foo"] = "bar" + else: + event = {"foo": "bar"} + return event + + +LAMBDA_TEST_CODE = inspect.getsource(handler_a) + + +@pytest.fixture +def mock_lambda_code(): + with io.BytesIO() as f: + with zipfile.ZipFile(f, mode="w") as z: + z.writestr("foo.py", LAMBDA_TEST_CODE) + f.seek(0) + yield f.read() + + +@pytest.fixture +def mock_lambda_function(lambda_mock, mock_iam_rule, mock_lambda_code): + r = lambda_mock.create_function( + FunctionName="test-function", + Runtime="python3.10", + Role=mock_iam_rule["Role"]["Arn"], + Handler="foo.handler", + Code={"ZipFile": mock_lambda_code}, + ) + r2 = lambda_mock.publish_version( + FunctionName="test-function", + ) + r["Version"] = r2["Version"] + yield r + + +def handler_b(event, context): + event = {"data": [1, 2, 3]} + return event + + +LAMBDA_TEST_CODE_V2 = inspect.getsource(handler_b) + + +@pytest.fixture +def mock_lambda_code_v2(): + with io.BytesIO() as f: + with zipfile.ZipFile(f, mode="w") as z: + z.writestr("foo.py", LAMBDA_TEST_CODE_V2) + f.seek(0) + yield f.read() + + +@pytest.fixture +def 
add_lambda_version(mock_lambda_function, lambda_mock, mock_lambda_code_v2): + r = mock_lambda_function.copy() + lambda_mock.update_function_code( + FunctionName="test-function", + ZipFile=mock_lambda_code_v2, + ) + r2 = lambda_mock.publish_version( + FunctionName="test-function", + ) + r["Version"] = r2["Version"] + yield r + + +@pytest.fixture +def lambda_function(aws_credentials): + return LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ) + + +def make_patched_invocation(client, handler): + """Creates a patched invoke method for moto lambda. The method replaces + the response 'Payload' with the result of the handler function. + """ + true_invoke = client.invoke + + def invoke(*args, **kwargs): + """Calls the true invoke and replaces the Payload with its result.""" + result = true_invoke(*args, **kwargs) + blob = json.dumps( + handler( + event=kwargs.get("Payload"), + context=kwargs.get("ClientContext"), + ) + ).encode() + result["Payload"] = StreamingBody(io.BytesIO(blob), len(blob)) + return result + + return invoke + + +@pytest.fixture +def mock_invoke( + lambda_function: LambdaFunction, handler, monkeypatch: pytest.MonkeyPatch +): + """Fixture to patch the invocation response's 'Payload' field. + + When `result["Payload"].read` is called, moto attempts to run the function + in a Docker container and return the result. This is total overkill, so + we actually call the handler with the given arguments. + """ + client = lambda_function._get_lambda_client() + + monkeypatch.setattr( + client, + "invoke", + make_patched_invocation(client, handler), + ) + + def _get_lambda_client(): + return client + + monkeypatch.setattr( + lambda_function, + "_get_lambda_client", + _get_lambda_client, + ) + + yield + + +class TestLambdaFunction: + def test_init(self, aws_credentials): + function = LambdaFunction( + function_name="test-function", + aws_credentials=aws_credentials, + ) + assert function.function_name == "test-function" + assert function.qualifier is None + + @pytest.mark.parametrize( + "payload,expected,handler", + [ + ({"foo": "baz"}, {"foo": "bar"}, handler_a), + (None, {"foo": "bar"}, handler_a), + ], + ) + def test_invoke_lambda_payloads( + self, + payload: Optional[dict], + expected: dict, + handler, + mock_lambda_function, + lambda_function: LambdaFunction, + mock_invoke, + ): + result = lambda_function.invoke(payload) + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == expected + + @pytest.mark.parametrize("handler", [handler_a]) + def test_invoke_lambda_tail( + self, lambda_function: LambdaFunction, mock_lambda_function, mock_invoke + ): + result = lambda_function.invoke(tail=True) + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == {"foo": "bar"} + assert "LogResult" in result + + @pytest.mark.parametrize("handler", [handler_a]) + def test_invoke_lambda_client_context( + self, lambda_function: LambdaFunction, mock_lambda_function, mock_invoke + ): + # Just making sure boto doesn't throw an error + result = lambda_function.invoke(client_context={"bar": "foo"}) + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == {"foo": "bar"} + + @pytest.mark.parametrize( + "func_fixture,expected,handler", + [ + (lazy_fixture("mock_lambda_function"), {"foo": "bar"}, handler_a), + (lazy_fixture("add_lambda_version"), {"data": [1, 2, 3]}, handler_b), + 
], + ) + def test_invoke_lambda_qualifier( + self, + func_fixture, + expected, + lambda_function: LambdaFunction, + mock_invoke, + ): + try: + lambda_function.qualifier = func_fixture["Version"] + result = lambda_function.invoke() + assert result["StatusCode"] == 200 + response_payload = json.loads(result["Payload"].read()) + assert response_payload == expected + finally: + lambda_function.qualifier = None From 50d5f6395e860d2c00c5f4377399d84232954069 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Wed, 3 Jan 2024 09:57:17 -0600 Subject: [PATCH 07/22] update changelog (#363) --- CHANGELOG.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cefe0b6d..64454b60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,12 +13,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Bug where `S3Bucket.load()` constructed `AwsCredentials` instead of `MinIOCredentials` - [#359](https://github.com/PrefectHQ/prefect-aws/pull/359) - ### Deprecated ### Removed +## 0.4.7 + +Released January 3rd, 2024. + +### Added + +- `LambdaFunction` block to invoke lambda functions - [#355](https://github.com/PrefectHQ/prefect-aws/pull/355) + +### Fixed + +- Bug where `S3Bucket.load()` constructed `AwsCredentials` instead of `MinIOCredentials` - [#359](https://github.com/PrefectHQ/prefect-aws/pull/359) + ## 0.4.6 Released December 11th, 2023. From eeb3252b16ca823c23ec5b35d128d1abc422cb6e Mon Sep 17 00:00:00 2001 From: nate nowack Date: Fri, 5 Jan 2024 11:23:15 -0600 Subject: [PATCH 08/22] fix typo in docs (#364) --- prefect_aws/ecs.py | 2 +- prefect_aws/workers/ecs_worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/prefect_aws/ecs.py b/prefect_aws/ecs.py index 8e7052f2..ac5dd333 100644 --- a/prefect_aws/ecs.py +++ b/prefect_aws/ecs.py @@ -407,7 +407,7 @@ class ECSTask(Infrastructure): description=( "The type of ECS task run infrastructure that should be used. Note that" " 'FARGATE_SPOT' is not a formal ECS launch type, but we will configure" - " the proper capacity provider stategy if set here." + " the proper capacity provider strategy if set here." ), ) ) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index c55e0f30..6d3c8ec7 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -420,7 +420,7 @@ class ECSVariables(BaseVariables): description=( "The type of ECS task run infrastructure that should be used. Note that" " 'FARGATE_SPOT' is not a formal ECS launch type, but we will configure" - " the proper capacity provider stategy if set here." + " the proper capacity provider strategy if set here." 
), ) ) From f6977603f4f997459695f42f9e728a7fb45ec6cb Mon Sep 17 00:00:00 2001 From: nate nowack Date: Fri, 19 Jan 2024 12:15:39 -0600 Subject: [PATCH 09/22] provide ability to cache boto client instances directly and on `S3Bucket` (#369) --- CHANGELOG.md | 4 + prefect_aws/client_parameters.py | 12 +++ prefect_aws/credentials.py | 76 ++++++++++++++++--- prefect_aws/s3.py | 2 +- tests/test_credentials.py | 122 ++++++++++++++++++++++++++++++- 5 files changed, 204 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 64454b60..b40563f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361) + ### Fixed ### Deprecated @@ -105,6 +107,7 @@ Released August 31st, 2023. Released July 20th, 2023. ### Changed + - Promoted workers to GA, removed beta disclaimers ## 0.3.5 @@ -293,6 +296,7 @@ Released on October 28th, 2022. - `ECSTask` is no longer experimental — [#137](https://github.com/PrefectHQ/prefect-aws/pull/137) ### Fixed + - Fix ignore_file option in `S3Bucket` skipping files which should be included — [#139](https://github.com/PrefectHQ/prefect-aws/pull/139) - Fixed bug where `basepath` is used twice in the path when using `S3Bucket.put_directory` - [#143](https://github.com/PrefectHQ/prefect-aws/pull/143) diff --git a/prefect_aws/client_parameters.py b/prefect_aws/client_parameters.py index bf030590..eb3be09b 100644 --- a/prefect_aws/client_parameters.py +++ b/prefect_aws/client_parameters.py @@ -70,6 +70,18 @@ class AwsClientParameters(BaseModel): title="Botocore Config", ) + def __hash__(self): + return hash( + ( + self.api_version, + self.use_ssl, + self.verify, + self.verify_cert_path, + self.endpoint_url, + self.config, + ) + ) + @validator("config", pre=True) def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, Any]: """ diff --git a/prefect_aws/credentials.py b/prefect_aws/credentials.py index 64f49efe..5aeddaa6 100644 --- a/prefect_aws/credentials.py +++ b/prefect_aws/credentials.py @@ -1,6 +1,8 @@ """Module handling AWS credentials""" from enum import Enum +from functools import lru_cache +from threading import Lock from typing import Any, Optional, Union import boto3 @@ -16,14 +18,43 @@ from prefect_aws.client_parameters import AwsClientParameters +_LOCK = Lock() + class ClientType(Enum): + """The supported boto3 clients.""" + S3 = "s3" ECS = "ecs" BATCH = "batch" SECRETS_MANAGER = "secretsmanager" +@lru_cache(maxsize=8, typed=True) +def _get_client_cached(ctx, client_type: Union[str, ClientType]) -> Any: + """ + Helper method to cache and dynamically get a client type. + + Args: + client_type: The client's service name. + + Returns: + An authenticated client. + + Raises: + ValueError: if the client is not supported. + """ + with _LOCK: + if isinstance(client_type, ClientType): + client_type = client_type.value + + client = ctx.get_boto3_session().client( + service_name=client_type, + **ctx.aws_client_parameters.get_params_override(), + ) + return client + + class AwsCredentials(CredentialsBlock): """ Block used to manage authentication with AWS. 
AWS authentication is @@ -75,6 +106,22 @@ class AwsCredentials(CredentialsBlock): title="AWS Client Parameters", ) + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + field_hashes = ( + hash(self.aws_access_key_id), + hash(self.aws_secret_access_key), + hash(self.aws_session_token), + hash(self.profile_name), + hash(self.region_name), + hash(frozenset(self.aws_client_parameters.dict().items())), + ) + return hash(field_hashes) + def get_boto3_session(self) -> boto3.Session: """ Returns an authenticated boto3 session that can be used to create clients @@ -104,7 +151,7 @@ def get_boto3_session(self) -> boto3.Session: region_name=self.region_name, ) - def get_client(self, client_type: Union[str, ClientType]) -> Any: + def get_client(self, client_type: Union[str, ClientType]): """ Helper method to dynamically get a client type. @@ -120,10 +167,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any: if isinstance(client_type, ClientType): client_type = client_type.value - client = self.get_boto3_session().client( - service_name=client_type, **self.aws_client_parameters.get_params_override() - ) - return client + return _get_client_cached(ctx=self, client_type=client_type) def get_s3_client(self) -> S3Client: """ @@ -186,6 +230,21 @@ class MinIOCredentials(CredentialsBlock): description="Extra parameters to initialize the Client.", ) + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + return hash( + ( + hash(self.minio_root_user), + hash(self.minio_root_password), + hash(self.region_name), + hash(frozenset(self.aws_client_parameters.dict().items())), + ) + ) + def get_boto3_session(self) -> boto3.Session: """ Returns an authenticated boto3 session that can be used to create clients @@ -218,7 +277,7 @@ def get_boto3_session(self) -> boto3.Session: region_name=self.region_name, ) - def get_client(self, client_type: Union[str, ClientType]) -> Any: + def get_client(self, client_type: Union[str, ClientType]): """ Helper method to dynamically get a client type. @@ -234,10 +293,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any: if isinstance(client_type, ClientType): client_type = client_type.value - client = self.get_boto3_session().client( - service_name=client_type, **self.aws_client_parameters.get_params_override() - ) - return client + return _get_client_cached(ctx=self, client_type=client_type) def get_s3_client(self) -> S3Client: """ diff --git a/prefect_aws/s3.py b/prefect_aws/s3.py index 643d78ac..a10e2171 100644 --- a/prefect_aws/s3.py +++ b/prefect_aws/s3.py @@ -466,7 +466,7 @@ def _get_s3_client(self) -> boto3.client: Authenticate MinIO credentials or AWS credentials and return an S3 client. This is a helper function called by read_path() or write_path(). 
""" - return self.credentials.get_s3_client() + return self.credentials.get_client("s3") def _get_bucket_resource(self) -> boto3.resource: """ diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 6e0a1ff8..96ecbd22 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -3,7 +3,12 @@ from botocore.client import BaseClient from moto import mock_s3 -from prefect_aws.credentials import AwsCredentials, ClientType, MinIOCredentials +from prefect_aws.credentials import ( + AwsCredentials, + ClientType, + MinIOCredentials, + _get_client_cached, +) def test_aws_credentials_get_boto3_session(): @@ -44,3 +49,118 @@ def test_minio_credentials_get_boto3_session(): def test_credentials_get_client(credentials, client_type): with mock_s3(): assert isinstance(credentials.get_client(client_type), BaseClient) + + +@pytest.mark.parametrize( + "credentials", + [ + AwsCredentials(region_name="us-east-1"), + MinIOCredentials( + minio_root_user="root_user", + minio_root_password="root_password", + region_name="us-east-1", + ), + ], +) +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_get_client_cached(credentials, client_type): + """ + Test to ensure that _get_client_cached function returns the same instance + for multiple calls with the same parameters and properly utilizes lru_cache. + """ + + _get_client_cached.cache_clear() + + assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0" + + credentials.get_client(client_type) + credentials.get_client(client_type) + credentials.get_client(client_type) + + assert _get_client_cached.cache_info().misses == 1 + assert _get_client_cached.cache_info().hits == 2 + + +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_aws_credentials_change_causes_cache_miss(client_type): + """ + Test to ensure that changing configuration on an AwsCredentials instance + after fetching a client causes a cache miss. + """ + + _get_client_cached.cache_clear() + + credentials = AwsCredentials(region_name="us-east-1") + + initial_client = credentials.get_client(client_type) + + credentials.region_name = "us-west-2" + + new_client = credentials.get_client(client_type) + + assert ( + initial_client is not new_client + ), "Client should be different after configuration change" + + assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice" + + +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_minio_credentials_change_causes_cache_miss(client_type): + """ + Test to ensure that changing configuration on an AwsCredentials instance + after fetching a client causes a cache miss. 
+ """ + + _get_client_cached.cache_clear() + + credentials = MinIOCredentials( + minio_root_user="root_user", + minio_root_password="root_password", + region_name="us-east-1", + ) + + initial_client = credentials.get_client(client_type) + + credentials.region_name = "us-west-2" + + new_client = credentials.get_client(client_type) + + assert ( + initial_client is not new_client + ), "Client should be different after configuration change" + + assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice" + + +@pytest.mark.parametrize( + "credentials_type, initial_field, new_field", + [ + ( + AwsCredentials, + {"region_name": "us-east-1"}, + {"region_name": "us-east-2"}, + ), + ( + MinIOCredentials, + { + "region_name": "us-east-1", + "minio_root_user": "root_user", + "minio_root_password": "root_password", + }, + { + "region_name": "us-east-2", + "minio_root_user": "root_user", + "minio_root_password": "root_password", + }, + ), + ], +) +def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field): + credentials = credentials_type(**initial_field) + initial_hash = hash(credentials) + + setattr(credentials, list(new_field.keys())[0], list(new_field.values())[0]) + new_hash = hash(credentials) + + assert initial_hash != new_hash, "Hash should change when region_name changes" From 7a254923b9be43383d132f550e654e7534c979b6 Mon Sep 17 00:00:00 2001 From: kevingrismore <146098880+kevingrismore@users.noreply.github.com> Date: Fri, 19 Jan 2024 16:09:19 -0600 Subject: [PATCH 10/22] Support MinIO Credentials block as credentials for `push_to_s3` and `pull_from_s3` (#366) Co-authored-by: Chris Guidry --- CHANGELOG.md | 8 +- prefect_aws/deployments/steps.py | 14 +++- .../{deploments => deployments}/test_steps.py | 74 ++++++++++++++++++- 3 files changed, 88 insertions(+), 8 deletions(-) rename tests/{deploments => deployments}/test_steps.py (82%) diff --git a/CHANGELOG.md b/CHANGELOG.md index b40563f9..70227627 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Support MinIO Credentials as `credentials` dict for `push_to_s3` and `pull_from_s3` - [#366](https://github.com/PrefectHQ/prefect-aws/pull/366) + ### Changed - Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361) @@ -141,7 +143,7 @@ Released on June 13th, 2023. ### Fixed -- Change prefect.docker import to prefect.utilities.dockerutils to fix a crash when using custom blocks based on S3Bucket - [#273](https://github.com/PrefectHQ/prefect-aws/pull/273) +- Change prefect.docker import to prefect.utilities.dockerutils to fix a crash when using custom blocks based on S3Bucket - [#273](https://github.com/PrefectHQ/prefect-aws/pull/273) ## 0.3.2 @@ -230,7 +232,7 @@ Released on January 4th, 2023. 
- `list_objects`, `download_object_to_path`, `download_object_to_file_object`, `download_folder_to_path`, `upload_from_path`, `upload_from_file_object`, `upload_from_folder` methods in `S3Bucket` - [#85](https://github.com/PrefectHQ/prefect-aws/pull/175) - `aws_client_parameters` as a field in `AwsCredentials` and `MinioCredentials` blocks - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) -- `get_client` and `get_s3_client` methods to `AwsCredentials` and `MinioCredentials` blocks - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) +- `get_client` and `get_s3_client` methods to `AwsCredentials` and `MinioCredentials` blocks - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) ### Changed @@ -242,7 +244,7 @@ Released on January 4th, 2023. - `endpoint_url` field in S3Bucket; specify `aws_client_parameters` in `AwsCredentials` or `MinIOCredentials` instead - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) - `basepath` field in S3Bucket; specify `bucket_folder` instead - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) -- `minio_credentials` and `aws_credentials` field in S3Bucket; use the `credentials` field instead - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) +- `minio_credentials` and `aws_credentials` field in S3Bucket; use the `credentials` field instead - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) ## 0.2.1 diff --git a/prefect_aws/deployments/steps.py b/prefect_aws/deployments/steps.py index 7525a5e2..6161dfb8 100644 --- a/prefect_aws/deployments/steps.py +++ b/prefect_aws/deployments/steps.py @@ -62,7 +62,8 @@ def push_to_s3( bucket: The name of the S3 bucket where files will be uploaded. folder: The folder in the S3 bucket where files will be uploaded. credentials: A dictionary of AWS credentials (aws_access_key_id, - aws_secret_access_key, aws_session_token). + aws_secret_access_key, aws_session_token) or MinIO credentials + (minio_root_user, minio_root_password). client_parameters: A dictionary of additional parameters to pass to the boto3 client. ignore_file: The name of the file containing ignore patterns. @@ -139,7 +140,8 @@ def pull_from_s3( bucket: The name of the S3 bucket where files are stored. folder: The folder in the S3 bucket where files are stored. credentials: A dictionary of AWS credentials (aws_access_key_id, - aws_secret_access_key, aws_session_token). + aws_secret_access_key, aws_session_token) or MinIO credentials + (minio_root_user, minio_root_password). client_parameters: A dictionary of additional parameters to pass to the boto3 client. 
@@ -204,8 +206,12 @@ def get_s3_client( client_parameters = {} # Get credentials from credentials (regardless if block or not) - aws_access_key_id = credentials.get("aws_access_key_id", None) - aws_secret_access_key = credentials.get("aws_secret_access_key", None) + aws_access_key_id = credentials.get( + "aws_access_key_id", credentials.get("minio_root_user", None) + ) + aws_secret_access_key = credentials.get( + "aws_secret_access_key", credentials.get("minio_root_password", None) + ) aws_session_token = credentials.get("aws_session_token", None) # Get remaining session info from credentials, or client_parameters diff --git a/tests/deploments/test_steps.py b/tests/deployments/test_steps.py similarity index 82% rename from tests/deploments/test_steps.py rename to tests/deployments/test_steps.py index 22608bd7..15c4fc25 100644 --- a/tests/deploments/test_steps.py +++ b/tests/deployments/test_steps.py @@ -11,6 +11,16 @@ from prefect_aws.deployments.steps import get_s3_client, pull_from_s3, push_to_s3 +@pytest.fixture(scope="module", autouse=True) +def set_custom_endpoint(): + original = os.environ.get("MOTO_S3_CUSTOM_ENDPOINTS") + os.environ["MOTO_S3_CUSTOM_ENDPOINTS"] = "http://custom.minio.endpoint:9000" + yield + os.environ.pop("MOTO_S3_CUSTOM_ENDPOINTS") + if original is not None: + os.environ["MOTO_S3_CUSTOM_ENDPOINTS"] = original + + @pytest.fixture def s3_setup(): with mock_s3(): @@ -215,8 +225,15 @@ def test_s3_session_with_params(): }, ) get_s3_client(credentials=creds_block.dict()) + get_s3_client( + credentials={ + "minio_root_user": "MY_USER", + "minio_root_password": "MY_PASSWORD", + }, + client_parameters={"endpoint_url": "http://custom.minio.endpoint:9000"}, + ) all_calls = mock_session.mock_calls - assert len(all_calls) == 6 + assert len(all_calls) == 8 assert all_calls[0].kwargs == { "aws_access_key_id": "THE_KEY", "aws_secret_access_key": "SHHH!", @@ -265,6 +282,20 @@ def test_s3_session_with_params(): }.items() <= all_calls[5].kwargs.items() assert all_calls[5].kwargs.get("config").connect_timeout == 123 assert all_calls[5].kwargs.get("config").signature_version is None + assert all_calls[6].kwargs == { + "aws_access_key_id": "MY_USER", + "aws_secret_access_key": "MY_PASSWORD", + "aws_session_token": None, + "profile_name": None, + "region_name": None, + } + assert all_calls[7].args[0] == "s3" + assert { + "api_version": None, + "use_ssl": True, + "verify": None, + "endpoint_url": "http://custom.minio.endpoint:9000", + }.items() <= all_calls[7].kwargs.items() def test_custom_credentials_and_client_parameters(s3_setup, tmp_files): @@ -309,6 +340,47 @@ def test_custom_credentials_and_client_parameters(s3_setup, tmp_files): assert (tmp_path / file.name).exists() +def test_custom_credentials_and_client_parameters_minio(s3_setup, tmp_files): + s3, bucket_name = s3_setup + folder = "my-project" + + # Custom credentials and client parameters + custom_credentials = { + "minio_root_user": "fake_user", + "minio_root_password": "fake_password", + } + + custom_client_parameters = { + "endpoint_url": "http://custom.minio.endpoint:9000", + } + + os.chdir(tmp_files) + + # Test push_to_s3 with custom credentials and client parameters + push_to_s3( + bucket_name, + folder, + credentials=custom_credentials, + client_parameters=custom_client_parameters, + ) + + # Test pull_from_s3 with custom credentials and client parameters + tmp_path = tmp_files / "test_pull" + tmp_path.mkdir(parents=True, exist_ok=True) + os.chdir(tmp_path) + + pull_from_s3( + bucket_name, + folder, + 
credentials=custom_credentials, + client_parameters=custom_client_parameters, + ) + + for file in tmp_files.iterdir(): + if file.is_file() and file.name != ".prefectignore": + assert (tmp_path / file.name).exists() + + def test_without_prefectignore_file(s3_setup, tmp_files: Path, mock_aws_credentials): s3, bucket_name = s3_setup folder = "my-project" From 897ab4318525a92eee0b1970a2a2de2fc008cba2 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Tue, 23 Jan 2024 10:25:58 -0600 Subject: [PATCH 11/22] Update CHANGELOG.md (#371) --- CHANGELOG.md | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 70227627..8121da96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,18 +9,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Support MinIO Credentials as `credentials` dict for `push_to_s3` and `pull_from_s3` - [#366](https://github.com/PrefectHQ/prefect-aws/pull/366) - ### Changed -- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361) - ### Fixed ### Deprecated ### Removed +## 0.4.8 +Released January 23rd, 2024; + +### Added + +- Support MinIO Credentials as `credentials` dict for `push_to_s3` and `pull_from_s3` - [#366](https://github.com/PrefectHQ/prefect-aws/pull/366) + +### Changed + +- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361) + ## 0.4.7 Released January 3rd, 2024. From 04b90006a1a632fb9d6edb58c98a7a4c865733b6 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Wed, 24 Jan 2024 12:27:37 -0600 Subject: [PATCH 12/22] fix client hashing in nested client params case (#373) --- prefect_aws/client_parameters.py | 4 +++- prefect_aws/credentials.py | 2 +- prefect_aws/utilities.py | 35 ++++++++++++++++++++++++++++++++ tests/test_credentials.py | 25 +++++++++++++++++++++++ tests/test_utilities.py | 34 +++++++++++++++++++++++++++++++ 5 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 prefect_aws/utilities.py create mode 100644 tests/test_utilities.py diff --git a/prefect_aws/client_parameters.py b/prefect_aws/client_parameters.py index eb3be09b..6b47c422 100644 --- a/prefect_aws/client_parameters.py +++ b/prefect_aws/client_parameters.py @@ -7,6 +7,8 @@ from botocore.client import Config from pydantic import VERSION as PYDANTIC_VERSION +from prefect_aws.utilities import hash_collection + if PYDANTIC_VERSION.startswith("2."): from pydantic.v1 import BaseModel, Field, FilePath, root_validator, validator else: @@ -78,7 +80,7 @@ def __hash__(self): self.verify, self.verify_cert_path, self.endpoint_url, - self.config, + hash_collection(self.config), ) ) diff --git a/prefect_aws/credentials.py b/prefect_aws/credentials.py index 5aeddaa6..474a610a 100644 --- a/prefect_aws/credentials.py +++ b/prefect_aws/credentials.py @@ -118,7 +118,7 @@ def __hash__(self): hash(self.aws_session_token), hash(self.profile_name), hash(self.region_name), - hash(frozenset(self.aws_client_parameters.dict().items())), + hash(self.aws_client_parameters), ) return hash(field_hashes) diff --git a/prefect_aws/utilities.py b/prefect_aws/utilities.py new file mode 100644 index 00000000..ad1e6ed2 --- /dev/null +++ b/prefect_aws/utilities.py @@ -0,0 +1,35 @@ +"""Utilities for working with AWS services.""" + +from prefect.utilities.collections import visit_collection + + +def hash_collection(collection) -> int: + """Use visit_collection to transform and hash a collection. 
+ + Args: + collection (Any): The collection to hash. + + Returns: + int: The hash of the transformed collection. + + Example: + ```python + from prefect_aws.utilities import hash_collection + + hash_collection({"a": 1, "b": 2}) + ``` + + """ + + def make_hashable(item): + """Make an item hashable by converting it to a tuple.""" + if isinstance(item, dict): + return tuple(sorted((k, make_hashable(v)) for k, v in item.items())) + elif isinstance(item, list): + return tuple(make_hashable(v) for v in item) + return item + + hashable_collection = visit_collection( + collection, visit_fn=make_hashable, return_data=True + ) + return hash(hashable_collection) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 96ecbd22..6e593212 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -164,3 +164,28 @@ def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field new_hash = hash(credentials) assert initial_hash != new_hash, "Hash should change when region_name changes" + + +def test_aws_credentials_nested_client_parameters_are_hashable(): + """ + Test to ensure that nested client parameters are hashable. + """ + + creds = AwsCredentials( + region_name="us-east-1", + aws_client_parameters=dict( + config=dict( + connect_timeout=5, + read_timeout=5, + retries=dict(max_attempts=10, mode="standard"), + ) + ), + ) + + assert hash(creds) is not None + + client = creds.get_client("s3") + + _client = creds.get_client("s3") + + assert client is _client diff --git a/tests/test_utilities.py b/tests/test_utilities.py new file mode 100644 index 00000000..0e0fdc6f --- /dev/null +++ b/tests/test_utilities.py @@ -0,0 +1,34 @@ +import pytest + +from prefect_aws.utilities import hash_collection + + +class TestHashCollection: + def test_simple_dict(self): + simple_dict = {"key1": "value1", "key2": "value2"} + assert hash_collection(simple_dict) == hash_collection( + simple_dict + ), "Simple dictionary hashing failed" + + def test_nested_dict(self): + nested_dict = {"key1": {"subkey1": "subvalue1"}, "key2": "value2"} + assert hash_collection(nested_dict) == hash_collection( + nested_dict + ), "Nested dictionary hashing failed" + + def test_complex_structure(self): + complex_structure = { + "key1": [1, 2, 3], + "key2": {"subkey1": {"subsubkey1": "value"}}, + } + assert hash_collection(complex_structure) == hash_collection( + complex_structure + ), "Complex structure hashing failed" + + def test_unhashable_structure(self): + typically_unhashable_structure = dict(key=dict(subkey=[1, 2, 3])) + with pytest.raises(TypeError): + hash(typically_unhashable_structure) + assert hash_collection(typically_unhashable_structure) == hash_collection( + typically_unhashable_structure + ), "Unhashable structure hashing failed after transformation" From 933c6aa2388ac4aa500c4336952edce6dbbcde60 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Wed, 24 Jan 2024 12:43:03 -0600 Subject: [PATCH 13/22] prep v0.4.9 (#374) --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8121da96..35f0fdc3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed +## 0.4.9 +Released January 24rd, 2024; + +### Fixed + +- Hashing of nested objects within `AwsClientParameters` - [#373](https://github.com/PrefectHQ/prefect-aws/pull/373) + + ## 0.4.8 Released January 23rd, 2024; From ac19635afd83f46bc4ccf864c0624134ed2e9a2a Mon Sep 17 00:00:00 2001 From: 
kevingrismore <146098880+kevingrismore@users.noreply.github.com> Date: Mon, 29 Jan 2024 12:07:29 -0500 Subject: [PATCH 14/22] Use cached boto3 clients in `ECSWorker` (#375) --- CHANGELOG.md | 9 ++++-- prefect_aws/workers/ecs_worker.py | 49 +++++++++++++------------------ tests/workers/test_ecs_worker.py | 20 +++++++++++++ 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35f0fdc3..1445d22c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Use cached boto3 clients in `ECSWorker` - [#375](https://github.com/PrefectHQ/prefect-aws/pull/375) + ### Fixed ### Deprecated @@ -18,15 +20,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed ## 0.4.9 -Released January 24rd, 2024; + +Released January 24rd, 2024. ### Fixed - Hashing of nested objects within `AwsClientParameters` - [#373](https://github.com/PrefectHQ/prefect-aws/pull/373) - ## 0.4.8 -Released January 23rd, 2024; + +Released January 23rd, 2024. ### Added diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 6d3c8ec7..372e6b29 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -51,12 +51,11 @@ import sys import time from copy import deepcopy -from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple +from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union from uuid import UUID import anyio import anyio.abc -import boto3 import yaml from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound from prefect.server.schemas.core import FlowRun @@ -79,7 +78,7 @@ from tenacity import retry, stop_after_attempt, wait_fixed, wait_random from typing_extensions import Literal -from prefect_aws import AwsCredentials +from prefect_aws.credentials import AwsCredentials, ClientType # Internal type alias for ECS clients which are generated dynamically in botocore _ECSClient = Any @@ -584,8 +583,8 @@ async def run( """ Runs a given flow run on the current worker. """ - boto_session, ecs_client = await run_sync_in_worker_thread( - self._get_session_and_client, configuration + ecs_client = await run_sync_in_worker_thread( + self._get_client, configuration, "ecs" ) logger = self.get_flow_run_logger(flow_run) @@ -598,7 +597,6 @@ async def run( ) = await run_sync_in_worker_thread( self._create_task_and_wait_for_start, logger, - boto_session, ecs_client, configuration, flow_run, @@ -625,7 +623,6 @@ async def run( cluster_arn, task_definition, is_new_task_definition and configuration.auto_deregister_task_definition, - boto_session, ecs_client, ) @@ -636,21 +633,17 @@ async def run( status_code=status_code if status_code is not None else -1, ) - def _get_session_and_client( - self, - configuration: ECSJobConfiguration, - ) -> Tuple[boto3.Session, _ECSClient]: + def _get_client( + self, configuration: ECSJobConfiguration, client_type: Union[str, ClientType] + ) -> _ECSClient: """ - Retrieve a boto3 session and ECS client + Get a boto3 client of client_type. Will use a cached client if one exists. 
""" - boto_session = configuration.aws_credentials.get_boto3_session() - ecs_client = boto_session.client("ecs") - return boto_session, ecs_client + return configuration.aws_credentials.get_client(client_type) def _create_task_and_wait_for_start( self, logger: logging.Logger, - boto_session: boto3.Session, ecs_client: _ECSClient, configuration: ECSJobConfiguration, flow_run: FlowRun, @@ -741,7 +734,6 @@ def _create_task_and_wait_for_start( # Prepare the task run request task_run_request = self._prepare_task_run_request( - boto_session, configuration, task_definition, task_definition_arn, @@ -782,7 +774,6 @@ def _watch_task_and_get_exit_code( cluster_arn: str, task_definition: dict, deregister_task_definition: bool, - boto_session: boto3.Session, ecs_client: _ECSClient, ) -> Optional[int]: """ @@ -798,7 +789,6 @@ def _watch_task_and_get_exit_code( cluster_arn, task_definition, ecs_client, - boto_session, ) if deregister_task_definition: @@ -992,7 +982,6 @@ def _wait_for_task_finish( cluster_arn: str, task_definition: dict, ecs_client: _ECSClient, - boto_session: boto3.Session, ): """ Watch an ECS task until it reaches a STOPPED status. @@ -1031,7 +1020,7 @@ def _wait_for_task_finish( else: # Prepare to stream the output log_config = container_def["logConfiguration"]["options"] - logs_client = boto_session.client("logs") + logs_client = self._get_client(configuration, "logs") can_stream_output = True # Track the last log timestamp to prevent double display last_log_timestamp: Optional[int] = None @@ -1300,13 +1289,13 @@ def _prepare_task_definition( return task_definition def _load_network_configuration( - self, vpc_id: Optional[str], boto_session: boto3.Session + self, vpc_id: Optional[str], configuration: ECSJobConfiguration ) -> dict: """ Load settings from a specific VPC or the default VPC and generate a task run request's network configuration. """ - ec2_client = boto_session.client("ec2") + ec2_client = self._get_client(configuration, "ec2") vpc_message = "the default VPC" if not vpc_id else f"VPC with ID {vpc_id}" if not vpc_id: @@ -1347,13 +1336,16 @@ def _load_network_configuration( } def _custom_network_configuration( - self, vpc_id: str, network_configuration: dict, boto_session: boto3.Session + self, + vpc_id: str, + network_configuration: dict, + configuration: ECSJobConfiguration, ) -> dict: """ Load settings from a specific VPC or the default VPC and generate a task run request's network configuration. """ - ec2_client = boto_session.client("ec2") + ec2_client = self._get_client(configuration, "ec2") vpc_message = f"VPC with ID {vpc_id}" vpcs = ec2_client.describe_vpcs(VpcIds=[vpc_id]).get("Vpcs") @@ -1389,7 +1381,6 @@ def _custom_network_configuration( def _prepare_task_run_request( self, - boto_session: boto3.Session, configuration: ECSJobConfiguration, task_definition: dict, task_definition_arn: str, @@ -1422,7 +1413,7 @@ def _prepare_task_run_request( and not configuration.network_configuration ): task_run_request["networkConfiguration"] = self._load_network_configuration( - configuration.vpc_id, boto_session + configuration.vpc_id, configuration ) # Use networkConfiguration if supplied by user @@ -1435,7 +1426,7 @@ def _prepare_task_run_request( self._custom_network_configuration( configuration.vpc_id, configuration.network_configuration, - boto_session, + configuration, ) ) @@ -1628,7 +1619,7 @@ def _stop_task( f"{cluster!r}." 
) - _, ecs_client = self._get_session_and_client(configuration) + ecs_client = self._get_client(configuration, "ecs") try: ecs_client.stop_task(cluster=cluster, task=task) except Exception as exc: diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index 077a178a..c9fbaabf 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -21,6 +21,7 @@ from tenacity import RetryError +from prefect_aws.credentials import _get_client_cached from prefect_aws.workers.ecs_worker import ( _TASK_DEFINITION_CACHE, ECS_DEFAULT_CONTAINER_NAME, @@ -2183,6 +2184,25 @@ async def test_retry_on_failed_task_start( assert run_task_mock.call_count == 3 +@pytest.mark.usefixtures("ecs_mocks") +async def test_worker_uses_cached_boto3_client(aws_credentials: AwsCredentials): + configuration = await construct_configuration( + aws_credentials=aws_credentials, + ) + + _get_client_cached.cache_clear() + + assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0" + + async with ECSWorker(work_pool_name="test") as worker: + worker._get_client(configuration, "ecs") + worker._get_client(configuration, "ecs") + worker._get_client(configuration, "ecs") + + assert _get_client_cached.cache_info().misses == 1 + assert _get_client_cached.cache_info().hits == 2 + + async def test_mask_sensitive_env_values(): task_run_request = { "overrides": { From 87755b2d7fd979d7975486abdd56f2e4d191e2f7 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Thu, 8 Feb 2024 15:32:19 -0500 Subject: [PATCH 15/22] The title (#380) --- .github/PULL_REQUEST_TEMPLATE.md | 1 - CHANGELOG.md | 404 ------------------------------- 2 files changed, 405 deletions(-) delete mode 100644 CHANGELOG.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 1db88eb0..f7bec8da 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -25,4 +25,3 @@ Any relevant screenshots - Run `pre-commit install && pre-commit run --all` locally for formatting and linting. - [ ] Includes screenshots of documentation updates. - Run `mkdocs serve` view documentation locally. -- [ ] Summarizes PR's changes in [CHANGELOG.md](https://github.com/PrefectHQ/prefect-aws/blob/main/CHANGELOG.md) diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index 1445d22c..00000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,404 +0,0 @@ -# Changelog - -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## Unreleased - -### Added - -### Changed - -- Use cached boto3 clients in `ECSWorker` - [#375](https://github.com/PrefectHQ/prefect-aws/pull/375) - -### Fixed - -### Deprecated - -### Removed - -## 0.4.9 - -Released January 24rd, 2024. - -### Fixed - -- Hashing of nested objects within `AwsClientParameters` - [#373](https://github.com/PrefectHQ/prefect-aws/pull/373) - -## 0.4.8 - -Released January 23rd, 2024. - -### Added - -- Support MinIO Credentials as `credentials` dict for `push_to_s3` and `pull_from_s3` - [#366](https://github.com/PrefectHQ/prefect-aws/pull/366) - -### Changed - -- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361) - -## 0.4.7 - -Released January 3rd, 2024. 
- -### Added - -- `LambdaFunction` block to invoke lambda functions - [#355](https://github.com/PrefectHQ/prefect-aws/pull/355) - -### Fixed - -- Bug where `S3Bucket.load()` constructed `AwsCredentials` instead of `MinIOCredentials` - [#359](https://github.com/PrefectHQ/prefect-aws/pull/359) - -## 0.4.6 - -Released December 11th, 2023. - -### Added - -Ability to publish `ECSTask`` block as an ecs work pool - [#353](https://github.com/PrefectHQ/prefect-aws/pull/353) - -## 0.4.5 - -Released November 30th, 2023. - -### Fixed - -- Bug where Prefect API key provided to ECS tasks was masked - [#347](https://github.com/PrefectHQ/prefect-aws/pull/347) - -## 0.4.4 - -Released November 29th, 2023. - -### Changed - -- Mask Prefect API key in logs - [#341](https://github.com/PrefectHQ/prefect-aws/pull/341) - -## 0.4.3 - -Released November 13th, 2023. - -### Added - -- `SecretBrinary` suport to `AwsSecret` block - [#274](https://github.com/PrefectHQ/prefect-aws/pull/274) - -## 0.4.2 - -Released November 6th, 2023. - -### Fixed - -- Fixed use_ssl default for s3 client. - -## 0.4.1 - -Released October 13th, 2023. - -### Added - -- AWS S3 copy and move tasks and `S3Bucket` methods - [#316](https://github.com/PrefectHQ/prefect-aws/pull/316) - -### Fixed - -- `ECSWorker` issue where defining a custom network configuration with a subnet would erroneously report it as missing from the VPC when more than one subnet exists in the VPC. [#321](https://github.com/PrefectHQ/prefect-aws/pull/321) -- Updated `push_to_s3` and `pull_from_s3` deployment steps to properly create a boto3 session client if the passed credentials are a referenced `AwsCredentials` block [#322](https://github.com/PrefectHQ/prefect-aws/pull/322) - -## 0.4.0 - -Released October 5th, 2023. - -### Changed - -- Changed `push_to_s3` deployment step function to write paths `as_posix()` to allow support for deploying from windows [#314](https://github.com/PrefectHQ/prefect-aws/pull/314) -- Conditional imports to support operating with pydantic>2 installed - [#317](https://github.com/PrefectHQ/prefect-aws/pull/317) - -## 0.3.7 - -Released August 31st, 2023. - -### Added - -- Added retries to ECS task run creation for ECS worker - [#303](https://github.com/PrefectHQ/prefect-aws/pull/303) -- Added support to `ECSWorker` for `awsvpcConfiguration` [#304](https://github.com/PrefectHQ/prefect-aws/pull/304) - -## 0.3.6 - -Released July 20th, 2023. - -### Changed - -- Promoted workers to GA, removed beta disclaimers - -## 0.3.5 - -Released on July 14th, 2023. - -### Fixed - -- Fixed `S3Bucket.stream_from` path resolution - [#291](https://github.com/PrefectHQ/prefect-aws/pull/291) -- Fixed `ECSWorker` debug logs from failing to parse json - [#296](https://github.com/PrefectHQ/prefect-aws/pull/296) - -## 0.3.4 - -Released on June 15th, 2023. - -### Added - -- Added `S3Bucket.stream_from` to copy objects between buckets - [#276](https://github.com/PrefectHQ/prefect-aws/pull/276) - -### Deprecated - -- `prefect_aws.projects` module. Use `prefect_aws.deployments` instead. - [#278](https://github.com/PrefectHQ/prefect-aws/pull/278) -- `pull_project_from_s3` step. Use `pull_from_s3` instead. - [#278](https://github.com/PrefectHQ/prefect-aws/pull/278) -- `push_project_to_s3` step. Use `push_to_s3` instead. - [#278](https://github.com/PrefectHQ/prefect-aws/pull/278) -- `PullProjectFromS3Output` step output. Use `PullFromS3Output` instead. - [#278](https://github.com/PrefectHQ/prefect-aws/pull/278) -- `PushProjectToS3Output` step output. Use `PushToS3Output` instead. 
- [#278](https://github.com/PrefectHQ/prefect-aws/pull/278) - -## 0.3.3 - -Released on June 13th, 2023. - -### Fixed - -- Change prefect.docker import to prefect.utilities.dockerutils to fix a crash when using custom blocks based on S3Bucket - [#273](https://github.com/PrefectHQ/prefect-aws/pull/273) - -## 0.3.2 - -Released on May 25th, 2023. - -### Added - -- Stream ECS Worker flow run logs to the API - [#267](https://github.com/PrefectHQ/prefect-aws/pull/267) - -### Fixed - -- Fixed bug where incorrect credentials model was selected when `MinIOCredentials` was used with `S3Bucket` - [#254](https://github.com/PrefectHQ/prefect-aws/pull/254) -- Fixed bug where `S3Bucket.list_objects` was truncating prefix paths ending with slashes - [#263](https://github.com/PrefectHQ/prefect-aws/pull/263) -- Fixed bug where ECS worker could not cancel flow runs - [#268](https://github.com/PrefectHQ/prefect-aws/pull/268) - -## 0.3.1 - -Released on April 20th, 2023. - -### Added - -- `ECSWorker` for executing Prefect flow runs as ECS tasks - [#238](https://github.com/PrefectHQ/prefect-aws/pull/238) - -### Fixed - -- Fixes retrieving files from large buckets via pagination in the `pull_project_from_s3` step - [#240](https://github.com/PrefectHQ/prefect-aws/pull/240) -- Slugify tags to ensure compatibility with ECS limitations - [#245](https://github.com/PrefectHQ/prefect-aws/pull/245) - -## 0.3.0 - -Released on April 6th, 2023. - -### Added - -- Support for unsigned AWS requests - [#220](https://github.com/PrefectHQ/prefect-aws/pull/220) -- Added push and pull project steps for S3 - [#229](https://github.com/PrefectHQ/prefect-aws/pull/229) -- `basepath` property to `S3Bucket` to maintain compatibility with storage block based deployments - [#231](https://github.com/PrefectHQ/prefect-aws/pull/231) - -### Changed - -- Added string support to `JsonPatch` implementation for task customizations Link [#233](https://github.com/PrefectHQ/prefect-aws/pull/233) - -### Removed - -- `basepath`, `aws_credentials` and `minio_credentials` fields from `S3Bucket` - [#231](https://github.com/PrefectHQ/prefect-aws/pull/231) - -## 0.2.5 - -Released on March 13th, 2023. - -### Fixed - -- Fixed errors raised when using `write_path` and `read_path` with `credentials` field on `S3Bucket` - [#208](https://github.com/PrefectHQ/prefect-aws/pull/208) -- Resolving paths in `S3Bucket` unintentionally generating an arbitrary UUID when path is an empty string - [#212](https://github.com/PrefectHQ/prefect-aws/pull/212) -- Fixed crashes when pausing flow runs executed with `ECSTask` - [#218](https://github.com/PrefectHQ/prefect-aws/pull/218) - -## 0.2.4 - -Released on January 23rd, 2023. - -### Added - -- `AwsSecret` block with `read_secret`, `write_secret`, and `delete_secret` methods - [#176](https://github.com/PrefectHQ/prefect-aws/pull/176) - -### Changed - -- Object keys sent in S3 requests use '/' delimiters instead of system default - [#192](https://github.com/PrefectHQ/prefect-aws/pull/192) - -### Fixed - -- Fix bug where ECSTask could fail to stream logs - [#186](https://github.com/PrefectHQ/prefect-aws/pull/186) - -## 0.2.3 - -Released on January 4th, 2023. - -### Fixed - -- Missing `mypy_boto3_s3` in requirements.txt - [#189](https://github.com/PrefectHQ/prefect-aws/pull/189) - -## 0.2.2 - -Released on January 4th, 2023. 
- -### Added - -- `list_objects`, `download_object_to_path`, `download_object_to_file_object`, `download_folder_to_path`, `upload_from_path`, `upload_from_file_object`, `upload_from_folder` methods in `S3Bucket` - [#85](https://github.com/PrefectHQ/prefect-aws/pull/175) -- `aws_client_parameters` as a field in `AwsCredentials` and `MinioCredentials` blocks - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) -- `get_client` and `get_s3_client` methods to `AwsCredentials` and `MinioCredentials` blocks - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) - -### Changed - -- `S3Bucket` additionally inherits from `ObjectStorageBlock` - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) -- Exposed all existing blocks to the top level init - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) -- Inherit `CredentialsBlock` for `AwsCredentials` and `MinIOCredentials` - [#183](https://github.com/PrefectHQ/prefect-aws/pull/183) - -### Deprecated - -- `endpoint_url` field in S3Bucket; specify `aws_client_parameters` in `AwsCredentials` or `MinIOCredentials` instead - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) -- `basepath` field in S3Bucket; specify `bucket_folder` instead - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) -- `minio_credentials` and `aws_credentials` field in S3Bucket; use the `credentials` field instead - [#175](https://github.com/PrefectHQ/prefect-aws/pull/175) - -## 0.2.1 - -Released on December 7th, 2022. - -### Changed - -- `ECSTask` now logs the difference between the requested and the pre-registered task definition when using a `task_definition_arn` - [#166](https://github.com/PrefectHQ/prefect-aws/pull/166) -- Default of `S3Bucket` to be an empty string rather than None - [#169](https://github.com/PrefectHQ/prefect-aws/pull/169) - -### Fixed - -- Deployments of `S3Bucket` - [#169](https://github.com/PrefectHQ/prefect-aws/pull/169) -- The image from `task_definition_arn` will be respected if `image` is not explicitly set - [#170](https://github.com/PrefectHQ/prefect-aws/pull/170) - -## 0.2.0 - -Released on December 2nd, 2022. - -### Added - -- `ECSTask.kill` method for cancellation support - [#163](https://github.com/PrefectHQ/prefect-aws/pull/163) - -### Changed - -- Breaking: Identifiers `ECSTask` now include the cluster in addition to the task ARN - [#163](https://github.com/PrefectHQ/prefect-aws/pull/163) -- Bumped minimum required `prefect` version - [#154](https://github.com/PrefectHQ/prefect-aws/pull/154) - -## 0.1.8 - -Released on November 16th, 2022. - -### Added - -- Added `family` field to `ECSTask` to configure task definition family names — [#152](https://github.com/PrefectHQ/prefect-aws/pull/152) - -### Changed - -- Changes the default `ECSTask` family to include the flow and deployment names if available — [#152](https://github.com/PrefectHQ/prefect-aws/pull/152) - -### Fixed - -- Fixed failure while watching ECS task execution when the task is missing — [#153](https://github.com/PrefectHQ/prefect-aws/pull/153) - -## 0.1.7 - -Released on October 28th, 2022. 
- -### Changed - -- `ECSTask` is no longer experimental — [#137](https://github.com/PrefectHQ/prefect-aws/pull/137) - -### Fixed - -- Fix ignore_file option in `S3Bucket` skipping files which should be included — [#139](https://github.com/PrefectHQ/prefect-aws/pull/139) -- Fixed bug where `basepath` is used twice in the path when using `S3Bucket.put_directory` - [#143](https://github.com/PrefectHQ/prefect-aws/pull/143) - -## 0.1.6 - -Released on October 19th, 2022. - -### Added - -- `get_directory` and `put_directory` methods on `S3Bucket`. The `S3Bucket` block is now usable for remote flow storage with deployments. - [#82](https://github.com/PrefectHQ/prefect-aws/pull/82) - -## 0.1.5 - -Released on October 14th, 2022. - -### Added - -- Add `ECSTask.cloudwatch_logs_options` for customization of CloudWatch logging — [#116](https://github.com/PrefectHQ/prefect-aws/pull/116) -- Added `config` parameter to AwsClientParameters to support advanced configuration (e.g. accessing public S3 buckets) [#117](https://github.com/PrefectHQ/prefect-aws/pull/117) -- Add `@sync_compatible` to `S3Bucket` methods to allow calling them in sync contexts - [#119](https://github.com/PrefectHQ/prefect-aws/pull/119). -- Add `ECSTask.task_customizations` for customization of arbitary fields in the run task payload — [#120](https://github.com/PrefectHQ/prefect-aws/pull/120) - -### Fixed - -- Fix configuration to submit doc edits via GitHub - [#110](https://github.com/PrefectHQ/prefect-aws/pull/110) -- Removed invalid ecs task register fields - [#126](https://github.com/PrefectHQ/prefect-aws/issues/126) - -## 0.1.4 - -Released on September 13th, 2022. - -### Changed - -- Increased default timeout on the `ECSTask` block - [#106](https://github.com/PrefectHQ/prefect-aws/pull/106) - -## 0.1.3 - -Released on September 12th, 2022. - -### Added - -- `client_waiter` task - [#43](https://github.com/PrefectHQ/prefect-aws/pull/43) -- `ECSTask` infrastructure block - [#85](https://github.com/PrefectHQ/prefect-aws/pull/85) - -## 0.1.2 - -Released on August 2nd, 2022. - -### Added - -- `batch_submit` task - [#41](https://github.com/PrefectHQ/prefect-aws/pull/41) -- `MinIOCredentials` block - [#46](https://github.com/PrefectHQ/prefect-aws/pull/46) -- `S3Bucket` block - [#47](https://github.com/PrefectHQ/prefect-aws/pull/47) - -### Changed - -- Converted `AwsCredentials` into a `Block` [#45](https://github.com/PrefectHQ/prefect-aws/pull/45) - -### Deprecated - -### Removed - -- Removed `.result()` and `is_complete` on test flow calls. [#45](https://github.com/PrefectHQ/prefect-aws/pull/45) - -## 0.1.1 - -## Added - -Released on April 18th, 2022. - -- `AwsClientParameters` for added configuration of the `boto3` S3 client - [#29](https://github.com/PrefectHQ/prefect-aws/pull/29) - - Contributed by [davzucky](https://github.com/davzucky) -- Added boto3 client type hinting via `types-boto3` - [#26](https://github.com/PrefectHQ/prefect-aws/pull/26) - - Contributed by [davzucky](https://github.com/davzucky) - -## 0.1.0 - -Released on March 9th, 2022. 
- -### Added - -- `s3_download`, `s3_upload` and `s3_list_objects` tasks -- `read_secret` task - [#6](https://github.com/PrefectHQ/prefect-aws/pull/6) -- `update_secret` task - [#12](https://github.com/PrefectHQ/prefect-aws/pull/12) -- `create_secret` and `delete_secret` tasks - [#13](https://github.com/PrefectHQ/prefect-aws/pull/13) From 3481235652e27babb9e2c6ce9134a2838218a026 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Tue, 27 Feb 2024 11:01:06 -0600 Subject: [PATCH 16/22] update pytest requirements (#387) --- requirements-dev.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index bcbdc906..d29a1074 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,8 +15,8 @@ moto >= 3.1.16, < 4.2.5 mypy pillow pre-commit -pytest -pytest-asyncio +pytest > 7, < 8 +pytest-asyncio >= 0.18.2, != 0.22.0, < 0.23.0 # Cannot override event loop in 0.23.0. See https://github.com/pytest-dev/pytest-asyncio/issues/706 for more details. pytest-cov pytest-lazy-fixture pytest-xdist From f14daa4918c1dc0d9ea8c7cef4530b8a0cfdc1d4 Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Tue, 27 Feb 2024 12:49:09 -0600 Subject: [PATCH 17/22] Removes `pytest-lazy-fixture` dependency to allow `pytest` to be unpinned (#388) --- requirements-dev.txt | 3 +-- tests/test_lambda_function.py | 7 ++++--- tests/test_s3.py | 23 +++++++++++------------ 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index d29a1074..1a71f0be 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,9 +15,8 @@ moto >= 3.1.16, < 4.2.5 mypy pillow pre-commit -pytest > 7, < 8 +pytest pytest-asyncio >= 0.18.2, != 0.22.0, < 0.23.0 # Cannot override event loop in 0.23.0. See https://github.com/pytest-dev/pytest-asyncio/issues/706 for more details. 
pytest-cov -pytest-lazy-fixture pytest-xdist types-boto3 >= 1.0.2 diff --git a/tests/test_lambda_function.py b/tests/test_lambda_function.py index 32210629..92880fdf 100644 --- a/tests/test_lambda_function.py +++ b/tests/test_lambda_function.py @@ -8,7 +8,6 @@ import pytest from botocore.response import StreamingBody from moto import mock_iam, mock_lambda -from pytest_lazyfixture import lazy_fixture from prefect_aws.credentials import AwsCredentials from prefect_aws.lambda_function import LambdaFunction @@ -232,8 +231,8 @@ def test_invoke_lambda_client_context( @pytest.mark.parametrize( "func_fixture,expected,handler", [ - (lazy_fixture("mock_lambda_function"), {"foo": "bar"}, handler_a), - (lazy_fixture("add_lambda_version"), {"data": [1, 2, 3]}, handler_b), + ("mock_lambda_function", {"foo": "bar"}, handler_a), + ("add_lambda_version", {"data": [1, 2, 3]}, handler_b), ], ) def test_invoke_lambda_qualifier( @@ -242,7 +241,9 @@ def test_invoke_lambda_qualifier( expected, lambda_function: LambdaFunction, mock_invoke, + request, ): + func_fixture = request.getfixturevalue(func_fixture) try: lambda_function.qualifier = func_fixture["Version"] result = lambda_function.invoke() diff --git a/tests/test_s3.py b/tests/test_s3.py index 93d11cc1..3dc83d91 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -8,7 +8,6 @@ from moto import mock_s3 from prefect import flow from prefect.deployments import Deployment -from pytest_lazyfixture import lazy_fixture from prefect_aws import AwsCredentials, MinIOCredentials from prefect_aws.client_parameters import AwsClientParameters @@ -22,9 +21,9 @@ ) aws_clients = [ - (lazy_fixture("aws_client_parameters_custom_endpoint")), - (lazy_fixture("aws_client_parameters_empty")), - (lazy_fixture("aws_client_parameters_public_bucket")), + "aws_client_parameters_custom_endpoint", + "aws_client_parameters_empty", + "aws_client_parameters_public_bucket", ] @@ -38,7 +37,7 @@ def s3_mock(monkeypatch, client_parameters): @pytest.fixture def client_parameters(request): - client_parameters = request.param + client_parameters = request.getfixturevalue(request.param) return client_parameters @@ -114,7 +113,7 @@ def a_lot_of_objects(bucket, tmp_path): @pytest.mark.parametrize( "client_parameters", - [lazy_fixture("aws_client_parameters_custom_endpoint")], + ["aws_client_parameters_custom_endpoint"], indirect=True, ) async def test_s3_download_failed_with_wrong_endpoint_setup( @@ -141,30 +140,30 @@ async def test_flow(): "client_parameters", [ pytest.param( - lazy_fixture("aws_client_parameters_custom_endpoint"), + "aws_client_parameters_custom_endpoint", marks=pytest.mark.is_public(False), ), pytest.param( - lazy_fixture("aws_client_parameters_custom_endpoint"), + "aws_client_parameters_custom_endpoint", marks=pytest.mark.is_public(True), ), pytest.param( - lazy_fixture("aws_client_parameters_empty"), + "aws_client_parameters_empty", marks=pytest.mark.is_public(False), ), pytest.param( - lazy_fixture("aws_client_parameters_empty"), + "aws_client_parameters_empty", marks=pytest.mark.is_public(True), ), pytest.param( - lazy_fixture("aws_client_parameters_public_bucket"), + "aws_client_parameters_public_bucket", marks=[ pytest.mark.is_public(False), pytest.mark.xfail(reason="Bucket is not a public one"), ], ), pytest.param( - lazy_fixture("aws_client_parameters_public_bucket"), + "aws_client_parameters_public_bucket", marks=pytest.mark.is_public(True), ), ], From a2070b27c7ca3ff11e2b8e2e7b8ed2db5407ef4e Mon Sep 17 00:00:00 2001 From: Kevin Grismore 
<146098880+kevingrismore@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:08:30 -0600 Subject: [PATCH 18/22] Add ECS worker option to use most recent revision in task definition family (#370) Co-authored-by: nate nowack --- docs/gen_examples_catalog.py | 2 +- prefect_aws/workers/ecs_worker.py | 141 ++++++++++++++++++------------ tests/workers/test_ecs_worker.py | 91 +++++++++++++++++++ 3 files changed, 177 insertions(+), 57 deletions(-) diff --git a/docs/gen_examples_catalog.py b/docs/gen_examples_catalog.py index ba8b1c7c..9293c427 100644 --- a/docs/gen_examples_catalog.py +++ b/docs/gen_examples_catalog.py @@ -56,7 +56,7 @@ def get_code_examples(obj: Union[ModuleType, Callable]) -> Set[str]: for section in parsed_sections: if section.kind == DocstringSectionKind.examples: code_example = "\n".join( - (part[1] for part in section.as_dict().get("value", [])) + part[1] for part in section.as_dict().get("value", []) ) if not skip_block_load_code_example(code_example): code_examples.add(code_example) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 372e6b29..2e63eac4 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -267,6 +267,7 @@ class ECSJobConfiguration(BaseJobConfiguration): vpc_id: Optional[str] = Field(default=None) container_name: Optional[str] = Field(default=None) cluster: Optional[str] = Field(default=None) + match_latest_revision_in_family: bool = Field(default=False) @root_validator def task_run_request_requires_arn_if_no_task_definition_given(cls, values) -> dict: @@ -550,6 +551,16 @@ class ECSVariables(BaseVariables): "your AWS account, instead it will be marked as INACTIVE." ), ) + match_latest_revision_in_family: bool = Field( + default=False, + description=( + "If enabled, the most recent active revision in the task definition " + "family will be compared against the desired ECS task configuration. " + "If they are equal, the existing task definition will be used instead " + "of registering a new one. If no family is specified the default family " + f'"{ECS_DEFAULT_FAMILY}" will be used.' 
+ ), + ) class ECSWorkerResult(BaseWorkerResult): @@ -661,55 +672,15 @@ def _create_task_and_wait_for_start( new_task_definition_registered = False if not task_definition_arn: - cached_task_definition_arn = _TASK_DEFINITION_CACHE.get( - flow_run.deployment_id - ) task_definition = self._prepare_task_definition( configuration, region=ecs_client.meta.region_name ) - - if cached_task_definition_arn: - # Read the task definition to see if the cached task definition is valid - try: - cached_task_definition = self._retrieve_task_definition( - logger, ecs_client, cached_task_definition_arn - ) - except Exception as exc: - logger.warning( - "Failed to retrieve cached task definition" - f" {cached_task_definition_arn!r}: {exc!r}" - ) - # Clear from cache - _TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None) - cached_task_definition_arn = None - else: - if not cached_task_definition["status"] == "ACTIVE": - # Cached task definition is not active - logger.warning( - "Cached task definition" - f" {cached_task_definition_arn!r} is not active" - ) - _TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None) - cached_task_definition_arn = None - elif not self._task_definitions_equal( - task_definition, cached_task_definition - ): - # Cached task definition is not valid - logger.warning( - "Cached task definition" - f" {cached_task_definition_arn!r} does not meet" - " requirements" - ) - _TASK_DEFINITION_CACHE.pop(flow_run.deployment_id, None) - cached_task_definition_arn = None - - if not cached_task_definition_arn: - task_definition_arn = self._register_task_definition( - logger, ecs_client, task_definition - ) - new_task_definition_registered = True - else: - task_definition_arn = cached_task_definition_arn + ( + task_definition_arn, + new_task_definition_registered, + ) = self._get_or_register_task_definition( + logger, ecs_client, configuration, flow_run, task_definition + ) else: task_definition = self._retrieve_task_definition( logger, ecs_client, task_definition_arn @@ -722,9 +693,6 @@ def _create_task_and_wait_for_start( self._validate_task_definition(task_definition, configuration) - # Update the cached task definition ARN to avoid re-registering the task - # definition on this worker unless necessary; registration is agressively - # rate limited by AWS _TASK_DEFINITION_CACHE[flow_run.deployment_id] = task_definition_arn logger.info(f"Using ECS task definition {task_definition_arn!r}...") @@ -732,7 +700,6 @@ def _create_task_and_wait_for_start( f"Task definition {json.dumps(task_definition, indent=2, default=str)}" ) - # Prepare the task run request task_run_request = self._prepare_task_run_request( configuration, task_definition, @@ -753,7 +720,6 @@ def _create_task_and_wait_for_start( self._report_task_run_creation_failure(configuration, task_run_request, exc) raise - # Raises an exception if the task does not start logger.info("Waiting for ECS task run to start...") self._wait_for_task_start( logger, @@ -766,6 +732,65 @@ def _create_task_and_wait_for_start( return task_arn, cluster_arn, task_definition, new_task_definition_registered + def _get_or_register_task_definition( + self, + logger: logging.Logger, + ecs_client: _ECSClient, + configuration: ECSJobConfiguration, + flow_run: FlowRun, + task_definition: dict, + ) -> Tuple[str, bool]: + """Get or register a task definition for the given flow run. + + Returns a tuple of the task definition ARN and a bool indicating if the task + definition is newly registered. 
+ """ + + cached_task_definition_arn = _TASK_DEFINITION_CACHE.get(flow_run.deployment_id) + new_task_definition_registered = False + + if cached_task_definition_arn: + try: + cached_task_definition = self._retrieve_task_definition( + logger, ecs_client, cached_task_definition_arn + ) + if not cached_task_definition[ + "status" + ] == "ACTIVE" or not self._task_definitions_equal( + task_definition, cached_task_definition + ): + cached_task_definition_arn = None + except Exception: + cached_task_definition_arn = None + + if ( + not cached_task_definition_arn + and configuration.match_latest_revision_in_family + ): + family_name = task_definition.get("family", ECS_DEFAULT_FAMILY) + try: + task_definition_from_family = self._retrieve_task_definition_by_family( + logger, ecs_client, family_name + ) + if task_definition_from_family and self._task_definitions_equal( + task_definition, task_definition_from_family + ): + cached_task_definition_arn = task_definition_from_family[ + "taskDefinitionArn" + ] + except Exception: + pass + + if not cached_task_definition_arn: + task_definition_arn = self._register_task_definition( + logger, ecs_client, task_definition + ) + new_task_definition_registered = True + else: + task_definition_arn = cached_task_definition_arn + + return task_definition_arn, new_task_definition_registered + def _watch_task_and_get_exit_code( self, logger: logging.Logger, @@ -928,15 +953,19 @@ def _retrieve_task_definition( self, logger: logging.Logger, ecs_client: _ECSClient, - task_definition_arn: str, + task_definition: str, ): """ Retrieve an existing task definition from AWS. """ - logger.info(f"Retrieving ECS task definition {task_definition_arn!r}...") - response = ecs_client.describe_task_definition( - taskDefinition=task_definition_arn - ) + if task_definition.startswith("arn:aws:ecs:"): + logger.info(f"Retrieving ECS task definition {task_definition!r}...") + else: + logger.info( + "Retrieving most recent active revision from " + f"ECS task family {task_definition!r}..." 
+ ) + response = ecs_client.describe_task_definition(taskDefinition=task_definition) return response["taskDefinition"] def _wait_for_task_start( diff --git a/tests/workers/test_ecs_worker.py b/tests/workers/test_ecs_worker.py index c9fbaabf..702f2ddd 100644 --- a/tests/workers/test_ecs_worker.py +++ b/tests/workers/test_ecs_worker.py @@ -1415,6 +1415,97 @@ async def test_deregister_task_definition_does_not_apply_to_linked_arn( describe_task_definition(ecs_client, task)["status"] == "ACTIVE" +@pytest.mark.usefixtures("ecs_mocks") +async def test_match_latest_revision_in_family( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + configuration_1 = await construct_configuration( + aws_credentials=aws_credentials, + ) + + configuration_2 = await construct_configuration( + aws_credentials=aws_credentials, + execution_role_arn="test", + ) + + configuration_3 = await construct_configuration( + aws_credentials=aws_credentials, + match_latest_revision_in_family=True, + execution_role_arn="test", + ) + + # Let the first worker run and register two task definitions + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration_1, flow_run) + result_1 = await run_then_stop_task(worker, configuration_2, flow_run) + + # Start a new worker with an empty cache + async with ECSWorker(work_pool_name="test") as worker: + result_2 = await run_then_stop_task(worker, configuration_3, flow_run) + + assert result_1.status_code == 0 + _, task_arn_1 = parse_identifier(result_1.identifier) + + assert result_2.status_code == 0 + _, task_arn_2 = parse_identifier(result_2.identifier) + + task_1 = describe_task(ecs_client, task_arn_1) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"] + assert task_2["taskDefinitionArn"].endswith(":2") + + +@pytest.mark.usefixtures("ecs_mocks") +async def test_match_latest_revision_in_family_custom_family( + aws_credentials: AwsCredentials, flow_run: FlowRun +): + session = aws_credentials.get_boto3_session() + ecs_client = session.client("ecs") + + configuration_1 = await construct_configuration( + aws_credentials=aws_credentials, + family="test-family", + ) + + configuration_2 = await construct_configuration( + aws_credentials=aws_credentials, + execution_role_arn="test", + family="test-family", + ) + + configuration_3 = await construct_configuration( + aws_credentials=aws_credentials, + match_latest_revision_in_family=True, + execution_role_arn="test", + family="test-family", + ) + + # Let the first worker run and register two task definitions + async with ECSWorker(work_pool_name="test") as worker: + await run_then_stop_task(worker, configuration_1, flow_run) + result_1 = await run_then_stop_task(worker, configuration_2, flow_run) + + # Start a new worker with an empty cache + async with ECSWorker(work_pool_name="test") as worker: + result_2 = await run_then_stop_task(worker, configuration_3, flow_run) + + assert result_1.status_code == 0 + _, task_arn_1 = parse_identifier(result_1.identifier) + + assert result_2.status_code == 0 + _, task_arn_2 = parse_identifier(result_2.identifier) + + task_1 = describe_task(ecs_client, task_arn_1) + task_2 = describe_task(ecs_client, task_arn_2) + + assert task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"] + assert task_2["taskDefinitionArn"].endswith(":2") + + @pytest.mark.usefixtures("ecs_mocks") async def 
test_worker_caches_registered_task_definitions( aws_credentials: AwsCredentials, flow_run: FlowRun From c24bf1408f86c94b444a5c4737ee4229f7e697aa Mon Sep 17 00:00:00 2001 From: James Martin Date: Tue, 5 Mar 2024 01:37:39 +1100 Subject: [PATCH 19/22] Fixed Batch docs to reference the correct order of arguments for the batch_submit task (#391) --- prefect_aws/batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prefect_aws/batch.py b/prefect_aws/batch.py index 0dc57810..8abfb50a 100644 --- a/prefect_aws/batch.py +++ b/prefect_aws/batch.py @@ -21,8 +21,8 @@ async def batch_submit( Args: job_name: The AWS batch job name. - job_definition: The AWS batch job definition. job_queue: Name of the AWS batch job queue. + job_definition: The AWS batch job definition. aws_credentials: Credentials to use for authentication with AWS. **batch_kwargs: Additional keyword arguments to pass to the boto3 `submit_job` function. See the documentation for @@ -49,8 +49,8 @@ def example_batch_submit_flow(): ) job_id = batch_submit( "job_name", - "job_definition", "job_queue", + "job_definition", aws_credentials ) return job_id From 72c106fe6fb0dde5bff9ac321d5be5b84b0eb3db Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Wed, 6 Mar 2024 13:17:52 -0600 Subject: [PATCH 20/22] Adds porting of network configuration to generated base job templates (#392) --- prefect_aws/ecs.py | 19 ++++++++++ prefect_aws/utilities.py | 81 ++++++++++++++++++++++++++++++++++++++++ tests/test_ecs.py | 28 ++++++++++---- tests/test_utilities.py | 59 ++++++++++++++++++++++++++++- 4 files changed, 178 insertions(+), 9 deletions(-) diff --git a/prefect_aws/ecs.py b/prefect_aws/ecs.py index ac5dd333..6368748f 100644 --- a/prefect_aws/ecs.py +++ b/prefect_aws/ecs.py @@ -126,6 +126,8 @@ from prefect.utilities.pydantic import JsonPatch from pydantic import VERSION as PYDANTIC_VERSION +from prefect_aws.utilities import assemble_document_for_patches + if PYDANTIC_VERSION.startswith("2."): from pydantic.v1 import Field, root_validator, validator else: @@ -739,6 +741,23 @@ async def generate_work_pool_base_job_template(self) -> dict: ) if self.task_customizations: + network_config_patches = JsonPatch( + [ + patch + for patch in self.task_customizations + if "networkConfiguration" in patch["path"] + ] + ) + minimal_network_config = assemble_document_for_patches( + network_config_patches + ) + if minimal_network_config: + minimal_network_config_with_patches = network_config_patches.apply( + minimal_network_config + ) + base_job_template["variables"]["properties"]["network_configuration"][ + "default" + ] = minimal_network_config_with_patches["networkConfiguration"] try: base_job_template["job_configuration"]["task_run_request"] = ( self.task_customizations.apply( diff --git a/prefect_aws/utilities.py b/prefect_aws/utilities.py index ad1e6ed2..33b6cdc6 100644 --- a/prefect_aws/utilities.py +++ b/prefect_aws/utilities.py @@ -1,5 +1,7 @@ """Utilities for working with AWS services.""" +from typing import Dict, List, Union + from prefect.utilities.collections import visit_collection @@ -33,3 +35,82 @@ def make_hashable(item): collection, visit_fn=make_hashable, return_data=True ) return hash(hashable_collection) + + +def ensure_path_exists(doc: Union[Dict, List], path: List[str]): + """ + Ensures the path exists in the document, creating empty dictionaries or lists as + needed. + + Args: + doc: The current level of the document or sub-document. + path: The remaining path parts to ensure exist. 
+ """ + if not path: + return + current_path = path.pop(0) + # Check if the next path part exists and is a digit + next_path_is_digit = path and path[0].isdigit() + + # Determine if the current path is for an array or an object + if isinstance(doc, list): # Path is for an array index + current_path = int(current_path) + # Ensure the current level of the document is a list and long enough + + while len(doc) <= current_path: + doc.append({}) + next_level = doc[current_path] + else: # Path is for an object + if current_path not in doc or ( + next_path_is_digit and not isinstance(doc.get(current_path), list) + ): + doc[current_path] = [] if next_path_is_digit else {} + next_level = doc[current_path] + + ensure_path_exists(next_level, path) + + +def assemble_document_for_patches(patches): + """ + Assembles an initial document that can successfully accept the given JSON Patch + operations. + + Args: + patches: A list of JSON Patch operations. + + Returns: + An initial document structured to accept the patches. + + Example: + + ```python + patches = [ + {"op": "replace", "path": "/name", "value": "Jane"}, + {"op": "add", "path": "/contact/address", "value": "123 Main St"}, + {"op": "remove", "path": "/age"} + ] + + initial_document = assemble_document_for_patches(patches) + + #output + { + "name": {}, + "contact": {}, + "age": {} + } + ``` + """ + document = {} + + for patch in patches: + operation = patch["op"] + path = patch["path"].lstrip("/").split("/") + + if operation == "add": + # Ensure all but the last element of the path exists + ensure_path_exists(document, path[:-1]) + elif operation in ["remove", "replace"]: + # For remove adn replace, the entire path should exist + ensure_path_exists(document, path) + + return document diff --git a/tests/test_ecs.py b/tests/test_ecs.py index 2f970116..a81c446b 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -2128,6 +2128,15 @@ def base_job_template_with_defaults(default_base_job_template, aws_credentials): base_job_template_with_defaults["variables"]["properties"][ "auto_deregister_task_definition" ]["default"] = False + base_job_template_with_defaults["variables"]["properties"]["network_configuration"][ + "default" + ] = { + "awsvpcConfiguration": { + "subnets": ["subnet-***"], + "assignPublicIp": "DISABLED", + "securityGroups": ["sg-***"], + } + } return base_job_template_with_defaults @@ -2188,10 +2197,20 @@ async def test_generate_work_pool_base_job_template( cpu=2048, memory=4096, task_customizations=[ + { + "op": "replace", + "path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp", + "value": "DISABLED", + }, + { + "op": "add", + "path": "/networkConfiguration/awsvpcConfiguration/subnets", + "value": ["subnet-***"], + }, { "op": "add", "path": "/networkConfiguration/awsvpcConfiguration/securityGroups", - "value": ["sg-d72e9599956a084f5"], + "value": ["sg-***"], }, ], family="test-family", @@ -2229,10 +2248,3 @@ async def test_generate_work_pool_base_job_template( template = await job.generate_work_pool_base_job_template() assert template == expected_template - - if job_config == "custom": - assert ( - "Unable to apply task customizations to the base job template." - "You may need to update the template manually." 
- in caplog.text - ) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 0e0fdc6f..cecf863f 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,6 +1,10 @@ import pytest -from prefect_aws.utilities import hash_collection +from prefect_aws.utilities import ( + assemble_document_for_patches, + ensure_path_exists, + hash_collection, +) class TestHashCollection: @@ -32,3 +36,56 @@ def test_unhashable_structure(self): assert hash_collection(typically_unhashable_structure) == hash_collection( typically_unhashable_structure ), "Unhashable structure hashing failed after transformation" + + +class TestAssembleDocumentForPatches: + def test_initial_document(self): + patches = [ + {"op": "replace", "path": "/name", "value": "Jane"}, + {"op": "add", "path": "/contact/address", "value": "123 Main St"}, + {"op": "remove", "path": "/age"}, + ] + + initial_document = assemble_document_for_patches(patches) + + expected_document = {"name": {}, "contact": {}, "age": {}} + + assert initial_document == expected_document, "Initial document assembly failed" + + +class TestEnsurePathExists: + def test_existing_path(self): + doc = {"key1": {"subkey1": "value1"}} + path = ["key1", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": {"subkey1": "value1"} + }, "Existing path modification failed" + + def test_new_path_object(self): + doc = {} + path = ["key1", "subkey1"] + ensure_path_exists(doc, path) + assert doc == {"key1": {"subkey1": {}}}, "New path creation for object failed" + + def test_new_path_array(self): + doc = {} + path = ["key1", "0"] + ensure_path_exists(doc, path) + assert doc == {"key1": [{}]}, "New path creation for array failed" + + def test_existing_path_array(self): + doc = {"key1": [{"subkey1": "value1"}]} + path = ["key1", "0", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": [{"subkey1": "value1"}] + }, "Existing path modification for array failed" + + def test_existing_path_array_index_out_of_range(self): + doc = {"key1": []} + path = ["key1", "0", "subkey1"] + ensure_path_exists(doc, path) + assert doc == { + "key1": [{"subkey1": {}}] + }, "Existing path modification for array index out of range failed" From 1add0b98107307c11240d470cc085293b182f1b6 Mon Sep 17 00:00:00 2001 From: Kevin Grismore <146098880+kevingrismore@users.noreply.github.com> Date: Tue, 12 Mar 2024 13:54:01 -0500 Subject: [PATCH 21/22] Call existing function for getting recent revision from family (#393) --- prefect_aws/workers/ecs_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prefect_aws/workers/ecs_worker.py b/prefect_aws/workers/ecs_worker.py index 2e63eac4..edbe843d 100644 --- a/prefect_aws/workers/ecs_worker.py +++ b/prefect_aws/workers/ecs_worker.py @@ -769,7 +769,7 @@ def _get_or_register_task_definition( ): family_name = task_definition.get("family", ECS_DEFAULT_FAMILY) try: - task_definition_from_family = self._retrieve_task_definition_by_family( + task_definition_from_family = self._retrieve_task_definition( logger, ecs_client, family_name ) if task_definition_from_family and self._task_definitions_equal( @@ -779,7 +779,7 @@ def _get_or_register_task_definition( "taskDefinitionArn" ] except Exception: - pass + cached_task_definition_arn = None if not cached_task_definition_arn: task_definition_arn = self._register_task_definition( From 5fc8a3a2864dbbfc15faab3acf6afaa524adf584 Mon Sep 17 00:00:00 2001 From: Alexander Streed Date: Fri, 15 Mar 2024 08:14:57 -0500 Subject: [PATCH 22/22] Deprecate `ECSTask` 
block (#395) --- prefect_aws/ecs.py | 19 ++++++++++++++++++- requirements.txt | 2 +- tests/test_ecs.py | 15 +++++++++++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/prefect_aws/ecs.py b/prefect_aws/ecs.py index 6368748f..b0227bf6 100644 --- a/prefect_aws/ecs.py +++ b/prefect_aws/ecs.py @@ -1,4 +1,11 @@ """ +DEPRECATION WARNING: + +This module is deprecated as of March 2024 and will not be available after September 2024. +It has been replaced by the ECS worker, which offers enhanced functionality and better performance. + +For upgrade instructions, see https://docs.prefect.io/latest/guides/upgrade-guide-agents-to-workers/. + Integrations with the Amazon Elastic Container Service. Examples: @@ -102,7 +109,8 @@ ], ) ``` -""" +""" # noqa + import copy import difflib import json @@ -118,6 +126,7 @@ import yaml from anyio.abc import TaskStatus from jsonpointer import JsonPointerException +from prefect._internal.compatibility.deprecated import deprecated_class from prefect.blocks.core import BlockNotSavedError from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound from prefect.infrastructure.base import Infrastructure, InfrastructureResult @@ -205,6 +214,14 @@ def _pretty_diff(d1: dict, d2: dict) -> str: ) +@deprecated_class( + start_date="Mar 2024", + help=( + "Use the ECS worker instead." + " Refer to the upgrade guide for more information:" + " https://docs.prefect.io/latest/guides/upgrade-guide-agents-to-workers/." + ), +) class ECSTask(Infrastructure): """ Run a command as an ECS task. diff --git a/requirements.txt b/requirements.txt index e5cfb0b0..4a764f81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,5 @@ boto3>=1.24.53 botocore>=1.27.53 mypy_boto3_s3>=1.24.94 mypy_boto3_secretsmanager>=1.26.49 -prefect>=2.14.10 +prefect>=2.16.4 tenacity>=8.0.0 \ No newline at end of file diff --git a/tests/test_ecs.py b/tests/test_ecs.py index a81c446b..6c429e9f 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -12,6 +12,7 @@ from botocore.exceptions import ClientError from moto import mock_ec2, mock_ecs, mock_logs from moto.ec2.utils import generate_instance_identity_document +from prefect._internal.compatibility.deprecated import PrefectDeprecationWarning from prefect.exceptions import InfrastructureNotAvailable, InfrastructureNotFound from prefect.logging.configuration import setup_logging from prefect.server.schemas.core import Deployment, Flow, FlowRun @@ -35,6 +36,20 @@ parse_task_identifier, ) + +def test_ecs_task_emits_deprecation_warning(): + with pytest.warns( + PrefectDeprecationWarning, + match=( + "prefect_aws.ecs.ECSTask has been deprecated." + " It will not be available after Sep 2024." + " Use the ECS worker instead." + " Refer to the upgrade guide for more information" + ), + ): + ECSTask() + + setup_logging()
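
A minimal sketch of how the helpers introduced in "Adds porting of network configuration to generated base job templates" fit together may help when reading that diff; it assumes `prefect_aws.utilities.assemble_document_for_patches` and Prefect's `JsonPatch` wrapper are importable exactly as shown in the patches above, and it only mirrors the behaviour the new tests exercise rather than a documented public API.

```python
# Editor's illustration, not part of the patch series. Assumes a prefect-aws
# checkout that includes assemble_document_for_patches from the diff above.
from prefect.utilities.pydantic import JsonPatch

from prefect_aws.utilities import assemble_document_for_patches

task_customizations = [
    {
        "op": "replace",
        "path": "/networkConfiguration/awsvpcConfiguration/assignPublicIp",
        "value": "DISABLED",
    },
    {
        "op": "add",
        "path": "/networkConfiguration/awsvpcConfiguration/subnets",
        "value": ["subnet-***"],
    },
]

# Keep only the operations that touch the network configuration, mirroring
# ECSTask.generate_work_pool_base_job_template in the diff.
network_patches = JsonPatch(
    [op for op in task_customizations if "networkConfiguration" in op["path"]]
)

# Build a skeleton document the patches can be applied to, then apply them.
skeleton = assemble_document_for_patches(network_patches)
patched = network_patches.apply(skeleton)

print(patched["networkConfiguration"])
# {'awsvpcConfiguration': {'assignPublicIp': 'DISABLED', 'subnets': ['subnet-***']}}
```

Deriving the skeleton from the patch operations themselves is what lets the block port its `task_customizations` into a work pool's `network_configuration` default without knowing the full shape of the task run request in advance.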