diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py
index 75c326c7a..04fdfb6bb 100644
--- a/tensordict/nn/probabilistic.py
+++ b/tensordict/nn/probabilistic.py
@@ -7,9 +7,10 @@
 import re
 import warnings
+from collections.abc import MutableSequence
 from textwrap import indent
-from typing import Any, Dict, List, Optional, overload, OrderedDict
+from typing import Any, Dict, List, Optional, OrderedDict, overload

 import torch

@@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
         log(p(z | x, y))

     Args:
-        *modules (sequence of TensorDictModules): An ordered sequence of
-            :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
+        *modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule): An ordered sequence of
+            :class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`,
             to be run sequentially.
+            The modules can be instances of TensorDictModuleBase or any other function that matches this signature.
+            Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
+            and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.

     Keyword Args:
         partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some
@@ -794,14 +798,13 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
     @overload
     def __init__(
         self,
-        modules: OrderedDict,
+        modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule],
         partial_tolerant: bool = False,
         return_composite: bool | None = None,
         aggregate_probabilities: bool | None = None,
         include_sum: bool | None = None,
         inplace: bool | None = None,
-    ) -> None:
-        ...
+    ) -> None: ...

     @overload
     def __init__(
@@ -812,8 +815,7 @@ def __init__(
         aggregate_probabilities: bool | None = None,
         include_sum: bool | None = None,
         inplace: bool | None = None,
-    ) -> None:
-        ...
+    ) -> None: ...

     def __init__(
         self,
@@ -829,7 +831,14 @@ def __init__(
                 "ProbabilisticTensorDictSequential must consist of zero or more "
                 "TensorDictModules followed by a ProbabilisticTensorDictModule"
             )
-        if not return_composite and not isinstance(
+        self._ordered_dict = False
+        if len(modules) == 1 and isinstance(modules[0], (OrderedDict, MutableSequence)):
+            if isinstance(modules[0], OrderedDict):
+                modules_list = list(modules[0].values())
+                self._ordered_dict = True
+            else:
+                modules = modules_list = list(modules[0])
+        elif not return_composite and not isinstance(
             modules[-1],
             (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
         ):
@@ -838,13 +847,22 @@ def __init__(
                 "could not find ProbabilisticTensorDictModule or "
                 "an instance of ProbabilisticTensorDictModule or another "
                 "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)."
             )
+        else:
+            modules_list = list(modules)
+
         # if the modules not including the final probabilistic module return the sampled
         # key we won't be sampling it again, in that case
         # ProbabilisticTensorDictSequential is presumably used to return the
         # distribution using `get_dist` or to sample log_probabilities
-        _, out_keys = self._compute_in_and_out_keys(modules[:-1])
-        self._requires_sample = modules[-1].out_keys[0] not in set(out_keys)
-        self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
+        _, out_keys = self._compute_in_and_out_keys(modules_list[:-1])
+        self._requires_sample = modules_list[-1].out_keys[0] not in set(out_keys)
+        if self._ordered_dict:
+            self.__dict__["_det_part"] = TensorDictSequential(
+                OrderedDict(list(modules[0].items())[:-1])
+            )
+        else:
+            self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
+
         super().__init__(*modules, partial_tolerant=partial_tolerant)
         self.return_composite = return_composite
         self.aggregate_probabilities = aggregate_probabilities
@@ -885,7 +903,7 @@ def get_dist_params(
         tds = self.det_part
         type = interaction_type()
         if type is None:
-            for m in reversed(self.module):
+            for m in reversed(list(self._module_iter())):
                 if hasattr(m, "default_interaction_type"):
                     type = m.default_interaction_type
                     break
@@ -897,7 +915,7 @@
     @property
     def num_samples(self):
         num_samples = ()
-        for tdm in self.module:
+        for tdm in self._module_iter():
             if isinstance(
                 tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
             ):
@@ -941,7 +959,7 @@ def get_dist(
         td_copy = tensordict.copy()
         dists = {}
-        for i, tdm in enumerate(self.module):
+        for i, tdm in enumerate(self._module_iter()):
             if isinstance(
                 tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential)
             ):
@@ -981,12 +999,21 @@ def default_interaction_type(self):
         encountered is returned. If no such value is found, a default
         `interaction_type()` is returned.
         """
-        for m in reversed(self.module):
+        for m in reversed(list(self._module_iter())):
             interaction = getattr(m, "default_interaction_type", None)
             if interaction is not None:
                 return interaction
         return interaction_type()

+    @property
+    def _last_module(self):
+        if not self._ordered_dict:
+            return self.module[-1]
+        mod = None
+        for mod in self._module_iter():  # noqa: B007
+            continue
+        return mod
+
     def log_prob(
         self,
         tensordict,
@@ -1103,7 +1130,7 @@ def log_prob(
                 include_sum=include_sum,
                 **kwargs,
             )
-            last_module: ProbabilisticTensorDictModule = self.module[-1]
+            last_module: ProbabilisticTensorDictModule = self._last_module
             out = last_module.log_prob(tensordict_inp, dist=dist, **kwargs)
             if is_tensor_collection(out):
                 if tensordict_out is not None:
@@ -1162,7 +1189,7 @@ def forward(
         else:
             tensordict_exec = tensordict
         if self.return_composite:
-            for m in self.module:
+            for m in self._module_iter():
                 if isinstance(
                     m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule)
                 ):
@@ -1173,7 +1200,7 @@ def forward(
                 tensordict_exec = m(tensordict_exec, **kwargs)
         else:
             tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs)
-            tensordict_exec = self.module[-1](
+            tensordict_exec = self._last_module(
                 tensordict_exec, _requires_sample=self._requires_sample
             )
         if tensordict_out is not None:
diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py
index 3f20b4195..adb2ff314 100644
--- a/tensordict/nn/sequence.py
+++ b/tensordict/nn/sequence.py
@@ -53,14 +53,18 @@ class TensorDictSequential(TensorDictModule):
         buffers) will be concatenated in a single list.

     Args:
-        modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially.
+        modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]):
+            ordered sequence of callables that take a TensorDictBase as input and return a TensorDictBase.
+            These can be instances of TensorDictModuleBase or any other function that matches this signature.
+            Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked,
+            and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential.

     Keyword Args:
         partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys.
             If so, the only module that will be executed are those who can be executed given the keys that
             are present.
             Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the
             stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts
-            looking for those that have the required keys, if any.
+            looking for those that have the required keys, if any. Defaults to False.
         selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not
             provided, all ``out_keys`` will be written.
diff --git a/test/test_nn.py b/test/test_nn.py
index 582d3a147..948af946b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -480,6 +480,9 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
         in_keys = ["in"]
         net = TensorDictModule(module=net, in_keys=in_keys, out_keys=out_keys)
+        corr = TensorDictModule(
+            lambda low: max_dist - low.abs(), in_keys=out_keys, out_keys=out_keys
+        )

         kwargs = {
             "distribution_class": distributions.Uniform,
@@ -494,7 +497,7 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
             in_keys=dist_in_keys, out_keys=["out"], **kwargs
         )

-        tensordict_module = ProbabilisticTensorDictSequential(net, prob_module)
+        tensordict_module = ProbabilisticTensorDictSequential(net, corr, prob_module)
         assert tensordict_module.default_interaction_type is not None
         td = TensorDict({"in": torch.randn(3, 3)}, [3])
@@ -2156,6 +2159,8 @@ def test_nested_keys_probabilistic_normal(self, log_prob_key):
             in_keys=[("data", "states")],
             out_keys=[("data", "scale")],
         )
+        scale_module.module.weight.data.abs_()
+        scale_module.module.bias.data.abs_()
         td = TensorDict(
             {"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3]
         )
@@ -3019,7 +3024,8 @@ def test_prob_module_nested(self, interaction, map_names):
         "interaction", [InteractionType.MODE, InteractionType.MEAN]
     )
     @pytest.mark.parametrize("return_log_prob", [True, False])
-    def test_prob_module_seq(self, interaction, return_log_prob):
+    @pytest.mark.parametrize("ordereddict", [True, False])
+    def test_prob_module_seq(self, interaction, return_log_prob, ordereddict):
         params = TensorDict(
             {
                 "params": {
@@ -3042,7 +3048,7 @@ def test_prob_module_seq(self, interaction, return_log_prob):
             ("nested", "cont"): distributions.Normal,
         }
         backbone = TensorDictModule(lambda: None, in_keys=[], out_keys=[])
-        module = ProbabilisticTensorDictSequential(
+        args = [
             backbone,
             ProbabilisticTensorDictModule(
                 in_keys=in_keys,
@@ -3052,7 +3058,15 @@ def test_prob_module_seq(self, interaction, return_log_prob):
                 default_interaction_type=interaction,
                 return_log_prob=return_log_prob,
             ),
-        )
+        ]
+        if ordereddict:
+            args = [
+                OrderedDict(
+                    backbone=args[0],
+                    proba=args[1],
+                )
+            ]
+        module = ProbabilisticTensorDictSequential(*args)
         sample = module(params)
         if return_log_prob:
             assert "cont_log_prob" in sample.keys()
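
Usage sketch (reviewer note, not part of the patch): the snippet below mirrors the
new `ordereddict=True` branch of `test_prob_module_seq` and exercises the OrderedDict
construction path added to `ProbabilisticTensorDictSequential.__init__`. The module
names ("backbone", "proba") and the toy parameterization are illustrative assumptions,
not something prescribed by the patch.

    from collections import OrderedDict

    import torch
    from tensordict import TensorDict
    from tensordict.nn import (
        InteractionType,
        ProbabilisticTensorDictModule,
        ProbabilisticTensorDictSequential,
        TensorDictModule,
    )
    from torch import distributions

    # Deterministic part: maps "x" to the distribution parameters "loc" / "scale".
    backbone = TensorDictModule(
        lambda x: (x, x.abs() + 1e-4), in_keys=["x"], out_keys=["loc", "scale"]
    )
    # Probabilistic head: reads "loc" / "scale" and writes a sample under "sample".
    proba = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=distributions.Normal,
        default_interaction_type=InteractionType.RANDOM,
    )
    # New in this patch: a single OrderedDict is accepted instead of positional modules.
    module = ProbabilisticTensorDictSequential(
        OrderedDict(backbone=backbone, proba=proba)
    )

    td = module(TensorDict({"x": torch.randn(3)}, [3]))
    assert "sample" in td.keys()
    # det_part and get_dist keep working for the OrderedDict-built sequence.
    dist = module.get_dist(TensorDict({"x": torch.randn(3)}, [3]))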
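
A second sketch, for the sequence.py docstring update: a TensorDictSequential built
from a single OrderedDict, the form documented above and already relied upon
internally for `_det_part`. The module names and toy lambdas are illustrative only.

    from collections import OrderedDict

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential

    seq = TensorDictSequential(
        OrderedDict(
            add_one=TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]),
            double=TensorDictModule(lambda y: y * 2, in_keys=["y"], out_keys=["z"]),
        )
    )
    td = seq(TensorDict({"x": torch.zeros(3)}, [3]))
    assert (td["z"] == 2).all()
    # "y" is produced internally by the first module, so only "x" is required as input.
    assert seq.in_keys == ["x"]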