diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index d70a2ced2..d23c2bec3 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -9,7 +9,18 @@ import inspect import warnings from textwrap import indent -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + MutableSequence, + Optional, + Sequence, + Tuple, + Union, +) import torch from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads @@ -981,20 +992,20 @@ def __init__( else: if isinstance(in_keys, (str, tuple)): in_keys = [in_keys] - elif not isinstance(in_keys, list): + elif not isinstance(in_keys, MutableSequence): raise ValueError(self._IN_KEY_ERR) self._kwargs = None if isinstance(out_keys, (str, tuple)): out_keys = [out_keys] - elif not isinstance(out_keys, list): + elif not isinstance(out_keys, MutableSequence): raise ValueError(self._OUT_KEY_ERR) try: - in_keys = unravel_key_list(in_keys) + in_keys = unravel_key_list(list(in_keys)) except Exception: raise ValueError(self._IN_KEY_ERR) try: - out_keys = unravel_key_list(out_keys) + out_keys = unravel_key_list(list(out_keys)) except Exception: raise ValueError(self._OUT_KEY_ERR) diff --git a/test/test_nn.py b/test/test_nn.py index 6a07ed482..948af946b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -11,6 +11,7 @@ import unittest import weakref from collections import OrderedDict +from collections.abc import MutableSequence import pytest import torch @@ -118,6 +119,34 @@ def test_from_str_correct_raise(self, unsupported_type_str): class TestTDModule: + class MyMutableSequence(MutableSequence): + def __init__(self, initial_data=None): + self._data = [] if initial_data is None else list(initial_data) + + def __getitem__(self, index): + return self._data[index] + + def __setitem__(self, index, value): + self._data[index] = value + + def __delitem__(self, index): + del self._data[index] + + def __len__(self): + return len(self._data) + + def insert(self, index, value): + self._data.insert(index, value) + + def test_mutable_sequence(self): + in_keys = self.MyMutableSequence(["a", "b", "c"]) + out_keys = self.MyMutableSequence(["d", "e", "f"]) + mod = TensorDictModule(lambda *x: x, in_keys=in_keys, out_keys=out_keys) + td = mod(TensorDict(a=0, b=0, c=0)) + assert "d" in td + assert "e" in td + assert "f" in td + def test_auto_unravel(self): tdm = TensorDictModule( lambda x: x,