From f7665dbac800892faa05bca04202434062aa7a31 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 9 Dec 2024 11:56:01 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/_lazy.py | 6 +++--- tensordict/base.py | 6 +++--- tensordict/tensorclass.py | 11 ++++++++++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index aeb5a8f23..87255d70f 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -2127,10 +2127,10 @@ def __getitem__(self, index: IndexType) -> Any: if index_key: leaf = self._get_tuple(index_key, NO_DEFAULT) if is_non_tensor(leaf): - result = getattr(leaf, "data", NO_DEFAULT) - if result is NO_DEFAULT: + # Only lazy stacks of non tensors are actually tensordict instances + if isinstance(leaf, TensorDictBase): return leaf.tolist() - return result + return leaf.data return leaf split_index = self._split_index(index) converted_idx = split_index["index_dict"] diff --git a/tensordict/base.py b/tensordict/base.py index b8712abd1..aef81ee19 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6268,10 +6268,10 @@ def _get_tuple(self, key, default): ... def _get_tuple_maybe_non_tensor(self, key, default): result = self._get_tuple(key, default) if is_non_tensor(result): - result_data = getattr(result, "data", NO_DEFAULT) - if result_data is NO_DEFAULT: + # Only lazy stacks of non tensors are actually tensordict instances + if isinstance(result, TensorDictBase): return result.tolist() - return result_data + return result.data return result def get_at( diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index e5a433af2..03df24562 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -3445,7 +3445,16 @@ def update_at_( @property def data(self): - raise AttributeError + """Attempts to return the unique value in the stack. + + Raises a ValueError if there is more than one unique value. + """ + try: + return NonTensorData._stack_non_tensor( + self.tensordicts, raise_if_non_unique=True + ).data + except ValueError: + raise AttributeError("Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead.") _register_tensor_class(NonTensorStack)