From 41755bab3b1248a36bc16c5bcf7105bad2d82c84 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 18 Dec 2024 10:33:11 +0000 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- tensordict/nn/probabilistic.py | 28 ++++++++++++++++++++++++++-- tensordict/nn/sequence.py | 10 +++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 28f0ce2de..75c326c7a 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -9,7 +9,7 @@ import warnings from textwrap import indent -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, overload, OrderedDict import torch @@ -791,6 +791,30 @@ class ProbabilisticTensorDictSequential(TensorDictSequential): """ + @overload + def __init__( + self, + modules: OrderedDict, + partial_tolerant: bool = False, + return_composite: bool | None = None, + aggregate_probabilities: bool | None = None, + include_sum: bool | None = None, + inplace: bool | None = None, + ) -> None: + ... + + @overload + def __init__( + self, + modules: List[TensorDictModuleBase | ProbabilisticTensorDictModule], + partial_tolerant: bool = False, + return_composite: bool | None = None, + aggregate_probabilities: bool | None = None, + include_sum: bool | None = None, + inplace: bool | None = None, + ) -> None: + ... + def __init__( self, *modules: TensorDictModuleBase | ProbabilisticTensorDictModule, @@ -815,7 +839,7 @@ def __init__( "ProbabilisticTensorDictSequential (unless return_composite is set to ``True``)." ) # if the modules not including the final probabilistic module return the sampled - # key we wont be sampling it again, in that case + # key we won't be sampling it again, in that case # ProbabilisticTensorDictSequential is presumably used to return the # distribution using `get_dist` or to sample log_probabilities _, out_keys = self._compute_in_and_out_keys(modules[:-1]) diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index e668f37fe..3f20b4195 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -488,7 +488,9 @@ def select_subsequence( return type(self)(*modules) else: keys = [key for key in self.module if self.module[key] in modules] - modules_dict = OrderedDict(**{key: val for key, val in zip(keys, modules)}) + modules_dict = collections.OrderedDict( + **{key: val for key, val in zip(keys, modules)} + ) return type(self)(modules_dict) def _run_module( @@ -565,8 +567,10 @@ def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase: else: return type(self)(*self.module.__getitem__(index)) - def __setitem__(self, index: int, tensordict_module: TensorDictModuleBase) -> None: + def __setitem__( + self, index: int | slice | str, tensordict_module: TensorDictModuleBase + ) -> None: return self.module.__setitem__(idx=index, module=tensordict_module) - def __delitem__(self, index: int | slice) -> None: + def __delitem__(self, index: int | slice | str) -> None: self.module.__delitem__(idx=index)