diff --git a/prefect_aws/client_parameters.py b/prefect_aws/client_parameters.py index eb3be09b..6b47c422 100644 --- a/prefect_aws/client_parameters.py +++ b/prefect_aws/client_parameters.py @@ -7,6 +7,8 @@ from botocore.client import Config from pydantic import VERSION as PYDANTIC_VERSION +from prefect_aws.utilities import hash_collection + if PYDANTIC_VERSION.startswith("2."): from pydantic.v1 import BaseModel, Field, FilePath, root_validator, validator else: @@ -78,7 +80,7 @@ def __hash__(self): self.verify, self.verify_cert_path, self.endpoint_url, - self.config, + hash_collection(self.config), ) ) diff --git a/prefect_aws/credentials.py b/prefect_aws/credentials.py index 5aeddaa6..474a610a 100644 --- a/prefect_aws/credentials.py +++ b/prefect_aws/credentials.py @@ -118,7 +118,7 @@ def __hash__(self): hash(self.aws_session_token), hash(self.profile_name), hash(self.region_name), - hash(frozenset(self.aws_client_parameters.dict().items())), + hash(self.aws_client_parameters), ) return hash(field_hashes) diff --git a/prefect_aws/utilities.py b/prefect_aws/utilities.py new file mode 100644 index 00000000..ad1e6ed2 --- /dev/null +++ b/prefect_aws/utilities.py @@ -0,0 +1,35 @@ +"""Utilities for working with AWS services.""" + +from prefect.utilities.collections import visit_collection + + +def hash_collection(collection) -> int: + """Use visit_collection to transform and hash a collection. + + Args: + collection (Any): The collection to hash. + + Returns: + int: The hash of the transformed collection. + + Example: + ```python + from prefect_aws.utilities import hash_collection + + hash_collection({"a": 1, "b": 2}) + ``` + + """ + + def make_hashable(item): + """Make an item hashable by converting it to a tuple.""" + if isinstance(item, dict): + return tuple(sorted((k, make_hashable(v)) for k, v in item.items())) + elif isinstance(item, list): + return tuple(make_hashable(v) for v in item) + return item + + hashable_collection = visit_collection( + collection, visit_fn=make_hashable, return_data=True + ) + return hash(hashable_collection) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 96ecbd22..6e593212 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -164,3 +164,28 @@ def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field new_hash = hash(credentials) assert initial_hash != new_hash, "Hash should change when region_name changes" + + +def test_aws_credentials_nested_client_parameters_are_hashable(): + """ + Test to ensure that nested client parameters are hashable. + """ + + creds = AwsCredentials( + region_name="us-east-1", + aws_client_parameters=dict( + config=dict( + connect_timeout=5, + read_timeout=5, + retries=dict(max_attempts=10, mode="standard"), + ) + ), + ) + + assert hash(creds) is not None + + client = creds.get_client("s3") + + _client = creds.get_client("s3") + + assert client is _client diff --git a/tests/test_utilities.py b/tests/test_utilities.py new file mode 100644 index 00000000..0e0fdc6f --- /dev/null +++ b/tests/test_utilities.py @@ -0,0 +1,34 @@ +import pytest + +from prefect_aws.utilities import hash_collection + + +class TestHashCollection: + def test_simple_dict(self): + simple_dict = {"key1": "value1", "key2": "value2"} + assert hash_collection(simple_dict) == hash_collection( + simple_dict + ), "Simple dictionary hashing failed" + + def test_nested_dict(self): + nested_dict = {"key1": {"subkey1": "subvalue1"}, "key2": "value2"} + assert hash_collection(nested_dict) == hash_collection( + nested_dict + ), "Nested dictionary hashing failed" + + def test_complex_structure(self): + complex_structure = { + "key1": [1, 2, 3], + "key2": {"subkey1": {"subsubkey1": "value"}}, + } + assert hash_collection(complex_structure) == hash_collection( + complex_structure + ), "Complex structure hashing failed" + + def test_unhashable_structure(self): + typically_unhashable_structure = dict(key=dict(subkey=[1, 2, 3])) + with pytest.raises(TypeError): + hash(typically_unhashable_structure) + assert hash_collection(typically_unhashable_structure) == hash_collection( + typically_unhashable_structure + ), "Unhashable structure hashing failed after transformation"