diff --git a/tensordict/functional.py b/tensordict/functional.py index 2699f36bb..226e55da1 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -216,7 +216,7 @@ def pad_sequence( try: item0 = list_of_dicts[0][key] if is_non_tensor(item0): - out.set(key, torch.stack([d[key] for d in list_of_dicts])) + out.set(key, TensorDict.lazy_stack([d[key] for d in list_of_dicts])) continue tensor_shape = item0.shape pos_pad_dim = ( diff --git a/test/test_tensordict.py b/test/test_tensordict.py index a77cb11e7..60ba211a9 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1920,6 +1920,13 @@ def test_pad_sequence_nontensor(self): assert (d["a"] == torch.tensor([[1, 1], [2, 0]])).all() assert d["b"] == ["asd", "efg"] + def test_pad_sequence_single_nontensor(self): + d1 = TensorDict({"a": torch.tensor([1, 1]), "b": "asd"}) + d = pad_sequence([d1]) + assert (d["a"] == torch.tensor([[1, 1]])).all() + assert d["b"] == ["asd"] + assert isinstance(d.get("b"), NonTensorStack) + def test_pad_sequence_tensorclass_nontensor(self): @tensorclass class Sample: