-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
332 lines (276 loc) · 17.3 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
Ok, so this module looks scary. It has abstract base classes, metaclasses, and all that stuff. But what you really need
to know to use it, is that there are two abstract base classes: SingleInputModel and SequenceInputModel, and each model
used in this project should be a sublass of one of these two abstract base classes. If the model only takes as input
information about one timestamp, that is, a tensor of shape [num_nodes, features_dim] (which might include information
about targets from previous timestamps in the features), than this model shpuld be a subclass of SingleInputModel.
If the model takes as input information about a sequence of timestamps, that is, a tensor of shape
[num_nodes, num_timestamps, features_dim], than this model should be a subclass of SequenceInputModel.
"""
from abc import ABC, ABCMeta
import torch
from torch import nn
from modules import (ResidualModulesWrapper, FeedForwardModule,
NEIGHBORHOOD_AGGREGATION_MODULES, SEQUENCE_ENCODER_MODULES, NORMALIZATION_MODULES,
FeaturesPreparatorForDeepModels)
class ModelRegistry(ABCMeta):
    """Metaclass that automatically registers every class created with it.

    Each class whose metaclass is (or inherits) ModelRegistry is stored in the
    shared ``registry`` dict under its class name, so that model classes can
    later be looked up from a configuration string via ``get_model_class``.
    """
    # Maps class name -> class object; shared across all classes using this metaclass.
    registry = {}

    def __new__(mcs, name, bases, attrs):
        new_cls = ABCMeta.__new__(mcs, name, bases, attrs)
        # Register under the final class name (equal to the `name` argument).
        mcs.registry[new_cls.__name__] = new_cls
        return new_cls

    @classmethod
    def get_model_class(mcs, model_class_name):
        """Return the registered model class named ``model_class_name``.

        Raises:
            KeyError: if no class with this name has been registered; the
                message lists the names that are available, which makes
                config typos much easier to diagnose.
        """
        try:
            return mcs.registry[model_class_name]
        except KeyError:
            raise KeyError(
                f'Unknown model class name: {model_class_name!r}. '
                f'Registered model classes: {sorted(mcs.registry)}.'
            ) from None
class SingleInputModel(nn.Module, ABC, metaclass=ModelRegistry):
    """
    Abstract base class for models that take as input a tensor of shape [num_nodes, features_dim].

    This input tensor contains for each node the features of this node at the current timestamp and the target values
    for this node at the current and previous timestamps. The model uses this data (in the case of a linear model) or
    deep representations of this data (in the case of deep models) to predict target values for this node at the future
    timestamps.

    Each model in this project should be a subclass of either this abstract base class or SequenceInputModel abstract
    base class.
    """
    # Dispatch flags: code elsewhere can check these to learn which input
    # format (single-timestamp vs. sequence) a model class expects.
    single_input: bool = True
    sequence_input: bool = False
class SequenceInputModel(nn.Module, ABC, metaclass=ModelRegistry):
    """
    Abstract base class for models that take as input a tensor of shape [num_nodes, num_timestamps, features_dim].

    This input tensor contains for each node a sequence, in which each sequence element contains the features of
    this node and the target value of this node at a single timestamp - the timestamp corresponding to this sequence
    element. To predict target values for this node at the future timestamps, the deep representations of this sequence
    need to be converted to a single representation. This can be done by taking the deep representations of the final
    element of this sequence or by taking the mean and max of the deep representations computed over the entire
    sequence or by using all these methods and concatenating their outputs.

    Each model in this project should be a subclass of either this abstract base class or SingleInputModel abstract
    base class.
    """
    # Dispatch flags: mirror image of SingleInputModel, so callers can check
    # either attribute to learn which input format a model class expects.
    single_input: bool = False
    sequence_input: bool = True
class LinearModel(SingleInputModel):
    """Graph-agnostic baseline: a single linear layer over node features.

    Ignores the graph entirely and maps each node's [features_dim] input
    vector straight to [output_dim] predictions.
    """

    def __init__(self, features_dim, output_dim, **kwargs):
        super().__init__()
        # Extra keyword arguments are accepted and ignored so that one shared
        # config dict can be used to construct any model class.
        self.linear = nn.Linear(features_dim, output_dim)

    def forward(self, graph, x):
        predictions = self.linear(x)
        # squeeze(1) collapses a [num_nodes, 1] output to [num_nodes] for
        # single-target prediction.
        return predictions.squeeze(1)
class ResNet(SingleInputModel):
    """A graph-agnostic deep model with skip-connections and normalization.

    Pipeline: features preparation -> input linear projection -> dropout -> GELU ->
    num_residual_blocks residual blocks (each: normalization followed by a feed-forward
    module, wrapped by ResidualModulesWrapper, which presumably supplies the
    skip-connection - defined in modules.py) -> output normalization -> output linear layer.
    No neighborhood aggregation is performed; the graph argument of forward is only
    passed through to the residual module wrappers.
    """

    def __init__(self, normalization_name, num_residual_blocks, features_dim, hidden_dim, output_dim, dropout,
                 use_learnable_node_embeddings, num_nodes, learnable_node_embeddings_dim,
                 initialize_learnable_node_embeddings_with_deepwalk, deepwalk_node_embeddings,
                 use_plr_for_num_features, num_features_mask, plr_num_features_frequencies_dim,
                 plr_num_features_frequencies_scale, plr_num_features_embedding_dim,
                 plr_num_features_shared_linear, plr_num_features_shared_frequencies,
                 use_plr_for_past_targets, past_targets_mask, plr_past_targets_frequencies_dim,
                 plr_past_targets_frequencies_scale, plr_past_targets_embedding_dim,
                 plr_past_targets_shared_linear, plr_past_targets_shared_frequencies,
                 **kwargs):
        """Build the model.

        The use_learnable_node_embeddings / use_plr_* argument groups are forwarded
        verbatim to FeaturesPreparatorForDeepModels. Extra keyword arguments are
        accepted (and ignored) so one config dict can be shared across model classes.
        """
        super().__init__()
        # Resolve the normalization layer class by its configured name.
        NormalizationModule = NORMALIZATION_MODULES[normalization_name]
        self.features_preparator = FeaturesPreparatorForDeepModels(
            features_dim=features_dim,
            use_learnable_node_embeddings=use_learnable_node_embeddings,
            num_nodes=num_nodes,
            learnable_node_embeddings_dim=learnable_node_embeddings_dim,
            initialize_learnable_node_embeddings_with_deepwalk=initialize_learnable_node_embeddings_with_deepwalk,
            deepwalk_node_embeddings=deepwalk_node_embeddings,
            use_plr_for_num_features=use_plr_for_num_features,
            num_features_mask=num_features_mask,
            plr_num_features_frequencies_dim=plr_num_features_frequencies_dim,
            plr_num_features_frequencies_scale=plr_num_features_frequencies_scale,
            plr_num_features_embedding_dim=plr_num_features_embedding_dim,
            plr_num_features_shared_linear=plr_num_features_shared_linear,
            plr_num_features_shared_frequencies=plr_num_features_shared_frequencies,
            use_plr_for_past_targets=use_plr_for_past_targets,
            past_targets_mask=past_targets_mask,
            plr_past_targets_frequencies_dim=plr_past_targets_frequencies_dim,
            plr_past_targets_frequencies_scale=plr_past_targets_frequencies_scale,
            plr_past_targets_embedding_dim=plr_past_targets_embedding_dim,
            plr_past_targets_shared_linear=plr_past_targets_shared_linear,
            plr_past_targets_shared_frequencies=plr_past_targets_shared_frequencies
        )
        # Project the prepared features to the hidden dimension.
        self.input_linear = nn.Linear(in_features=self.features_preparator.output_dim, out_features=hidden_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.act = nn.GELU()
        self.residual_modules = nn.ModuleList()
        for _ in range(num_residual_blocks):
            # Each residual block: normalization -> two-layer MLP.
            residual_module = ResidualModulesWrapper(
                modules=[
                    NormalizationModule(hidden_dim),
                    FeedForwardModule(dim=hidden_dim, dropout=dropout)
                ]
            )
            self.residual_modules.append(residual_module)
        self.output_normalization = NormalizationModule(hidden_dim)
        self.output_linear = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, graph, x):
        """Predict targets from node features x of shape [num_nodes, features_dim]."""
        x = self.features_preparator(x)
        x = self.input_linear(x)
        x = self.dropout(x)
        x = self.act(x)
        for residual_module in self.residual_modules:
            x = residual_module(graph, x)
        x = self.output_normalization(x)
        # squeeze(1) collapses a [num_nodes, 1] output to [num_nodes] for single-target prediction.
        x = self.output_linear(x).squeeze(1)
        return x
class SingleInputGNN(SingleInputModel):
    """
    Like ResNet, but additionally has a graph neighborhood aggregation (aka message passing) module in each residual
    block. That is, each residual block consists of the following sequence of modules: normalization, graph
    neighborhood aggregation, two-layer MLP.
    """

    def __init__(self, neighborhood_aggregation_name, neighborhood_aggregation_sep, normalization_name, num_edge_types,
                 num_residual_blocks, features_dim, hidden_dim, output_dim, neighborhood_aggr_attn_num_heads, dropout,
                 use_learnable_node_embeddings, num_nodes, learnable_node_embeddings_dim,
                 initialize_learnable_node_embeddings_with_deepwalk, deepwalk_node_embeddings,
                 use_plr_for_num_features, num_features_mask, plr_num_features_frequencies_dim,
                 plr_num_features_frequencies_scale, plr_num_features_embedding_dim,
                 plr_num_features_shared_linear, plr_num_features_shared_frequencies,
                 use_plr_for_past_targets, past_targets_mask, plr_past_targets_frequencies_dim,
                 plr_past_targets_frequencies_scale, plr_past_targets_embedding_dim,
                 plr_past_targets_shared_linear, plr_past_targets_shared_frequencies,
                 **kwargs):
        """Build the model.

        The use_learnable_node_embeddings / use_plr_* argument groups are forwarded
        verbatim to FeaturesPreparatorForDeepModels. Extra keyword arguments are
        accepted (and ignored) so one config dict can be shared across model classes.
        """
        super().__init__()
        # Resolve the configurable module classes by their configured names.
        NeighborhoodAggregationModule = NEIGHBORHOOD_AGGREGATION_MODULES[neighborhood_aggregation_name]
        NormalizationModule = NORMALIZATION_MODULES[normalization_name]
        self.features_preparator = FeaturesPreparatorForDeepModels(
            features_dim=features_dim,
            use_learnable_node_embeddings=use_learnable_node_embeddings,
            num_nodes=num_nodes,
            learnable_node_embeddings_dim=learnable_node_embeddings_dim,
            initialize_learnable_node_embeddings_with_deepwalk=initialize_learnable_node_embeddings_with_deepwalk,
            deepwalk_node_embeddings=deepwalk_node_embeddings,
            use_plr_for_num_features=use_plr_for_num_features,
            num_features_mask=num_features_mask,
            plr_num_features_frequencies_dim=plr_num_features_frequencies_dim,
            plr_num_features_frequencies_scale=plr_num_features_frequencies_scale,
            plr_num_features_embedding_dim=plr_num_features_embedding_dim,
            plr_num_features_shared_linear=plr_num_features_shared_linear,
            plr_num_features_shared_frequencies=plr_num_features_shared_frequencies,
            use_plr_for_past_targets=use_plr_for_past_targets,
            past_targets_mask=past_targets_mask,
            plr_past_targets_frequencies_dim=plr_past_targets_frequencies_dim,
            plr_past_targets_frequencies_scale=plr_past_targets_frequencies_scale,
            plr_past_targets_embedding_dim=plr_past_targets_embedding_dim,
            plr_past_targets_shared_linear=plr_past_targets_shared_linear,
            plr_past_targets_shared_frequencies=plr_past_targets_shared_frequencies
        )
        # Project the prepared features to the hidden dimension.
        self.input_linear = nn.Linear(in_features=self.features_preparator.output_dim, out_features=hidden_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.act = nn.GELU()
        self.residual_modules = nn.ModuleList()
        for _ in range(num_residual_blocks):
            # Each residual block: normalization -> neighborhood aggregation -> MLP.
            # num_inputs of the MLP accounts for per-edge-type aggregation outputs
            # plus the optional separate self-representation (neighborhood_aggregation_sep).
            residual_module = ResidualModulesWrapper(
                modules=[
                    NormalizationModule(hidden_dim),
                    NeighborhoodAggregationModule(dim=hidden_dim, num_heads=neighborhood_aggr_attn_num_heads,
                                                  num_edge_types=num_edge_types, dropout=dropout,
                                                  sep=neighborhood_aggregation_sep),
                    FeedForwardModule(dim=hidden_dim, num_inputs=num_edge_types + neighborhood_aggregation_sep,
                                      dropout=dropout)
                ]
            )
            self.residual_modules.append(residual_module)
        self.output_normalization = NormalizationModule(hidden_dim)
        self.output_linear = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, graph, x):
        """Predict targets from node features x of shape [num_nodes, features_dim] using graph message passing."""
        x = self.features_preparator(x)
        x = self.input_linear(x)
        x = self.dropout(x)
        x = self.act(x)
        for residual_module in self.residual_modules:
            x = residual_module(graph, x)
        x = self.output_normalization(x)
        # squeeze(1) collapses a [num_nodes, 1] output to [num_nodes] for single-target prediction.
        x = self.output_linear(x).squeeze(1)
        return x
class SequenceInputGNN(SequenceInputModel):
    """
    Like ResNet, but additionally has a sequence encoder module and a graph neighborhood aggregation (aka message
    passing) module in each residual block. That is, each residual block consists of the following sequence of modules:
    normalization, sequence encoder, graph neighborhood aggregation, two-layer MLP. Also has one sequence encoder
    module before all residual blocks.

    The per-timestamp representations are reduced to one representation per node by concatenating the final sequence
    element with the mean and max taken over the sequence dimension, hence the hidden_dim * 3 size of the output layers.
    """

    def __init__(self, sequence_encoder_name, neighborhood_aggregation_name, neighborhood_aggregation_sep,
                 normalization_name, num_edge_types, num_residual_blocks, features_dim, hidden_dim, output_dim,
                 neighborhood_aggr_attn_num_heads, seq_encoder_num_layers, seq_encoder_rnn_type_name,
                 seq_encoder_attn_num_heads, seq_encoder_bidir_attn, seq_encoder_seq_len, dropout,
                 use_learnable_node_embeddings, num_nodes, learnable_node_embeddings_dim,
                 initialize_learnable_node_embeddings_with_deepwalk, deepwalk_node_embeddings,
                 use_plr_for_num_features, num_features_mask, plr_num_features_frequencies_dim,
                 plr_num_features_frequencies_scale, plr_num_features_embedding_dim,
                 plr_num_features_shared_linear, plr_num_features_shared_frequencies,
                 use_plr_for_past_targets, past_targets_mask, plr_past_targets_frequencies_dim,
                 plr_past_targets_frequencies_scale, plr_past_targets_embedding_dim,
                 plr_past_targets_shared_linear, plr_past_targets_shared_frequencies,
                 **kwargs):
        """Build the model.

        The use_learnable_node_embeddings / use_plr_* argument groups are forwarded
        verbatim to FeaturesPreparatorForDeepModels, and the seq_encoder_* arguments
        to each SequenceEncoderModule. Extra keyword arguments are accepted (and
        ignored) so one config dict can be shared across model classes.
        """
        super().__init__()
        # Resolve the configurable module classes by their configured names.
        SequenceEncoderModule = SEQUENCE_ENCODER_MODULES[sequence_encoder_name]
        NeighborhoodAggregationModule = NEIGHBORHOOD_AGGREGATION_MODULES[neighborhood_aggregation_name]
        NormalizationModule = NORMALIZATION_MODULES[normalization_name]
        self.features_preparator = FeaturesPreparatorForDeepModels(
            features_dim=features_dim,
            use_learnable_node_embeddings=use_learnable_node_embeddings,
            num_nodes=num_nodes,
            learnable_node_embeddings_dim=learnable_node_embeddings_dim,
            initialize_learnable_node_embeddings_with_deepwalk=initialize_learnable_node_embeddings_with_deepwalk,
            deepwalk_node_embeddings=deepwalk_node_embeddings,
            use_plr_for_num_features=use_plr_for_num_features,
            num_features_mask=num_features_mask,
            plr_num_features_frequencies_dim=plr_num_features_frequencies_dim,
            plr_num_features_frequencies_scale=plr_num_features_frequencies_scale,
            plr_num_features_embedding_dim=plr_num_features_embedding_dim,
            plr_num_features_shared_linear=plr_num_features_shared_linear,
            plr_num_features_shared_frequencies=plr_num_features_shared_frequencies,
            use_plr_for_past_targets=use_plr_for_past_targets,
            past_targets_mask=past_targets_mask,
            plr_past_targets_frequencies_dim=plr_past_targets_frequencies_dim,
            plr_past_targets_frequencies_scale=plr_past_targets_frequencies_scale,
            plr_past_targets_embedding_dim=plr_past_targets_embedding_dim,
            plr_past_targets_shared_linear=plr_past_targets_shared_linear,
            plr_past_targets_shared_frequencies=plr_past_targets_shared_frequencies
        )
        # Project the prepared features to the hidden dimension.
        self.input_linear = nn.Linear(in_features=self.features_preparator.output_dim, out_features=hidden_dim)
        # One sequence encoder applied before all residual blocks.
        self.input_sequence_encoder = SequenceEncoderModule(rnn_type_name=seq_encoder_rnn_type_name,
                                                            num_layers=seq_encoder_num_layers, dim=hidden_dim,
                                                            num_heads=seq_encoder_attn_num_heads,
                                                            bidir_attn=seq_encoder_bidir_attn,
                                                            seq_len=seq_encoder_seq_len, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.act = nn.GELU()
        self.residual_modules = nn.ModuleList()
        for _ in range(num_residual_blocks):
            # Each residual block: normalization -> sequence encoder ->
            # neighborhood aggregation -> MLP. num_inputs of the MLP accounts for
            # per-edge-type aggregation outputs plus the optional separate
            # self-representation (neighborhood_aggregation_sep).
            residual_module = ResidualModulesWrapper(
                modules=[
                    NormalizationModule(hidden_dim),
                    SequenceEncoderModule(rnn_type_name=seq_encoder_rnn_type_name, num_layers=seq_encoder_num_layers,
                                          dim=hidden_dim, num_heads=seq_encoder_attn_num_heads,
                                          bidir_attn=seq_encoder_bidir_attn, seq_len=seq_encoder_seq_len,
                                          dropout=dropout),
                    NeighborhoodAggregationModule(dim=hidden_dim, num_heads=neighborhood_aggr_attn_num_heads,
                                                  num_edge_types=num_edge_types, dropout=dropout,
                                                  sep=neighborhood_aggregation_sep),
                    FeedForwardModule(dim=hidden_dim, num_inputs=num_edge_types + neighborhood_aggregation_sep,
                                      dropout=dropout)
                ]
            )
            self.residual_modules.append(residual_module)
        # hidden_dim * 3 because the sequence is summarized by concatenating
        # [final element, mean, max] representations.
        self.output_normalization = NormalizationModule(hidden_dim * 3)
        self.output_linear = nn.Linear(in_features=hidden_dim * 3, out_features=output_dim)

    def forward(self, graph, x):
        """Predict targets from x of shape [num_nodes, num_timestamps, features_dim]."""
        x = self.features_preparator(x)
        x = self.input_linear(x)
        x = self.input_sequence_encoder(x)
        x = self.dropout(x)
        x = self.act(x)
        for residual_module in self.residual_modules:
            x = residual_module(graph, x)
        # Summarize the sequence dimension: final element + mean + max, concatenated.
        # Use PyTorch's idiomatic `dim` keyword (the original `axis` is a NumPy-style
        # alias with identical behavior).
        x_final = x[:, -1]
        x_mean = x.mean(dim=1)
        x_max = x.max(dim=1).values
        x = torch.cat([x_final, x_mean, x_max], dim=1)
        x = self.output_normalization(x)
        # squeeze(1) collapses a [num_nodes, 1] output to [num_nodes] for single-target prediction.
        x = self.output_linear(x).squeeze(1)
        return x