diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index f3f2cf35e..2771d4985 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -84,6 +84,21 @@ def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: return td.unbind(*args, **kwargs) +@implements_for_td(torch.unflatten) +def _unflatten(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: + return td.unflatten(*args, **kwargs) + + +@implements_for_td(torch.flatten) +def _flatten(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: + return td.flatten(*args, **kwargs) + + +@implements_for_td(torch.transpose) +def _transpose(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: + return td.transpose(*args, **kwargs) + + @implements_for_td(torch.gather) def _gather( input: T, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 957c1ddd4..b09d2467a 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -104,6 +104,7 @@ def __subclasscheck__(self, subclass): torch.cat: True, torch.clone: True, torch.empty_like: True, + torch.flatten: True, torch.full_like: True, torch.gather: True, torch.ones_like: True, @@ -114,6 +115,7 @@ def __subclasscheck__(self, subclass): torch.squeeze: True, torch.stack: True, torch.unbind: True, + torch.unflatten: True, torch.unsqueeze: True, torch.zeros_like: True, }