[BUG] Stacking NonTensorData does not appear to return a NonTensorStack #1047

Open
3 tasks done
rehno-lindeque opened this issue Oct 18, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@rehno-lindeque

rehno-lindeque commented Oct 18, 2024

Describe the bug

Hi, please let me know if I'm using this feature incorrectly or if this is well known.

I've been unable to get NonTensorStack to work in various contexts.

The simplest example I can come up with is this one:

import torch
from tensordict import *

a = NonTensorData({})
b = NonTensorData({}, batch_size=[1])
a_stack = NonTensorStack.from_nontensordata(a)
b_stack = NonTensorStack.from_nontensordata(b)

I expected all of these examples to produce a NonTensorStack, yet only stacking b_stack produces what I was expecting:

>>> torch.stack((a,a), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2]), device=None)

>>> torch.stack((b,b), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2, 1]), device=None)

>>> torch.stack((a_stack,a_stack), dim=0)
NonTensorData(data={}, batch_size=torch.Size([2]), device=None)

>>> torch.stack((b_stack,b_stack), dim=0)
NonTensorStack(
    [[{}], [{}]],
    batch_size=torch.Size([2, 1]),
    device=None)

I would have expected to see:

  • torch.stack((a,a), dim=0).data == [{}, {}]
  • torch.stack((b,b), dim=0).data == [[{}], [{}]]
  • torch.stack((a_stack,a_stack), dim=0).data == [{}, {}]
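To make the expectation above concrete, the following plain-Python sketch models it as "repeat each item's content over its own batch shape, then stack along a new leading dimension". The (data, batch_size) tuples stand in for NonTensorData, and materialize and stack_contents are hypothetical helpers, not part of the tensordict API.

```python
from typing import Any, Sequence, Tuple


def materialize(data: Any, batch_size: Sequence[int]) -> Any:
    """Hypothetical helper: repeat `data` over `batch_size` as nested lists.

    Every position references the same object; nothing is copied.
    """
    if len(batch_size) == 0:
        return data
    return [materialize(data, batch_size[1:]) for _ in range(batch_size[0])]


def stack_contents(
    items: Sequence[Tuple[Any, Sequence[int]]],
) -> Tuple[list, Tuple[int, ...]]:
    """Hypothetical stack along dim 0 that always keeps per-item contents."""
    batch_sizes = {tuple(bs) for _, bs in items}
    if len(batch_sizes) != 1:
        raise ValueError("all items must share the same batch_size")
    (common,) = batch_sizes
    data = [materialize(d, bs) for d, bs in items]
    return data, (len(items),) + common


a = ({}, [])   # models NonTensorData({})
b = ({}, [1])  # models NonTensorData({}, batch_size=[1])
print(stack_contents([a, a]))  # ([{}, {}], (2,))
print(stack_contents([b, b]))  # ([[{}], [{}]], (2, 1))
```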

This may be a separate issue, but even for the final case that appears to somewhat work...

>>> torch.stack((b_stack,b_stack), dim=0).batch_size
torch.Size([2, 1])

>>> torch.stack((b_stack,b_stack), dim=0)[...,0]
NonTensorStack(
    [{}, {}],
    batch_size=torch.Size([2]),
    device=None)

>>> torch.stack((b_stack,b_stack), dim=0)[0,0]
NonTensorData(data={}, batch_size=torch.Size([]), device=None)

there are still a number of issues that make it unusable for even the most basic use cases:

>>> torch.stack((b_stack,b_stack), dim=0).contiguous()
TensorDict(
    fields={
    },
    batch_size=torch.Size([2, 1]),
    device=None,
    is_shared=False)

>>> torch.stack((b_stack,b_stack), dim=0).reshape(-1)
TensorDict(
    fields={
    },
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

>>> torch.stack((b_stack,b_stack), dim=0).reshape(2)
TensorDict(
    fields={
    },
    batch_size=torch.Size([2]),
    device=None,

>>> torch.stack((b_stack,b_stack), dim=0).squeeze(dim=1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/utils.py", line 1255, in new_func
    out = func(_self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/base.py", line 2070, in squeeze
    result = self._squeeze(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/_lazy.py", line 2927, in _squeeze
    [td.squeeze(dim) for td in self.tensordicts],
     ^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/utils.py", line 1257, in new_func
    out._last_op = (new_func.__name__, (args, kwargs, _self))
    ^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 1062, in wrapper
    out = self.set(key, value)
          ^^^^^^^^^^^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 1482, in _set
    raise AttributeError(
AttributeError: Cannot set the attribute '_last_op', expected attributes are {'_is_non_tensor', '_metadata', 'data'}.
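For reference, if the stacked object simply kept one entry per batch element, flattening or squeezing the [2, 1] stack would be expected to yield two entries with batch size [2], rather than an empty TensorDict or an AttributeError. The plain-Python model below spells that out; flatten_nested and squeeze_dim1 are hypothetical helpers operating on nested lists, and they say nothing about how tensordict implements these operations.

```python
from typing import Any, List


def flatten_nested(data: Any, ndim: int) -> List[Any]:
    """Hypothetical model of reshape(-1): flatten `ndim` levels of nesting."""
    if ndim == 0:
        return [data]
    out: List[Any] = []
    for item in data:
        out.extend(flatten_nested(item, ndim - 1))
    return out


def squeeze_dim1(data: List[List[Any]]) -> List[Any]:
    """Hypothetical model of squeeze(dim=1) for a batch of shape [N, 1]."""
    if any(len(row) != 1 for row in data):
        raise ValueError("dim 1 must have size 1 to be squeezed")
    return [row[0] for row in data]


stacked = [[{}], [{}]]  # contents of the [2, 1] stack from this issue
print(flatten_nested(stacked, 2))  # [{}, {}]  (batch_size [2])
print(squeeze_dim1(stacked))       # [{}, {}]  (batch_size [2])
```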

>>> @tensorclass
... class B:
...   b: NonTensorStack

>>> B(b=torch.stack((b_stack,b_stack), dim=0))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 679, in wrapper
    key: value.data if is_non_tensor(value) else value
         ^^^^^^^^^^
  File "/nix/store/x46lwllqra2ca4wbyhk2cihzmwzml4cj-python3-3.12.4-env/lib/python3.12/site-packages/tensordict/tensorclass.py", line 3095, in data
    raise AttributeError
AttributeError. Did you mean: '_data'?

Thanks!

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
rehno-lindeque added the bug (Something isn't working) label on Oct 18, 2024
@vmoens
Contributor

vmoens commented Nov 11, 2024

Hello,

Yes, we do check whether the contents match when using torch.stack. This avoids creating countless copies of the same non-tensor data when all the contents match, and gives more consistent behaviour with index + stack. What we want is for this to work:

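import torch
from tensordict import TensorDict
td = TensorDict(a=set(), batch_size=[2])
td_reconstruct = torch.stack([td[0], td[1]])
td_reconstruct["a"] is td["a"]  # goal: indexing then re-stacking gives back the same object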

Currently we use __eq__ to compare the contents of the NonTensorData, but that's not great. Using an identity check (is) would lead to better behaviour (and faster execution).
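The difference between comparing with __eq__ and comparing with is can be illustrated without tensordict. In this sketch maybe_collapse is a hypothetical stand-in for a "collapse if the contents match" step, not tensordict's implementation. It shows that == merges two distinct empty sets while is keeps them apart, yet is still lets an index-then-stack round trip return the original object.

```python
import operator
from typing import Any, Callable, List, Union


def maybe_collapse(
    contents: List[Any],
    same: Callable[[Any, Any], bool] = operator.is_,
) -> Union[Any, List[Any]]:
    """Hypothetical stack policy: return one shared value if every element
    is `same` as the first, otherwise keep a per-element list."""
    if not contents:
        raise ValueError("cannot stack an empty sequence")
    first = contents[0]
    if all(same(first, other) for other in contents[1:]):
        return first
    return list(contents)


a0, a1 = set(), set()  # equal but distinct objects
by_eq = maybe_collapse([a0, a1], operator.eq)   # collapses, since a0 == a1
by_id = maybe_collapse([a0, a1], operator.is_)  # stays a list, since a0 is not a1

shared = set()
indexed = [shared, shared]  # hypothetical: entries read back from a collapsed stack
rebuilt = maybe_collapse(indexed)  # identity survives the round trip
print(by_eq is a0, by_id == [a0, a1], rebuilt is shared)  # True True True
```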

To summarize the current state, we have

from tensordict import TensorDict
import torch

# 1. This gives a stack
a0 = set()
a1 = set([1])
torch.stack([TensorDict(a=a0), TensorDict(a=a1)])

# 2. This does not give a stack - but maybe it should?
a0 = set()
a1 = set()
torch.stack([TensorDict(a=a0), TensorDict(a=a1)])

# 3. This gives a stack
a0 = set()
a1 = set()
TensorDict.lazy_stack([TensorDict(a=a0), TensorDict(a=a1)])

# 4. This does not give a stack - but maybe it should?
a0 = set()
a1 = set()
TensorDict.maybe_dense_stack([TensorDict(a=a0), TensorDict(a=a1)])

and we want to change the behaviour of 2. and 4.

@vmoens
Contributor

vmoens commented Nov 11, 2024

@rehno-lindeque I implemented this in #1083. Given the BC-breaking nature of this change, I can only fully change the behaviour two major releases from now (v0.8), but I think your use case will be covered as soon as v0.7.

@rehno-lindeque
Author

rehno-lindeque commented Jan 2, 2025

Thank you @vmoens. I have a workaround implemented, but it is encouraging to see ongoing work.

Just a quick thought:

This is to avoid creating countless copies of the same non-tensor data

I'm not sure if I fully understand the motivation for overloading the stack operation... Is it mostly a performance consideration?

If I'm correct in assuming that tensordict makes shallow copies of non-tensor data, surely there's almost no overhead?
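The shallow-copy intuition can be checked in plain Python, assuming the wrapper stores a reference to its content rather than a copy. Wrapper is a hypothetical, minimal stand-in; a real NonTensorData carries more state (batch size, device, metadata), so its per-object cost is larger, but under that assumption it still does not grow with the size of the wrapped content.

```python
import sys
from typing import Any


class Wrapper:
    """Hypothetical stand-in for a NonTensorData-like wrapper (not tensordict API)."""

    __slots__ = ("data",)

    def __init__(self, data: Any) -> None:
        self.data = data  # keeps a reference to the content; nothing is copied


payload = {"config": list(range(10_000))}
wrappers = [Wrapper(payload) for _ in range(1_000)]

# Every wrapper points at the same payload object, i.e. they are shallow copies.
print(all(w.data is payload for w in wrappers))  # True

# sys.getsizeof reports the wrapper's own (shallow) size, independent of the payload.
print(sys.getsizeof(wrappers[0]) < sys.getsizeof(payload["config"]))  # True
```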

@vmoens
Contributor

vmoens commented Jan 9, 2025

What I meant is just that having a gazillion NonTensorData objects in a list, all wrapping the same content, may cause some clutter, but it could be that I'm worrying too much :)

I wrote some more documentation. Can you give it a look and let me know if it's of any help?

#1173

@rehno-lindeque
Author

rehno-lindeque commented Jan 17, 2025

Thanks @vmoens, I will try to reply coherently once I get back into my dependent project.

By the way, thank you: I've noted lazy_stack, though I still have doubts about the potentially confusing behavior of overloading stack.

(I have to wonder about things like torch.cat((torch.cat((TensorDict(a=a0), TensorDict(a=a0))), TensorDict(a=a1))) for example, but I'm already breaking my promise to be coherent.)

EDIT:

>>> tensordict.__version__
'0.5.0'
>>> a0 = set()
>>> a1 = set([1])
>>> torch.cat((torch.cat((TensorDict(a=a0, batch_size=1), TensorDict(a=a0, batch_size=1))), TensorDict(a=a1, batch_size=1)))
TensorDict(
    fields={
        a: NonTensorData(data=set(), batch_size=torch.Size([3]), device=None)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
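In the 0.5.0 output above, the {1} coming from a1 is lost and all three entries report set(). The plain-Python sketch below models the result one would expect instead: a collapsed part is expanded back to per-entry values before concatenating, and the result is collapsed again only if every entry is the same object (the identity rule suggested earlier in this thread). Collapsed, Stacked, entries and cat are hypothetical names, not tensordict classes or functions.

```python
from dataclasses import dataclass
from typing import Any, List, Union


@dataclass
class Collapsed:
    """Hypothetical: one value shared by all `size` entries along dim 0."""
    value: Any
    size: int


@dataclass
class Stacked:
    """Hypothetical: one value per entry along dim 0."""
    values: List[Any]


NonTensor = Union[Collapsed, Stacked]


def entries(x: NonTensor) -> List[Any]:
    """Expand either representation into an explicit per-entry list."""
    if isinstance(x, Collapsed):
        return [x.value] * x.size
    return list(x.values)


def cat(parts: List[NonTensor]) -> NonTensor:
    """Concatenate along dim 0; collapse only if all entries are the same object."""
    flat = [e for p in parts for e in entries(p)]
    first = flat[0]
    if all(e is first for e in flat):
        return Collapsed(first, len(flat))
    return Stacked(flat)


a0, a1 = set(), {1}
inner = cat([Collapsed(a0, 1), Collapsed(a0, 1)])
outer = cat([inner, Collapsed(a1, 1)])
print(inner)  # Collapsed(value=set(), size=2)
print(outer)  # Stacked(values=[set(), set(), {1}])
```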

@vmoens
Contributor

vmoens commented Jan 17, 2025

Yes, I agree, it's not super clear. Perhaps one way to simplify this would be to always return a stack and leave it to users to agglomerate it into a single NonTensorData if they want to.

We have one shot to do this, which is that we'll be porting tensordict to torch core, and that will leave us the opportunity to do some clean-up of the API like this!

Re cat:
I can fix that, thanks for reporting!

@rehno-lindeque
Author

rehno-lindeque commented Jan 17, 2025

Thanks, I just wanted to add that tensordict has been amazingly useful even at this early stage. So far, stacking, reshaping, etc. of non-tensor metadata has been the only real stumbling block for me.
