diff --git a/python/src/ptens/ptensor.py b/python/src/ptens/ptensor.py index 558978d..6c3c27f 100644 --- a/python/src/ptens/ptensor.py +++ b/python/src/ptens/ptensor.py @@ -37,7 +37,26 @@ def make(cls, atoms:list, M:torch.Tensor | torch.Size): # ---- Operations ---------------------------------------------------------------------------------------- - + _covariant_functions = [ + torch.Tensor.to, + torch.Tensor.clone, + ] + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + r= super().__torch_function__(func, types, args, kwargs) + if func in cls._covariant_functions: + # A bit more robuts with the order of arguments. + for arg in args: + if hasattr(arg, "atoms"): + r.atoms = arg.atoms + break + return r + + def __add__(self,y): assert self.size()==y.size() assert self.atoms==y.atoms @@ -51,6 +70,7 @@ def to_string(self,indent): return self.backend().str(indent) + def to(self, *args, **kwargs): M = super().to(*args, **kwargs) M.atoms = self.atoms