From ffea80f31cace5c13ad83d98f194d061914112e6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 18:17:29 +0000 Subject: [PATCH] [Feature] UnbatchedTensor ghstack-source-id: d3c5067beeb099d1ae080752bc6e218d543c7515 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1170 --- docs/source/overview.rst | 55 +++++++--- tensordict/__init__.py | 1 + tensordict/_torch_func.py | 10 +- tensordict/_unbatched.py | 214 ++++++++++++++++++++++++++++++++++++++ tensordict/base.py | 16 ++- tensordict/tensorclass.py | 2 + tensordict/utils.py | 34 ++++-- test/test_tensordict.py | 104 ++++++++++++++++++ 8 files changed, 405 insertions(+), 31 deletions(-) create mode 100644 tensordict/_unbatched.py diff --git a/docs/source/overview.rst b/docs/source/overview.rst index aebbf8c66..4a91391c9 100644 --- a/docs/source/overview.rst +++ b/docs/source/overview.rst @@ -61,25 +61,26 @@ Decomposing the output dictionary in three similarly structured dictionaries aft Features -------- -A ``TensorDict`` is a dict-like container for tensors. To instantiate a ``TensorDict``, you must specify key-value pairs as well as the batch size. The leading dimensions of any values in the ``TensorDict`` must be compatible with the batch size. +A ``TensorDict`` is a dict-like container for tensors. To instantiate a ``TensorDict``, you can specify key-value pairs +as well as the batch size (an empty tensordict can be created via `TensorDict()`). +The leading dimensions of any values in the ``TensorDict`` must be compatible with the batch size. ->>> import torch ->>> from tensordict import TensorDict - ->>> tensordict = TensorDict( -... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)}, -... batch_size=[2, 3], -... ) + >>> import torch + >>> from tensordict import TensorDict + >>> tensordict = TensorDict( + ... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)}, + ... batch_size=[2, 3], + ... ) The syntax for setting or retrieving values is much like that for a regular dictionary. ->>> zeros = tensordict["zeros"] ->>> tensordict["twos"] = 2 * torch.ones(2, 3) + >>> zeros = tensordict["zeros"] + >>> tensordict["twos"] = 2 * torch.ones(2, 3) One can also index a tensordict along its batch_size which makes it possible to obtain congruent slices of data in just a few characters (notice that indexing the nth leading dimensions with tree_map using an ellipsis would require a bit more coding): ->>> sub_tensordict = tensordict[..., :2] + >>> sub_tensordict = tensordict[..., :2] One can also use the set method with ``inplace=True`` or the ``set_`` method to do inplace updates of the contents. The former is a fault-tolerant version of the latter: if no matching key is found, it will write a new one. @@ -87,13 +88,39 @@ The former is a fault-tolerant version of the latter: if no matching key is foun The contents of the TensorDict can now be manipulated collectively. For example, to place all of the contents onto a particular device one can simply do ->>> tensordict = tensordict.to("cuda:0") + >>> tensordict = tensordict.to("cuda:0") + +You can then assert that the device of the tensordict is `"cuda:0"`: + + >>> assert tensordict.device == torch.device("cuda:0") To reshape the batch dimensions one can do ->>> tensordict = tensordict.reshape(6) + >>> tensordict = tensordict.reshape(6) + +The class supports many other operations, including :func:`~torch.squeeze`, :func:`~torch.unsqueeze`, +:meth:`~tensordict.TensorDict.view`, :func:`~torch.permute`, :meth:`~tensordict.TensorDict.unbind`, +:func:`~torch.stack`, :func:`~torch.cat` and many more. + +If an operation is not present, the :meth:`~tensordict.TensorDict.apply` method will usually provide the solution +that was needed. + +Escaping shape operations +~~~~~~~~~~~~~~~~~~~~~~~~~ + +In some cases, it may be desirable to store tensors in a TensorDict without enforcing batch size consistency during +shape operations. + +This can be achieved by wrapping the tensor in an :class:`~tensordict.UnbatchedTensor` instance. + +An :class:`~tensordict.UnbatchedTensor` ignores its shape during shape operations on the TensorDict, allowing for +flexible storage and manipulation of tensors with arbitrary shapes. -The class supports many other operations, including squeeze, unsqueeze, view, permute, unbind, stack, cat and many more. If an operation is not present, the TensorDict.apply method will usually provide the solution that was needed. + >>> from tensordict import UnbatchedTensor + >>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3]) + >>> reshaped_td = tensordict.reshape(6) + >>> reshaped_td["zeros"] is tensordict["zeros"] + True Named dimensions ---------------- diff --git a/tensordict/__init__.py b/tensordict/__init__.py index c339365be..ae080a5f0 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -23,6 +23,7 @@ stack, TensorDict, ) +from tensordict._unbatched import UnbatchedTensor from tensordict.base import ( from_any, diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 2771d4985..be504b2b5 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -23,10 +23,10 @@ from tensordict.utils import ( _check_keys, _ErrorInteceptor, + _pass_through, _shape, _zip_strict, DeviceType, - is_non_tensor, is_tensorclass, lazy_legacy, set_lazy_legacy, @@ -454,10 +454,10 @@ def _stack( if maybe_dense_stack is None: maybe_dense_stack = lazy_legacy() is_tc = any(is_tensorclass(td) for td in list_of_tensordicts) - if all(is_non_tensor(td) for td in list_of_tensordicts): - from tensordict.tensorclass import NonTensorData - - return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + if all(_pass_through(td) for td in list_of_tensordicts): + return type(list_of_tensordicts[0])._stack_non_tensor( + list_of_tensordicts, dim=dim + ) if is_tc: tc_type = type(list_of_tensordicts[0]) list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts] diff --git a/tensordict/_unbatched.py b/tensordict/_unbatched.py new file mode 100644 index 000000000..70ff94e84 --- /dev/null +++ b/tensordict/_unbatched.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from functools import wraps +from typing import Any, Callable + +import torch +from tensordict.base import TensorDictBase + +from tensordict.tensorclass import ( + _arg_to_tensordict, + _from_tensordict_with_copy, + _TD_PASS_THROUGH, + TD_HANDLED_FUNCTIONS, + TensorClass, +) +from tensordict.utils import _getitem_batch_size, _is_tensorclass, unravel_key +from torch import Tensor + + +def _arg_to_tensordict_unbatched(arg, batch_size): + if _is_tensorclass(type(arg)): + arg = arg._tensordict.empty() + arg.batch_size = batch_size + return arg + elif isinstance(arg, (tuple, list)) and all( + _is_tensorclass(type(item)) for item in arg + ): + arg_list = [] + for item in arg: + item = item._tensordict.empty() + item.batch_size = batch_size + arg_list.append(item) + + return type(arg)(arg_list) + return arg + + +def _bypass(func): + @wraps(func) + def bypassed_func(self, *args, **kwargs): + meta_tensor = torch.zeros( + self.batch_size, dtype=self.dtype, device=torch.device("meta") + ) + name = func.__name__ + r = getattr(meta_tensor, name)(*args, **kwargs) + self_copy = self.copy() + self_copy.batch_size = r.shape + return self_copy + + return bypassed_func + + +_TORCH_SHAPE_OPS = ( + torch.gather, + torch.unbind, + torch.cat, + torch.stack, + torch.unflatten, + torch.flatten, + torch.split, + torch.squeeze, + torch.unsqueeze, +) + + +class UnbatchedTensor(TensorClass): + """A TensorClass that represents a tensor whose shape is ignored during shape operations. + + This class allows tensors to be stored in a TensorDict without enforcing batch size consistency. + Shape operations (e.g., reshape, unsqueeze, squeeze) on the TensorDict will return the same UnbatchedTensor instance, + while other operations (e.g., apply, key manipulation, pointwise arithmetic) may modify the underlying tensor content. + + Example: + >>> td = TensorDict(a=UnbatchedTensor(torch.randn(3, 4)), b=torch.randn(2, 3), batch_size=(2,)) + >>> td_reshaped = td.reshape((1, 2)) + >>> td_reshaped["a"] is td["a"] + True + + Note that accessing an UnbatchedTensor using `get()` and `__getitem__()` will return different results. + `get()` returns the UnbatchedTensor instance, while `__getitem__()` returns the underlying tensor content. + + Example: + >>> td.get("a") + + >>> td["a"] + tensor([[...]]) + + """ + + data: torch.Tensor | TensorDictBase + _pass_through = True + + @classmethod + def __torch_function__( + cls, + func: Callable, + types: tuple[type, ...], + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ) -> Callable: + if func not in _TD_PASS_THROUGH or not all( + issubclass(t, (Tensor, cls, TensorDictBase)) for t in types + ): + return NotImplemented + + if kwargs is None: + kwargs = {} + + # get the output type from the arguments / keyword arguments + if len(args) > 0: + tensorclass_instance = args[0] + else: + tensorclass_instance = kwargs.get("input", kwargs["tensors"]) + if isinstance(tensorclass_instance, (tuple, list)): + tensorclass_instance = tensorclass_instance[0] + + if func not in _TORCH_SHAPE_OPS: + args = tuple(_arg_to_tensordict(arg) for arg in args) + kwargs = {key: _arg_to_tensordict(value) for key, value in kwargs.items()} + result = TD_HANDLED_FUNCTIONS[func](*args, **kwargs) + else: + # Get a brute force batch size + args = tuple( + _arg_to_tensordict_unbatched(arg, tensorclass_instance.batch_size) + for arg in args + ) + kwargs = { + key: _arg_to_tensordict_unbatched( + value, tensorclass_instance.batch_size + ) + for key, value in kwargs.items() + } + example_td = TD_HANDLED_FUNCTIONS[func](*args, **kwargs) + result = tensorclass_instance.copy() + result.batch_size = example_td.batch_size + return result + + if isinstance(result, (list, tuple)): + return type(result)( + _from_tensordict_with_copy(tensorclass_instance, tensordict_result) + for tensordict_result in result + ) + return _from_tensordict_with_copy(tensorclass_instance, result) + + def __getitem__(self, index): + if isinstance(index, (tuple, str)) and unravel_key(index): + raise ValueError( + "TensorClass fields must be accessed as attributes, not items." + ) + self_copy = self.copy() + self_copy.batch_size = _getitem_batch_size(self.batch_size, index) + return self_copy + + @property + def batch_size(self): + return self._batch_size + + @batch_size.setter + def batch_size(self, batch_size): + self.__dict__["_batch_size"] = torch.Size(batch_size) + + shape = batch_size + + def unbind(self, dim: int): + return tuple( + self[(slice(None),) * dim + (0,)] for _ in range(self.batch_size[dim]) + ) + + @_bypass + def reshape(self, *shape): ... + + @_bypass + def view(self, *shape): ... + + def unsqueeze(self, dim): + shape = list(self.batch_size) + shape.insert(dim, 0) + self_copy = self.copy() + self_copy.batch_size = shape + return self_copy + + def transpose(self, dim0, dim1): + batch_size = list(self.batch_size) + batch_size[dim1], batch_size[dim0] = batch_size[dim0], batch_size[dim1] + self_copy = self.copy() + self_copy.batch_size = batch_size + return self_copy + + def permute(self, *dims): + if len(dims) == 1 and not isinstance(dims[0], int): + return self.permute(*dims[0]) + batch_size = list(self.batch_size) + batch_size = [batch_size[d] for d in dims] + self_copy = self.copy() + self_copy.batch_size = batch_size + return self_copy + + @classmethod + def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False): + result = list_of_non_tensor[0].copy() + batch_size = list(result.batch_size) + batch_size.insert(dim, len(list_of_non_tensor)) + result.batch_size = torch.Size(batch_size) + return result + + @_bypass + def unflatten(self, dim, unflattened_size): ... + + @_bypass + def flatten(self, start_dim=0, end_dim=-1): ... diff --git a/tensordict/base.py b/tensordict/base.py index a37ecaf31..822fbb4cb 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -66,6 +66,8 @@ _lock_warn, _make_dtype_promotion, _parse_to, + _pass_through, + _pass_through_cls, _pin_mem, _PIN_MEM_TIMEOUT, _prefix_last_key, @@ -6450,7 +6452,7 @@ def _get_tuple(self, key, default): ... def _get_tuple_maybe_non_tensor(self, key, default): result = self._get_tuple(key, default) - if is_non_tensor(result): + if _pass_through(result): # Only lazy stacks of non tensors are actually tensordict instances if isinstance(result, TensorDictBase): return result.tolist() @@ -7371,7 +7373,11 @@ def flatten(tensor): else: names = None out = self._fast_apply( - flatten, batch_size=batch_size, propagate_lock=True, names=names + flatten, + batch_size=batch_size, + propagate_lock=True, + names=names, + call_on_nested=True, ) return out @@ -7417,7 +7423,9 @@ def unflatten(tensor): else: batch_size = list(unflattened_size) + list(self.batch_size[1:]) # TODO: check that this works with nested tds of different batch size - out = self._fast_apply(unflatten, batch_size=batch_size, propagate_lock=True) + out = self._fast_apply( + unflatten, batch_size=batch_size, propagate_lock=True, call_on_nested=True + ) if self._has_names(): names = copy(self.names) for _ in range(len(unflattened_size) - 1): @@ -13317,7 +13325,7 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: if _is_tensor_collection(cls): - return _is_non_tensor(cls) + return _pass_through_cls(cls) # if issubclass(cls, KeyedJaggedTensor): # return False return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index b09d2467a..405db4601 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -3321,6 +3321,8 @@ def _maybe_from_list(nontensor): def is_empty(self) -> bool: return False + _stack_non_tensor = NonTensorData._stack_non_tensor + @classmethod def from_nontensordata(cls, non_tensor: NonTensorData): data = non_tensor.data diff --git a/tensordict/utils.py b/tensordict/utils.py index 5813e5b7f..b92cd0f9f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1231,13 +1231,13 @@ def __call__(func): if attr is not None: @wraps(func) - def new_func(_self, *args, **kwargs): + def func_as_decorator(_self, *args, **kwargs): _attr_pre = getattr(_self, attr) out = func(_self, *args, **kwargs) _attr_post = getattr(_self, attr) if out is not None: if _attr_post is not _attr_pre: - out._last_op = (new_func.__name__, (args, kwargs, _self)) + out._last_op = (func.__name__, (args, kwargs, _self)) else: out._last_op = None return out @@ -1245,13 +1245,13 @@ def new_func(_self, *args, **kwargs): else: @wraps(func) - def new_func(_self, *args, **kwargs): + def func_as_decorator(_self, *args, **kwargs): out = func(_self, *args, **kwargs) if out is not None: - out._last_op = (new_func.__name__, (args, kwargs, _self)) + out._last_op = (func.__name__, (args, kwargs, _self)) return out - return new_func + return func_as_decorator return __call__ @@ -2452,9 +2452,13 @@ def __call__(self, mod: torch.nn.Module, args, kwargs): raise RuntimeError("did not find pre-hook") -def is_non_tensor(data): +def is_non_tensor(data) -> bool: """Checks if an item is a non-tensor.""" - return getattr(type(data), "_is_non_tensor", False) + return _is_non_tensor(type(data)) + + +def _pass_through(data) -> bool: + return _pass_through_cls(type(data)) _NON_TENSOR_MEMO = {} @@ -2466,7 +2470,21 @@ def _is_non_tensor(cls: type): if not is_dynamo: out = _NON_TENSOR_MEMO.get(cls) if out is None: - out = getattr(cls, "_is_non_tensor", False) + out = bool(getattr(cls, "_is_non_tensor", False)) + if not is_dynamo: + _NON_TENSOR_MEMO[cls] = out + return out + + +def _pass_through_cls(cls: type): + out = None + is_dynamo = is_compiling() + if not is_dynamo: + out = _NON_TENSOR_MEMO.get(cls) + if out is None: + out = bool(getattr(cls, "_is_non_tensor", False)) or getattr( + cls, "_pass_through", False + ) if not is_dynamo: _NON_TENSOR_MEMO[cls] = out return out diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 60ba211a9..ce17a33c1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -45,6 +45,7 @@ set_get_defaults_to_none, TensorClass, TensorDict, + UnbatchedTensor, ) from tensordict._lazy import _CustomOpTensorDict from tensordict._reductions import _reduce_td @@ -59,6 +60,7 @@ from tensordict.utils import ( _getitem_batch_size, _LOCK_ERROR, + _pass_through, assert_allclose_td, convert_ellipsis_to_idx, is_non_tensor, @@ -11780,6 +11782,108 @@ class SubTC(NonTensorData): ... assert is_non_tensor(SubTC(data=1, batch_size=[])) +class TestUnbatchedTensor: + def test_unbatched(self): + assert UnbatchedTensor._pass_through + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + batch_size=(3,), + ) + assert _pass_through(td.get("a")) + assert isinstance(td["a"], torch.Tensor) + assert isinstance(td.get("a"), UnbatchedTensor) + + def test_unbatched_shape_ops(self): + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + batch_size=(3,), + ) + # get item + assert td[0]["a"] is td["a"] + assert td[:]["a"] is td["a"] + + unbind = td.unbind(0)[0] + assert unbind["a"] is td["a"] + assert unbind.batch_size == () + + split = td.split(1)[0] + assert split["a"] is td["a"] + assert split.batch_size == (1,) + assert td.split((2, 1))[0]["a"] is td["a"] + + reshape = td.reshape((1, 3)) + assert reshape["a"] is td["a"] + assert reshape.batch_size == (1, 3) + transpose = reshape.transpose(0, 1) + assert transpose["a"] is td["a"] + assert transpose.batch_size == (3, 1) + permute = reshape.permute(1, 0) + assert permute["a"] is td["a"] + assert permute.batch_size == (3, 1) + squeeze = reshape.squeeze() + assert squeeze["a"] is td["a"] + assert squeeze.batch_size == (3,) + + view = td.view((1, 3)) + assert view["a"] is td["a"] + assert view.batch_size == (1, 3) + unsqueeze = td.unsqueeze(0) + assert unsqueeze["a"] is td["a"] + assert unsqueeze.batch_size == (1, 3) + gather = td.gather(0, torch.tensor((0,))) + assert gather["a"] is td["a"] + assert gather.batch_size == (1,) + + unflatten = td.unflatten(0, (1, 3)) + assert unflatten["a"] is td["a"] + assert unflatten.batch_size == (1, 3) + assert unflatten.get("a").batch_size == (1, 3) + assert unflatten.get("a")._tensordict.batch_size == () + + flatten = unflatten.flatten(0, 1) + assert flatten["a"] is td["a"] + assert flatten.batch_size == (3,) + + def test_unbatched_torch_func(self): + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + batch_size=(3,), + ) + assert torch.unbind(td, 0)[0]["a"] is td["a"] + assert torch.stack([td, td], 0)[0]["a"] is td["a"] + assert torch.cat([td, td], 0)[0]["a"] is td["a"] + assert (torch.ones_like(td)["a"] == 1).all() + assert torch.unsqueeze(td, 0)["a"] is td["a"] + assert torch.squeeze(td)["a"] is td["a"] + unflatten = torch.unflatten(td, 0, (1, 3)) + assert unflatten["a"] is td["a"] + flatten = torch.flatten(unflatten, 0, 1) + assert flatten["a"] is td["a"] + permute = torch.permute(unflatten, (1, 0)) + assert permute["a"] is td["a"] + transpose = torch.transpose(unflatten, 1, 0) + assert transpose["a"] is td["a"] + + def test_unbatched_other_ops(self): + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + c_d=UnbatchedTensor(torch.randn(10)), + batch_size=(3,), + ) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + assert td.copy()["a"] is td["a"] + assert td.int()["a"].dtype == torch.int + assert td.to(device)["a"].device == device + assert td.select("a")["a"] is td["a"] + assert td.exclude("b")["a"] is td["a"] + assert td.unflatten_keys(separator="_")["c", "d"] is td["c_d"] + assert td.unflatten_keys(separator="_").flatten_keys()["c.d"] is td["c_d"] + + def _to_float(td, td_name, tmpdir): if hasattr(td, "_source"): td._source = td._source.float()