Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
avoid modifying default behavior
Browse files: browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Jan 18, 2024
1 parent 56210f8 commit 865975c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
22 changes: 20 additions & 2 deletions prefect_aws/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def get_boto3_session(self) -> boto3.Session:
region_name=self.region_name,
)

def get_client(self, client_type: Union[str, ClientType]) -> Any:
def get_client(self, client_type: Union[str, ClientType], use_cache: bool = False):
"""
Helper method to dynamically get a client type.
Expand All @@ -161,6 +161,15 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any:
Raises:
ValueError: if the client is not supported.
"""
if isinstance(client_type, ClientType):
client_type = client_type.value

if not use_cache:
return self.get_boto3_session().client(
service_name=client_type,
**self.aws_client_parameters.get_params_override()
)

return _get_client_cached(ctx=self, client_type=client_type)

def get_s3_client(self) -> S3Client:
Expand Down Expand Up @@ -271,7 +280,7 @@ def get_boto3_session(self) -> boto3.Session:
region_name=self.region_name,
)

def get_client(self, client_type: Union[str, ClientType]) -> Any:
def get_client(self, client_type: Union[str, ClientType], use_cache: bool = False):
"""
Helper method to dynamically get a client type.
Expand All @@ -284,6 +293,15 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any:
Raises:
ValueError: if the client is not supported.
"""
if isinstance(client_type, ClientType):
client_type = client_type.value

if not use_cache:
return self.get_boto3_session().client(
service_name=client_type,
**self.aws_client_parameters.get_params_override()
)

return _get_client_cached(ctx=self, client_type=client_type)

def get_s3_client(self) -> S3Client:
Expand Down
11 changes: 10 additions & 1 deletion prefect_aws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,15 @@ class S3Bucket(WritableFileSystem, WritableDeploymentStorage, ObjectStorageBlock
"for reading and writing objects."
),
)

# Opt-in client caching (off by default). `_get_s3_client` forwards this flag
# as `use_cache` to `credentials.get_client("s3", ...)`: when False, a client
# is built through `get_boto3_session().client(...)` on each call; when True,
# the call is delegated to `_get_client_cached`.
cache_client: bool = Field(
default=False,
description=(
"If True, the S3 client will be cached. This is useful for "
"performance, but can cause issues if the S3 client is used "
"in multiple threads."
),
)

# Property to maintain compatibility with storage block based deployments
@property
Expand Down Expand Up @@ -466,7 +475,7 @@ def _get_s3_client(self) -> boto3.client:
Authenticate MinIO credentials or AWS credentials and return an S3 client.
This is a helper function called by read_path() or write_path().
"""
return self.credentials.get_s3_client()
return self.credentials.get_client("s3", use_cache=self.cache_client)

def _get_bucket_resource(self) -> boto3.resource:
"""
Expand Down
18 changes: 12 additions & 6 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def test_credentials_get_client(credentials, client_type):
with mock_s3():
assert isinstance(credentials.get_client(client_type), BaseClient)


def test_get_client_cached():
@pytest.mark.parametrize(
"client_type", [member.value for member in ClientType]
)
def test_get_client_cached(client_type):
"""
Test to ensure that _get_client_cached function returns the same instance
for multiple calls with the same parameters and properly utilizes lru_cache.
Expand All @@ -65,11 +67,15 @@ def test_get_client_cached():

assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0"

assert aws_credentials_block.get_client(client_type) is not None

assert _get_client_cached.cache_info().hits == 0, "Cache should not yet be used"

# Call get_client multiple times with the same parameters
aws_credentials_block.get_client(ClientType.S3)
aws_credentials_block.get_client(ClientType.S3)
aws_credentials_block.get_client(ClientType.S3)
aws_credentials_block.get_client(client_type, use_cache=True)
aws_credentials_block.get_client(client_type, use_cache=True)
aws_credentials_block.get_client(client_type, use_cache=True)

# Verify that _get_client_cached is called only once due to caching
assert _get_client_cached.cache_info().misses == 1
assert _get_client_cached.cache_info().hits == 2
assert _get_client_cached.cache_info().hits == 2

0 comments on commit 865975c

Please sign in to comment.