From 83dfcefed5c1b8ab48bd7d1a7843a2380c1ec666 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 18:17:27 +0000 Subject: [PATCH] [Feature] Add missing `__torch_function__` ghstack-source-id: 4c432bf4bee9a5cb15b804f08e8bd8d3d2db4ea4 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1169 --- tensordict/_torch_func.py | 15 +++++++++++++++ tensordict/tensorclass.py | 2 ++ 2 files changed, 17 insertions(+) diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index f3f2cf35e..2771d4985 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -84,6 +84,21 @@ def _unbind(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: return td.unbind(*args, **kwargs) +@implements_for_td(torch.unflatten) +def _unflatten(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: + return td.unflatten(*args, **kwargs) + + +@implements_for_td(torch.flatten) +def _flatten(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: + return td.flatten(*args, **kwargs) + + +@implements_for_td(torch.transpose) +def _transpose(td: T, *args: Any, **kwargs: Any) -> tuple[T, ...]: + return td.transpose(*args, **kwargs) + + @implements_for_td(torch.gather) def _gather( input: T, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 957c1ddd4..b09d2467a 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -104,6 +104,7 @@ def __subclasscheck__(self, subclass): torch.cat: True, torch.clone: True, torch.empty_like: True, + torch.flatten: True, torch.full_like: True, torch.gather: True, torch.ones_like: True, @@ -114,6 +115,7 @@ def __subclasscheck__(self, subclass): torch.squeeze: True, torch.stack: True, torch.unbind: True, + torch.unflatten: True, torch.unsqueeze: True, torch.zeros_like: True, }