Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

fix client hashing in nested client params case #373

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion prefect_aws/client_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from botocore.client import Config
from pydantic import VERSION as PYDANTIC_VERSION

from prefect_aws.utilities import hash_collection

if PYDANTIC_VERSION.startswith("2."):
from pydantic.v1 import BaseModel, Field, FilePath, root_validator, validator
else:
Expand Down Expand Up @@ -78,7 +80,7 @@ def __hash__(self):
self.verify,
self.verify_cert_path,
self.endpoint_url,
self.config,
hash_collection(self.config),
)
)

Expand Down
2 changes: 1 addition & 1 deletion prefect_aws/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __hash__(self):
hash(self.aws_session_token),
hash(self.profile_name),
hash(self.region_name),
hash(frozenset(self.aws_client_parameters.dict().items())),
hash(self.aws_client_parameters),
)
return hash(field_hashes)

Expand Down
35 changes: 35 additions & 0 deletions prefect_aws/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Utilities for working with AWS services."""

from prefect.utilities.collections import visit_collection


def hash_collection(collection) -> int:
"""Use visit_collection to transform and hash a collection.

Args:
collection (Any): The collection to hash.

Returns:
int: The hash of the transformed collection.

Example:
```python
from prefect_aws.utilities import hash_collection

hash_collection({"a": 1, "b": 2})
```

"""

def make_hashable(item):
"""Make an item hashable by converting it to a tuple."""
if isinstance(item, dict):
return tuple(sorted((k, make_hashable(v)) for k, v in item.items()))
elif isinstance(item, list):
return tuple(make_hashable(v) for v in item)
return item

hashable_collection = visit_collection(
collection, visit_fn=make_hashable, return_data=True
)
return hash(hashable_collection)
25 changes: 25 additions & 0 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,28 @@ def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field
new_hash = hash(credentials)

assert initial_hash != new_hash, "Hash should change when region_name changes"


def test_aws_credentials_nested_client_parameters_are_hashable():
"""
Test to ensure that nested client parameters are hashable.
"""

creds = AwsCredentials(
region_name="us-east-1",
aws_client_parameters=dict(
config=dict(
connect_timeout=5,
read_timeout=5,
retries=dict(max_attempts=10, mode="standard"),
)
),
)

assert hash(creds) is not None

client = creds.get_client("s3")

_client = creds.get_client("s3")

assert client is _client
34 changes: 34 additions & 0 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest

from prefect_aws.utilities import hash_collection


class TestHashCollection:
def test_simple_dict(self):
simple_dict = {"key1": "value1", "key2": "value2"}
assert hash_collection(simple_dict) == hash_collection(
simple_dict
), "Simple dictionary hashing failed"

def test_nested_dict(self):
nested_dict = {"key1": {"subkey1": "subvalue1"}, "key2": "value2"}
assert hash_collection(nested_dict) == hash_collection(
nested_dict
), "Nested dictionary hashing failed"

def test_complex_structure(self):
complex_structure = {
"key1": [1, 2, 3],
"key2": {"subkey1": {"subsubkey1": "value"}},
}
assert hash_collection(complex_structure) == hash_collection(
complex_structure
), "Complex structure hashing failed"

def test_unhashable_structure(self):
typically_unhashable_structure = dict(key=dict(subkey=[1, 2, 3]))
with pytest.raises(TypeError):
hash(typically_unhashable_structure)
assert hash_collection(typically_unhashable_structure) == hash_collection(
typically_unhashable_structure
), "Unhashable structure hashing failed after transformation"