Skip to content

Commit

Permalink
extra cascade for fast subgraph.__eq__
Browse files Browse the repository at this point in the history
  • Loading branch information
InnocentBug committed Oct 27, 2024
1 parent fd6f382 commit 0b04c4f
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions python/src/ptens/subgraph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#
# This file is part of ptens, a C++/CUDA library for permutation
# equivariant message passing.
#
# This file is part of ptens, a C++/CUDA library for permutation
# equivariant message passing.
#
# Copyright (c) 2023, Imre Risi Kondor
#
# This source code file is subject to the terms of the noncommercial
# license distributed with cnine in the file LICENSE.TXT. Commercial
# use is prohibited. All redistributed versions of this file (in
# original or modified form) must retain this copyright notice and
# must be accompanied by a verbatim copy of the license.
# This source code file is subject to the terms of the noncommercial
# license distributed with cnine in the file LICENSE.TXT. Commercial
# use is prohibited. All redistributed versions of this file (in
# original or modified form) must retain this copyright notice and
# must be accompanied by a verbatim copy of the license.
#
#
import torch
Expand All @@ -29,7 +29,7 @@ def make(self,x):
@classmethod
def from_edge_index(self,M,n=-1,labels=None,degrees=None):
G=subgraph()
if degrees is None:
if degrees is None:
if labels is None:
G.obj=_subgraph.edge_index(M,n)
else:
Expand Down Expand Up @@ -94,15 +94,15 @@ def n_espaces(self):
def evecs(self):
self.set_evecs()
return self.obj.evecs()

def set_evecs(self):
if self.has_espaces()>0:
return
L=self.torch().float()
L=torch.diag(torch.sum(L,1))-L
U,S,V=torch.linalg.svd(L)
self.obj.set_evecs(U,S)

def torch(self):
return self.obj.dense()

Expand All @@ -118,6 +118,8 @@ def __repr__(self):

# ---- Operators --------------------------------------------------------------------------------------------
def __eq__(self, other):
if id(self) == id(other):
return True:
if id(self.obj) == id(other.obj):
return True
return self.obj.__eq__(other.obj)


0 comments on commit 0b04c4f

Please sign in to comment.