-
Notifications
You must be signed in to change notification settings - Fork 73
/
model.py
82 lines (66 loc) · 2.51 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import args
class VGAE(nn.Module):
	"""Variational Graph Auto-Encoder.

	Encodes node features X into latent node embeddings with a shared first
	GCN layer followed by separate linear-activation heads for the mean and
	log-std, samples z with the reparameterization trick, and decodes the
	adjacency matrix with an inner-product decoder.

	Hyperparameters (input_dim, hidden1_dim, hidden2_dim) come from the
	project-level `args` module.
	"""
	def __init__(self, adj):
		super(VGAE, self).__init__()
		# Shared ReLU GCN layer, then identity-activation heads for mu / log(sigma).
		self.base_gcn = GraphConvSparse(args.input_dim, args.hidden1_dim, adj)
		self.gcn_mean = GraphConvSparse(args.hidden1_dim, args.hidden2_dim, adj, activation=lambda x: x)
		self.gcn_logstddev = GraphConvSparse(args.hidden1_dim, args.hidden2_dim, adj, activation=lambda x: x)

	def encode(self, X):
		"""Sample latent embeddings z ~ N(mean, exp(logstd)^2).

		Side effect: caches self.mean and self.logstd so the training loop
		can compute the KL term of the ELBO after calling forward().
		"""
		hidden = self.base_gcn(X)
		self.mean = self.gcn_mean(hidden)
		self.logstd = self.gcn_logstddev(hidden)
		# BUG FIX: draw the noise with the same shape/device/dtype as the mean.
		# The original `torch.randn(X.size(0), args.hidden2_dim)` always
		# allocated on CPU, causing a device mismatch once the model was
		# moved to GPU.
		gaussian_noise = torch.randn_like(self.mean)
		sampled_z = gaussian_noise * torch.exp(self.logstd) + self.mean
		return sampled_z

	def forward(self, X):
		"""Return A_pred, the reconstructed adjacency probabilities."""
		Z = self.encode(X)
		A_pred = dot_product_decode(Z)
		return A_pred
class GraphConvSparse(nn.Module):
	"""One graph-convolution layer: activation(adj @ (inputs @ weight)).

	`adj` is a fixed (pre-normalized) adjacency tensor captured at
	construction time; only `weight` is a learned parameter.
	"""
	def __init__(self, input_dim, output_dim, adj, activation=F.relu, **kwargs):
		super(GraphConvSparse, self).__init__(**kwargs)
		self.weight = glorot_init(input_dim, output_dim)
		self.adj = adj
		self.activation = activation

	def forward(self, inputs):
		# Feature transform first, then neighborhood aggregation.
		support = torch.mm(inputs, self.weight)
		aggregated = torch.mm(self.adj, support)
		return self.activation(aggregated)
def dot_product_decode(Z):
	"""Inner-product decoder: edge probabilities sigmoid(Z @ Z^T)."""
	logits = torch.matmul(Z, Z.t())
	return torch.sigmoid(logits)
def glorot_init(input_dim, output_dim):
	"""Glorot/Xavier uniform init: weights drawn from U(-r, r),
	r = sqrt(6 / (fan_in + fan_out)). Returned as a trainable Parameter."""
	limit = np.sqrt(6.0 / (input_dim + output_dim))
	weights = torch.rand(input_dim, output_dim) * 2 * limit - limit
	return nn.Parameter(weights)
class GAE(nn.Module):
	"""Non-variational Graph Auto-Encoder.

	Same architecture as VGAE but with a deterministic encoder (no noise,
	no log-std head); reconstruction uses the inner-product decoder.
	"""
	def __init__(self, adj):
		super(GAE, self).__init__()
		self.base_gcn = GraphConvSparse(args.input_dim, args.hidden1_dim, adj)
		self.gcn_mean = GraphConvSparse(args.hidden1_dim, args.hidden2_dim, adj, activation=lambda x: x)

	def encode(self, X):
		"""Return deterministic latent embeddings; also cached on self.mean."""
		hidden = self.base_gcn(X)
		self.mean = self.gcn_mean(hidden)
		return self.mean

	def forward(self, X):
		"""Return A_pred, the reconstructed adjacency probabilities."""
		embeddings = self.encode(X)
		return dot_product_decode(embeddings)
# NOTE(review): dead commented-out draft of a dense GraphConv layer. The body
# was copy-pasted from VGAE and is inconsistent with its own signature:
# super() names VGAE instead of GraphConv, self.w0 is never defined, and the
# same layer line is repeated instead of applying a second weight. Rewrite
# from scratch if this layer is ever needed; otherwise delete.
# class GraphConv(nn.Module):
# 	def __init__(self, input_dim, hidden_dim, output_dim):
# 		super(VGAE,self).__init__()
# 		self.base_gcn = GraphConvSparse(args.input_dim, args.hidden1_dim, adj)
# 		self.gcn_mean = GraphConvSparse(args.hidden1_dim, args.hidden2_dim, adj, activation=lambda x:x)
# 		self.gcn_logstddev = GraphConvSparse(args.hidden1_dim, args.hidden2_dim, adj, activation=lambda x:x)
# 	def forward(self, X, A):
# 		out = A*X*self.w0
# 		out = F.relu(out)
# 		out = A*X*self.w0
# 		return out