From 0b04c4f2a1969a8f372b15e7aea364b512089f41 Mon Sep 17 00:00:00 2001 From: Ludwig Schneider Date: Sun, 27 Oct 2024 16:03:04 -0500 Subject: [PATCH] extra cascade for fast subgraph.__eq__ --- python/src/ptens/subgraph.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/python/src/ptens/subgraph.py b/python/src/ptens/subgraph.py index f1f915a..187eaad 100644 --- a/python/src/ptens/subgraph.py +++ b/python/src/ptens/subgraph.py @@ -1,14 +1,14 @@ # -# This file is part of ptens, a C++/CUDA library for permutation -# equivariant message passing. -# +# This file is part of ptens, a C++/CUDA library for permutation +# equivariant message passing. +# # Copyright (c) 2023, Imre Risi Kondor # -# This source code file is subject to the terms of the noncommercial -# license distributed with cnine in the file LICENSE.TXT. Commercial -# use is prohibited. All redistributed versions of this file (in -# original or modified form) must retain this copyright notice and -# must be accompanied by a verbatim copy of the license. +# This source code file is subject to the terms of the noncommercial +# license distributed with cnine in the file LICENSE.TXT. Commercial +# use is prohibited. All redistributed versions of this file (in +# original or modified form) must retain this copyright notice and +# must be accompanied by a verbatim copy of the license. # # import torch @@ -29,7 +29,7 @@ def make(self,x): @classmethod def from_edge_index(self,M,n=-1,labels=None,degrees=None): G=subgraph() - if degrees is None: + if degrees is None: if labels is None: G.obj=_subgraph.edge_index(M,n) else: @@ -94,7 +94,7 @@ def n_espaces(self): def evecs(self): self.set_evecs() return self.obj.evecs() - + def set_evecs(self): if self.has_espaces()>0: return @@ -102,7 +102,7 @@ def set_evecs(self): L=torch.diag(torch.sum(L,1))-L U,S,V=torch.linalg.svd(L) self.obj.set_evecs(U,S) - + def torch(self): return self.obj.dense() @@ -118,6 +118,8 @@ def __repr__(self): # ---- Operators -------------------------------------------------------------------------------------------- def __eq__(self, other): + if id(self) == id(other): + return True: + if id(self.obj) == id(other.obj): + return True return self.obj.__eq__(other.obj) - -