Skip to content

Commit

Permalink
fix linkpred models
Browse files Browse the repository at this point in the history
  • Loading branch information
EdisonLeeeee committed Sep 21, 2021
1 parent f391c19 commit c95e749
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion graphgallery/gallery/linkpred/pyg/gae.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def model_step(self,
lr=0.01,
bias=False):

model = get_model("autoencoder.VGAE", self.backend)
model = get_model("autoencoder.GAE", self.backend)
model = model(self.graph.num_node_attrs,
out_features=out_features,
hids=hids,
Expand Down
6 changes: 3 additions & 3 deletions graphgallery/nn/models/pyg/autoencoder/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def train_step_on_batch(self,
self.train()
optimizer = self.optimizer
optimizer.zero_grad()
x = to_device(x, device=device)
x, _ = to_device(x, device=device)
z = self.encode(*x)
# here `out_index` maybe pos_edge_index
# or (pos_edge_index, neg_edge_index)
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_step_on_batch(self,
device="cpu"):
self.eval()
metrics = self.metrics
x = to_device(x, device=device)
x, _ = to_device(x, device=device)
z = self.encode(*x)
pred = self.decode(z, out_index)

Expand All @@ -78,7 +78,7 @@ def test_step_on_batch(self,
@torch.no_grad()
def predict_step_on_batch(self, x, out_index=None, device="cpu"):
self.eval()
x = to_device(x, device=device)
x, _ = to_device(x, device=device)
z = self.encode(*x)
pred = self.decode(z, out_index)
return pred.cpu().detach()
2 changes: 1 addition & 1 deletion graphgallery/nn/models/pytorch/autoencoder/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_step_on_batch(self,
x, y = to_device(x, y, device=device)
z = self.encode(*x)
out = self.decode(z, out_index)
loss = self.compute_loss(out, y)
loss, out = self.compute_loss(out, y)
self.update_metrics(out, y)

if loss is not None:
Expand Down
5 changes: 3 additions & 2 deletions graphgallery/nn/models/pytorch/autoencoder/vgae.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def forward(self, x, adj):
out = self.decode(z)
return out

def compute_loss(self, out, y):
def compute_loss(self, out, y, out_index=None):
out = self.index_select(out, out_index=out_index)
if self.training:
mu = self.cache.pop('mu')
logstd = self.cache.pop('logstd')
kl_loss = -0.5 / mu.size(0) * torch.mean(torch.sum(1 + 2 * logstd - mu.pow(2) - logstd.exp().pow(2), dim=1))
else:
kl_loss = 0.
return self.loss(out, y) + kl_loss
return self.loss(out, y) + kl_loss, out
2 changes: 1 addition & 1 deletion graphgallery/nn/models/torch_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_step_on_batch(self,
@torch.no_grad()
def predict_step_on_batch(self, x, out_index=None, device="cpu"):
self.eval()
x = to_device(x, device=device)
x, _ = to_device(x, device=device)
out = self.index_select(self(*x), out_index=out_index)
return out.cpu().detach()

Expand Down

0 comments on commit c95e749

Please sign in to comment.