diff --git a/docs/source/reference/tensordict.rst b/docs/source/reference/tensordict.rst index c843ad986..d6e941c7e 100644 --- a/docs/source/reference/tensordict.rst +++ b/docs/source/reference/tensordict.rst @@ -32,6 +32,7 @@ or ``cat``. cat from_consolidated + from_any from_dict from_h5 from_module @@ -39,6 +40,7 @@ or ``cat``. from_namedtuple from_pytree from_struct_array + from_tuple fromkeys is_batchedtensor lazy_stack diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 7fc9d349d..c339365be 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -9,13 +9,9 @@ from tensordict._td import ( cat, from_consolidated, - from_dict, - from_h5, from_module, from_modules, - from_namedtuple, from_pytree, - from_struct_array, fromkeys, is_tensor_collection, lazy_stack, @@ -29,6 +25,12 @@ ) from tensordict.base import ( + from_any, + from_dict, + from_h5, + from_namedtuple, + from_struct_array, + from_tuple, get_defaults_to_none, set_get_defaults_to_none, TensorDictBase, diff --git a/tensordict/_td.py b/tensordict/_td.py index 5672fc070..c50fb437a 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2050,6 +2050,8 @@ def from_dict( input_dict[key] = TensorDict.from_any( value, auto_batch_size=False, + device=device, + batch_size=batch_size, ) # regular __init__ breaks because a tensor may have the same batch-size as the tensordict out = cls( @@ -4863,139 +4865,6 @@ def from_modules( ) -def from_dict( - input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None -): - """Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`. - - If ``batch_size`` is not specified, returns the maximum batch size possible. - - This function works on nested dictionaries too, or can be used to determine the - batch-size of a nested tensordict. - - Args: - input_dict (dictionary, optional): a dictionary to use as a data source - (nested keys compatible). 
- batch_size (iterable of int, optional): a batch size for the tensordict. - device (torch.device or compatible type, optional): a device for the TensorDict. - batch_dims (int, optional): the ``batch_dims`` (ie number of leading dimensions - to be considered for ``batch_size``). Exclusinve with ``batch_size``. - Note that this is the __maximum__ number of batch dims of the tensordict, - a smaller number is tolerated. - names (list of str, optional): the dimension names of the tensordict. - - Examples: - >>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)} - >>> print(from_dict(input_dict)) - TensorDict( - fields={ - a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), - b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([3]), - device=None, - is_shared=False) - >>> # nested dict: the nested TensorDict can have a different batch-size - >>> # as long as its leading dims match. - >>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}} - >>> print(from_dict(input_dict)) - TensorDict( - fields={ - a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), - b: TensorDict( - fields={ - c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([3, 4]), - device=None, - is_shared=False)}, - batch_size=torch.Size([3]), - device=None, - is_shared=False) - >>> # we can also use this to work out the batch sie of a tensordict - >>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, []) - >>> print( - from_dict(input_td)) - TensorDict( - fields={ - a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), - b: TensorDict( - fields={ - c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([3, 4]), - device=None, - is_shared=False)}, - batch_size=torch.Size([3]), - device=None, - 
is_shared=False) - - """ - return TensorDict.from_dict( - input_dict, - *others, - batch_size=batch_size, - device=device, - batch_dims=batch_dims, - names=names, - ) - - -def from_namedtuple(named_tuple, *, auto_batch_size: bool = False): - """Converts a namedtuple to a TensorDict recursively. - - Keyword Args: - auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. - Defaults to ``False``. - - Examples: - >>> from tensordict import TensorDict, from_namedtuple - >>> import torch - >>> data = TensorDict({ - ... "a_tensor": torch.zeros((3)), - ... "nested": {"a_tensor": torch.zeros((3)), "a_string": "zero!"}}, [3]) - >>> nt = data.to_namedtuple() - >>> print(nt) - GenericDict(a_tensor=tensor([0., 0., 0.]), nested=GenericDict(a_tensor=tensor([0., 0., 0.]), a_string='zero!')) - >>> from_namedtuple(nt, auto_batch_size=True) - TensorDict( - fields={ - a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), - nested: TensorDict( - fields={ - a_string: NonTensorData(data=zero!, batch_size=torch.Size([3]), device=None), - a_tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, - batch_size=torch.Size([3]), - device=None, - is_shared=False)}, - batch_size=torch.Size([3]), - device=None, - is_shared=False) - - """ - return TensorDict.from_namedtuple(named_tuple, auto_batch_size=auto_batch_size) - - -def from_struct_array(struct_array: np.ndarray, device: torch.device | None = None): - """Converts a structured numpy array to a TensorDict. - - The content of the resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy - operation). Changing values of the structured numpy array in-place will affect the content of the TensorDict. - - Examples: - >>> x = np.array( - ... [("Rex", 9, 81.0), ("Fido", 3, 27.0)], - ... dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")], - ... 
) - >>> td = from_struct_array(x) - >>> x_recon = td.to_struct_array() - >>> assert (x_recon == x).all() - >>> assert x_recon.shape == x.shape - >>> # Try modifying x age field and check effect on td - >>> x["age"] += 1 - >>> assert (td["age"] == np.array([10, 4])).all() - - """ - return TensorDict.from_struct_array(struct_array, device=device) - - def from_pytree( pytree, *, @@ -5060,22 +4929,6 @@ def from_pytree( ) -def from_h5( - filename, - mode="r", -): - """Creates a PersistentTensorDict from a h5 file. - - This function will automatically determine the batch-size for each nested - tensordict. - - Args: - filename (str): the path to the h5 file. - mode (str, optional): reading mode. Defaults to ``"r"``. - """ - return TensorDict.from_h5(filename, mode="r") - - def stack(input, dim=0, *, out=None): """Stacks tensordicts into a single tensordict along the given dimension. diff --git a/tensordict/base.py b/tensordict/base.py index 8b86cb250..007189df7 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -1412,19 +1412,52 @@ def to_pytree(self): raise NotImplementedError(f"unknown type {_pytree_type}") @classmethod - def from_h5(cls, filename, mode="r"): + def from_h5( + cls, + filename, + *, + mode: str = "r", + auto_batch_size: bool = False, + batch_dims: int | None = None, + batch_size: torch.Size | None = None, + ): """Creates a PersistentTensorDict from a h5 file. - This function will automatically determine the batch-size for each nested - tensordict. - Args: - filename (str): the path to the h5 file. - mode (str, optional): reading mode. Defaults to ``"r"``. + filename (str): The path to the h5 file. + + Keyword Args: + mode (str, optional): Reading mode. Defaults to ``"r"``. + auto_batch_size (bool, optional): If ``True``, the batch size will be computed automatically. + Defaults to ``False``. + batch_dims (int, optional): If auto_batch_size is ``True``, defines how many dimensions the output + tensordict should have. 
Defaults to ``None`` (full batch-size at each level). + batch_size (torch.Size, optional): The batch size of the TensorDict. Defaults to ``None``. + + Returns: + A PersistentTensorDict representation of the input h5 file. + + Examples: + >>> td = TensorDict.from_h5("path/to/file.h5") + >>> print(td) + PersistentTensorDict( + fields={ + key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + key2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) """ from tensordict.persistent import PersistentTensorDict - return PersistentTensorDict.from_h5(filename, mode=mode) + result = PersistentTensorDict.from_h5( + filename, mode=mode, batch_size=batch_size + ) + if auto_batch_size: + if batch_size is not None: + raise TypeError(cls._CONFLICTING_BATCH_SIZES.format("from_h5")) + result.auto_batch_size_(batch_dims=batch_dims) + return result # Module interaction @classmethod @@ -9968,22 +10001,47 @@ def dict_to_namedtuple(dictionary): return dict_to_namedtuple(self.to_dict(retain_none=False)) @classmethod - def from_any(cls, obj, *, auto_batch_size: bool = False): - """Converts any object to a TensorDict, recursively. + def from_any( + cls, + obj, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + ): + """Recursively converts any object to a TensorDict. + + .. note:: ``from_any`` is less restrictive than the regular TensorDict constructor. It can cast data structures like + dataclasses or tuples to a tensordict using custom heuristics. This approach may incur some extra overhead and + involves more opinionated choices in terms of mapping strategies. + + .. note:: This method recursively converts the input object to a TensorDict. If the object is already a + TensorDict (or any similar tensor collection object), it will be returned as is. 
+ + Args: + obj: The object to be converted. Keyword Args: auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. Defaults to ``False``. + batch_dims (int, optional): If auto_batch_size is ``True``, defines how many dimensions the output tensordict + should have. Defaults to ``None`` (full batch-size at each level). + device (torch.device, optional): The device on which the TensorDict will be created. + batch_size (torch.Size, optional): The batch size of the TensorDict. + Exclusive with ``auto_batch_size``. - Support includes: + Returns: + A TensorDict representation of the input object. + + Supported objects: - - dataclasses through :meth:`~.from_dataclass` (dataclasses will be converted to TensorDict instances, not - tensorclasses). - - namedtuple through :meth:`~.from_namedtuple` - - dict through :meth:`~.from_dict` - - tuple through :meth:`~.from_tuple` - - numpy's structured arrays through :meth:`~.from_struct_array` - - h5 objects through :meth:`~.from_h5` + - Dataclasses through :meth:`~.from_dataclass` (dataclasses will be converted to TensorDict instances, not tensorclasses). + - Namedtuples through :meth:`~.from_namedtuple`. + - Dictionaries through :meth:`~.from_dict`. + - Tuples through :meth:`~.from_tuple`. + - NumPy's structured arrays through :meth:`~.from_struct_array`. + - HDF5 objects through :meth:`~.from_h5`. 
""" if is_tensor_collection(obj): @@ -9995,15 +10053,45 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): # return cls.from_any(obj.data, auto_batch_size=auto_batch_size) return obj if isinstance(obj, dict): - return cls.from_dict(obj, auto_batch_size=auto_batch_size) + return cls.from_dict( + obj, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) if isinstance(obj, UserDict): - return cls.from_dict(dict(obj), auto_batch_size=auto_batch_size) + return cls.from_dict( + dict(obj), + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): - return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) + return cls.from_struct_array( + obj, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) if isinstance(obj, tuple): if is_namedtuple(obj): - return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) - return cls.from_tuple(obj, auto_batch_size=auto_batch_size) + return cls.from_namedtuple( + obj, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) + return cls.from_tuple( + obj, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) if isinstance(obj, list): if _is_list_tensor_compatible(obj)[0]: return torch.tensor(obj) @@ -10012,7 +10100,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return NonTensorStack.from_list(obj) if is_dataclass(obj): - return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) + return cls.from_dataclass( + obj, + auto_batch_size=auto_batch_size, + device=device, + batch_size=batch_size, + ) if _has_h5: import h5py @@ -10026,17 +10119,72 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return obj @classmethod - def from_tuple(cls, obj, *, auto_batch_size: bool = 
False): + def from_tuple( + cls, + obj, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + ): + """Converts a tuple to a TensorDict. + + Args: + obj: The tuple instance to be converted. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, the batch size will be computed automatically. Defaults to ``False``. + batch_dims (int, optional): If auto_batch_size is ``True``, defines how many dimensions the output tensordict + should have. Defaults to ``None`` (full batch-size at each level). + device (torch.device, optional): The device on which the TensorDict will be created. Defaults to ``None``. + batch_size (torch.Size, optional): The batch size of the TensorDict. Defaults to ``None``. + + Returns: + A TensorDict representation of the input tuple. + + Examples: + >>> my_tuple = (1, 2, 3) + >>> td = TensorDict.from_tuple(my_tuple) + >>> print(td) + TensorDict( + fields={ + 0: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + 1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + 2: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + """ from tensordict import TensorDict - result = TensorDict({str(i): cls.from_any(item) for i, item in enumerate(obj)}) + result = TensorDict( + { + str(i): cls.from_any(item, batch_size=batch_size, device=device) + for i, item in enumerate(obj) + }, + batch_size=batch_size, + device=device, + ) if auto_batch_size: - result.auto_batch_size_() + if batch_size is not None: + raise TypeError(cls._CONFLICTING_BATCH_SIZES.format("from_tuple")) + result.auto_batch_size_(batch_dims=batch_dims) return result + _CONFLICTING_BATCH_SIZES = "Conflicting batch sizes in {}: batch_size and auto_batch_size cannot be both specified." 
+ @classmethod def from_dataclass( - cls, dataclass, *, auto_batch_size: bool = False, as_tensorclass: bool = False + cls, + dataclass, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + as_tensorclass: bool = False, + device: torch.device | None = None, + batch_size: torch.Size | None = None, ): """Converts a dataclass into a TensorDict instance. @@ -10046,9 +10194,15 @@ def from_dataclass( Keyword Args: auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the resulting TensorDict. Defaults to ``False``. + batch_dims (int, optional): If ``auto_batch_size`` is ``True``, defines how many dimensions the output + tensordict should have. Defaults to ``None`` (full batch-size at each level). as_tensorclass (bool, optional): If ``True``, delegates the conversion to the free function - :func:`~tensordict.from_dataclass` and returns a tensor-compatible class - (:func:`~tensordict.tensorclass`) or instance instead of a ``TensorDict``. Defaults to ``False``. + :func:`~tensordict.from_dataclass` and returns a tensor-compatible class (:func:`~tensordict.tensorclass`) + or instance instead of a TensorDict. Defaults to ``False``. + device (torch.device, optional): The device on which the TensorDict will be created. + Defaults to ``None``. + batch_size (torch.Size, optional): The batch size of the TensorDict. + Defaults to ``None``. Returns: A TensorDict instance derived from the provided dataclass, unless `as_tensorclass` is True, in which case a tensor-compatible class or instance is returned. 
@@ -10081,19 +10235,43 @@ def from_dataclass( ) source = {} for field in fields(dataclass): - source[field.name] = cls.from_any(getattr(dataclass, field.name)) - result = TensorDict(source) + source[field.name] = cls.from_any( + getattr(dataclass, field.name), device=device, batch_size=batch_size + ) + result = TensorDict(source, device=device, batch_size=batch_size) if auto_batch_size: - result.auto_batch_size_() + if batch_size is not None: + raise TypeError(cls._CONFLICTING_BATCH_SIZES.format("from_dataclass")) + result.auto_batch_size_(batch_dims=batch_dims) return result @classmethod - def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): + def from_namedtuple( + cls, + named_tuple, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + ): """Converts a namedtuple to a TensorDict recursively. + Args: + named_tuple: The namedtuple instance to be converted. + Keyword Args: auto_batch_size (bool, optional): if ``True``, the batch size will be computed automatically. Defaults to ``False``. + batch_dims (int, optional): If ``auto_batch_size`` is ``True``, defines how many dimensions the output + tensordict should have. Defaults to ``None`` (full batch-size at each level). + device (torch.device, optional): The device on which the TensorDict will be created. + Defaults to ``None``. + batch_size (torch.Size, optional): The batch size of the TensorDict. + Defaults to ``None``. + + Returns: + A TensorDict representation of the input namedtuple. 
Examples: >>> from tensordict import TensorDict @@ -10135,22 +10313,56 @@ def namedtuple_to_dict(namedtuple_obj): "indices": namedtuple_obj.indices, } for key, value in namedtuple_obj.items(): - namedtuple_obj[key] = cls.from_any(value) + namedtuple_obj[key] = cls.from_any( + value, device=device, batch_size=batch_size + ) return dict(namedtuple_obj) - result = TensorDict(namedtuple_to_dict(named_tuple)) + result = TensorDict( + namedtuple_to_dict(named_tuple), device=device, batch_size=batch_size + ) if auto_batch_size: - result.auto_batch_size_() + if batch_size is not None: + raise TypeError(cls._CONFLICTING_BATCH_SIZES.format("from_namedtuple")) + result.auto_batch_size_(batch_dims=batch_dims) return result @classmethod def from_struct_array( - cls, struct_array: np.ndarray, device: torch.device | None = None + cls, + struct_array: np.ndarray, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, ) -> T: """Converts a structured numpy array to a TensorDict. - The content of the resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy - operation). Changing values of the structured numpy array in-place will affect the content of the TensorDict. + The resulting TensorDict will share the same memory content as the numpy array (it is a zero-copy operation). + Changing values of the structured numpy array in-place will affect the content of the TensorDict. + + .. note:: This method performs a zero-copy operation, meaning that the resulting TensorDict will share the same memory + content as the input numpy array. Therefore, changing values of the numpy array in-place will affect the content + of the TensorDict. + + Args: + struct_array (np.ndarray): The structured numpy array to be converted. + + Keyword Args: + auto_batch_size (bool, optional): If ``True``, the batch size will be computed automatically. Defaults to ``False``. 
+ batch_dims (int, optional): If ``auto_batch_size`` is ``True``, defines how many dimensions the output + tensordict should have. Defaults to ``None`` (full batch-size at each level). + device (torch.device, optional): The device on which the TensorDict will be created. + Defaults to ``None``. + + .. note:: Changing the device (i.e., specifying any device other than ``None`` or ``"cpu"``) will transfer the data, + resulting in a change to the memory location of the returned data. + + batch_size (torch.Size, optional): The batch size of the TensorDict. Defaults to None. + + Returns: + A TensorDict representation of the input structured numpy array. Examples: >>> x = np.array( @@ -10172,18 +10384,35 @@ def from_struct_array( cls = TensorDict td = cls( {name: struct_array[name] for name in struct_array.dtype.names}, - batch_size=struct_array.shape, + batch_size=struct_array.shape if batch_size is None else batch_size, device=device, ) + if auto_batch_size: + if batch_size is not None: + raise TypeError( + cls._CONFLICTING_BATCH_SIZES.format("from_struct_array") + ) + td.auto_batch_size_(batch_dims=batch_dims) return td def to_struct_array(self): """Converts a tensordict to a numpy structured array. - In a :meth:`.from_struct_array` - :meth:`.to_struct_array` loop, the content of the input and output - arrays should match. However, `to_struct_array` will not keep the memory content of the original arrays. + In a :meth:`.from_struct_array` - :meth:`.to_struct_array` loop, the content of the input and output arrays should match. + However, `to_struct_array` will not keep the memory content of the original arrays. + + .. seealso:: :meth:`.from_struct_array` for more information. + + Returns: + A numpy structured array representation of the input TensorDict. - See :meth:`.from_struct_array` for more information. 
+ Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> td = TensorDict({'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4.0, 5.0, 6.0])}, batch_size=[3]) + >>> arr = td.to_struct_array() + >>> print(arr) + [(1, 4.) (2, 5.) (3, 6.)] """ from .utils import TORCH_TO_NUMPY_DTYPE_DICT @@ -10218,9 +10447,6 @@ def to_h5( Args: filename (str or path): path to the h5 file. - device (torch.device or compatible, optional): the device where to - expect the tensor once they are returned. Defaults to ``None`` - (on cpu by default). **kwargs: kwargs to be passed to :meth:`h5py.File.create_dataset`. Returns: @@ -11551,3 +11777,188 @@ def _expand_to_match_shape( batch_size = torch.Size([*parent_batch_size, *_shape(data)[self_batch_dims:]]) result = data.empty(batch_size=batch_size) return result + + +def from_any( + obj, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, +): + """Converts any object to a TensorDict. + + .. seealso:: :meth:`~tensordict.TensorDictBase.from_any` for more information. + """ + return TensorDictBase.from_any( + obj, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) + + +def from_tuple( + obj, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, +): + """Converts a tuple to a TensorDict. + + .. seealso:: :meth:`TensorDictBase.from_tuple` for more information. + """ + return TensorDictBase.from_tuple( + obj, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) + + +def from_namedtuple( + named_tuple, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, +): + """Converts a namedtuple to a TensorDict. + + .. 
seealso:: :meth:`TensorDictBase.from_namedtuple` for more information. + """ + return TensorDictBase.from_namedtuple( + named_tuple, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) + + +def from_struct_array( + struct_array, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, +): + """Converts a structured numpy array to a TensorDict. + + .. seealso:: :meth:`TensorDictBase.from_struct_array` for more information. + + Examples: + >>> x = np.array( + ... [("Rex", 9, 81.0), ("Fido", 3, 27.0)], + ... dtype=[("name", "U10"), ("age", "i4"), ("weight", "f4")], + ... ) + >>> td = from_struct_array(x) + >>> x_recon = td.to_struct_array() + >>> assert (x_recon == x).all() + >>> assert x_recon.shape == x.shape + >>> # Try modifying x age field and check effect on td + >>> x["age"] += 1 + >>> assert (td["age"] == np.array([10, 4])).all() + + """ + return TensorDictBase.from_struct_array( + struct_array, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) + + +def from_dict( + d, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, +): + """Converts a dictionary to a TensorDict. + + .. seealso:: :meth:`TensorDictBase.from_dict` for more information. + + + Examples: + >>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)} + >>> print(from_dict(input_dict)) + TensorDict( + fields={ + a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), + b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + >>> # nested dict: the nested TensorDict can have a different batch-size + >>> # as long as its leading dims match. 
+ >>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}} + >>> print(from_dict(input_dict)) + TensorDict( + fields={ + a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + b: TensorDict( + fields={ + c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3, 4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + >>> # we can also use this to work out the batch size of a tensordict + >>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, []) + >>> print( + from_dict(input_td)) + TensorDict( + fields={ + a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + b: TensorDict( + fields={ + c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3, 4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + + """ + return TensorDictBase.from_dict( + d, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) + + +def from_h5( + h5_file, + *, + auto_batch_size: bool = False, + batch_dims: int | None = None, + device: torch.device | None = None, + batch_size: torch.Size | None = None, +): + """Converts an HDF5 file to a TensorDict. + + .. seealso:: :meth:`TensorDictBase.from_h5` for more information. 
+ """ + return TensorDictBase.from_h5( + h5_file, + auto_batch_size=auto_batch_size, + batch_dims=batch_dims, + device=device, + batch_size=batch_size, + ) diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 70d06aa8c..2a3957a93 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -191,19 +191,38 @@ def __init__( self._check_batch_size(self._batch_size) @classmethod - def from_h5(cls, filename, mode="r"): + def from_h5(cls, filename, *, mode="r", batch_size: torch.Size | None = None): """Creates a PersistentTensorDict from a h5 file. - This function will automatically determine the batch-size for each nested - tensordict. + This function will automatically determine the batch-size for each nested tensordict (unless ``batch_size`` + is provided). Args: - filename (str): the path to the h5 file. - mode (str, optional): reading mode. Defaults to ``"r"``. + filename (str): The path to the h5 file. + + Keyword Args: + mode (str, optional): Reading mode. Defaults to ``"r"``. + batch_size (torch.Size, optional): The batch size of the TensorDict. Defaults to None (batch-size automatically + determined). + + Returns: + A PersistentTensorDict representation of the input h5 file. 
+ + Examples: + >>> ptd = PersistentTensorDict.from_h5("path/to/file.h5") + >>> print(ptd) + PersistentTensorDict( + fields={ + key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + key2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) """ - out = cls(filename=filename, mode=mode, batch_size=[]) - # determine batch size - _set_max_batch_size(out) + out = cls(filename=filename, mode=mode, batch_size=batch_size) + if batch_size is None: + # determine batch size + _set_max_batch_size(out) return out @classmethod diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 0ed505279..03e91f147 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -395,10 +395,13 @@ def from_dataclass( obj: Any, *, auto_batch_size: bool = False, + batch_dims: int | None = None, + batch_size: torch.Size | None = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, + device: torch.device | None = None, ) -> Any: """Converts a dataclass instance or a type into a tensorclass instance or type, respectively. @@ -410,11 +413,14 @@ def from_dataclass( Keyword Args: auto_batch_size (bool, optional): If ``True``, automatically determines and applies batch size to the resulting object. Defaults to ``False``. + batch_dims (int, optional): If auto_batch_size is ``True``, defines how many dimensions the output tensordict should have. Defaults to ``None`` (full batch-size at each level). + batch_size (torch.Size, optional): The batch size of the TensorDict. Defaults to ``None``. frozen (bool, optional): If ``True``, the resulting class or instance will be immutable. Defaults to ``False``. autocast (bool, optional): If ``True``, enables automatic type casting for the resulting class or instance. Defaults to ``False``. 
nocast (bool, optional): If ``True``, disables any type casting for the resulting class or instance. Defaults to ``False``. inplace (bool, optional): If ``True``, the dataclass type passed will be modified in-place. Defaults to ``False``. Without effect if an instance is provided. + device (torch.device, optional): The device on which the TensorDict will be created. Defaults to ``None``. Returns: A tensor-compatible class or instance derived from the provided dataclass. @@ -487,9 +493,13 @@ def from_dataclass( clz._autocast = autocast clz._nocast = nocast clz._frozen = frozen - result = clz(**asdict(obj)) + result = clz(**asdict(obj), batch_size=batch_size, device=device) if auto_batch_size: - result = result.auto_batch_size_() + if batch_size is not None: + raise TypeError( + TensorDictBase._CONFLICTING_BATCH_SIZES.format("from_dataclass") + ) + result = result.auto_batch_size_(batch_dims=batch_dims) return result