Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
kurtamohler committed Jan 10, 2025
1 parent edf46ba commit 6b28487
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ to be able to create this other composition:
FlattenObservation
FrameSkipTransform
GrayScale
Hash
InitTracker
KLRewardTransform
NoopResetEnv
Expand Down
14 changes: 6 additions & 8 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2188,7 +2188,7 @@ def test_transform_no_env(self, datatype):
hash_fn = hash
elif datatype == "str":
obs = "abcdefg"
hash_fn = Hash.reproducible_hash_parts
hash_fn = Hash.reproducible_hash
elif datatype == "NonTensorStack":
obs = torch.stack(
[
Expand All @@ -2199,9 +2199,7 @@ def test_transform_no_env(self, datatype):
)

def fn0(x):
return torch.stack(
[Hash.reproducible_hash_parts(x_.get("data")) for x_ in x]
)
return torch.stack([Hash.reproducible_hash(x_.get("data")) for x_ in x])

hash_fn = fn0
else:
Expand Down Expand Up @@ -2311,9 +2309,9 @@ def test_trans_serial_env_check(self, datatype):
in_keys=["string"],
out_keys=["hash"],
hash_fn=lambda x: torch.stack(
[Hash.reproducible_hash_parts(x_.get("data")) for x_ in x]
[Hash.reproducible_hash(x_.get("data")) for x_ in x]
),
output_spec=Unbounded(shape=(2, 4), dtype=torch.int64),
output_spec=Unbounded(shape=(2, 32), dtype=torch.uint8),
)
base_env = CountingEnvWithString

Expand All @@ -2335,9 +2333,9 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
in_keys=["string"],
out_keys=["hash"],
hash_fn=lambda x: torch.stack(
[Hash.reproducible_hash_parts(x_.get("data")) for x_ in x]
[Hash.reproducible_hash(x_.get("data")) for x_ in x]
),
output_spec=Unbounded(shape=(2, 4), dtype=torch.int64),
output_spec=Unbounded(shape=(2, 32), dtype=torch.uint8),
)
base_env = CountingEnvWithString

Expand Down
28 changes: 7 additions & 21 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4463,9 +4463,9 @@ class Hash(UnaryTransform):
out_keys (sequence of NestedKey): the keys of the resulting hashes.
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
the hash function must accept it as its second argument. Default is
Python's builtin ``hash`` function.
``Hash.reproducible_hash``.
output_spec (TensorSpec, optional): the spec of the hash output. Default
is ``Unbounded(shape=(), dtype=torch.int64)``.
is ``Unbounded(shape=(32,), dtype=torch.uint8)``.
seed (optional): seed to use for the hash function, if it requires one.
"""

Expand All @@ -4478,10 +4478,10 @@ def __init__(
seed: Any | None = None,
):
if hash_fn is None:
hash_fn = Hash.reproducible_hash_parts
hash_fn = Hash.reproducible_hash

if output_spec is None:
output_spec = Unbounded(shape=(4,), dtype=torch.int64)
output_spec = Unbounded(shape=(32,), dtype=torch.uint8)

self._seed = seed
self._hash_fn = hash_fn
Expand All @@ -4499,15 +4499,15 @@ def call_hash_fn(self, value):
return self._hash_fn(value, self._seed)

@classmethod
def reproducible_hash_parts(cls, string, seed=None):
def reproducible_hash(cls, string, seed=None):
"""Creates a reproducible 256-bit hash from a string, optionally prepending a seed.
Args:
string (str): The input string.
seed (str, optional): The seed value. Default is ``None``.
Returns:
tuple: Four 64-bit integers representing the parts of the 256-bit hash value.
Tensor: Shape ``(32,)`` with dtype ``torch.uint8``.
"""
# Prepend the seed to the string
if seed is not None:
Expand All @@ -4524,21 +4524,7 @@ def reproducible_hash_parts(cls, string, seed=None):
# Get the hash value as bytes
hash_bytes = hash_object.digest()

# Split the hash bytes into four parts
part1 = hash_bytes[:8]
part2 = hash_bytes[8:16]
part3 = hash_bytes[16:24]
part4 = hash_bytes[24:]

# Convert each part to a 64-bit integer
part1_value = int.from_bytes(part1, "big", signed=True)
part2_value = int.from_bytes(part2, "big", signed=True)
part3_value = int.from_bytes(part3, "big", signed=True)
part4_value = int.from_bytes(part4, "big", signed=True)

return torch.tensor(
[part1_value, part2_value, part3_value, part4_value], dtype=torch.int64
)
return torch.frombuffer(hash_bytes, dtype=torch.uint8)


class Stack(Transform):
Expand Down

0 comments on commit 6b28487

Please sign in to comment.