diff --git a/docs/source/overview.rst b/docs/source/overview.rst index 4a91391c9..72e8942e0 100644 --- a/docs/source/overview.rst +++ b/docs/source/overview.rst @@ -1,16 +1,21 @@ Overview ======== -TensorDict makes it easy to organise data and write reusable, generic PyTorch code. Originally developed for TorchRL, we've spun it out into a separate library. +TensorDict makes it easy to organise data and write reusable, generic PyTorch code. Originally developed for TorchRL, +we've spun it out into a separate library. -TensorDict is primarily a dictionary but also a tensor-like class: it supports multiple tensor operations that are mostly shape and storage-related. It is designed to be efficiently serialised or transmitted from node to node or process to process. Finally, it is shipped with its own ``tensordict.nn`` module which is compatible with ``functorch`` and aims at making model ensembling and parameter manipulation easier. +TensorDict is primarily a dictionary but also a tensor-like class: it supports multiple tensor operations that are +mostly shape and storage-related. It is designed to be efficiently serialised or transmitted from node to node or +process to process. Finally, it is shipped with its own :mod:`~tensordict.nn` module which is compatible with ``torch.func`` +and aims at making model ensembling and parameter manipulation easier. -On this page we will motivate ``TensorDict`` and give some examples of what it can do. +On this page we will motivate :class:`~tensordict.TensorDict` and give some examples of what it can do. Motivation ---------- -TensorDict allows you to write generic code modules that are re-usable across paradigms. For instance, the following loop can be re-used across most SL, SSL, UL and RL tasks. +TensorDict allows you to write generic code modules that are re-usable across paradigms. For instance, the following +loop can be re-used across most SL, SSL, UL and RL tasks. >>> for i, tensordict in enumerate(dataset): ... # the model reads and writes tensordicts @@ -20,9 +25,11 @@ TensorDict allows you to write generic code modules that are re-usable across pa ... optimizer.step() ... optimizer.zero_grad() -With its ``tensordict.nn`` module, the package provides many tools to use ``TensorDict`` in a code base with little or no effort. +With its :mod:`~tensordict.nn` module, the package provides many tools to use :class:`~tensordict.TensorDict` in a code +base with little or no effort. -In multiprocessing or distributed settings, ``tensordict`` allows you to seamlessly dispatch data to each worker: +In multiprocessing or distributed settings, :class:`~tensordict.TensorDict` allows you to seamlessly dispatch data to +each worker: >>> # creates batches of 10 datapoints >>> splits = torch.arange(tensordict.shape[0]).split(10) @@ -56,14 +63,17 @@ The nested case is even more compelling: ... {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]} ... for i in range(3) -Decomposing the output dictionary in three similarly structured dictionaries after applying the unbind operation quickly becomes significantly cumbersome when working naively with pytree. With tensordict, we provide a simple API for users that want to unbind or split nested structures, rather than computing a nested split / unbound nested structure. +Decomposing the output dictionary in three similarly structured dictionaries after applying the unbind operation quickly +becomes significantly cumbersome when working naively with pytree. With tensordict, we provide a simple API for users +that want to unbind or split nested structures, rather than computing a nested split / unbound nested structure. Features -------- -A ``TensorDict`` is a dict-like container for tensors. To instantiate a ``TensorDict``, you can specify key-value pairs +A :class:`~tensordict.TensorDict` is a dict-like container for tensors. To instantiate a :class:`~tensordict.TensorDict`, +you can specify key-value pairs as well as the batch size (an empty tensordict can be created via `TensorDict()`). -The leading dimensions of any values in the ``TensorDict`` must be compatible with the batch size. +The leading dimensions of any values in the :class:`~tensordict.TensorDict` must be compatible with the batch size. >>> import torch >>> from tensordict import TensorDict @@ -82,7 +92,7 @@ a few characters (notice that indexing the nth leading dimensions with tree_map >>> sub_tensordict = tensordict[..., :2] -One can also use the set method with ``inplace=True`` or the ``set_`` method to do inplace updates of the contents. +One can also use the set method with ``inplace=True`` or the :meth:`~tensordict.TensorDict.set_` method to do inplace updates of the contents. The former is a fault-tolerant version of the latter: if no matching key is found, it will write a new one. The contents of the TensorDict can now be manipulated collectively. @@ -122,6 +132,165 @@ flexible storage and manipulation of tensors with arbitrary shapes. >>> reshaped_td["zeros"] is tensordict["zeros"] True +Non-tensor data +--------------- + +Tensordict is a powerful library for working with tensor data, but it also supports non-tensor data. This guide will +show you how to use tensordict with non-tensor data. + +Creating a TensorDict with Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can create a TensorDict with non-tensor data using the :class:`~tensordict.NonTensorData` class. + + >>> from tensordict import TensorDict, NonTensorData + >>> import torch + >>> td = TensorDict( + ... a=NonTensorData("a string!"), + ... b=torch.zeros(()), + ... ) + >>> print(td) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +As you can see, the :class:`~tensordict.NonTensorData` object is stored in the TensorDict just like a regular tensor. + +Accessing Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~~~ + +You can access the non-tensor data using the key or the get method. Regular `getattr` calls will return the content of +the :class:`~tensordict.NonTensorData` object whereas :meth:`~tensordict.TensorDict.get` will return the +:class:`~tensordict.NonTensorData` object itself. + + >>> print(td["a"]) # prints: a string! + >>> print(td.get("a")) # prints: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None) + + +Batched Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~ + +If you have a batch of non-tensor data, you can store it in a TensorDict with a specified batch size. + + >>> td = TensorDict( + ... a=NonTensorData("a string!"), + ... b=torch.zeros(3), + ... batch_size=[3] + ... ) + >>> print(td) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None), + b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + +In this case, we assume that all elements of the tensordict have the same non-tensor data. + + >>> print(td[0]) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +To assign a different non-tensor data object to each element in a shaped tensordict, you can use stacks of non-tensor +data. + +Stacked Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~ + +If you have a list of non-tensor data that you want to store in a :class:`~tensordict.TensorDict`, you can use the +:class:`~tensordict.NonTensorStack` class. + + >>> td = TensorDict( + ... a=NonTensorStack("a string!", "another string!", "a third string!"), + ... b=torch.zeros(3), + ... batch_size=[3] + ... ) + >>> print(td) + TensorDict( + fields={ + a: NonTensorStack( + ['a string!', 'another string!', 'a third string!'..., + batch_size=torch.Size([3]), + device=None), + b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + +You can access the first element and you will get the first of the strings: + + >>> print(td[0]) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +In contrast, using :class:`~tensordict.NonTensorData` with a list will not lead to the same result, as there is no +way to tell what to do in general with a non-tensor data that happens to be a list: + + >>> td = TensorDict( + ... a=NonTensorData(["a string!", "another string!", "a third string!"]), + ... b=torch.zeros(3), + ... batch_size=[3] + ... ) + >>> print(td[0]) + TensorDict( + fields={ + a: NonTensorData(data=['a string!', 'another string!', 'a third string!'], batch_size=torch.Size([]), device=None), + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + +Stacking TensorDicts with Non-Tensor Data +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To stack non-tensor data, :func:`~torch.stack` will check the identity of the non-tensor objects and produce a single +:class:`~tensordict.NonTensorData` if they match, or a :class:`~tensordict.NonTensorStack` otherwise: + + >>> td = TensorDict( + ... a=NonTensorData("a string!"), + ... b = torch.zeros(()), + ... ) + >>> print(torch.stack([td, td])) + TensorDict( + fields={ + a: NonTensorData(data=a string!, batch_size=torch.Size([2]), device=None), + b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False) + +If you want to make sure the result is a stack, use :meth:`~tensordict.TensorDict.lazy_stack` instead. + + >>> print(TensorDict.lazy_stack([td, td])) + LazyStackedTensorDict( + fields={ + a: NonTensorStack( + ['a string!', 'a string!'], + batch_size=torch.Size([2]), + device=None), + b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([2]), + device=None, + is_shared=False, + stack_dim=0) + Named dimensions ---------------- @@ -138,7 +307,8 @@ similar to the torch.Tensor dimension name feature: Nested TensorDicts ------------------ -The values in a ``TensorDict`` can themselves be TensorDicts (the nested dictionaries in the example below will be converted to nested TensorDicts). +The values in a :class:`~tensordict.TensorDict` can themselves be TensorDicts (the nested dictionaries in the example +below will be converted to nested TensorDicts). >>> tensordict = TensorDict( ... { @@ -160,7 +330,10 @@ Accessing or setting nested keys can be done with tuples of strings Lazy evaluation --------------- -Some operations on ``TensorDict`` defer execution until items are accessed. For example stacking, squeezing, unsqueezing, permuting batch dimensions and creating a view are not executed immediately on all the contents of the ``TensorDict``. Instead they are performed lazily when values in the ``TensorDict`` are accessed. This can save a lot of unnecessary calculation should the ``TensorDict`` contain many values. +Some operations on :class:`~tensordict.TensorDict` defer execution until items are accessed. For example stacking, +squeezing, unsqueezing, permuting batch dimensions and creating a view are not executed immediately on all the contents +of the :class:`~tensordict.TensorDict`. Instead they are performed lazily when values in the :class:`~tensordict.TensorDict` +are accessed. This can save a lot of unnecessary calculation should the :class:`~tensordict.TensorDict` contain many values. >>> tensordicts = [TensorDict({ ... "a": torch.rand(10), @@ -174,7 +347,10 @@ It also has the advantage that we can manipulate the original tensordicts in a s >>> stacked["a"] = torch.zeros_like(stacked["a"]) >>> assert (tensordicts[0]["a"] == 0).all() -The caveat is that the get method has now become an expensive operation and, if repeated many times, may cause some overhead. One can avoid this by simply calling tensordict.contiguous() after the execution of stack. To further mitigate this, TensorDict comes with its own meta-data class (MetaTensor) that keeps track of the type, shape, dtype and device of each entry of the dict, without performing the expensive operation. +The caveat is that the get method has now become an expensive operation and, if repeated many times, may cause some +overhead. One can avoid this by simply calling tensordict.contiguous() after the execution of stack. To further mitigate +this, TensorDict comes with its own meta-data class (MetaTensor) that keeps track of the type, shape, dtype and device +of each entry of the dict, without performing the expensive operation. Lazy pre-allocation ------------------- @@ -185,14 +361,16 @@ Suppose we have some function foo() -> TensorDict and that we do something like >>> for i in range(N): ... tensordict[i] = foo() -When ``i == 0`` the empty ``TensorDict`` will automatically be populated with empty tensors with batch size N. In subsequent iterations of the loop the updates will all be written in-place. +When ``i == 0`` the empty :class:`~tensordict.TensorDict` will automatically be populated with empty tensors with batch +size N. In subsequent iterations of the loop the updates will all be written in-place. TensorDictModule ---------------- -To make it easy to integrate ``TensorDict`` in one's code base, we provide a tensordict.nn package that allows users to pass ``TensorDict`` instances to ``nn.Module`` objects. +To make it easy to integrate :class:`~tensordict.TensorDict` in one's code base, we provide a tensordict.nn package that allows users to +pass :class:`~tensordict.TensorDict` instances to :class:`~torch.nn.Module` objects (or any callable). -``TensorDictModule`` wraps ``nn.Module`` and accepts a single ``TensorDict`` as an input. You can specify where the underlying module should take its input from, and where it should write its output. This is a key reason we can write reusable, generic high-level code such as the training loop in the motivation section. +:class:`~tensordict.nn.TensorDictModule` wraps :class:`~torch.nn.Module` and accepts a single :class:`~tensordict.TensorDict` as an input. You can specify where the underlying module should take its input from, and where it should write its output. This is a key reason we can write reusable, generic high-level code such as the training loop in the motivation section. >>> from tensordict.nn import TensorDictModule >>> class Net(nn.Module): @@ -218,11 +396,17 @@ To facilitate the adoption of this class, one can also pass the tensors as kwarg >>> tensordict = module(input=torch.randn(32, 100)) -which will return a ``TensorDict`` identical to the one in the previous code box. +which will return a :class:`~tensordict.TensorDict` identical to the one in the previous code box. See :ref:`the export tutorial` for +more context on this feature. -A key pain-point of multiple PyTorch users is the inability of nn.Sequential to handle modules with multiple inputs. Working with key-based graphs can easily solve that problem as each node in the sequence knows what data needs to be read and where to write it. +A key pain-point of multiple PyTorch users is the inability of nn.Sequential to handle modules with multiple inputs. +Working with key-based graphs can easily solve that problem as each node in the sequence knows what data needs to be +read and where to write it. -For this purpose, we provide the ``TensorDictSequential`` class which passes data through a sequence of ``TensorDictModules``. Each module in the sequence takes its input from, and writes its output to the original ``TensorDict``, meaning it's possible for modules in the sequence to ignore output from their predecessors, or take additional input from the tensordict as necessary. Here's an example. +For this purpose, we provide the :class:`~tensordict.nn.TensorDictSequential` class which passes data through a +sequence of ``TensorDictModules``. Each module in the sequence takes its input from, and writes its output to the +original :class:`~tensordict.TensorDict`, meaning it's possible for modules in the sequence to ignore output from their +predecessors, or take additional input from the tensordict as necessary. Here's an example: >>> class Net(nn.Module): ... def __init__(self, input_size=100, hidden_size=50, output_size=10): @@ -259,38 +443,12 @@ For this purpose, we provide the ``TensorDictSequential`` class which passes dat >>> intermediate_x = tensordict["intermediate", "x"] >>> probabilities = tensordict["output", "probabilities"] -In this example, the second module combines the output of the first with the mask stored under ("inputs", "mask") in the ``TensorDict``. - -``TensorDictSequential`` offers a bunch of other features: one can access the list of input and output keys by querying the in_keys and out_keys attributes. It is also possible to ask for a sub-graph by querying ``select_subsequence()`` with the desired sets of input and output keys that are desired. This will return another ``TensorDictSequential`` with only the modules that are indispensable to satisfy those requirements. The ``TensorDictModule`` is also compatible with ``vmap`` and other ``functorch`` capabilities. - -Functional Programming ----------------------- - -We provide and API to use ``TensorDict`` in conjunction with ``functorch``. For instance, ``TensorDict`` makes it easy to concatenate model weights to do model ensembling: - ->>> from torch import nn ->>> from tensordict import TensorDict ->>> from tensordict.nn import make_functional ->>> import torch ->>> from torch import vmap ->>> layer1 = nn.Linear(3, 4) ->>> layer2 = nn.Linear(4, 4) ->>> model = nn.Sequential(layer1, layer2) ->>> # we represent the weights hierarchically ->>> weights1 = TensorDict(layer1.state_dict(), []).unflatten_keys(separator=".") ->>> weights2 = TensorDict(layer2.state_dict(), []).unflatten_keys(separator=".") ->>> params = make_functional(model) ->>> # params provided by make_functional match state_dict: ->>> assert (params == TensorDict({"0": weights1, "1": weights2}, [])).all() ->>> # Let's use our functional module ->>> x = torch.randn(10, 3) ->>> out = model(x, params=params) # params is the last arg (or kwarg) ->>> # an ensemble of models: we stack params along the first dimension... ->>> params_stack = torch.stack([params, params], 0) ->>> # ... and use it as an input we'd like to pass through the model ->>> y = vmap(model, (None, 0))(x, params_stack) ->>> print(y.shape) -torch.Size([2, 10, 4]) - - -The functional API is comparable if not faster than the current ``FunctionalModule`` implemented in ``functorch``. +In this example, the second module combines the output of the first with the mask stored under ("inputs", "mask") in the +:class:`~tensordict.TensorDict`. + +:class:`~tensordict.nn.TensorDictSequential` offers a bunch of other features: one can access the list of input and +output keys by querying the in_keys and out_keys attributes. It is also possible to ask for a sub-graph by querying +:meth:`~tensordict.nn.TensorDictSequential.select_subsequence` with the desired sets of input and output keys that are desired. This will return another +:class:`~tensordict.nn.TensorDictSequential` with only the modules that are indispensable to satisfy those requirements. +The :class:`~tensordict.nn.TensorDictModule` is also compatible with :func:`~torch.vmap` and other ``torch.func`` +capabilities. diff --git a/tensordict/base.py b/tensordict/base.py index 822fbb4cb..e0b1d0411 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -7946,7 +7946,7 @@ def reduce( async_op=False, return_premature=False, group=None, - ): + ) -> None: """Reduces the tensordict across all machines. Only the process with ``rank`` dst is going to receive the final result. @@ -9036,7 +9036,7 @@ def newfn(item_and_out): return out # Stream - def record_stream(self, stream: torch.cuda.Stream): + def record_stream(self, stream: torch.cuda.Stream) -> T: """Marks the tensordict as having been used by this stream. When the tensordict is deallocated, ensure the tensor memory is not reused for other tensors until all work @@ -11353,7 +11353,7 @@ def copy(self): """ return self.clone(recurse=False) - def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None): + def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None) -> T: """Converts all nested tensors to a padded version and adapts the batch-size accordingly. Args: @@ -12438,7 +12438,7 @@ def split_keys( default: Any = NO_DEFAULT, strict: bool = True, reproduce_struct: bool = False, - ): + ) -> Tuple[T, ...]: """Splits the tensordict in subsets given one or more set of keys. The method will return ``N+1`` tensordicts, where ``N`` is the number of diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index a9b2cdf5a..621388f2c 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -510,9 +510,12 @@ def forward( kwargs = {"aggregate_probabilities": False} log_prob = dist.log_prob(out_tensors, **kwargs) if log_prob is not out_tensors: - # Composite dists return the tensordict_out directly when aggrgate_prob is False - out_tensors.set(self.log_prob_key, log_prob) - else: + if is_tensor_collection(log_prob): + out_tensors.update(log_prob) + else: + # Composite dists return the tensordict_out directly when aggrgate_prob is False + out_tensors.set(self.log_prob_key, log_prob) + elif dist.log_prob_key in out_tensors: out_tensors.rename_key_(dist.log_prob_key, self.log_prob_key) tensordict_out.update(out_tensors) else: diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 405db4601..04ada3198 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -121,6 +121,7 @@ def __subclasscheck__(self, subclass): } # Methods to be executed from tensordict, any ref to self means 'tensorclass' _METHOD_FROM_TD = [ + "dumps", "load_", "memmap", "memmap_", @@ -145,21 +146,48 @@ def __subclasscheck__(self, subclass): "_items_list", "_maybe_names", "_multithread_apply_flat", + "_multithread_apply_nest", "_multithread_rebuild", # rebuild checks if self is a non tensor "_propagate_lock", "_propagate_unlock", "_reduce_get_metadata", "_values_list", + "bytes", + "cat_tensors", "data_ptr", + "depth", "dim", + "dtype", + "entry_class", + "get_item_shape", + "get_non_tensor", + "irecv", + "is_consolidated", + "is_contiguous", + "is_cpu", + "is_cuda", "is_empty", + "is_floating_point", "is_memmap", + "is_meta", "is_shared", + "isend", "items", "keys", + "make_memmap", + "make_memmap_from_tensor", "ndimension", "numel", + "numpy", + "param_count", + "pop", + "recv", + "reduce", + "saved_path", + "send", "size", + "sorted_keys", + "to_struct_array", "values", # "ndim", ] @@ -214,9 +242,6 @@ def __subclasscheck__(self, subclass): "_map", "_maybe_remove_batch_dim", "_memmap_", - "_multithread_apply_flat", - "_multithread_apply_nest", - "_multithread_rebuild", "_permute", "_remove_batch_dim", "_repeat", @@ -235,6 +260,8 @@ def __subclasscheck__(self, subclass): "addcmul", "addcmul_", "all", + "amax", + "amin", "any", "apply", "apply_", @@ -245,31 +272,43 @@ def __subclasscheck__(self, subclass): "atan_", "auto_batch_size_", "auto_device_", + "bfloat16", "bitwise_and", "bool", + "cat", + "cat_from_tensordict", "ceil", "ceil_", "chunk", + "clamp", "clamp_max", "clamp_max_", "clamp_min", "clamp_min_", "clear", "clear_device_", + "complex128", + "complex32", + "complex64", "consolidate", "contiguous", "copy_", + "copy_at_", "cos", "cos_", "cosh", "cosh_", "cpu", + "create_nested", "cuda", "cummax", "cummin", "densify", + "detach", + "detach_", "div", "div_", + "double", "empty", "erf", "erf_", @@ -282,20 +321,43 @@ def __subclasscheck__(self, subclass): "expand_as", "expm1", "expm1_", + "fill_", + "filter_empty_", "filter_non_tensor_data", "flatten", + "flatten_keys", + "float", + "float16", + "float32", + "float64", "floor", "floor_", "frac", "frac_", "from_any", + "from_consolidated", "from_dataclass", + "from_h5", + "from_modules", "from_namedtuple", "from_pytree", + "from_struct_array", + "from_tuple", + "fromkeys", "gather", + "gather_and_stack", + "half", + "int", + "int16", + "int32", + "int64", + "int8", "isfinite", "isnan", + "isneginf", + "isposinf", "isreal", + "lazy_stack", "lerp", "lerp_", "lgamma", @@ -312,13 +374,16 @@ def __subclasscheck__(self, subclass): "log_", "logical_and", "logsumexp", + "make_memmap_from_storage", "map", "map_iter", "masked_fill", "masked_fill_", + "masked_select", "max", "maximum", "maximum_", + "maybe_dense_stack", "mean", "min", "minimum", @@ -338,13 +403,22 @@ def __subclasscheck__(self, subclass): "norm", "permute", "pin_memory", + "pin_memory_", + "popitem", "pow", "pow_", "prod", + "qint32", + "qint8", + "quint4x2", + "quint8", "reciprocal", "reciprocal_", + "record_stream", "refine_names", + "rename", "rename_", # TODO: must be specialized + "rename_key_", "repeat", "repeat_interleave", "replace", @@ -353,6 +427,10 @@ def __subclasscheck__(self, subclass): "round", "round_", "select", + "separates", + "set_", + "set_non_tensor", + "setdefault", "sigmoid", "sigmoid_", "sign", @@ -363,9 +441,13 @@ def __subclasscheck__(self, subclass): "sinh_", "softmax", "split", + "split_keys", "sqrt", "sqrt_", "squeeze", + "stack", + "stack_from_tensordict", + "stack_tensors", "std", "sub", "sub_", @@ -375,13 +457,21 @@ def __subclasscheck__(self, subclass): "tanh", "tanh_", "to", + "to_h5", "to_module", "to_namedtuple", + "to_padded_tensor", "to_pytree", "transpose", "trunc", "trunc_", + "type", + "uint16", + "uint32", + "uint64", + "uint8", "unflatten", + "unflatten_keys", "unlock_", "unsqueeze", "var", @@ -390,10 +480,6 @@ def __subclasscheck__(self, subclass): "zero_", "zero_grad", ] -assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set( - _METHOD_FROM_TD -).intersection(_FALLBACK_METHOD_FROM_TD) -assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD) # These methods require a copy of the non tensor data _FALLBACK_METHOD_FROM_TD_COPY = [ @@ -865,6 +951,14 @@ def __torch_function__( cls.device = property(_device, _device_setter) if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys: cls.batch_size = property(_batch_size, _batch_size_setter) + if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys: + cls.batch_dims = property(_batch_dims) + if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys: + cls.requires_grad = property(_requires_grad) + if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys: + cls.is_locked = property(_is_locked) + if not hasattr(cls, "ndim") and "ndim" not in expected_keys: + cls.ndim = property(_batch_dims) if not hasattr(cls, "shape") and "shape" not in expected_keys: cls.shape = property(_batch_size, _batch_size_setter) if not hasattr(cls, "names") and "names" not in expected_keys: @@ -2160,6 +2254,18 @@ def _batch_size(self) -> torch.Size: return self._tensordict.batch_size +def _batch_dims(self) -> torch.Size: + return self._tensordict.batch_dims + + +def _requires_grad(self) -> torch.Size: + return self._tensordict.requires_grad + + +def _is_locked(self) -> torch.Size: + return self._tensordict.is_locked + + def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 """Set the value of batch_size. diff --git a/tensordict/tensorclass.pyi b/tensordict/tensorclass.pyi index 87a7d7753..f4131802d 100644 --- a/tensordict/tensorclass.pyi +++ b/tensordict/tensorclass.pyi @@ -526,7 +526,15 @@ class TensorClass: return_early: bool = False, share_non_tensor: bool = False, ) -> T: ... - dumps = save + def dumps( + self, + prefix: str | None = None, + copy_existing: bool = False, + *, + num_threads: int = 0, + return_early: bool = False, + share_non_tensor: bool = False, + ) -> T: ... def memmap( self, prefix: str | None = None, @@ -892,6 +900,14 @@ class TensorClass: *, default: str | CompatibleType | None = None, ) -> T: ... + def clamp( + self, + min: TensorDictBase | torch.Tensor = None, + max: TensorDictBase | torch.Tensor = None, + *, + out=None, + ): ... + def logsumexp(self, dim=None, keepdim=False, *, out=None): ... def clamp_max_(self, other: TensorDictBase | torch.Tensor) -> T: ... def clamp_max( self, @@ -944,6 +960,27 @@ class TensorClass: def to_namedtuple(self, dest_cls: type | None = None): ... @classmethod def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): ... + def from_tuple( + cls, + obj, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + ): ... + def logical_and( + self, + other: TensorDictBase | torch.Tensor, + *, + default: str | CompatibleType | None = None, + ) -> TensorDictBase: ... + def bitwise_and( + self, + other: TensorDictBase | torch.Tensor, + *, + default: str | CompatibleType | None = None, + ) -> TensorDictBase: ... @classmethod def from_struct_array( cls, struct_array: np.ndarray, device: torch.device | None = None @@ -987,6 +1024,20 @@ class TensorClass: strict: bool = True, reproduce_struct: bool = False, ): ... + def separates( + self, + *keys: NestedKey, + default: Any = NO_DEFAULT, + strict: bool = True, + filter_empty: bool = True, + ) -> T: ... + def norm( + self, + *, + out=None, + dtype: torch.dtype | None = None, + ): ... + def softmax(self, dim: int, dtype: torch.dtype | None = None): ... @property def is_locked(self) -> bool: ... @is_locked.setter diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index e2ccc9989..01a5aec37 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -5,10 +5,12 @@ from __future__ import annotations import argparse +import ast import contextlib import dataclasses import inspect import os +import pathlib import pickle import re import sys @@ -61,6 +63,126 @@ ] +def _get_methods_from_pyi(file_path): + """ + Reads a .pyi file and returns a set of method names. + + Args: + file_path (str): Path to the .pyi file. + + Returns: + set: A set of method names. + """ + with open(file_path, "r") as f: + tree = ast.parse(f.read()) + + methods = set() + for node in tree.body: + if isinstance(node, ast.ClassDef): + for child_node in node.body: + if isinstance(child_node, ast.FunctionDef): + methods.add(child_node.name) + + return methods + + +def _get_methods_from_class(cls): + """ + Returns a set of method names from a given class. + + Args: + cls (class): The class to get methods from. + + Returns: + set: A set of method names. + """ + methods = set() + for name in dir(cls): + attr = getattr(cls, name) + if ( + inspect.isfunction(attr) + or inspect.ismethod(attr) + or isinstance(attr, property) + ): + methods.add(name) + + return methods + + +def test_tensorclass_stub_methods(): + tensorclass_pyi_path = ( + pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi" + ) + tensorclass_methods = _get_methods_from_pyi(str(tensorclass_pyi_path)) + + from tensordict import TensorDict + + tensordict_methods = _get_methods_from_class(TensorDict) + + missing_methods = tensordict_methods - tensorclass_methods + missing_methods = [ + method for method in missing_methods if (not method.startswith("_")) + ] + + if missing_methods: + raise Exception( + f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}" + ) + + +def test_tensorclass_instance_methods(): + @tensorclass + class X: + x: torch.Tensor + + tensorclass_pyi_path = ( + pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi" + ) + tensorclass_abstract_methods = _get_methods_from_pyi(str(tensorclass_pyi_path)) + + tensorclass_methods = _get_methods_from_class(X) + + missing_methods = ( + tensorclass_abstract_methods - tensorclass_methods - {"data", "grad"} + ) + missing_methods = [ + method for method in missing_methods if (not method.startswith("_")) + ] + + if missing_methods: + raise Exception( + f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}" + ) + + +def test_sorted_methods(): + from tensordict.tensorclass import ( + _FALLBACK_METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + ) + + lists_to_check = [ + _FALLBACK_METHOD_FROM_TD_NOWRAP, + _METHOD_FROM_TD, + _FALLBACK_METHOD_FROM_TD_FORCE, + _FALLBACK_METHOD_FROM_TD, + ] + # Check that each list is sorted and has unique elements + for lst in lists_to_check: + assert lst == sorted(lst), f"List {lst} is not sorted" + assert len(lst) == len(set(lst)), f"List {lst} has duplicate elements" + # Check that no two lists share any elements + for i, lst1 in enumerate(lists_to_check): + for j, lst2 in enumerate(lists_to_check): + if i != j: + shared_elements = set(lst1) & set(lst2) + assert ( + not shared_elements + ), f"Lists {lst1} and {lst2} share elements: {shared_elements}" + + def _make_data(shape): return MyData( X=torch.rand(*shape), @@ -125,6 +247,12 @@ class MyClass1: MyClass1(torch.zeros(3, 1), "z", batch_size=[3, 1]), batch_size=[3, 1], ) + assert x.shape == x.batch_size + assert x.batch_size == (3, 1) + assert x.ndim == 2 + assert x.batch_dims == 2 + assert x.numel() == 3 + assert not x.all() assert not x.any() assert isinstance(x.all(), bool)