[BUG] torch.vmap fails when chunk_size is set to some positive integer. #1091

Open
3 tasks
busFred opened this issue Nov 15, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@busFred

busFred commented Nov 15, 2024

Describe the bug

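torch.vmap appears to be incompatible with tensordict.TensorDictBase inputs (in the example below, a @tensorclass instance) when chunk_size is not None: the call fails with AttributeError: tensor_split.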

To Reproduce

Steps to reproduce the behavior.

Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.

Please use the markdown code blocks for both code and stack traces.

import torch
from tensordict import tensorclass

@tensorclass
class Data:
    a: torch.Tensor
    b: torch.Tensor

def AplusB(data):
    return data.a+data.b

data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
result = torch.vmap(AplusB, chunk_size=1)(data)
print(result)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 13
     10     return data.a+data.b
     12 data = Data(a=torch.randn(10), b=torch.randn(10), batch_size=[10])
---> 13 result = torch.vmap(AplusB, chunk_size=1)(data)
     14 print(result)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/apis.py:203, in vmap.<locals>.wrapped(*args, **kwargs)
    202 def wrapped(*args, **kwargs):
--> 203     return vmap_impl(
    204         func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
    205     )

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:317, in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    312 batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
    313     in_dims, args, func
    314 )
    316 if chunk_size is not None:
--> 317     chunks_flat_args = _get_chunked_inputs(
    318         flat_args, flat_in_dims, batch_size, chunk_size
    319     )
    320     return _chunked_vmap(
    321         func,
    322         flat_in_dims,
   (...)
    327         **kwargs,
    328     )
    330 # If chunk_size is not specified.

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:359, in _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
--> 359 flat_args_chunks = tuple(
    360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/torch/_functorch/vmap.py:360, in <genexpr>(.0)
    356     chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
    357     split_idxs = tuple(itertools.accumulate(chunk_sizes))
    359 flat_args_chunks = tuple(
--> 360     t.tensor_split(split_idxs, dim=in_dim)
    361     if in_dim is not None
    362     else [
    363         t,
    364     ]
    365     * len(split_idxs)
    366     for t, in_dim in zip(flat_args, flat_in_dims)
    367 )
    369 # transpose chunk dim and flatten structure
    370 # chunks_flat_args is a list of flatten args
    371 chunks_flat_args = zip(*flat_args_chunks)

File ~/anaconda3/envs/stor712_hpso/lib/python3.12/site-packages/tensordict/tensorclass.py:1098, in _getattr(self, item)
   1096         return out.data if hasattr(out, "data") else out.tolist()
   1097     return _wrap_method(self, item, out)
-> 1098 raise AttributeError(item)

AttributeError: tensor_split
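The last torch frames show what chunking requires: `_get_chunked_inputs` turns `chunk_size` into cumulative split indices and then calls `.tensor_split(split_idxs, dim=in_dim)` on every batched entry of `flat_args`. A plain tensor supports that; the tensorclass instance that reaches this point does not, so its `_getattr` raises `AttributeError: tensor_split`. The stdlib-only sketch below reproduces just the index arithmetic. It assumes `get_chunk_sizes` returns as many full chunks as fit followed by one remainder chunk (an assumption, not copied from torch), and `split_like_tensor_split` is a hypothetical list-based stand-in for `Tensor.tensor_split` with integer indices.

```python
import itertools


def get_chunk_sizes(total_elems, chunk_size):
    # Assumed behaviour: full chunks of `chunk_size`, then one remainder chunk.
    n_full, remainder = divmod(total_elems, chunk_size)
    sizes = [chunk_size] * n_full
    if remainder:
        sizes.append(remainder)
    return sizes


def split_indices(batch_size, chunk_size):
    # Same expression as traceback lines 356-357: cumulative chunk boundaries.
    return tuple(itertools.accumulate(get_chunk_sizes(batch_size, chunk_size)))


def split_like_tensor_split(values, split_idxs):
    # Hypothetical list stand-in for Tensor.tensor_split(split_idxs, dim=0):
    # k indices yield k + 1 pieces, cut at each index.
    bounds = (0, *split_idxs, len(values))
    return [values[lo:hi] for lo, hi in zip(bounds, bounds[1:])]


idxs = split_indices(10, 3)
pieces = split_like_tensor_split(list(range(10)), idxs)
print(idxs)    # (3, 6, 9, 10)
print(pieces)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9], []]
```

Because the last index equals the batch size, the final piece is empty; this sketch does not model what torch does with that piece afterwards.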

Expected behavior

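No error should be raised, and result should equal data.a + data.b: chunk_size only controls how many batch elements vmap processes at a time, not the result.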
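For reference, the sketch below shows the invariant the report expects, using plain tensors in a session where tensordict is not imported. `a_plus_b` is a hypothetical tensor-only version of `AplusB`. Chunked and unchunked vmap calls should agree, and `chunk_size=1` behaves like a loop over the batch. Passing the tensorclass fields as separate tensors in this way may also serve as a stopgap, at the cost of unpacking the tensorclass by hand.

```python
import torch


def a_plus_b(a, b):
    # Same computation as AplusB in the report, but on plain tensors.
    return a + b


a = torch.randn(10)
b = torch.randn(10)

unchunked = torch.vmap(a_plus_b)(a, b)
# chunk_size only changes how many batch elements are processed per step.
chunked = torch.vmap(a_plus_b, chunk_size=3)(a, b)
one_at_a_time = torch.vmap(a_plus_b, chunk_size=1)(a, b)
```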

Screenshots

None.

System info

Describe the characteristics of your environment:

  • Linux Mint 22
  • conda
  • python=3.12
  • torch=2.5.1+cu124

Additional context

This might be related to #823.

Reason and Possible fixes

If you know or suspect the reason for this bug, paste the code lines and suggest modifications.
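One possible user-side workaround is to do the chunking outside vmap. You slice the inputs yourself and call `torch.vmap` without `chunk_size` on each slice, so the `_get_chunked_inputs` / `tensor_split` path is never taken. The sketch below is a hypothetical helper, `vmap_in_chunks`, and not a tensordict or torch API. It assumes every argument is batched along dim 0, supports `len()` and slicing, and that `func` returns a single tensor.

```python
import torch


def vmap_in_chunks(func, *args, chunk_size):
    """Hypothetical helper emulating torch.vmap(func, chunk_size=chunk_size)(*args).

    Each argument is sliced along dim 0 and torch.vmap is called without
    chunk_size on every slice; the per-slice outputs are concatenated.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    n = len(args[0])
    outputs = []
    for start in range(0, n, chunk_size):
        chunk_args = [arg[start:start + chunk_size] for arg in args]
        outputs.append(torch.vmap(func)(*chunk_args))
    return torch.cat(outputs, dim=0)


a = torch.randn(10)
b = torch.randn(10)
result = vmap_in_chunks(lambda x, y: x + y, a, b, chunk_size=3)
```

Applied to the report, the call would be vmap_in_chunks(AplusB, data, chunk_size=1). That relies on three assumptions, none verified here: the tensorclass supports len() and slices such as data[0:1], and the unchunked torch.vmap(AplusB)(data) works, as the issue title implies.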

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
busFred added the "bug" (Something isn't working) label on Nov 15, 2024
busFred changed the title from "[BUG]" to "[BUG] torch.vmap fails when chunk_size is not None" on Nov 15, 2024
busFred changed the title from "[BUG] torch.vmap fails when chunk_size is not None" to "[BUG] torch.vmap fails when chunk_size is set to some positive integer." on Nov 15, 2024
@vmoens
Contributor

vmoens commented Nov 18, 2024

Yep, this is likely caused by the ugly monkey patching we're doing (tensordict patches torch.vmap's internals so that it accepts its own types).
The plan is to make vmap extensible the same way we extend stack and similar functions. I opened a PR for that but never really moved it forward:
pytorch/pytorch#135471

In the meantime, I could patch the "tensordict" vmap to make this work!
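For context on "extend vmap like we extend stack", PyTorch lets a class that is not a torch.Tensor subclass intercept torch functions such as torch.stack through the `__torch_function__` protocol. The sketch below illustrates that protocol only. `Pair`, `HANDLED_FUNCTIONS` and `implements` are hypothetical names, and this is not tensordict's implementation. Per the comment above, the linked PR aims to give vmap a comparable extension point.

```python
import torch

# Hypothetical registry mapping torch functions to Pair-specific handlers.
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register a handler for `torch_function` when called with Pair inputs."""
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class Pair:
    """Hypothetical container of two same-shaped tensors; not a Tensor subclass."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(torch.stack)
def _stack_pairs(pairs, dim=0):
    # Stack field by field and wrap the results back into a Pair.
    return Pair(
        torch.stack([p.a for p in pairs], dim=dim),
        torch.stack([p.b for p in pairs], dim=dim),
    )


pairs = [Pair(torch.zeros(3), torch.ones(3)) for _ in range(4)]
stacked = torch.stack(pairs)  # dispatched to _stack_pairs via __torch_function__
```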
