[BUG] Stacking NonTensorData does not appear to return a NonTensorStack #1047
Comments
Hello, yeah, we do indeed check whether the contents match when using torch.stack. This is to avoid creating countless copies of the same non-tensor data when all the contents match, and to get behaviour that is more consistent with index + stack. What we want is for this to work:

```python
td = TensorDict(a=set(), batch_size=[2])
td_reconstruct = torch.stack([td[0], td[1]])
td_reconstruct["a"] is td["a"]
```

Currently we use … To summarize the current state, we have:

```python
from tensordict import TensorDict
import torch
# 1. This gives a stack
a0 = set()
a1 = set([1])
torch.stack([TensorDict(a=a0), TensorDict(a=a1)])
# 2. This does not give a stack - but maybe it should?
a0 = set()
a1 = set()
torch.stack([TensorDict(a=a0), TensorDict(a=a1)])
# 3. This gives a stack
a0 = set()
a1 = set()
TensorDict.lazy_stack([TensorDict(a=a0), TensorDict(a=a1)])
# 4. This does not give a stack - but maybe it should?
a0 = set()
a1 = set()
TensorDict.maybe_dense_stack([TensorDict(a=a0), TensorDict(a=a1)])
```

and we want to change the behaviour of 2. and 4.
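A small sketch contrasting cases 2 and 3 above; the types noted in the comments are what those cases describe, not guaranteed output:

```python
import torch
from tensordict import TensorDict

td0 = TensorDict(a=set())
td1 = TensorDict(a=set())

# Case 2 above: torch.stack consolidates matching non-tensor values.
dense = torch.stack([td0, td1])
# Case 3 above: lazy_stack keeps one entry per stacked element.
lazy = TensorDict.lazy_stack([td0, td1])

print(type(dense.get("a")))  # a single NonTensorData, per case 2
print(type(lazy.get("a")))   # a NonTensorStack, per case 3
```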
@rehno-lindeque I implemented this in #1083. Given the bc-breaking nature of this change, I can only fully change the behaviour two major releases from now (v0.8), but I think your use case will be covered as early as v0.7.
Thank you @vmoens, I have a workaround implemented, but it is encouraging to see ongoing work. Just a quick thought:
I'm not sure I fully understand the motivation for overloading the stack operation... Is it mostly a performance consideration? If I've assumed correctly that tensordict makes shallow copies of non-tensor data, surely there's almost no overhead?
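One way to probe that assumption, as a minimal sketch; whether the stored value really is the same object is exactly the open question here:

```python
from tensordict import TensorDict

payload = {"some": "metadata"}
td = TensorDict(a=payload)

# .get() returns the NonTensorData wrapper; .data is the wrapped value.
# If tensordict keeps a shallow reference, this prints True.
print(td.get("a").data is payload)
```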
What I meant is just that having a gazillion copies of the same non-tensor data is what we want to avoid. I wrote some more doc - can you give it a look and lmk if that's of any help?
Thanks @vmoens, I will try to reply coherently once I get back into my dependent project. Btw. thank you, I've noted … (I have to wonder about things like torch.cat, which in the edit below silently swallows the distinct {1} value.)

EDIT:

```python
>>> tensordict.__version__
'0.5.0'
>>> a0 = set()
>>> a1 = set([1])
>>> torch.cat((torch.cat((TensorDict(a=a0, batch_size=1), TensorDict(a=a0, batch_size=1))), TensorDict(a=a1, batch_size=1)))
TensorDict(
    fields={
        a: NonTensorData(data=set(), batch_size=torch.Size([3]), device=None)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
```
Yes, I agree, it's not super clear. Perhaps one way to simplify this would be to always return a stack, and leave it to the users to agglomerate that into a single NonTensorData if they want to. We have one shot to do this: we'll be porting tensordict to torch core, and that will give us the opportunity to do some clean-up of the API like this! RE cat: …
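To illustrate the user-side agglomeration that proposal would imply, here is a sketch in plain Python over the per-element values, independent of any particular tensordict API:

```python
def agglomerate(values):
    """Collapse a list of non-tensor values into one object if they all match."""
    first = values[0]
    if all(v == first for v in values[1:]):
        return first        # identical values collapse into a single object
    return list(values)     # distinct values stay element-wise

print(agglomerate([set(), set()]))  # set()
print(agglomerate([set(), {1}]))    # [set(), {1}]
```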
Thanks, just wanted to add that tensordict has been amazingly useful even at this early stage. So far, stacking / reshaping / etc. of non-tensor metadata has been the only real stumbling block for me.
Describe the bug
Hi, please let me know if I'm using this feature incorrectly or if this is well known.
I've been unable to get NonTensorStack to work in various contexts. The simplest example I can come up with is this one:
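A hypothetical reconstruction of the kind of snippet being described; the exact definitions of a, b, a_stack, and b_stack, as well as the top-level imports and the NonTensorStack constructor call, are assumptions inferred from the names used below:

```python
import torch
from tensordict import NonTensorData, NonTensorStack  # assumed top-level exports

a = NonTensorData(data={}, batch_size=[])    # hypothetical: scalar non-tensor value
b = NonTensorData(data={}, batch_size=[1])   # hypothetical: same value with a batch dim
a_stack = NonTensorStack(a, a)               # hypothetical: explicitly constructed stacks
b_stack = NonTensorStack(b, b)

print(type(torch.stack((a, a), dim=0)))              # hoped for: NonTensorStack
print(type(torch.stack((b, b), dim=0)))              # hoped for: NonTensorStack
print(type(torch.stack((a_stack, a_stack), dim=0)))
```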
I expected all of these examples to produce a NonTensorStack, yet only b_stack appears to produce what I was expecting. I think I'd have hoped to see:
```python
torch.stack((a, a), dim=0).data == [{}, {}]
torch.stack((b, b), dim=0).data == [[{}], [{}]]
torch.stack((a_stack, a_stack), dim=0).data == [{}, {}]
```
This may be a separate issue, but even for the final case that appears to somewhat work, there are still a number of issues that make it unusable for even the most basic use cases...
Thanks!