Skip to content

Commit

Permalink
use covariant functions
Browse files Browse the repository at this point in the history
  • Loading branch information
InnocentBug committed Nov 8, 2024
1 parent ce330d7 commit 27feef4
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion python/src/ptens/ptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,26 @@ def make(cls, atoms:list, M:torch.Tensor | torch.Size):

# ---- Operations ----------------------------------------------------------------------------------------


_covariant_functions = [
torch.Tensor.to,
torch.Tensor.clone,
]

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}

r= super().__torch_function__(func, types, args, kwargs)
if func in cls._covariant_functions:
# A bit more robuts with the order of arguments.
for arg in args:
if hasattr(arg, "atoms"):
r.atoms = arg.atoms
break
return r


def __add__(self,y):
assert self.size()==y.size()
assert self.atoms==y.atoms
Expand All @@ -51,6 +70,7 @@ def to_string(self,indent):
return self.backend().str(indent)



def to(self, *args, **kwargs):
M = super().to(*args, **kwargs)
M.atoms = self.atoms
Expand Down

0 comments on commit 27feef4

Please sign in to comment.