Skip to content

Commit

Permalink
[Feature] OrderedDict for TensorDictSequential
Browse files Browse the repository at this point in the history
ghstack-source-id: a8aed1eaefe066dafaa974f5b96190860de2f8f1
Pull Request resolved: #1142
  • Loading branch information
vmoens committed Dec 18, 2024
1 parent 2aea3dd commit 7df2062
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 37 deletions.
81 changes: 66 additions & 15 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import re
import warnings
from collections.abc import MutableSequence

from textwrap import indent
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, OrderedDict, overload

import torch

Expand Down Expand Up @@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
log(p(z | x, y))
Args:
*modules (sequence of TensorDictModules): An ordered sequence of
:class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
*modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule): An ordered sequence of
:class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
to be run sequentially.
The modules can be instances of TensorDictModuleBase or any other function that matches this signature.
Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
Keyword Args:
partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
Expand Down Expand Up @@ -791,6 +795,28 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
"""

@overload
def __init__(
self,
modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule],
partial_tolerant: bool = False,
return_composite: bool | None = None,
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> None: ...

@overload
def __init__(
self,
modules: List[TensorDictModuleBase | ProbabilisticTensorDictModule],
partial_tolerant: bool = False,
return_composite: bool | None = None,
aggregate_probabilities: bool | None = None,
include_sum: bool | None = None,
inplace: bool | None = None,
) -> None: ...

def __init__(
self,
*modules: TensorDictModuleBase | ProbabilisticTensorDictModule,
Expand All @@ -805,7 +831,14 @@ def __init__(
"ProbabilisticTensorDictSequential must consist of zero or more "
"TensorDictModules followed by a ProbabilisticTensorDictModule"
)
if not return_composite and not isinstance(
self._ordered_dict = False
if len(modules) == 1 and isinstance(modules[0], (OrderedDict, MutableSequence)):
if isinstance(modules[0], OrderedDict):
modules_list = list(modules[0].values())
self._ordered_dict = True
else:
modules = modules_list = list(modules[0])
elif not return_composite and not isinstance(
modules[-1],
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
):
Expand All @@ -814,13 +847,22 @@ def __init__(
"an instance of ProbabilisticTensorDictModule or another "
"ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
)
else:
modules_list = list(modules)

# if the modules not including the final probabilistic module return the sampled
# key we wont be sampling it again, in that case
# key we won't be sampling it again, in that case
# ProbabilisticTensorDictSequential is presumably used to return the
# distribution using `get_dist` or to sample log_probabilities
_, out_keys = self._compute_in_and_out_keys(modules[:-1])
self._requires_sample = modules[-1].out_keys[0] not in set(out_keys)
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
_, out_keys = self._compute_in_and_out_keys(modules_list[:-1])
self._requires_sample = modules_list[-1].out_keys[0] not in set(out_keys)
if self._ordered_dict:
self.__dict__["_det_part"] = TensorDictSequential(
OrderedDict(list(modules[0].items())[:-1])
)
else:
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])

super().__init__(*modules, partial_tolerant=partial_tolerant)
self.return_composite = return_composite
self.aggregate_probabilities = aggregate_probabilities
Expand Down Expand Up @@ -861,7 +903,7 @@ def get_dist_params(
tds = self.det_part
type = interaction_type()
if type is None:
for m in reversed(self.module):
for m in reversed(list(self._module_iter())):
if hasattr(m, "default_interaction_type"):
type = m.default_interaction_type
break
Expand All @@ -873,7 +915,7 @@ def get_dist_params(
@property
def num_samples(self):
num_samples = ()
for tdm in self.module:
for tdm in self._module_iter():
if isinstance(
tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
):
Expand Down Expand Up @@ -917,7 +959,7 @@ def get_dist(

td_copy = tensordict.copy()
dists = {}
for i, tdm in enumerate(self.module):
for i, tdm in enumerate(self._module_iter()):
if isinstance(
tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
):
Expand Down Expand Up @@ -957,12 +999,21 @@ def default_interaction_type(self):
encountered is returned. If no such value is found, a default `interaction_type()` is returned.
"""
for m in reversed(self.module):
for m in reversed(list(self._module_iter())):
interaction = getattr(m, "default_interaction_type", None)
if interaction is not None:
return interaction
return interaction_type()

@property
def _last_module(self):
    """Return the final module of the sequence.

    List-backed sequences support direct negative indexing; dict-backed
    sequences are exhausted through the module iterator and the final
    element is kept (``None`` if the sequence is empty).
    """
    if not self._ordered_dict:
        return self.module[-1]
    tail = None
    for tail in self._module_iter():  # noqa: B007
        pass
    return tail

def log_prob(
self,
tensordict,
Expand Down Expand Up @@ -1079,7 +1130,7 @@ def log_prob(
include_sum=include_sum,
**kwargs,
)
last_module: ProbabilisticTensorDictModule = self.module[-1]
last_module: ProbabilisticTensorDictModule = self._last_module
out = last_module.log_prob(tensordict_inp, dist=dist, **kwargs)
if is_tensor_collection(out):
if tensordict_out is not None:
Expand Down Expand Up @@ -1138,7 +1189,7 @@ def forward(
else:
tensordict_exec = tensordict
if self.return_composite:
for m in self.module:
for m in self._module_iter():
if isinstance(
m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule)
):
Expand All @@ -1149,7 +1200,7 @@ def forward(
tensordict_exec = m(tensordict_exec, **kwargs)
else:
tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
tensordict_exec = self.module[-1](
tensordict_exec = self._last_module(
tensordict_exec, _requires_sample=self._requires_sample
)
if tensordict_out is not None:
Expand Down
93 changes: 75 additions & 18 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from __future__ import annotations

import collections
import logging
from copy import deepcopy
from typing import Any, Iterable, List
from typing import Any, Callable, Iterable, List, OrderedDict, overload

from tensordict._nestedkey import NestedKey

Expand Down Expand Up @@ -52,14 +53,18 @@ class TensorDictSequential(TensorDictModule):
buffers) will be concatenated in a single list.
Args:
modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]):
ordered sequence of callables that take a TensorDictBase as input and return a TensorDictBase.
These can be instances of TensorDictModuleBase or any other function that matches this signature.
Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.
Keyword Args:
partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
If so, the only module that will be executed are those who can be executed given the keys that
are present.
Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
looking for those that have the required keys, if any.
looking for those that have the required keys, if any. Defaults to False.
selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
``out_keys`` will be written.
Expand Down Expand Up @@ -170,19 +175,57 @@ class TensorDictSequential(TensorDictModule):
module: nn.ModuleList
_select_before_return = False

@overload
def __init__(
self,
*modules: TensorDictModuleBase,
modules: OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]],
*,
partial_tolerant: bool = False,
selected_out_keys: List[NestedKey] | None = None,
) -> None: ...

@overload
def __init__(
self,
modules: List[Callable[[TensorDictBase], TensorDictBase]],
*,
partial_tolerant: bool = False,
selected_out_keys: List[NestedKey] | None = None,
) -> None: ...

def __init__(
self,
*modules: Callable[[TensorDictBase], TensorDictBase],
partial_tolerant: bool = False,
selected_out_keys: List[NestedKey] | None = None,
) -> None:
modules = self._convert_modules(modules)
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)

super().__init__(
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
)
if len(modules) == 1 and isinstance(modules[0], collections.OrderedDict):
modules_vals = self._convert_modules(modules[0].values())
in_keys, out_keys = self._compute_in_and_out_keys(modules_vals)
self._complete_out_keys = list(out_keys)
modules = collections.OrderedDict(
**{key: val for key, val in zip(modules[0], modules_vals)}
)
super().__init__(
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
)
elif len(modules) == 1 and isinstance(
modules[0], collections.abc.MutableSequence
):
modules = self._convert_modules(modules[0])
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)
super().__init__(
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
)
else:
modules = self._convert_modules(modules)
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)
super().__init__(
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
)

self.partial_tolerant = partial_tolerant
if selected_out_keys:
Expand Down Expand Up @@ -408,7 +451,7 @@ def select_subsequence(
out_keys = deepcopy(self.out_keys)
out_keys = unravel_key_list(out_keys)

module_list = list(self.module)
module_list = list(self._module_iter())
id_to_keep = set(range(len(module_list)))
for i, module in enumerate(module_list):
if (
Expand Down Expand Up @@ -445,8 +488,14 @@ def select_subsequence(
raise ValueError(
"No modules left after selection. Make sure that in_keys and out_keys are coherent."
)

return type(self)(*modules)
if isinstance(self.module, nn.ModuleList):
return type(self)(*modules)
else:
keys = [key for key in self.module if self.module[key] in modules]
modules_dict = collections.OrderedDict(
**{key: val for key, val in zip(keys, modules)}
)
return type(self)(modules_dict)

def _run_module(
self,
Expand All @@ -466,6 +515,12 @@ def _run_module(
module(sub_td, **kwargs)
return tensordict

def _module_iter(self):
    """Yield the contained modules in execution order.

    An ``nn.ModuleDict`` container yields its named submodules via
    ``children()``; any other container (typically ``nn.ModuleList``)
    is iterated directly.
    """
    container = self.module
    source = (
        container.children()
        if isinstance(container, nn.ModuleDict)
        else container
    )
    yield from source

@dispatch(auto_batch_size=False)
@_set_skip_existing_None()
def forward(
Expand All @@ -481,7 +536,7 @@ def forward(
else:
tensordict_exec = tensordict
if not len(kwargs):
for module in self.module:
for module in self._module_iter():
tensordict_exec = self._run_module(module, tensordict_exec, **kwargs)
else:
raise RuntimeError(
Expand Down Expand Up @@ -510,14 +565,16 @@ def forward(
def __len__(self) -> int:
    """Return the number of modules held by the sequence."""
    return len(self.module)

def __getitem__(self, index: int | slice) -> TensorDictModuleBase:
if isinstance(index, int):
def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase:
    """Return the module at an integer index or string key, or a new
    sequence of the same type for a slice."""
    if isinstance(index, (int, str)):
        # Direct lookup: int for list-backed storage, str for dict-backed.
        return self.module.__getitem__(index)
    else:
        # NOTE(review): slice indexing presumes a ModuleList container;
        # nn.ModuleDict does not support slices — confirm dict-backed
        # sequences are never sliced.
        return type(self)(*self.module.__getitem__(index))

def __setitem__(self, index: int, tensordict_module: TensorDictModuleBase) -> None:
def __setitem__(
    self, index: int | slice | str, tensordict_module: TensorDictModuleBase
) -> None:
    """Replace the module stored at ``index``.

    Args:
        index: integer/slice position (list-backed storage) or string key
            (dict-backed storage).
        tensordict_module: the module to store at that position.
    """
    # Pass arguments positionally: nn.ModuleList.__setitem__ names its
    # parameters (idx, module) while nn.ModuleDict.__setitem__ names them
    # (key, module), so a keyword call with ``idx=`` breaks the str-key
    # (ModuleDict) path introduced alongside OrderedDict support.
    return self.module.__setitem__(index, tensordict_module)

def __delitem__(self, index: int | slice) -> None:
def __delitem__(self, index: int | slice | str) -> None:
    """Remove the module(s) stored at ``index``.

    Args:
        index: integer/slice position (list-backed storage) or string key
            (dict-backed storage).
    """
    # Positional call: nn.ModuleList.__delitem__'s parameter is ``idx`` but
    # nn.ModuleDict.__delitem__'s is ``key``; a keyword call with ``idx=``
    # raises TypeError for dict-backed storage.
    self.module.__delitem__(index)
Loading

0 comments on commit 7df2062

Please sign in to comment.