-
Notifications
You must be signed in to change notification settings - Fork 0
/
sub_modules.py
445 lines (358 loc) · 15.5 KB
/
sub_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
"""
Implementation of ``nn.Modules`` for temporal fusion transformer.
"""
import math
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class TimeDistributed(nn.Module):
    """Apply ``module`` independently to every step of a sequence.

    Inputs with more than two dimensions are flattened to ``(-1, input_size)`` so that
    ``module`` only ever sees a 2-D batch; the result is then folded back into 3-D.

    Args:
        module: module applied to each ``(..., input_size)`` slice.
        batch_first: if True the output is shaped ``(samples, timesteps, output_size)``,
            otherwise ``(timesteps, samples, output_size)``.
    """

    def __init__(self, module: nn.Module, batch_first: bool = False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        # At most 2-D: there is no time axis to distribute over.
        if x.dim() <= 2:
            return self.module(x)

        n_features = x.size(-1)
        flat_out = self.module(x.contiguous().view(-1, n_features))  # (samples * timesteps, output_size)
        out_features = flat_out.size(-1)

        if not self.batch_first:
            return flat_out.view(-1, x.size(1), out_features)  # (timesteps, samples, output_size)
        return flat_out.contiguous().view(x.size(0), -1, out_features)  # (samples, timesteps, output_size)
class TimeDistributedInterpolation(nn.Module):
    """Linearly resample the last dimension of a tensor to ``output_size``.

    Inputs with more than two dimensions are flattened to ``(-1, input_size)`` before
    resampling and reshaped back afterwards, like ``TimeDistributed``.

    Args:
        output_size: target size of the last dimension.
        batch_first: if True the output is shaped ``(samples, timesteps, output_size)``,
            otherwise ``(timesteps, samples, output_size)``.
        trainable: if True, the resampled features are scaled by ``2 * sigmoid(mask)``
            with a learnable ``mask`` initialised to zero (an initial scale of 1).
    """

    def __init__(self, output_size: int, batch_first: bool = False, trainable: bool = False):
        super().__init__()
        self.output_size = output_size
        self.batch_first = batch_first
        self.trainable = trainable
        if trainable:
            self.mask = nn.Parameter(torch.zeros(output_size, dtype=torch.float32))
            self.gate = nn.Sigmoid()

    def interpolate(self, x):
        # (N, L) -> (N, 1, L): F.interpolate treats the last axis as the spatial axis.
        resized = F.interpolate(x.unsqueeze(1), self.output_size, mode="linear", align_corners=True)
        resized = resized.squeeze(1)
        if not self.trainable:
            return resized
        return resized * self.gate(self.mask.unsqueeze(0)) * 2.0

    def forward(self, x):
        if x.dim() <= 2:
            return self.interpolate(x)

        flat = self.interpolate(x.contiguous().view(-1, x.size(-1)))  # (samples * timesteps, output_size)
        if self.batch_first:
            return flat.contiguous().view(x.size(0), -1, flat.size(-1))  # (samples, timesteps, output_size)
        return flat.view(-1, x.size(1), flat.size(-1))  # (timesteps, samples, output_size)
class GatedLinearUnit(nn.Module):
    """Gated Linear Unit.

    Projects the input to ``2 * hidden_size`` features and applies ``F.glu`` along the last
    dimension, i.e. ``a * sigmoid(b)`` where ``a`` and ``b`` are the two halves.

    Args:
        input_size: size of the last input dimension.
        hidden_size: output size; defaults to ``input_size``.
        dropout: dropout rate applied to the input before the projection; ``None`` disables it.
    """

    def __init__(self, input_size: int, hidden_size: int = None, dropout: float = None):
        super().__init__()
        self.dropout = None if dropout is None else nn.Dropout(dropout)
        self.hidden_size = hidden_size or input_size
        # Twice the width: F.glu splits it into value and gate halves.
        self.fc = nn.Linear(input_size, self.hidden_size * 2)
        self.init_weights()

    def init_weights(self):
        for param_name, param in self.named_parameters():
            if "bias" in param_name:
                torch.nn.init.zeros_(param)
                continue
            if "fc" in param_name:
                torch.nn.init.xavier_uniform_(param)

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        return F.glu(self.fc(x), dim=-1)
class ResampleNorm(nn.Module):
    """Resample the feature dimension to ``output_size``, optionally gate it, then layer-normalize.

    Args:
        input_size: size of the last input dimension.
        output_size: target size; defaults to ``input_size`` (no resampling).
        trainable_add: if True, features are scaled by ``2 * sigmoid(mask)`` with a learnable
            ``mask`` initialised to zero (an initial scale of 1) before normalization.
    """

    def __init__(self, input_size: int, output_size: int = None, trainable_add: bool = True):
        super().__init__()
        self.input_size = input_size
        self.trainable_add = trainable_add
        self.output_size = output_size or input_size

        if self.output_size != self.input_size:
            # Linear interpolation along the feature axis to change its width.
            self.resample = TimeDistributedInterpolation(self.output_size, batch_first=True, trainable=False)

        if trainable_add:
            self.mask = nn.Parameter(torch.zeros(self.output_size, dtype=torch.float))
            self.gate = nn.Sigmoid()

        self.norm = nn.LayerNorm(self.output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        needs_resample = self.input_size != self.output_size
        if needs_resample:
            x = self.resample(x)
        if self.trainable_add:
            x = x * self.gate(self.mask) * 2.0
        return self.norm(x)
class AddNorm(nn.Module):
    """Add a skip connection to ``x`` and layer-normalize the sum.

    The skip tensor is first resampled to ``input_size`` when ``skip_size`` differs and,
    if ``trainable_add`` is set, scaled by ``2 * sigmoid(mask)`` (``mask`` starts at zero,
    so the initial scale is 1).

    Args:
        input_size: size of the last dimension of ``x`` and of the output.
        skip_size: size of the last dimension of ``skip``; defaults to ``input_size``.
        trainable_add: whether to learn a per-feature scale for the skip connection.
    """

    def __init__(self, input_size: int, skip_size: int = None, trainable_add: bool = True):
        super().__init__()
        self.input_size = input_size
        self.trainable_add = trainable_add
        self.skip_size = skip_size or input_size

        if self.skip_size != self.input_size:
            self.resample = TimeDistributedInterpolation(self.input_size, batch_first=True, trainable=False)

        if trainable_add:
            self.mask = nn.Parameter(torch.zeros(self.input_size, dtype=torch.float))
            self.gate = nn.Sigmoid()

        self.norm = nn.LayerNorm(self.input_size)

    def forward(self, x: torch.Tensor, skip: torch.Tensor):
        if self.skip_size != self.input_size:
            skip = self.resample(skip)
        if self.trainable_add:
            skip = skip * self.gate(self.mask) * 2.0
        return self.norm(x + skip)
class GateAddNorm(nn.Module):
    """Gated linear unit followed by a residual add-and-layer-norm step.

    Args:
        input_size: size of the last dimension of ``x``.
        hidden_size: output size of the gate (and of the result); defaults to ``input_size``.
        skip_size: size of the last dimension of ``skip``; defaults to ``hidden_size``.
        trainable_add: passed to ``AddNorm`` to learn a scale for the skip connection.
        dropout: dropout rate applied inside the gated linear unit; ``None`` disables it.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int = None,
        skip_size: int = None,
        trainable_add: bool = False,
        dropout: float = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size or input_size
        self.skip_size = skip_size or self.hidden_size
        self.dropout = dropout

        # The gate is built first; this fixes parameter registration order.
        self.glu = GatedLinearUnit(self.input_size, hidden_size=self.hidden_size, dropout=self.dropout)
        self.add_norm = AddNorm(self.hidden_size, skip_size=self.skip_size, trainable_add=trainable_add)

    def forward(self, x, skip):
        gated = self.glu(x)
        return self.add_norm(gated, skip)
class GatedResidualNetwork(nn.Module):
    """Gated residual network (GRN).

    Computes ``GateAddNorm(fc2(ELU(fc1(x) [+ context_proj(context)])), skip)``. The context
    term is added only when a context tensor is passed. ``skip`` is the explicit ``residual``
    argument if given, otherwise ``x``; in either case it is passed through ``ResampleNorm``
    when ``input_size != output_size`` and the ``residual`` flag is False.

    Args:
        input_size: size of the last dimension of ``x``.
        hidden_size: width of the internal ``fc1`` / ``fc2`` layers.
        output_size: size of the returned features.
        dropout: dropout rate used by the gating ``GateAddNorm``.
        context_size: size of the optional context tensor; ``None`` means no context projection.
        residual: if True, no ``ResampleNorm`` is built for the skip path even when the input
            and output widths differ. The skip then goes straight into an add-norm of width
            ``output_size``, so the caller presumably must pass a ``residual`` of that width.
            Verify at call sites.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.1,
        context_size: int = None,
        residual: bool = False,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.context_size = context_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.residual = residual

        # The skip path needs resampling only when widths differ and the caller
        # has not opted into supplying its own residual.
        if input_size != output_size and not residual:
            self.resample_norm = ResampleNorm(input_size, output_size)

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.elu = nn.ELU()
        if context_size is not None:
            self.context = nn.Linear(context_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

        # Re-initialize the layers built so far. This runs before gate_norm exists,
        # so gate_norm keeps the initialization performed by its own sub-modules.
        self.init_weights()

        self.gate_norm = GateAddNorm(
            input_size=hidden_size,
            skip_size=output_size,
            hidden_size=output_size,
            dropout=dropout,
            trainable_add=False,
        )

    def init_weights(self):
        for param_name, param in self.named_parameters():
            if "bias" in param_name:
                torch.nn.init.zeros_(param)
            elif any(layer in param_name for layer in ("fc1", "fc2")):
                torch.nn.init.kaiming_normal_(param, a=0, mode="fan_in", nonlinearity="leaky_relu")
            elif "context" in param_name:
                torch.nn.init.xavier_uniform_(param)

    def forward(self, x, context=None, residual=None):
        skip = x if residual is None else residual
        if self.input_size != self.output_size and not self.residual:
            skip = self.resample_norm(skip)

        hidden = self.fc1(x)
        if context is not None:
            hidden = hidden + self.context(context)
        hidden = self.fc2(self.elu(hidden))
        return self.gate_norm(hidden, skip)
class VariableSelectionNetwork(nn.Module):
    def __init__(
        self,
        input_sizes: Dict[str, int],
        hidden_size: int,
        input_embedding_flags: Dict[str, bool] = None,
        dropout: float = 0.1,
        context_size: int = None,
        single_variable_grns: Dict[str, "GatedResidualNetwork"] = None,
        prescalers: Dict[str, nn.Linear] = None,
    ):
        """
        Calculate weights for ``num_inputs`` variables and combine their encodings.

        Each variable is encoded separately to ``hidden_size``. When there is more than one
        variable, a GRN over the concatenated inputs produces softmax weights that mix the
        per-variable encodings.

        Args:
            input_sizes: mapping of variable name to that variable's feature size.
            hidden_size: size of each per-variable encoding and of the combined output.
            input_embedding_flags: ``True`` for variables that are already embeddings. Those
                are encoded with ``ResampleNorm`` and get no prescaler. Defaults to no flags.
            dropout: dropout rate for the internally created ``GatedResidualNetwork``s.
            context_size: size of the optional context passed to the weighting network.
            single_variable_grns: pre-built per-variable encoders, keyed by variable name.
            prescalers: pre-built input scalers, keyed by variable name.
        """
        super().__init__()
        # ``None`` defaults instead of ``{}``: a dict default is created once and shared by
        # every instance, and ``input_embedding_flags`` is stored on ``self``.
        input_embedding_flags = {} if input_embedding_flags is None else input_embedding_flags
        single_variable_grns = {} if single_variable_grns is None else single_variable_grns
        prescalers = {} if prescalers is None else prescalers

        self.hidden_size = hidden_size
        self.input_sizes = input_sizes
        self.input_embedding_flags = input_embedding_flags
        self.dropout = dropout
        self.context_size = context_size

        if self.num_inputs > 1:
            # One selection logit per variable from the concatenated inputs.
            # GatedResidualNetwork's context_size defaults to None, so one call covers
            # both the with-context and without-context cases.
            self.flattened_grn = GatedResidualNetwork(
                self.input_size_total,
                min(self.hidden_size, self.num_inputs),
                self.num_inputs,
                self.dropout,
                self.context_size,
                residual=False,
            )

        self.single_variable_grns = nn.ModuleDict()
        self.prescalers = nn.ModuleDict()
        for name, input_size in self.input_sizes.items():
            is_embedding = self.input_embedding_flags.get(name, False)
            if name in single_variable_grns:
                self.single_variable_grns[name] = single_variable_grns[name]
            elif is_embedding:
                self.single_variable_grns[name] = ResampleNorm(input_size, self.hidden_size)
            else:
                self.single_variable_grns[name] = GatedResidualNetwork(
                    input_size,
                    min(input_size, self.hidden_size),
                    output_size=self.hidden_size,
                    dropout=self.dropout,
                )
            if name in prescalers:  # reals need to be first scaled up
                self.prescalers[name] = prescalers[name]
            elif not is_embedding:
                # Non-embedded variable: lift its width-1 input to input_size.
                self.prescalers[name] = nn.Linear(1, input_size)

        self.softmax = nn.Softmax(dim=-1)

    @property
    def input_size_total(self):
        """Total width of all variables concatenated (the input size of ``flattened_grn``)."""
        return sum(self.input_sizes.values())

    @property
    def num_inputs(self):
        """Number of variables."""
        return len(self.input_sizes)

    def forward(self, x: Dict[str, torch.Tensor], context: torch.Tensor = None):
        """Encode and combine the variables in ``x``.

        Args:
            x: mapping of variable name to its input tensor.
            context: optional context for the weighting network (only used with >1 variable).

        Returns:
            ``(outputs, sparse_weights)``. With several variables, ``sparse_weights`` are the
            softmax selection weights with a singleton axis inserted before the variable axis.
            With a single variable the encoding is returned directly with all-ones weights.
        """
        if self.num_inputs > 1:
            var_outputs = []
            weight_inputs = []
            for name in self.input_sizes:
                variable_embedding = x[name]
                if name in self.prescalers:
                    variable_embedding = self.prescalers[name](variable_embedding)
                weight_inputs.append(variable_embedding)
                var_outputs.append(self.single_variable_grns[name](variable_embedding))
            # Per-variable encodings stacked on a new last axis (one slot per variable).
            var_outputs = torch.stack(var_outputs, dim=-1)

            flat_embedding = torch.cat(weight_inputs, dim=-1)
            sparse_weights = self.flattened_grn(flat_embedding, context)
            sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)  # (..., 1, num_inputs)
            outputs = (var_outputs * sparse_weights).sum(dim=-1)
        else:  # a single input needs no selection, only encoding
            name = next(iter(self.single_variable_grns.keys()))
            variable_embedding = x[name]
            if name in self.prescalers:
                variable_embedding = self.prescalers[name](variable_embedding)
            outputs = self.single_variable_grns[name](variable_embedding)
            if outputs.ndim == 3:  # weights -> (batch size, time, 1, n_variables=1)
                sparse_weights = torch.ones(outputs.size(0), outputs.size(1), 1, 1, device=outputs.device)
            else:  # ndim == 2, weights -> (batch size, 1, n_variables=1)
                sparse_weights = torch.ones(outputs.size(0), 1, 1, device=outputs.device)
        return outputs, sparse_weights
class PositionalEncoder(torch.nn.Module):
    """Scale inputs by ``sqrt(d_model)`` and add sinusoidal positional encodings.

    The table follows the standard sinusoidal scheme (Vaswani et al., 2017)::

        pe[pos, 2k]     = sin(pos / 10000 ** (2k / d_model))
        pe[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_model))

    Args:
        d_model: feature size; must be even so that sin/cos pairs fill it exactly.
        max_seq_len: number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, max_seq_len=160):
        super().__init__()
        assert d_model % 2 == 0, "model dimension has to be multiple of 2 (encode sin(pos) and cos(pos))"
        self.d_model = d_model

        # Angles are computed in float64 (like ``math.sin``) and cast into ``pe`` on assignment.
        position = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)  # (max_seq_len, 1)
        # even_idx holds 2k for each sin/cos pair, so the exponent is even_idx / d_model.
        # The previous loop used 2*i / d_model (and 2*(i+1) / d_model for cos), which
        # doubled the exponent and gave each sin/cos pair a different frequency.
        even_idx = torch.arange(0, d_model, 2, dtype=torch.float64)
        angles = position / torch.pow(10000.0, even_idx / d_model)  # (max_seq_len, d_model // 2)

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_seq_len, d_model)

    def forward(self, x):
        """Return ``x * sqrt(d_model) + pe[:seq_len]``.

        ``x`` is assumed to be sequence-first, ``(seq_len, batch, d_model)``: the encoding is
        viewed as ``(seq_len, 1, d_model)`` and broadcast over dim 1. ``seq_len`` must not
        exceed ``max_seq_len``.
        """
        # No ``torch.no_grad()`` here. The encoding is a constant buffer, but wrapping the
        # whole computation detached the output and blocked gradients flowing back to ``x``.
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(0)
        pe = self.pe[:, :seq_len].view(seq_len, 1, self.d_model)
        return x + pe
class ScaledDotProductAttention(nn.Module):
    """Batched scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        dropout: dropout rate applied to the attention weights; ``None`` disables it.
        scale: if True, divide the scores by ``sqrt(k.size(-1))``.
    """

    def __init__(self, dropout: float = None, scale: bool = True):
        super(ScaledDotProductAttention, self).__init__()
        if dropout is not None:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = dropout
        self.softmax = nn.Softmax(dim=2)
        self.scale = scale

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: queries, ``(batch, len_q, d_k)`` (``torch.bmm`` requires 3-D inputs).
            k: keys, ``(batch, len_k, d_k)``.
            v: values, ``(batch, len_k, d_v)``.
            mask: optional boolean tensor broadcastable to ``(batch, len_q, len_k)``.
                ``True`` marks scores to exclude.

        Returns:
            ``(output, attn)`` of shapes ``(batch, len_q, d_v)`` and ``(batch, len_q, len_k)``.
        """
        attn = torch.bmm(q, k.permute(0, 2, 1))  # query-key overlap

        if self.scale:
            # A Python float keeps attn's dtype/device and avoids building a CPU tensor per call.
            attn = attn / math.sqrt(k.shape[-1])

        if mask is not None:
            # -1e9 is not representable in float16, so masked_fill would raise an overflow
            # error there (e.g. under autocast). Clamp to the dtype's lowest finite value;
            # float32/bfloat16 still use -1e9.
            attn = attn.masked_fill(mask, max(-1e9, torch.finfo(attn.dtype).min))

        attn = self.softmax(attn)
        if self.dropout is not None:
            attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn
class InterpretableMultiHeadAttention(nn.Module):
    """Multi-head attention with a shared value projection and averaged heads.

    Every head has its own query and key projections, but all heads use the single value
    projection ``v_layer``. The head outputs are averaged (not concatenated) and passed
    through ``w_h``; the per-head attention weights are returned stacked along dim 2.

    Args:
        n_head: number of attention heads.
        d_model: model width; each head works in ``d_model // n_head`` dimensions.
        dropout: dropout rate applied to each head's output and to the final output.
    """

    def __init__(self, n_head: int, d_model: int, dropout: float = 0.0):
        super(InterpretableMultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_model = d_model
        head_dim = d_model // n_head
        self.d_k = self.d_q = self.d_v = head_dim
        self.dropout = nn.Dropout(p=dropout)

        # Keep this creation order: it fixes parameter registration order, and therefore
        # the order in which init_weights draws random values.
        self.v_layer = nn.Linear(d_model, head_dim)  # shared by all heads
        self.q_layers = nn.ModuleList(nn.Linear(d_model, head_dim) for _ in range(n_head))
        self.k_layers = nn.ModuleList(nn.Linear(d_model, head_dim) for _ in range(n_head))
        self.attention = ScaledDotProductAttention()
        self.w_h = nn.Linear(head_dim, d_model, bias=False)

        self.init_weights()

    def init_weights(self):
        for param_name, param in self.named_parameters():
            if "bias" in param_name:
                torch.nn.init.zeros_(param)
            else:
                torch.nn.init.xavier_uniform_(param)

    def forward(self, q, k, v, mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        shared_values = self.v_layer(v)

        head_outputs = []
        head_attns = []
        for q_layer, k_layer in zip(self.q_layers, self.k_layers):
            head, attn = self.attention(q_layer(q), k_layer(k), shared_values, mask)
            head_outputs.append(self.dropout(head))
            head_attns.append(attn)

        attn = torch.stack(head_attns, dim=2)
        if self.n_head > 1:
            outputs = torch.stack(head_outputs, dim=2).mean(dim=2)
        else:
            outputs = head_outputs[0]

        outputs = self.dropout(self.w_h(outputs))
        return outputs, attn