# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Linear, GATv2Conv, SAGEConv


class MLP(nn.Module):
    """Three-layer perceptron; optionally flattens the output to 1-D."""

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=None, flatten=True):
        super().__init__()
        self.lin1 = Linear(in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)
        self.lin3 = Linear(hidden_channels, out_channels)
        self.use_dropout = bool(dropout)
        self.dropout = nn.Dropout(p=dropout) if self.use_dropout else nn.Identity()
        self.flatten = flatten
        self.reset_parameters()

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.lin3.reset_parameters()

    def forward(self, x):
        z = self.lin1(x).relu()
        z = self.dropout(z)
        z = self.lin2(z).relu()
        z = self.dropout(z)
        z = self.lin3(z)
        return z.view(-1) if self.flatten else z
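

# A minimal shape-check sketch for MLP (the sizes below are assumptions for
# illustration, not values from the original repo). With flatten=True the
# output collapses to 1-D, which SpatialGAT's decoder relies on further down.
def _demo_mlp():
    mlp = MLP(in_channels=16, hidden_channels=32, out_channels=1, dropout=0.2)
    x = torch.randn(8, 16)                            # batch of 8 feature vectors
    assert mlp(x).shape == (8,)                       # flatten=True -> [batch]
    assert MLP(16, 32, 4, flatten=False)(x).shape == (8, 4)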


class HomoGATEncoder(nn.Module):
    """Node encoder for edge regression on a directed, weighted graph.

    Expected inputs:
        node_feature -- projected node feature matrix: [num_nodes, hidden_channels]
        edge_index   -- edge indices: [2, num_edges]
        edge_weight  -- edge attribute (OD volume): [num_edges, 1]

    Returns separate "in" and "out" embeddings per node, plus the shared
    trunk embedding they are projected from.
    """

    def __init__(self,
                 hidden_channels,
                 out_channels,
                 heads,
                 dropout,
                 num_layers=3,
                 layer_type='gat',
                 analyze_mode=False):
        super().__init__()
        self.num_layers = num_layers
        self.layer_type = layer_type
        self.analyze_mode = analyze_mode
        self.convs = nn.ModuleList()
        self.lins = nn.ModuleList()
        self.norms = nn.ModuleList()
        if self.layer_type == 'gat':
            # The GAT stack is hard-wired to three layers; num_layers must be 3
            # for the forward pass below to line up with these modules.
            self.convs.append(
                GATv2Conv(hidden_channels, hidden_channels, heads=heads, edge_dim=1))
            self.lins.append(Linear(hidden_channels, hidden_channels * heads))
            self.norms.append(nn.BatchNorm1d(hidden_channels * heads))
            self.convs.append(
                GATv2Conv(hidden_channels * heads, hidden_channels, heads=heads, edge_dim=1))
            self.lins.append(Linear(hidden_channels * heads, hidden_channels * heads))
            self.norms.append(nn.BatchNorm1d(hidden_channels * heads))
            self.convs.append(
                GATv2Conv(hidden_channels * heads, hidden_channels, heads=heads,
                          edge_dim=1, concat=False))
            self.lins.append(Linear(hidden_channels * heads, hidden_channels))
            self.norms.append(nn.BatchNorm1d(hidden_channels))
        elif self.layer_type == 'sage':
            for _ in range(num_layers):
                self.convs.append(SAGEConv(hidden_channels, hidden_channels))
                self.lins.append(Linear(hidden_channels, hidden_channels))
                self.norms.append(nn.BatchNorm1d(hidden_channels))
        else:
            raise NotImplementedError(f'Unknown layer type: {layer_type!r}')
        self.in_emb_encoder = nn.Linear(hidden_channels, out_channels)
        self.out_emb_encoder = nn.Linear(hidden_channels, out_channels)
        self.use_dropout = bool(dropout)
        self.dropout = nn.Dropout(p=dropout) if self.use_dropout else nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        for norm in self.norms:
            norm.reset_parameters()
        self.in_emb_encoder.reset_parameters()
        self.out_emb_encoder.reset_parameters()

    def forward(self, edge_index, node_feature, edge_weight):
        x = node_feature
        # Each layer is a residual block: conv(x) + linear skip, then batch norm.
        for i in range(self.num_layers - 1):
            if self.layer_type == 'gat':
                x = self.norms[i](self.convs[i](x, edge_index, edge_weight) + self.lins[i](x))
            elif self.layer_type == 'sage':
                x = self.norms[i](self.convs[i](x, edge_index) + self.lins[i](x))
            x = self.dropout(F.leaky_relu(x))
        # Final layer: no activation or dropout before the embedding heads.
        if self.layer_type == 'gat':
            x = self.norms[-1](self.convs[-1](x, edge_index, edge_weight) + self.lins[-1](x))
        elif self.layer_type == 'sage':
            x = self.norms[-1](self.convs[-1](x, edge_index) + self.lins[-1](x))
        in_embed = self.in_emb_encoder(x)
        out_embed = self.out_emb_encoder(x)
        return in_embed, out_embed, x
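

# A hedged usage sketch for HomoGATEncoder on a random directed graph (the
# dimensions below are hypothetical). Both embedding heads share the same GNN
# trunk; only the final linear projections differ.
def _demo_encoder():
    num_nodes, num_edges, hidden = 10, 40, 8
    enc = HomoGATEncoder(hidden_channels=hidden, out_channels=4, heads=2, dropout=0.1)
    x = torch.randn(num_nodes, hidden)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_weight = torch.rand(num_edges, 1)
    in_emb, out_emb, trunk = enc(edge_index, x, edge_weight)
    assert in_emb.shape == (num_nodes, 4)
    assert out_emb.shape == (num_nodes, 4)
    assert trunk.shape == (num_nodes, hidden)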


class SpatialGAT(nn.Module):
    """Projection + GNN encoder + MLP decoder for OD-volume edge regression."""

    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads, dropout, layer_type='gat', analyze_mode=False):
        super().__init__()
        self.projection = MLP(in_channels, hidden_channels, hidden_channels,
                              dropout=dropout, flatten=False)
        self.encoder = HomoGATEncoder(hidden_channels=hidden_channels,
                                      out_channels=out_channels,
                                      heads=heads,
                                      dropout=dropout,
                                      layer_type=layer_type)
        self.decoder = MLP(out_channels * 2, hidden_channels, 1, dropout=dropout)
        self.analyze_mode = analyze_mode
        self.reset_parameters()

    def forward(self, node_feature, edge_index, edge_weight, edge_label_index):
        node_feature = self.projection(node_feature)
        in_embed, out_embed, general_embed = self.encoder(edge_index, node_feature, edge_weight)
        # Score each (origin, destination) pair by concatenating the two
        # node-level embeddings and decoding them to a scalar volume.
        edge_emb = torch.cat([in_embed[edge_label_index[0]],
                              out_embed[edge_label_index[1]]], dim=-1)
        volume = self.decoder(edge_emb).reshape([-1, 1])
        if self.analyze_mode:
            return volume
        return in_embed, out_embed, volume, general_embed

    def reset_parameters(self):
        self.projection.reset_parameters()
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()
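

# A minimal end-to-end smoke test for SpatialGAT (hypothetical sizes; a
# sketch, not a training script). edge_label_index holds the OD pairs whose
# volumes the decoder should predict.
if __name__ == '__main__':
    num_nodes, num_edges, num_pairs = 10, 40, 5
    model = SpatialGAT(in_channels=16, hidden_channels=8, out_channels=4,
                       heads=2, dropout=0.1)
    node_feature = torch.randn(num_nodes, 16)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    edge_weight = torch.rand(num_edges, 1)
    edge_label_index = torch.randint(0, num_nodes, (2, num_pairs))
    in_emb, out_emb, volume, trunk = model(node_feature, edge_index,
                                           edge_weight, edge_label_index)
    print(volume.shape)  # expected: torch.Size([5, 1])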