Skip to content

Commit

Permalink
[Refactor] __eq__ to identity check in non-tensor stacking
Browse files Browse the repository at this point in the history
ghstack-source-id: 9e7c30aa83aca63ae331093f9c028861370f88e7
Pull Request resolved: #1083
  • Loading branch information
vmoens committed Nov 11, 2024
1 parent 9607cf0 commit 6e83e59
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 26 deletions.
55 changes: 41 additions & 14 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2748,20 +2748,23 @@ def _stack_non_tensor(cls, list_of_non_tensor, dim=0):
# checks have been performed previously, so we're sure the list is non-empty
first = list_of_non_tensor[0]

def _check_equal(a, b):
try:
if isinstance(a, _ACCEPTED_CLASSES) or isinstance(b, _ACCEPTED_CLASSES):
return (a == b).all() and a.shape == b.shape
if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
return (a == b).all() and a.shape == b.shape
iseq = a == b
except Exception:
iseq = False
return iseq

if all(isinstance(data, NonTensorData) for data in list_of_non_tensor) and all(
_check_equal(data.data, first.data) for data in list_of_non_tensor[1:]
):
ids = set()
firstdata = NO_DEFAULT
for data in list_of_non_tensor:
if not isinstance(data, NonTensorData):
return_stack = True
break
if firstdata is NO_DEFAULT:
firstdata = data.data
ids.add(id(data.data))
if len(ids) > 1:
if _check_equal(data.data, firstdata):
continue
return_stack = True
break
else:
return_stack = False
if not return_stack:
batch_size = list(first.batch_size)
batch_size.insert(dim, len(list_of_non_tensor))
return NonTensorData(
Expand Down Expand Up @@ -3442,3 +3445,27 @@ class TensorClass(metaclass=_TensorClassMeta):
_nocast: bool = False
_frozen: bool = False
...


# TODO: v0.8: remove this func entirely
def _check_equal(a, b):
    """Check whether two non-tensor payloads match in value.

    This is a value comparison kept for backward compatibility; it is being
    replaced by an identity match, which is faster and easier to handle.

    Args:
        a: first payload to compare.
        b: second payload to compare.

    Returns:
        bool: ``True`` if the payloads compare equal, ``False`` otherwise.
        Any exception raised while comparing (missing ``shape`` attribute,
        non-broadcastable shapes, ambiguous truth value, ...) is treated as
        "not equal" rather than propagated.

    Warns:
        UserWarning: whenever the payloads are found equal, to announce that
        stacking such data will return a ``NonTensorStack`` from v0.8 on.
    """
    try:
        # Array-like payloads (presumably tensor-like classes for
        # _ACCEPTED_CLASSES -- confirm against its definition) compare
        # element-wise, so the result must be reduced with ``.all()``.
        is_array_like = (
            isinstance(a, _ACCEPTED_CLASSES)
            or isinstance(b, _ACCEPTED_CLASSES)
            or isinstance(a, np.ndarray)
            or isinstance(b, np.ndarray)
        )
        if is_array_like:
            # Shapes are checked first: a mismatch can never be "equal", and
            # this avoids an element-wise comparison on broadcastable shapes.
            # ``bool`` guarantees a plain Python bool instead of a 0-dim
            # tensor / ``np.bool_``, consistent with the fallback branch.
            iseq = bool(a.shape == b.shape and (a == b).all())
        else:
            iseq = bool(a == b)
    except Exception:
        # Best effort: payloads that cannot be compared are considered different.
        iseq = False
    if iseq:
        warnings.warn(
            "The content of the stacked NonTensorData objects matched in value but not identity. "
            "This will currently return a NonTensorData but in the future (v0.8) it will return "
            "a NonTensorStack instead. "
            "To obtain a non-tensor stack, use `TensorDict.lazy_stack` instead.",
            category=UserWarning,
        )
    return iseq
14 changes: 2 additions & 12 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,18 +1909,8 @@ def z(self) -> torch.Tensor:
return self._z()

obj = torch.ones(())
y0 = Y(
weakref.ref(obj),
batch_size=[
1,
],
)
y1 = Y(
weakref.ref(obj),
batch_size=[
1,
],
)
y0 = Y(weakref.ref(obj), batch_size=[1])
y1 = Y(weakref.ref(obj), batch_size=[1])
y = torch.cat([y0, y1])
assert y.z.shape == torch.Size(())
y = torch.stack([y0, y1])
Expand Down
14 changes: 14 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10448,6 +10448,9 @@ def nontensor_check(cls, td):
)
return td

@pytest.mark.filterwarnings(
"ignore:The content of the stacked NonTensorData objects matched in value but not identity"
)
def test_map_non_tensor(self):
gc.collect()
# with NonTensorStack
Expand Down Expand Up @@ -10881,6 +10884,17 @@ def test_memmap_stack(self, tmpdir, json_serializable, device):
assert data_memmap._is_memmap

def test_memmap_stack_updates(self, tmpdir):
with pytest.warns(
UserWarning,
match="The content of the stacked NonTensorData objects matched in value but not identity",
):
data = torch.stack(
[
NonTensorData(data=torch.zeros(())),
NonTensorData(data=torch.zeros(())),
],
0,
)
data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0)
assert is_non_tensor(data)
data = torch.stack([data] * 3)
Expand Down

0 comments on commit 6e83e59

Please sign in to comment.