Skip to content

Commit

Permalink
[Feature] UnbatchedTensor
Browse files Browse the repository at this point in the history
ghstack-source-id: fa25726d61e913a725a71f1579eb06b09455e7c8
Pull Request resolved: #1170
  • Loading branch information
vmoens committed Jan 9, 2025
1 parent ba50073 commit 50f80ca
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 32 deletions.
56 changes: 41 additions & 15 deletions docs/source/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,40 +71,66 @@ Features
--------

A :class:`~tensordict.TensorDict` is a dict-like container for tensors. To instantiate a :class:`~tensordict.TensorDict`,
you must specify key-value pairs as well as the batch size. The leading dimensions of any values in the :class:`~tensordict.TensorDict` must be compatible with the batch size.
you can specify key-value pairs
as well as the batch size (an empty tensordict can be created via `TensorDict()`).
The leading dimensions of any values in the :class:`~tensordict.TensorDict` must be compatible with the batch size.

>>> import torch
>>> from tensordict import TensorDict

>>> tensordict = TensorDict(
... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
... batch_size=[2, 3],
... )
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
... batch_size=[2, 3],
... )

The syntax for setting or retrieving values is much like that for a regular dictionary.

>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)
>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)

One can also index a tensordict along its batch_size which makes it possible to obtain congruent slices of data in just
a few characters (notice that indexing the nth leading dimensions with tree_map using an ellipsis would require a bit more coding):

>>> sub_tensordict = tensordict[..., :2]
>>> sub_tensordict = tensordict[..., :2]

One can also use the set method with ``inplace=True`` or the :meth:`~tensordict.TensorDict.set_` method to do inplace updates of the contents.
The former is a fault-tolerant version of the latter: if no matching key is found, it will write a new one.

The contents of the TensorDict can now be manipulated collectively.
For example, to place all of the contents onto a particular device one can simply do

>>> tensordict = tensordict.to("cuda:0")
>>> tensordict = tensordict.to("cuda:0")

You can then assert that the device of the tensordict is `"cuda:0"`:

>>> assert tensordict.device == torch.device("cuda:0")

To reshape the batch dimensions one can do

>>> tensordict = tensordict.reshape(6)
>>> tensordict = tensordict.reshape(6)

The class supports many other operations, including :func:`~torch.squeeze`, :func:`~torch.unsqueeze`,
:meth:`~tensordict.TensorDict.view`, :func:`~torch.permute`, :meth:`~tensordict.TensorDict.unbind`,
:func:`~torch.stack`, :func:`~torch.cat` and many more.

If an operation is not present, the :meth:`~tensordict.TensorDict.apply` method will usually provide the solution
that was needed.

Escaping shape operations
~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, it may be desirable to store tensors in a TensorDict without enforcing batch size consistency during
shape operations.

This can be achieved by wrapping the tensor in an :class:`~tensordict.UnbatchedTensor` instance.

An :class:`~tensordict.UnbatchedTensor` ignores its shape during shape operations on the TensorDict, allowing for
flexible storage and manipulation of tensors with arbitrary shapes.

The class supports many other operations, including squeeze, unsqueeze, view, permute, unbind, stack, cat and many more.
If an operation is not present, the TensorDict.apply method will usually provide the solution that was needed.
>>> from tensordict import UnbatchedTensor
>>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3])
>>> reshaped_td = tensordict.reshape(6)
>>> reshaped_td["zeros"] is tensordict["zeros"]
True

Non-tensor data
---------------
Expand Down
1 change: 1 addition & 0 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
stack,
TensorDict,
)
from tensordict._unbatched import UnbatchedTensor

from tensordict.base import (
from_any,
Expand Down
10 changes: 5 additions & 5 deletions tensordict/_torch_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from tensordict.utils import (
_check_keys,
_ErrorInteceptor,
_pass_through,
_shape,
_zip_strict,
DeviceType,
is_non_tensor,
is_tensorclass,
lazy_legacy,
set_lazy_legacy,
Expand Down Expand Up @@ -454,10 +454,10 @@ def _stack(
if maybe_dense_stack is None:
maybe_dense_stack = lazy_legacy()
is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
if all(is_non_tensor(td) for td in list_of_tensordicts):
from tensordict.tensorclass import NonTensorData

return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
if all(_pass_through(td) for td in list_of_tensordicts):
return type(list_of_tensordicts[0])._stack_non_tensor(
list_of_tensordicts, dim=dim
)
if is_tc:
tc_type = type(list_of_tensordicts[0])
list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]
Expand Down
214 changes: 214 additions & 0 deletions tensordict/_unbatched.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from functools import wraps
from typing import Any, Callable

import torch
from tensordict.base import TensorDictBase

from tensordict.tensorclass import (
_arg_to_tensordict,
_from_tensordict_with_copy,
_TD_PASS_THROUGH,
TD_HANDLED_FUNCTIONS,
TensorClass,
)
from tensordict.utils import _getitem_batch_size, _is_tensorclass, unravel_key
from torch import Tensor


def _arg_to_tensordict_unbatched(arg, batch_size):
if _is_tensorclass(type(arg)):
arg = arg._tensordict.empty()
arg.batch_size = batch_size
return arg
elif isinstance(arg, (tuple, list)) and all(
_is_tensorclass(type(item)) for item in arg
):
arg_list = []
for item in arg:
item = item._tensordict.empty()
item.batch_size = batch_size
arg_list.append(item)

return type(arg)(arg_list)
return arg


def _bypass(func):
@wraps(func)
def bypassed_func(self, *args, **kwargs):
meta_tensor = torch.zeros(
self.batch_size, dtype=self.dtype, device=torch.device("meta")
)
name = func.__name__
r = getattr(meta_tensor, name)(*args, **kwargs)
self_copy = self.copy()
self_copy.batch_size = r.shape
return self_copy

return bypassed_func


_TORCH_SHAPE_OPS = (
torch.gather,
torch.unbind,
torch.cat,
torch.stack,
torch.unflatten,
torch.flatten,
torch.split,
torch.squeeze,
torch.unsqueeze,
)


class UnbatchedTensor(TensorClass):
"""A TensorClass that represents a tensor whose shape is ignored during shape operations.
This class allows tensors to be stored in a TensorDict without enforcing batch size consistency.
Shape operations (e.g., reshape, unsqueeze, squeeze) on the TensorDict will return the same UnbatchedTensor instance,
while other operations (e.g., apply, key manipulation, pointwise arithmetic) may modify the underlying tensor content.
Example:
>>> td = TensorDict(a=UnbatchedTensor(torch.randn(3, 4)), b=torch.randn(2, 3), batch_size=(2,))
>>> td_reshaped = td.reshape((1, 2))
>>> td_reshaped["a"] is td["a"]
True
Note that accessing an UnbatchedTensor using `get()` and `__getitem__()` will return different results.
`get()` returns the UnbatchedTensor instance, while `__getitem__()` returns the underlying tensor content.
Example:
>>> td.get("a")
<UnbatchedTensor ...>
>>> td["a"]
tensor([[...]])
"""

data: torch.Tensor | TensorDictBase
_pass_through = True

@classmethod
def __torch_function__(
cls,
func: Callable,
types: tuple[type, ...],
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
) -> Callable:
if func not in _TD_PASS_THROUGH or not all(
issubclass(t, (Tensor, cls, TensorDictBase)) for t in types
):
return NotImplemented

if kwargs is None:
kwargs = {}

# get the output type from the arguments / keyword arguments
if len(args) > 0:
tensorclass_instance = args[0]
else:
tensorclass_instance = kwargs.get("input", kwargs["tensors"])
if isinstance(tensorclass_instance, (tuple, list)):
tensorclass_instance = tensorclass_instance[0]

if func not in _TORCH_SHAPE_OPS:
args = tuple(_arg_to_tensordict(arg) for arg in args)
kwargs = {key: _arg_to_tensordict(value) for key, value in kwargs.items()}
result = TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
else:
# Get a brute force batch size
args = tuple(
_arg_to_tensordict_unbatched(arg, tensorclass_instance.batch_size)
for arg in args
)
kwargs = {
key: _arg_to_tensordict_unbatched(
value, tensorclass_instance.batch_size
)
for key, value in kwargs.items()
}
example_td = TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
result = tensorclass_instance.copy()
result.batch_size = example_td.batch_size
return result

if isinstance(result, (list, tuple)):
return type(result)(
_from_tensordict_with_copy(tensorclass_instance, tensordict_result)
for tensordict_result in result
)
return _from_tensordict_with_copy(tensorclass_instance, result)

def __getitem__(self, index):
if isinstance(index, (tuple, str)) and unravel_key(index):
raise ValueError(
"TensorClass fields must be accessed as attributes, not items."
)
self_copy = self.copy()
self_copy.batch_size = _getitem_batch_size(self.batch_size, index)
return self_copy

@property
def batch_size(self):
return self._batch_size

@batch_size.setter
def batch_size(self, batch_size):
self.__dict__["_batch_size"] = torch.Size(batch_size)

shape = batch_size

def unbind(self, dim: int):
return tuple(
self[(slice(None),) * dim + (0,)] for _ in range(self.batch_size[dim])
)

@_bypass
def reshape(self, *shape): ...

@_bypass
def view(self, *shape): ...

def unsqueeze(self, dim):
shape = list(self.batch_size)
shape.insert(dim, 0)
self_copy = self.copy()
self_copy.batch_size = shape
return self_copy

def transpose(self, dim0, dim1):
batch_size = list(self.batch_size)
batch_size[dim1], batch_size[dim0] = batch_size[dim0], batch_size[dim1]
self_copy = self.copy()
self_copy.batch_size = batch_size
return self_copy

def permute(self, *dims):
if len(dims) == 1 and not isinstance(dims[0], int):
return self.permute(*dims[0])
batch_size = list(self.batch_size)
batch_size = [batch_size[d] for d in dims]
self_copy = self.copy()
self_copy.batch_size = batch_size
return self_copy

@classmethod
def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False):
result = list_of_non_tensor[0].copy()
batch_size = list(result.batch_size)
batch_size.insert(dim, len(list_of_non_tensor))
result.batch_size = torch.Size(batch_size)
return result

@_bypass
def unflatten(self, dim, unflattened_size): ...

@_bypass
def flatten(self, start_dim=0, end_dim=-1): ...
16 changes: 12 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
_lock_warn,
_make_dtype_promotion,
_parse_to,
_pass_through,
_pass_through_cls,
_pin_mem,
_PIN_MEM_TIMEOUT,
_prefix_last_key,
Expand Down Expand Up @@ -6450,7 +6452,7 @@ def _get_tuple(self, key, default): ...

def _get_tuple_maybe_non_tensor(self, key, default):
result = self._get_tuple(key, default)
if is_non_tensor(result):
if _pass_through(result):
# Only lazy stacks of non tensors are actually tensordict instances
if isinstance(result, TensorDictBase):
return result.tolist()
Expand Down Expand Up @@ -7371,7 +7373,11 @@ def flatten(tensor):
else:
names = None
out = self._fast_apply(
flatten, batch_size=batch_size, propagate_lock=True, names=names
flatten,
batch_size=batch_size,
propagate_lock=True,
names=names,
call_on_nested=True,
)
return out

Expand Down Expand Up @@ -7417,7 +7423,9 @@ def unflatten(tensor):
else:
batch_size = list(unflattened_size) + list(self.batch_size[1:])
# TODO: check that this works with nested tds of different batch size
out = self._fast_apply(unflatten, batch_size=batch_size, propagate_lock=True)
out = self._fast_apply(
unflatten, batch_size=batch_size, propagate_lock=True, call_on_nested=True
)
if self._has_names():
names = copy(self.names)
for _ in range(len(unflattened_size) - 1):
Expand Down Expand Up @@ -13317,7 +13325,7 @@ def _default_is_leaf(cls: Type) -> bool:

def _is_leaf_nontensor(cls: Type) -> bool:
if _is_tensor_collection(cls):
return _is_non_tensor(cls)
return _pass_through_cls(cls)
# if issubclass(cls, KeyedJaggedTensor):
# return False
return issubclass(cls, torch.Tensor)
Expand Down
2 changes: 2 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3427,6 +3427,8 @@ def _maybe_from_list(nontensor):
def is_empty(self) -> bool:
return False

_stack_non_tensor = NonTensorData._stack_non_tensor

@classmethod
def from_nontensordata(cls, non_tensor: NonTensorData):
data = non_tensor.data
Expand Down
Loading

0 comments on commit 50f80ca

Please sign in to comment.