From 7df2062625ff45cc600e20981b0c68b21913f3f7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 18 Dec 2024 11:14:43 +0000 Subject: [PATCH] [Feature] OrderedDict for TensorDictSequential ghstack-source-id: a8aed1eaefe066dafaa974f5b96190860de2f8f1 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1142 --- tensordict/nn/probabilistic.py | 81 +++++++++++++++++++++++------ tensordict/nn/sequence.py | 93 +++++++++++++++++++++++++++------- test/test_nn.py | 75 +++++++++++++++++++++++++-- 3 files changed, 212 insertions(+), 37 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 28f0ce2de..04fdfb6bb 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -7,9 +7,10 @@ import re import warnings +from collections.abc import MutableSequence from textwrap import indent -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, OrderedDict, overload import torch @@ -621,9 +622,12 @@ class ProbabilisticTensorDictSequential(TensorDictSequential): log(p(z | x, y)) Args: - *modules (sequence of TensorDictModules): An ordered sequence of - :class:`~tensordict.nn.TensorDictModule` instances, terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`, + *modules (sequence or OrderedDict of TensorDictModuleBase or ProbabilisticTensorDictModule): An ordered sequence of + :class:`~tensordict.nn.TensorDictModule` instances, usually terminating in a :class:`~tensordict.nn.ProbabilisticTensorDictModule`, to be run sequentially. + The modules can be instances of TensorDictModuleBase or any other function that matches this signature. + Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked, + and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential. Keyword Args: partial_tolerant (bool, optional): If ``True``, the input tensordict can miss some @@ -791,6 +795,28 @@ class ProbabilisticTensorDictSequential(TensorDictSequential): """ + @overload + def __init__( + self, + modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule], + partial_tolerant: bool = False, + return_composite: bool | None = None, + aggregate_probabilities: bool | None = None, + include_sum: bool | None = None, + inplace: bool | None = None, + ) -> None: ... + + @overload + def __init__( + self, + modules: List[TensorDictModuleBase | ProbabilisticTensorDictModule], + partial_tolerant: bool = False, + return_composite: bool | None = None, + aggregate_probabilities: bool | None = None, + include_sum: bool | None = None, + inplace: bool | None = None, + ) -> None: ... + def __init__( self, *modules: TensorDictModuleBase | ProbabilisticTensorDictModule, @@ -805,7 +831,14 @@ def __init__( "ProbabilisticTensorDictSequential must consist of zero or more " "TensorDictModules followed by a ProbabilisticTensorDictModule" ) - if not return_composite and not isinstance( + self._ordered_dict = False + if len(modules) == 1 and isinstance(modules[0], (OrderedDict, MutableSequence)): + if isinstance(modules[0], OrderedDict): + modules_list = list(modules[0].values()) + self._ordered_dict = True + else: + modules = modules_list = list(modules[0]) + elif not return_composite and not isinstance( modules[-1], (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential), ): @@ -814,13 +847,22 @@ def __init__( "an instance of ProbabilisticTensorDictModule or another " "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)." ) + else: + modules_list = list(modules) + # if the modules not including the final probabilistic module return the sampled - # key we wont be sampling it again, in that case + # key we won't be sampling it again, in that case # ProbabilisticTensorDictSequential is presumably used to return the # distribution using `get_dist` or to sample log_probabilities - _, out_keys = self._compute_in_and_out_keys(modules[:-1]) - self._requires_sample = modules[-1].out_keys[0] not in set(out_keys) - self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1]) + _, out_keys = self._compute_in_and_out_keys(modules_list[:-1]) + self._requires_sample = modules_list[-1].out_keys[0] not in set(out_keys) + if self._ordered_dict: + self.__dict__["_det_part"] = TensorDictSequential( + OrderedDict(list(modules[0].items())[:-1]) + ) + else: + self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1]) + super().__init__(*modules, partial_tolerant=partial_tolerant) self.return_composite = return_composite self.aggregate_probabilities = aggregate_probabilities @@ -861,7 +903,7 @@ def get_dist_params( tds = self.det_part type = interaction_type() if type is None: - for m in reversed(self.module): + for m in reversed(list(self._module_iter())): if hasattr(m, "default_interaction_type"): type = m.default_interaction_type break @@ -873,7 +915,7 @@ def get_dist_params( @property def num_samples(self): num_samples = () - for tdm in self.module: + for tdm in self._module_iter(): if isinstance( tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential) ): @@ -917,7 +959,7 @@ def get_dist( td_copy = tensordict.copy() dists = {} - for i, tdm in enumerate(self.module): + for i, tdm in enumerate(self._module_iter()): if isinstance( tdm, (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential) ): @@ -957,12 +999,21 @@ def default_interaction_type(self): encountered is returned. If no such value is found, a default `interaction_type()` is returned. """ - for m in reversed(self.module): + for m in reversed(list(self._module_iter())): interaction = getattr(m, "default_interaction_type", None) if interaction is not None: return interaction return interaction_type() + @property + def _last_module(self): + if not self._ordered_dict: + return self.module[-1] + mod = None + for mod in self._module_iter(): # noqa: B007 + continue + return mod + def log_prob( self, tensordict, @@ -1079,7 +1130,7 @@ def log_prob( include_sum=include_sum, **kwargs, ) - last_module: ProbabilisticTensorDictModule = self.module[-1] + last_module: ProbabilisticTensorDictModule = self._last_module out = last_module.log_prob(tensordict_inp, dist=dist, **kwargs) if is_tensor_collection(out): if tensordict_out is not None: @@ -1138,7 +1189,7 @@ def forward( else: tensordict_exec = tensordict if self.return_composite: - for m in self.module: + for m in self._module_iter(): if isinstance( m, (ProbabilisticTensorDictModule, ProbabilisticTensorDictModule) ): @@ -1149,7 +1200,7 @@ def forward( tensordict_exec = m(tensordict_exec, **kwargs) else: tensordict_exec = self.get_dist_params(tensordict_exec, **kwargs) - tensordict_exec = self.module[-1]( + tensordict_exec = self._last_module( tensordict_exec, _requires_sample=self._requires_sample ) if tensordict_out is not None: diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index faa2f60a0..adb2ff314 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -5,9 +5,10 @@ from __future__ import annotations +import collections import logging from copy import deepcopy -from typing import Any, Iterable, List +from typing import Any, Callable, Iterable, List, OrderedDict, overload from tensordict._nestedkey import NestedKey @@ -52,14 +53,18 @@ class TensorDictSequential(TensorDictModule): buffers) will be concatenated in a single list. Args: - modules (iterable of TensorDictModules): ordered sequence of TensorDictModule instances to be run sequentially. + modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]): + ordered sequence of callables that take a TensorDictBase as input and return a TensorDictBase. + These can be instances of TensorDictModuleBase or any other function that matches this signature. + Note that if a non-TensorDictModuleBase callable is used, its input and output keys will not be tracked, + and thus will not affect the `in_keys` and `out_keys` attributes of the TensorDictSequential. Keyword Args: partial_tolerant (bool, optional): if True, the input tensordict can miss some of the input keys. If so, the only module that will be executed are those who can be executed given the keys that are present. Also, if the input tensordict is a lazy stack of tensordicts AND if partial_tolerant is :obj:`True` AND if the stack does not have the required keys, then TensorDictSequential will scan through the sub-tensordicts - looking for those that have the required keys, if any. + looking for those that have the required keys, if any. Defaults to False. selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all ``out_keys`` will be written. @@ -170,19 +175,57 @@ class TensorDictSequential(TensorDictModule): module: nn.ModuleList _select_before_return = False + @overload def __init__( self, - *modules: TensorDictModuleBase, + modules: OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]], + *, + partial_tolerant: bool = False, + selected_out_keys: List[NestedKey] | None = None, + ) -> None: ... + + @overload + def __init__( + self, + modules: List[Callable[[TensorDictBase], TensorDictBase]], + *, + partial_tolerant: bool = False, + selected_out_keys: List[NestedKey] | None = None, + ) -> None: ... + + def __init__( + self, + *modules: Callable[[TensorDictBase], TensorDictBase], partial_tolerant: bool = False, selected_out_keys: List[NestedKey] | None = None, ) -> None: - modules = self._convert_modules(modules) - in_keys, out_keys = self._compute_in_and_out_keys(modules) - self._complete_out_keys = list(out_keys) - super().__init__( - module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys - ) + if len(modules) == 1 and isinstance(modules[0], collections.OrderedDict): + modules_vals = self._convert_modules(modules[0].values()) + in_keys, out_keys = self._compute_in_and_out_keys(modules_vals) + self._complete_out_keys = list(out_keys) + modules = collections.OrderedDict( + **{key: val for key, val in zip(modules[0], modules_vals)} + ) + super().__init__( + module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys + ) + elif len(modules) == 1 and isinstance( + modules[0], collections.abc.MutableSequence + ): + modules = self._convert_modules(modules[0]) + in_keys, out_keys = self._compute_in_and_out_keys(modules) + self._complete_out_keys = list(out_keys) + super().__init__( + module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys + ) + else: + modules = self._convert_modules(modules) + in_keys, out_keys = self._compute_in_and_out_keys(modules) + self._complete_out_keys = list(out_keys) + super().__init__( + module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys + ) self.partial_tolerant = partial_tolerant if selected_out_keys: @@ -408,7 +451,7 @@ def select_subsequence( out_keys = deepcopy(self.out_keys) out_keys = unravel_key_list(out_keys) - module_list = list(self.module) + module_list = list(self._module_iter()) id_to_keep = set(range(len(module_list))) for i, module in enumerate(module_list): if ( @@ -445,8 +488,14 @@ def select_subsequence( raise ValueError( "No modules left after selection. Make sure that in_keys and out_keys are coherent." ) - - return type(self)(*modules) + if isinstance(self.module, nn.ModuleList): + return type(self)(*modules) + else: + keys = [key for key in self.module if self.module[key] in modules] + modules_dict = collections.OrderedDict( + **{key: val for key, val in zip(keys, modules)} + ) + return type(self)(modules_dict) def _run_module( self, @@ -466,6 +515,12 @@ def _run_module( module(sub_td, **kwargs) return tensordict + def _module_iter(self): + if isinstance(self.module, nn.ModuleDict): + yield from self.module.children() + else: + yield from self.module + @dispatch(auto_batch_size=False) @_set_skip_existing_None() def forward( @@ -481,7 +536,7 @@ def forward( else: tensordict_exec = tensordict if not len(kwargs): - for module in self.module: + for module in self._module_iter(): tensordict_exec = self._run_module(module, tensordict_exec, **kwargs) else: raise RuntimeError( @@ -510,14 +565,16 @@ def forward( def __len__(self) -> int: return len(self.module) - def __getitem__(self, index: int | slice) -> TensorDictModuleBase: - if isinstance(index, int): + def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase: + if isinstance(index, (int, str)): return self.module.__getitem__(index) else: return type(self)(*self.module.__getitem__(index)) - def __setitem__(self, index: int, tensordict_module: TensorDictModuleBase) -> None: + def __setitem__( + self, index: int | slice | str, tensordict_module: TensorDictModuleBase + ) -> None: return self.module.__setitem__(idx=index, module=tensordict_module) - def __delitem__(self, index: int | slice) -> None: + def __delitem__(self, index: int | slice | str) -> None: self.module.__delitem__(idx=index) diff --git a/test/test_nn.py b/test/test_nn.py index 630b8d3d2..6a07ed482 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -10,6 +10,7 @@ import pickle import unittest import weakref +from collections import OrderedDict import pytest import torch @@ -450,6 +451,9 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist): in_keys = ["in"] net = TensorDictModule(module=net, in_keys=in_keys, out_keys=out_keys) + corr = TensorDictModule( + lambda low: max_dist - low.abs(), in_keys=out_keys, out_keys=out_keys + ) kwargs = { "distribution_class": distributions.Uniform, @@ -464,7 +468,7 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist): in_keys=dist_in_keys, out_keys=["out"], **kwargs ) - tensordict_module = ProbabilisticTensorDictSequential(net, prob_module) + tensordict_module = ProbabilisticTensorDictSequential(net, corr, prob_module) assert tensordict_module.default_interaction_type is not None td = TensorDict({"in": torch.randn(3, 3)}, [3]) @@ -797,6 +801,58 @@ def test_tdmodule_inplace(self): class TestTDSequence: + def test_ordered_dict(self): + linear = nn.Linear(3, 4) + linear.weight.data.fill_(0) + linear.bias.data.fill_(1) + layer0 = TensorDictModule(linear, in_keys=["x"], out_keys=["y"]) + ordered_dict = OrderedDict( + layer0=layer0, + layer1=lambda x: x + 1, + ) + seq = TensorDictSequential(ordered_dict) + td = seq(TensorDict(x=torch.ones(3))) + assert (td["x"] == 2).all() + assert (td["y"] == 2).all() + assert seq["layer0"] is layer0 + + def test_ordered_dict_select_subsequence(self): + ordered_dict = OrderedDict( + layer0=TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]), + layer1=TensorDictModule(lambda x: x - 1, in_keys=["y"], out_keys=["z"]), + layer2=TensorDictModule( + lambda x, y: x + y, in_keys=["x", "y"], out_keys=["a"] + ), + ) + seq = TensorDictSequential(ordered_dict) + assert len(seq) == 3 + assert isinstance(seq.module, nn.ModuleDict) + seq_select = seq.select_subsequence(out_keys=["a"]) + assert len(seq_select) == 2 + assert isinstance(seq_select.module, nn.ModuleDict) + assert list(seq_select.module) == ["layer0", "layer2"] + + def test_ordered_dict_select_outkeys(self): + ordered_dict = OrderedDict( + layer0=TensorDictModule( + lambda x: x + 1, in_keys=["x"], out_keys=["intermediate"] + ), + layer1=TensorDictModule( + lambda x: x - 1, in_keys=["intermediate"], out_keys=["z"] + ), + layer2=TensorDictModule( + lambda x, y: x + y, in_keys=["x", "z"], out_keys=["a"] + ), + ) + seq = TensorDictSequential(ordered_dict) + assert len(seq) == 3 + assert isinstance(seq.module, nn.ModuleDict) + seq.select_out_keys("z", "a") + td = seq(TensorDict(x=0)) + assert "intermediate" not in td + assert "z" in td + assert "a" in td + @pytest.mark.parametrize("args", [True, False]) def test_input_keys(self, args): module0 = TensorDictModule(lambda x: x + 0, in_keys=["input"], out_keys=["1"]) @@ -2074,6 +2130,8 @@ def test_nested_keys_probabilistic_normal(self, log_prob_key): in_keys=[("data", "states")], out_keys=[("data", "scale")], ) + scale_module.module.weight.data.abs_() + scale_module.module.bias.data.abs_() td = TensorDict( {"data": TensorDict({"states": torch.zeros(3, 4, 1)}, [3, 4])}, [3] ) @@ -2937,7 +2995,8 @@ def test_prob_module_nested(self, interaction, map_names): "interaction", [InteractionType.MODE, InteractionType.MEAN] ) @pytest.mark.parametrize("return_log_prob", [True, False]) - def test_prob_module_seq(self, interaction, return_log_prob): + @pytest.mark.parametrize("ordereddict", [True, False]) + def test_prob_module_seq(self, interaction, return_log_prob, ordereddict): params = TensorDict( { "params": { @@ -2960,7 +3019,7 @@ def test_prob_module_seq(self, interaction, return_log_prob): ("nested", "cont"): distributions.Normal, } backbone = TensorDictModule(lambda: None, in_keys=[], out_keys=[]) - module = ProbabilisticTensorDictSequential( + args = [ backbone, ProbabilisticTensorDictModule( in_keys=in_keys, @@ -2970,7 +3029,15 @@ def test_prob_module_seq(self, interaction, return_log_prob): default_interaction_type=interaction, return_log_prob=return_log_prob, ), - ) + ] + if ordereddict: + args = [ + OrderedDict( + backbone=args[0], + proba=args[1], + ) + ] + module = ProbabilisticTensorDictSequential(*args) sample = module(params) if return_log_prob: assert "cont_log_prob" in sample.keys()