Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 22, 2024
1 parent 3b99f5e commit 8344930
Show file tree
Hide file tree
Showing 13 changed files with 564 additions and 44 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/tensorclass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ Here is an example:
TensorClass
NonTensorData
NonTensorStack
from_dataclass

Auto-casting
------------
Expand Down
1 change: 1 addition & 0 deletions tensordict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from tensordict.memmap import MemoryMappedTensor
from tensordict.persistent import PersistentTensorDict
from tensordict.tensorclass import (
from_dataclass,
NonTensorData,
NonTensorStack,
tensorclass,
Expand Down
26 changes: 24 additions & 2 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,15 +329,37 @@ def _reduce_get_metadata(self):
@classmethod
def from_dict(
    cls,
    input_dict: "Dict[str, Dict[NestedKey, Any]]",
    *other,
    auto_batch_size: bool = False,
    batch_size=None,
    device=None,
    batch_dims=None,
    stack_dim_name=None,
    stack_dim=0,
):
    """Build a lazy stack from a mapping of per-element dictionaries.

    Args:
        input_dict: mapping whose keys are the string indices ``"0"``,
            ``"1"``, ... ``str(len(input_dict) - 1)``; each value is turned
            into one element of the stack through ``TensorDict.from_dict``.
        *other: extra positional arguments, forwarded unchanged to
            ``TensorDict.from_dict`` for every element.
        auto_batch_size: forwarded to ``TensorDict.from_dict``.
        batch_size: batch size of the resulting stack. When given, the entry
            at ``stack_dim`` is removed and must equal ``len(input_dict)``;
            the remaining dimensions are used as the batch size of every
            stacked element.
        device: forwarded to ``TensorDict.from_dict``.
        batch_dims: forwarded to ``TensorDict.from_dict``.
        stack_dim_name: forwarded to ``LazyStackedTensorDict``.
        stack_dim: dimension along which the elements are stacked.

    Raises:
        ValueError: if ``batch_size[stack_dim]`` differs from the number of
            entries in ``input_dict``.
    """
    if batch_size is not None:
        # Work on a copy so that the caller's sequence is never mutated.
        batch_size = list(batch_size)
        if stack_dim is None:
            stack_dim = 0
        # The stacked dimension is carried by the lazy stack itself, not by
        # its elements: strip it and check it against the element count.
        n = batch_size.pop(stack_dim)
        if n != len(input_dict):
            raise ValueError(
                "The number of dicts and the corresponding batch-size must match, "
                f"got len(input_dict)={len(input_dict)} and batch_size[{stack_dim}]={n}."
            )
        batch_size = torch.Size(batch_size)
    return LazyStackedTensorDict(
        *(
            TensorDict.from_dict(
                input_dict[str(i)],
                *other,
                auto_batch_size=auto_batch_size,
                # Previously the reduced batch size was computed and then
                # dropped; forward it so the request is actually honoured.
                batch_size=batch_size,
                device=device,
                batch_dims=batch_dims,
            )
            for i in range(len(input_dict))
        ),
        stack_dim=stack_dim,
        stack_dim_name=stack_dim_name,
    )
Expand Down
150 changes: 136 additions & 14 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -1957,8 +1957,46 @@ def _unsqueeze(tensor):

@classmethod
def from_dict(
cls, input_dict, batch_size=None, device=None, batch_dims=None, names=None
cls,
input_dict,
*others,
auto_batch_size: bool | None = None,
batch_size=None,
device=None,
batch_dims=None,
names=None,
):
if others:
if batch_size is not None:
raise TypeError(
"conflicting batch size values. Please use the keyword argument only."
)
if device is not None:
raise TypeError(
"conflicting device values. Please use the keyword argument only."
)
if batch_dims is not None:
raise TypeError(
"conflicting batch_dims values. Please use the keyword argument only."
)
if names is not None:
raise TypeError(
"conflicting names values. Please use the keyword argument only."
)
warn(
"All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.",
category=DeprecationWarning,
)
batch_size, *others = others
if len(others):
device, *others = others
if len(others):
batch_dims, *others = others
if len(others):
names, *others = others
if len(others):
raise TypeError("Too many positional arguments.")

if batch_dims is not None and batch_size is not None:
raise ValueError(
"Cannot pass both batch_size and batch_dims to `from_dict`."
Expand All @@ -1967,12 +2005,12 @@ def from_dict(
batch_size_set = torch.Size(()) if batch_size is None else batch_size
input_dict = dict(input_dict)
for key, value in list(input_dict.items()):
if isinstance(value, (dict,)):
# we don't know if another tensor of smaller size is coming
# so we can't be sure that the batch-size will still be valid later
input_dict[key] = TensorDict.from_dict(
value, batch_size=[], device=device, batch_dims=None
)
# we don't know if another tensor of smaller size is coming
# so we can't be sure that the batch-size will still be valid later
input_dict[key] = TensorDict.from_any(
value,
auto_batch_size=False,
)
# regular __init__ breaks because a tensor may have the same batch-size as the tensordict
out = cls(
input_dict,
Expand All @@ -1981,7 +2019,17 @@ def from_dict(
names=names,
)
if batch_size is None:
_set_max_batch_size(out, batch_dims)
if auto_batch_size is None:
warn(
"The batch-size was not provided and auto_batch_size isn't set either. "
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
"will be changed in v0.8 and auto_batch_size will be False onward. "
"To silence this warning, pass auto_batch_size directly.",
category=DeprecationWarning,
)
auto_batch_size = True
if auto_batch_size:
_set_max_batch_size(out, batch_dims)
else:
out.batch_size = batch_size
return out
Expand All @@ -1998,8 +2046,46 @@ def _from_dict_validated(
)

def from_dict_instance(
self, input_dict, batch_size=None, device=None, batch_dims=None, names=None
self,
input_dict,
*others,
auto_batch_size: bool | None = None,
batch_size=None,
device=None,
batch_dims=None,
names=None,
):
if others:
if batch_size is not None:
raise TypeError(
"conflicting batch size values. Please use the keyword argument only."
)
if device is not None:
raise TypeError(
"conflicting device values. Please use the keyword argument only."
)
if batch_dims is not None:
raise TypeError(
"conflicting batch_dims values. Please use the keyword argument only."
)
if names is not None:
raise TypeError(
"conflicting names values. Please use the keyword argument only."
)
warn(
"All positional arguments after filename will be deprecated in v0.8. Please use keyword arguments instead.",
category=DeprecationWarning,
)
batch_size, *others = others
if len(others):
device, *others = others
if len(others):
batch_dims, *others = others
if len(others):
names, *others = others
if len(others):
raise TypeError("Too many positional arguments.")

if batch_dims is not None and batch_size is not None:
raise ValueError(
"Cannot pass both batch_size and batch_dims to `from_dict`."
Expand All @@ -2014,22 +2100,42 @@ def from_dict_instance(
cur_value = self.get(key, None)
if cur_value is not None:
input_dict[key] = cur_value.from_dict_instance(
value, batch_size=[], device=device, batch_dims=None
value,
device=device,
auto_batch_size=auto_batch_size,
)
continue
# we don't know if another tensor of smaller size is coming
# so we can't be sure that the batch-size will still be valid later
input_dict[key] = TensorDict.from_dict(
value, batch_size=[], device=device, batch_dims=None
value,
device=device,
auto_batch_size=auto_batch_size,
)
else:
input_dict[key] = TensorDict.from_any(
value,
auto_batch_size=auto_batch_size,
)

out = TensorDict.from_dict(
input_dict,
batch_size=batch_size_set,
device=device,
names=names,
)
if batch_size is None:
_set_max_batch_size(out, batch_dims)
if auto_batch_size is None:
warn(
"The batch-size was not provided and auto_batch_size isn't set either. "
"Currently, from_dict will call set auto_batch_size=True but this behaviour "
"will be changed in v0.8 and auto_batch_size will be False onward. "
"To silence this warning, pass auto_batch_size directly.",
category=DeprecationWarning,
)
auto_batch_size = True
if auto_batch_size:
_set_max_batch_size(out, batch_dims)
else:
out.batch_size = batch_size
return out
Expand Down Expand Up @@ -3857,7 +3963,14 @@ def expand(self, *args: int, inplace: bool = False) -> T:

@classmethod
def from_dict(
    cls,
    input_dict,
    *others,
    auto_batch_size: bool = False,
    batch_size=None,
    device=None,
    batch_dims=None,
    names=None,
):
    """Unsupported constructor: always raises.

    The parameters are accepted but never read; the call fails
    unconditionally.

    Raises:
        NotImplementedError: always, with a message naming the class the
            method was called on.
    """
    message = f"from_dict not implemented for {cls.__name__}."
    raise NotImplementedError(message)

Expand Down Expand Up @@ -4273,6 +4386,12 @@ def _items(
(key, tensordict._get_str(key, NO_DEFAULT))
for key in tensordict._source.keys()
)
from tensordict.persistent import PersistentTensorDict

if isinstance(tensordict, PersistentTensorDict):
return (
(key, tensordict._get_str(key, NO_DEFAULT)) for key in tensordict.keys()
)
raise NotImplementedError(type(tensordict))

def _keys(self) -> _TensorDictKeysView:
Expand Down Expand Up @@ -4697,7 +4816,9 @@ def from_modules(
)


def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=None):
def from_dict(
input_dict, *others, batch_size=None, device=None, batch_dims=None, names=None
):
"""Returns a TensorDict created from a dictionary or another :class:`~.tensordict.TensorDict`.
If ``batch_size`` is not specified, returns the maximum batch size possible.
Expand Down Expand Up @@ -4762,6 +4883,7 @@ def from_dict(input_dict, batch_size=None, device=None, batch_dims=None, names=N
"""
return TensorDict.from_dict(
input_dict,
*others,
batch_size=batch_size,
device=device,
batch_dims=batch_dims,
Expand Down
Loading

0 comments on commit 8344930

Please sign in to comment.