diff --git a/CHANGELOG.md b/CHANGELOG.md index 64454b60..b40563f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361) + ### Fixed ### Deprecated @@ -105,6 +107,7 @@ Released August 31st, 2023. Released July 20th, 2023. ### Changed + - Promoted workers to GA, removed beta disclaimers ## 0.3.5 @@ -293,6 +296,7 @@ Released on October 28th, 2022. - `ECSTask` is no longer experimental — [#137](https://github.com/PrefectHQ/prefect-aws/pull/137) ### Fixed + - Fix ignore_file option in `S3Bucket` skipping files which should be included — [#139](https://github.com/PrefectHQ/prefect-aws/pull/139) - Fixed bug where `basepath` is used twice in the path when using `S3Bucket.put_directory` - [#143](https://github.com/PrefectHQ/prefect-aws/pull/143) diff --git a/prefect_aws/client_parameters.py b/prefect_aws/client_parameters.py index bf030590..eb3be09b 100644 --- a/prefect_aws/client_parameters.py +++ b/prefect_aws/client_parameters.py @@ -70,6 +70,18 @@ class AwsClientParameters(BaseModel): title="Botocore Config", ) + def __hash__(self): + return hash( + ( + self.api_version, + self.use_ssl, + self.verify, + self.verify_cert_path, + self.endpoint_url, + self.config, + ) + ) + @validator("config", pre=True) def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, Any]: """ diff --git a/prefect_aws/credentials.py b/prefect_aws/credentials.py index 64f49efe..7845b74b 100644 --- a/prefect_aws/credentials.py +++ b/prefect_aws/credentials.py @@ -1,6 +1,7 @@ """Module handling AWS credentials""" from enum import Enum +from functools import lru_cache from typing import Any, Optional, Union import boto3 @@ -18,12 +19,38 @@ class ClientType(Enum): + """The supported boto3 clients.""" + S3 = "s3" ECS = "ecs" BATCH = "batch" SECRETS_MANAGER = "secretsmanager" +@lru_cache(maxsize=8, typed=True) +def _get_client_cached(ctx, client_type: Union[str, ClientType]) -> Any: + """ + Helper method to cache and dynamically get a client type. + + Args: + client_type: The client's service name. + + Returns: + An authenticated client. + + Raises: + ValueError: if the client is not supported. + """ + if isinstance(client_type, ClientType): + client_type = client_type.value + + client = ctx.get_boto3_session().client( + service_name=client_type, + **ctx.aws_client_parameters.get_params_override(), + ) + return client + + class AwsCredentials(CredentialsBlock): """ Block used to manage authentication with AWS. AWS authentication is @@ -75,6 +102,23 @@ class AwsCredentials(CredentialsBlock): title="AWS Client Parameters", ) + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + return hash( + ( + self.aws_access_key_id, + self.aws_secret_access_key, + self.aws_session_token, + self.profile_name, + self.region_name, + self.aws_client_parameters, + ) + ) + def get_boto3_session(self) -> boto3.Session: """ Returns an authenticated boto3 session that can be used to create clients @@ -104,7 +148,7 @@ def get_boto3_session(self) -> boto3.Session: region_name=self.region_name, ) - def get_client(self, client_type: Union[str, ClientType]) -> Any: + def get_client(self, client_type: Union[str, ClientType], use_cache: bool = False): """ Helper method to dynamically get a client type. @@ -120,10 +164,13 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any: if isinstance(client_type, ClientType): client_type = client_type.value - client = self.get_boto3_session().client( - service_name=client_type, **self.aws_client_parameters.get_params_override() - ) - return client + if not use_cache: + return self.get_boto3_session().client( + service_name=client_type, + **self.aws_client_parameters.get_params_override(), + ) + + return _get_client_cached(ctx=self, client_type=client_type) def get_s3_client(self) -> S3Client: """ @@ -186,6 +233,21 @@ class MinIOCredentials(CredentialsBlock): description="Extra parameters to initialize the Client.", ) + class Config: + """Config class for pydantic model.""" + + arbitrary_types_allowed = True + + def __hash__(self): + return hash( + ( + self.minio_root_user, + self.minio_root_password, + self.region_name, + self.aws_client_parameters, + ) + ) + def get_boto3_session(self) -> boto3.Session: """ Returns an authenticated boto3 session that can be used to create clients @@ -218,7 +280,7 @@ def get_boto3_session(self) -> boto3.Session: region_name=self.region_name, ) - def get_client(self, client_type: Union[str, ClientType]) -> Any: + def get_client(self, client_type: Union[str, ClientType], use_cache: bool = False): """ Helper method to dynamically get a client type. @@ -234,10 +296,13 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any: if isinstance(client_type, ClientType): client_type = client_type.value - client = self.get_boto3_session().client( - service_name=client_type, **self.aws_client_parameters.get_params_override() - ) - return client + if not use_cache: + return self.get_boto3_session().client( + service_name=client_type, + **self.aws_client_parameters.get_params_override(), + ) + + return _get_client_cached(ctx=self, client_type=client_type) def get_s3_client(self) -> S3Client: """ diff --git a/prefect_aws/s3.py b/prefect_aws/s3.py index 643d78ac..34ca076f 100644 --- a/prefect_aws/s3.py +++ b/prefect_aws/s3.py @@ -425,6 +425,15 @@ class S3Bucket(WritableFileSystem, WritableDeploymentStorage, ObjectStorageBlock ), ) + cache_client: bool = Field( + default=False, + description=( + "If True, the S3 client will be cached. This is useful for " + "performance, but could cause issues if the S3 client is used " + "in multi-threaded environments." + ), + ) + # Property to maintain compatibility with storage block based deployments @property def basepath(self) -> str: @@ -466,7 +475,7 @@ def _get_s3_client(self) -> boto3.client: Authenticate MinIO credentials or AWS credentials and return an S3 client. This is a helper function called by read_path() or write_path(). """ - return self.credentials.get_s3_client() + return self.credentials.get_client("s3", use_cache=self.cache_client) def _get_bucket_resource(self) -> boto3.resource: """ diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 6e0a1ff8..1b66421b 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -3,7 +3,12 @@ from botocore.client import BaseClient from moto import mock_s3 -from prefect_aws.credentials import AwsCredentials, ClientType, MinIOCredentials +from prefect_aws.credentials import ( + AwsCredentials, + ClientType, + MinIOCredentials, + _get_client_cached, +) def test_aws_credentials_get_boto3_session(): @@ -44,3 +49,32 @@ def test_minio_credentials_get_boto3_session(): def test_credentials_get_client(credentials, client_type): with mock_s3(): assert isinstance(credentials.get_client(client_type), BaseClient) + + +@pytest.mark.parametrize("client_type", [member.value for member in ClientType]) +def test_get_client_cached(client_type): + """ + Test to ensure that _get_client_cached function returns the same instance + for multiple calls with the same parameters and properly utilizes lru_cache. + """ + + # Create a mock AwsCredentials instance + aws_credentials_block = AwsCredentials(region_name="us-east-1") + + # Clear cache + _get_client_cached.cache_clear() + + assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0" + + assert aws_credentials_block.get_client(client_type) is not None + + assert _get_client_cached.cache_info().hits == 0, "Cache should not yet be used" + + # Call get_client multiple times with the same parameters + aws_credentials_block.get_client(client_type, use_cache=True) + aws_credentials_block.get_client(client_type, use_cache=True) + aws_credentials_block.get_client(client_type, use_cache=True) + + # Verify that _get_client_cached is called only once due to caching + assert _get_client_cached.cache_info().misses == 1 + assert _get_client_cached.cache_info().hits == 2 diff --git a/tests/test_s3.py b/tests/test_s3.py index 93d11cc1..d6741625 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -12,6 +12,7 @@ from prefect_aws import AwsCredentials, MinIOCredentials from prefect_aws.client_parameters import AwsClientParameters +from prefect_aws.credentials import _get_client_cached from prefect_aws.s3 import ( S3Bucket, s3_copy, @@ -1047,3 +1048,16 @@ def test_move_object_between_buckets( with pytest.raises(ClientError): assert s3_bucket_with_object.read_path("object") == b"TEST" + + def test_client_is_cached_when_specified(self, aws_creds_block): + s3_bucket = S3Bucket( + bucket_name="bucket", credentials=aws_creds_block, cache_client=True + ) + + _get_client_cached.cache_clear() + + s3_bucket._get_s3_client() + s3_bucket._get_s3_client() + + assert _get_client_cached.cache_info().hits == 1 + assert _get_client_cached.cache_info().misses == 1