-
Notifications
You must be signed in to change notification settings - Fork 7
/
train_or_infer.py
570 lines (502 loc) · 22.8 KB
/
train_or_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
import os
import json
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, radius_graph
from torch.utils.tensorboard import SummaryWriter
# Output directories: TensorBoard runs / checkpoints and pickled rollouts.
os.makedirs('train_log', exist_ok=True)
os.makedirs('rollouts', exist_ok=True)

# --- Run configuration -------------------------------------------------------
INPUT_SEQUENCE_LENGTH = 6  # leading positions of a trajectory used as model input
batch_size = 2
noise_std = 6.7e-4  # random-walk noise std at the last input step; also folded into the stats below
training_steps = int(2e7)
log_steps = 5  # write TensorBoard scalars every `log_steps` training steps
eval_steps = 20  # only referenced by the commented-out evaluation in train()
save_steps = 100  # checkpoint every `save_steps` training steps
model_path = None # 'model425000.pth'
device = 'cuda'

# Dataset metadata is read once at import time from a fixed location.
with open('data/metadata.json', 'rt') as f:
    metadata = json.load(f)

# Number of steps to roll out after the initial input window.
num_steps = metadata['sequence_length'] - INPUT_SEQUENCE_LENGTH

# Per-quantity mean/std tensors on `device`; each std is combined in quadrature
# with `noise_std`.
normalization_stats = {
    quantity: {
        'mean': torch.FloatTensor(metadata[prefix + '_mean']).to(device),
        'std': torch.sqrt(torch.FloatTensor(metadata[prefix + '_std'])**2 + noise_std**2).to(device),
    }
    for quantity, prefix in (('acceleration', 'acc'), ('velocity', 'vel'))
}
def build_mlp(
    input_size,
    layer_sizes,
    output_size=None,
    output_activation=torch.nn.Identity,
    activation=torch.nn.ReLU,
):
    """Build a fully connected network as a ``torch.nn.Sequential``.

    Args:
        input_size: Number of input features of the first linear layer.
        layer_sizes: List with the output width of each hidden linear layer.
        output_size: Optional width of a final linear layer; skipped when
            falsy, in which case the last entry of ``layer_sizes`` is the
            output width.
        output_activation: Module class instantiated after the last linear
            layer.
        activation: Module class instantiated after every other linear layer.

    Returns:
        A ``Sequential`` alternating ``Linear`` layers and activation modules.
    """
    widths = [input_size] + layer_sizes
    if output_size:
        widths.append(output_size)

    final_layer = len(widths) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
        # Only the last linear layer is followed by the output activation.
        nonlinearity = output_activation if idx >= final_layer else activation
        modules.append(torch.nn.Linear(fan_in, fan_out))
        modules.append(nonlinearity())
    return torch.nn.Sequential(*modules)
def time_diff(input_sequence):
    """Return first-order finite differences along axis 1.

    The result has one fewer entry along axis 1 than ``input_sequence``.
    """
    later = input_sequence[:, 1:]
    earlier = input_sequence[:, :-1]
    return later - earlier
def get_random_walk_noise_for_position_sequence(position_sequence, noise_std_last_step):
    """Returns random-walk noise in the velocity applied to the position.

    Independent Gaussian noise is drawn for every velocity step and
    accumulated along axis 1, scaled so that the accumulated velocity noise at
    the last step has standard deviation ``noise_std_last_step``. The velocity
    noise is then accumulated once more to obtain position noise; the first
    position along axis 1 receives zero noise.

    Args:
        position_sequence: Tensor whose axis 1 is the time axis.
        noise_std_last_step: Target std of the velocity noise at the last step.

    Returns:
        Tensor with the same shape as ``position_sequence``, created on the
        same device and (for floating-point inputs) with the same dtype, so it
        can be added to the positions without a device transfer.
    """
    # Only the shape of the velocity sequence matters here, so read it from a
    # view instead of computing the finite differences.
    velocity_shape = tuple(position_sequence[:, 1:].shape)
    num_velocities = velocity_shape[1]

    # Fall back to the default float dtype for non-floating inputs, which
    # torch.randn cannot sample in.
    if position_sequence.is_floating_point():
        noise_dtype = position_sequence.dtype
    else:
        noise_dtype = torch.get_default_dtype()

    velocity_sequence_noise = torch.randn(
        velocity_shape,
        dtype=noise_dtype,
        device=position_sequence.device,
    ) * (noise_std_last_step/num_velocities**0.5)

    # Random walk in velocity space.
    velocity_sequence_noise = torch.cumsum(velocity_sequence_noise, dim=1)

    # Integrate the noisy velocities; the first position is left noise-free.
    position_sequence_noise = torch.cat([
        torch.zeros_like(velocity_sequence_noise[:, 0:1]),
        torch.cumsum(velocity_sequence_noise, dim=1)], dim=1)
    return position_sequence_noise
def _read_metadata(data_path):
    """Load and return the parsed contents of ``<data_path>/metadata.json``."""
    metadata_file = os.path.join(data_path, 'metadata.json')
    with open(metadata_file, 'rt') as fp:
        return json.load(fp)
class Encoder(nn.Module):
    """Embeds raw node and edge features with two independent MLPs.

    Each MLP has ``mlp_num_layers`` hidden layers of width ``mlp_hidden_dim``
    and is followed by a ``LayerNorm`` over its output features.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__()
        hidden_sizes = [mlp_hidden_dim] * mlp_num_layers
        self.node_fn = nn.Sequential(
            build_mlp(node_in, hidden_sizes, node_out),
            nn.LayerNorm(node_out),
        )
        self.edge_fn = nn.Sequential(
            build_mlp(edge_in, hidden_sizes, edge_out),
            nn.LayerNorm(edge_out),
        )

    def forward(self, x, edge_index, e_features): # global_features
        """Encode node and edge features; ``edge_index`` is accepted but unused.

        Args:
            x: Node features whose last dimension is ``node_in``.
            edge_index: Graph connectivity (not read here).
            e_features: Edge features whose last dimension is ``edge_in``.

        Returns:
            Tuple ``(encoded_nodes, encoded_edges)``.
        """
        encoded_nodes = self.node_fn(x)
        encoded_edges = self.edge_fn(e_features)
        return encoded_nodes, encoded_edges
class InteractionNetwork(MessagePassing):
    """One message-passing step with residual connections on nodes and edges.

    Messages are computed by ``edge_fn`` and summed per node (``aggr='add'``);
    nodes are then updated by ``node_fn``. Both functions are MLPs followed by
    ``LayerNorm``.

    NOTE(review): the parameter names of ``message`` and ``update`` are part of
    the contract with ``MessagePassing.propagate``, which is presumed to route
    keyword arguments by name — do not rename them without confirming.
    """
    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super(InteractionNetwork, self).__init__(aggr='add')
        # Node update consumes the node's own features plus its aggregated messages.
        self.node_fn = nn.Sequential(*[build_mlp(node_in+edge_out, [mlp_hidden_dim for _ in range(mlp_num_layers)], node_out),
            nn.LayerNorm(node_out)])
        # Edge/message function consumes both endpoint node features plus the edge features.
        self.edge_fn = nn.Sequential(*[build_mlp(node_in+node_in+edge_in, [mlp_hidden_dim for _ in range(mlp_num_layers)], edge_out),
            nn.LayerNorm(edge_out)])

    def forward(self, x, edge_index, e_features):
        """Run one propagation step and add the inputs back as residuals.

        Returns a tuple ``(nodes, edges)``.
        """
        # x: node features, last dimension node_in (presumably one row per node — TODO confirm)
        # edge_index: (2, E)
        # e_features: (E, edge_in)
        x_residual = x
        e_features_residual = e_features
        x, e_features = self.propagate(edge_index=edge_index, x=x, e_features=e_features)
        return x+x_residual, e_features+e_features_residual

    def message(self, edge_index, x_i, x_j, e_features):
        """Compute per-edge messages from both endpoint features and the edge features."""
        # x_i / x_j are presumably the per-edge endpoint node features gathered
        # by torch_geometric from `x` — TODO confirm which is source/target.
        e_features = torch.cat([x_i, x_j, e_features], dim=-1)
        e_features = self.edge_fn(e_features)
        return e_features

    def update(self, x_updated, x, e_features):
        """Update nodes from their aggregated messages; pass ``e_features`` through."""
        # x_updated: aggregated messages, last dimension edge_out
        # x: node features, last dimension node_in
        x_updated = torch.cat([x_updated, x], dim=-1)
        x_updated = self.node_fn(x_updated)
        # NOTE(review): `e_features` here is returned exactly as received. If
        # propagate() hands update() the original keyword argument rather than
        # the output of message(), the edge features returned by forward() are
        # the un-updated inputs (doubled by the residual) — confirm intent.
        return x_updated, e_features
class Processor(MessagePassing):
    """Applies a stack of independent ``InteractionNetwork`` steps in sequence.

    Every step is built with the same sizes and has its own parameters.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        num_message_passing_steps,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__(aggr='max')
        steps = []
        for _ in range(num_message_passing_steps):
            steps.append(
                InteractionNetwork(
                    node_in=node_in,
                    node_out=node_out,
                    edge_in=edge_in,
                    edge_out=edge_out,
                    mlp_num_layers=mlp_num_layers,
                    mlp_hidden_dim=mlp_hidden_dim,
                )
            )
        self.gnn_stacks = nn.ModuleList(steps)

    def forward(self, x, edge_index, e_features):
        """Feed node and edge features through every step; return the final pair."""
        for step in self.gnn_stacks:
            x, e_features = step(x, edge_index, e_features)
        return x, e_features
class Decoder(nn.Module):
    """Maps node features to the output size with a plain MLP (no LayerNorm)."""

    def __init__(
        self,
        node_in,
        node_out,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__()
        hidden_sizes = [mlp_hidden_dim] * mlp_num_layers
        self.node_fn = build_mlp(node_in, hidden_sizes, node_out)

    def forward(self, x):
        """Decode ``x`` (last dimension ``node_in``) to last dimension ``node_out``."""
        decoded = self.node_fn(x)
        return decoded
class EncodeProcessDecode(nn.Module):
    """Encoder -> Processor -> Decoder pipeline over a graph.

    Node and edge features are embedded to ``latent_dim``, refined by
    ``num_message_passing_steps`` processor steps, and the resulting node
    features are decoded to ``node_out`` values per node.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        latent_dim,
        num_message_passing_steps,
        mlp_num_layers,
        mlp_hidden_dim,
    ):
        super().__init__()
        # MLP shape shared by all three stages.
        mlp_config = dict(
            mlp_num_layers=mlp_num_layers,
            mlp_hidden_dim=mlp_hidden_dim,
        )
        self._encoder = Encoder(
            node_in=node_in,
            node_out=latent_dim,
            edge_in=edge_in,
            edge_out=latent_dim,
            **mlp_config,
        )
        self._processor = Processor(
            node_in=latent_dim,
            node_out=latent_dim,
            edge_in=latent_dim,
            edge_out=latent_dim,
            num_message_passing_steps=num_message_passing_steps,
            **mlp_config,
        )
        self._decoder = Decoder(
            node_in=latent_dim,
            node_out=node_out,
            **mlp_config,
        )

    def forward(self, x, edge_index, e_features):
        """Return the decoded per-node output; processed edge features are discarded."""
        latent_nodes, latent_edges = self._encoder(x, edge_index, e_features)
        latent_nodes, latent_edges = self._processor(latent_nodes, edge_index, latent_edges)
        return self._decoder(latent_nodes)
class Simulator(nn.Module):
    """Learned particle simulator built around ``EncodeProcessDecode``.

    Turns a short history of particle positions into a graph (nodes are
    particles, edges connect particles within ``connectivity_radius``),
    predicts a normalized acceleration per particle and integrates it into the
    next position.

    Position sequences are indexed as ``[:, t]``, i.e. axis 0 is the particle
    axis and axis 1 is time (presumably ``(n_nodes, seq_len, dim)`` — TODO
    confirm against the data pipeline).
    """

    def __init__(
        self,
        particle_dimension,
        node_in,
        edge_in,
        latent_dim,
        num_message_passing_steps,
        mlp_num_layers,
        mlp_hidden_dim,
        connectivity_radius,
        boundaries,
        normalization_stats,
        num_particle_types,
        particle_type_embedding_size,
        device='cuda',
    ):
        """Store the simulation settings and build the network.

        Args:
            particle_dimension: Size of the network output per node.
            node_in: Size of the assembled node feature vector.
            edge_in: Size of the assembled edge feature vector.
            latent_dim: Latent size used by the encoder and processor.
            num_message_passing_steps: Number of processor steps.
            mlp_num_layers: Hidden layers per MLP.
            mlp_hidden_dim: Hidden width per MLP.
            connectivity_radius: Radius used to connect particles and to
                normalize distances.
            boundaries: Array-like indexed as ``[dimension, lower/upper]``.
            normalization_stats: Dict with ``'velocity'`` and
                ``'acceleration'`` entries, each holding ``'mean'`` and
                ``'std'``.
            num_particle_types: Number of rows of the particle-type embedding.
            particle_type_embedding_size: Width of the particle-type embedding.
            device: Device on which helper tensors are created.
        """
        super(Simulator, self).__init__()
        self._boundaries = boundaries
        self._connectivity_radius = connectivity_radius
        self._normalization_stats = normalization_stats
        self._num_particle_types = num_particle_types

        self._particle_type_embedding = nn.Embedding(num_particle_types, particle_type_embedding_size)

        self._encode_process_decode = EncodeProcessDecode(
            node_in=node_in,
            node_out=particle_dimension,
            edge_in=edge_in,
            latent_dim=latent_dim,
            num_message_passing_steps=num_message_passing_steps,
            mlp_num_layers=mlp_num_layers,
            mlp_hidden_dim=mlp_hidden_dim,
        )

        self._device = device

    def forward(self):
        """Intentionally a no-op; use ``predict_positions`` / ``predict_accelerations``."""
        pass

    def _build_graph_from_raw(self, position_sequence, n_particles_per_example, particle_types):
        """Assemble node features, edge index and edge features from positions.

        Node features (concatenated on the last axis): the normalized velocity
        history flattened over time, the distances to the lower/upper
        boundaries divided by the connectivity radius and clipped to [-1, 1],
        and (when there is more than one particle type) the particle-type
        embedding.

        Edge features: the relative displacement between the two endpoints
        divided by the connectivity radius, and its norm.

        Returns:
            Tuple ``(node_features, edge_index, edge_features)`` where
            ``edge_index`` stacks the two index vectors on axis 0.
        """
        n_total_points = position_sequence.shape[0]
        most_recent_position = position_sequence[:, -1]
        velocity_sequence = time_diff(position_sequence)
        # Two integer index vectors with one entry per edge.
        senders, receivers = self._compute_connectivity(most_recent_position, n_particles_per_example, self._connectivity_radius)
        node_features = []

        # Normalized velocity sequence, merging the spatial and time axes.
        velocity_stats = self._normalization_stats["velocity"]
        normalized_velocity_sequence = (velocity_sequence - velocity_stats['mean']) / velocity_stats['std']
        flat_velocity_sequence = normalized_velocity_sequence.view(n_total_points, -1)
        node_features.append(flat_velocity_sequence)

        # Normalized clipped distances to lower and upper boundaries.
        # boundaries are an array of shape [num_dimensions, 2], where the
        # second axis provides the lower/upper boundaries.
        boundaries = torch.tensor(self._boundaries, requires_grad=False).float().to(self._device)
        distance_to_lower_boundary = (most_recent_position - boundaries[:, 0][None])
        distance_to_upper_boundary = (boundaries[:, 1][None] - most_recent_position)
        distance_to_boundaries = torch.cat([distance_to_lower_boundary, distance_to_upper_boundary], dim=1)
        normalized_clipped_distance_to_boundaries = torch.clamp(distance_to_boundaries / self._connectivity_radius, -1., 1.)
        node_features.append(normalized_clipped_distance_to_boundaries)

        # A type embedding only carries information with more than one type.
        if self._num_particle_types > 1:
            particle_type_embeddings = self._particle_type_embedding(particle_types)
            node_features.append(particle_type_embeddings)

        # Collect edge features.
        edge_features = []

        # Relative displacement and distance, both normalized to the radius.
        normalized_relative_displacements = (
            most_recent_position[senders, :] - most_recent_position[receivers, :]
        ) / self._connectivity_radius
        edge_features.append(normalized_relative_displacements)

        normalized_relative_distances = torch.norm(normalized_relative_displacements, dim=-1, keepdim=True)
        edge_features.append(normalized_relative_distances)

        return torch.cat(node_features, dim=-1), torch.stack([senders, receivers]), torch.cat(edge_features, dim=-1)

    def _batch_ids(self, n_particles_per_example):
        """Return, for every particle, the index of the example it belongs to.

        ``n_particles_per_example`` holds one particle count per example; the
        result repeats example index ``i`` ``n_particles_per_example[i]``
        times, as an int64 tensor on ``self._device``.
        """
        counts = torch.as_tensor(n_particles_per_example, device=self._device).reshape(-1).long()
        example_ids = torch.arange(counts.numel(), device=self._device)
        return torch.repeat_interleave(example_ids, counts)

    def _compute_connectivity(self, node_features, n_particles_per_example, radius, add_self_edges=True):
        """Connect particles closer than ``radius`` within the same example.

        Args:
            node_features: Positions handed to ``radius_graph``.
            n_particles_per_example: One particle count per example in the
                batch; used so that edges never cross examples.
            radius: Connectivity radius.
            add_self_edges: Forwarded to ``radius_graph`` as ``loop``.

        Returns:
            Tuple ``(edge_index[0], edge_index[1])`` of the ``(2, n_edges)``
            index tensor produced by ``radius_graph``.
        """
        batch_ids = self._batch_ids(n_particles_per_example)
        # radius = radius + 0.00001 # radius_graph takes r < radius not r <= radius
        edge_index = radius_graph(node_features, r=radius, batch=batch_ids, loop=add_self_edges) # (2, n_edges)
        return edge_index[0, :], edge_index[1, :]

    def _decoder_postprocessor(self, normalized_acceleration, position_sequence):
        """Turn a normalized acceleration into the next position.

        Undoes the acceleration normalization, then applies two Euler steps
        with dt = 1 (acceleration -> velocity -> position) starting from the
        last two positions of ``position_sequence``.
        """
        # The model produces the output in normalized space so we apply
        # inverse normalization.
        acceleration_stats = self._normalization_stats["acceleration"]
        acceleration = (
            normalized_acceleration * acceleration_stats['std']
        ) + acceleration_stats['mean']

        # Use an Euler integrator to go from acceleration to position, assuming
        # a dt=1 corresponding to the size of the finite difference.
        most_recent_position = position_sequence[:, -1]
        most_recent_velocity = most_recent_position - position_sequence[:, -2]

        new_velocity = most_recent_velocity + acceleration  # * dt = 1
        new_position = most_recent_position + new_velocity  # * dt = 1
        return new_position

    def predict_positions(self, current_positions, n_particles_per_example, particle_types):
        """Predict the next position of every particle from its position history."""
        node_features, edge_index, e_features = self._build_graph_from_raw(current_positions, n_particles_per_example, particle_types)
        predicted_normalized_acceleration = self._encode_process_decode(node_features, edge_index, e_features)
        next_position = self._decoder_postprocessor(predicted_normalized_acceleration, current_positions)
        return next_position

    def predict_accelerations(self, next_position, position_sequence_noise, position_sequence, n_particles_per_example, particle_types):
        """Return predicted and target normalized accelerations for training.

        The model is run on the noise-perturbed position sequence. The target
        is derived from ``next_position`` shifted by the noise of the last
        input step, using the same noisy sequence as history.

        Returns:
            Tuple ``(predicted_normalized_acceleration,
            target_normalized_acceleration)``.
        """
        noisy_position_sequence = position_sequence + position_sequence_noise
        node_features, edge_index, e_features = self._build_graph_from_raw(noisy_position_sequence, n_particles_per_example, particle_types)
        predicted_normalized_acceleration = self._encode_process_decode(node_features, edge_index, e_features)
        # Shift the target by the last-step noise so the noise is not part of
        # the velocity change the model is asked to predict.
        next_position_adjusted = next_position + position_sequence_noise[:, -1]
        target_normalized_acceleration = self._inverse_decoder_postprocessor(next_position_adjusted, noisy_position_sequence)
        return predicted_normalized_acceleration, target_normalized_acceleration

    def _inverse_decoder_postprocessor(self, next_position, position_sequence):
        """Inverse of `_decoder_postprocessor`."""
        previous_position = position_sequence[:, -1]
        previous_velocity = previous_position - position_sequence[:, -2]
        next_velocity = next_position - previous_position
        acceleration = next_velocity - previous_velocity

        acceleration_stats = self._normalization_stats["acceleration"]
        normalized_acceleration = (acceleration - acceleration_stats['mean']) / acceleration_stats['std']
        return normalized_acceleration

    def save(self, path='model.pth'):
        """Write the module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load(self, path):
        """Load a ``state_dict`` previously written by ``save``.

        Tensors are mapped onto ``self._device`` while loading, so a checkpoint
        written on another device can still be restored.
        """
        self.load_state_dict(torch.load(path, map_location=self._device))
def prepare_data_from_tfds(data_path='data/train.tfrecord', is_rollout=False, batch_size=2):
    """Build a NumPy-yielding iterable of examples from a TFRecord file.

    Args:
        data_path: Path to the TFRecord file to read.
        is_rollout: When true, yield whole trajectories (one per element).
            Otherwise split trajectories into windows of
            ``INPUT_SEQUENCE_LENGTH + 1`` positions, repeat, shuffle and
            concatenate ``batch_size`` windows per element.
        batch_size: Number of windows concatenated per training element;
            unused when ``is_rollout`` is true.

    Returns:
        The dataset wrapped by ``tfds.as_numpy``; each element is a
        ``(features, target_position)`` pair.

    Note:
        The dataset metadata is always read from ``data/metadata.json``,
        independently of ``data_path``.
    """
    # TensorFlow and the project-local record parser are imported lazily so
    # that importing this module does not require them.
    import functools
    import tensorflow.compat.v1 as tf
    import tensorflow_datasets as tfds
    import reading_utils
    import tree

    def prepare_inputs(tensor_dict):
        """Split one window into input positions and the target (last) position."""
        pos = tensor_dict['position']
        # Swap the first two axes so the time axis becomes axis 1.
        pos = tf.transpose(pos, perm=[1, 0, 2])
        target_position = pos[:, -1]
        tensor_dict['position'] = pos[:, :-1]
        num_particles = tf.shape(pos)[0]
        tensor_dict['n_particles_per_example'] = num_particles[tf.newaxis]
        if 'step_context' in tensor_dict:
            tensor_dict['step_context'] = tensor_dict['step_context'][-2]
            tensor_dict['step_context'] = tensor_dict['step_context'][tf.newaxis]
        return tensor_dict, target_position

    def batch_concat(dataset, batch_size):
        """Concatenate every ``batch_size`` consecutive elements along axis 0."""
        windowed_ds = dataset.window(batch_size)
        # Empty tensors (leading dimension 0) serve as the reduction seed.
        initial_state = tree.map_structure(lambda spec: tf.zeros(shape=[0] + spec.shape.as_list()[1:], dtype=spec.dtype),dataset.element_spec)

        def reduce_window(initial_state, ds):
            return ds.reduce(initial_state, lambda x, y: tf.concat([x, y], axis=0))

        return windowed_ds.map(lambda *x: tree.map_structure(reduce_window, initial_state, x))

    def prepare_rollout_inputs(context, features):
        """Split one full trajectory into input positions and the last position."""
        out_dict = {**context}
        pos = tf.transpose(features['position'], [1, 0, 2])
        target_position = pos[:, -1]
        out_dict['position'] = pos[:, :-1]
        out_dict['n_particles_per_example'] = [tf.shape(pos)[0]]
        if 'step_context' in features:
            out_dict['step_context'] = features['step_context']
        out_dict['is_trajectory'] = tf.constant([True], tf.bool)
        return out_dict, target_position

    metadata = _read_metadata('data/')
    ds = tf.data.TFRecordDataset([data_path])
    ds = ds.map(functools.partial(reading_utils.parse_serialized_simulation_example, metadata=metadata))
    if is_rollout:
        ds = ds.map(prepare_rollout_inputs)
    else:
        # Each training window holds the model inputs plus one target step.
        split_with_window = functools.partial(
            reading_utils.split_trajectory,
            window_length=INPUT_SEQUENCE_LENGTH + 1)
        ds = ds.flat_map(split_with_window)
        ds = ds.map(prepare_inputs)
        ds = ds.repeat()
        ds = ds.shuffle(512)
        ds = batch_concat(ds, batch_size)
    ds = tfds.as_numpy(ds)
    for _ in range(100): # clear screen
        print()
    return ds
def eval_single_rollout(simulator, features, num_steps, device):
    """Roll the simulator forward for one trajectory and compare to ground truth.

    The first ``INPUT_SEQUENCE_LENGTH`` positions seed the rollout. At every
    step the predicted positions of particles of type 3 (treated here as
    kinematic, i.e. following a prescribed trajectory) are overwritten with
    the ground truth before being fed back in.

    Args:
        simulator: Object exposing ``predict_positions``.
        features: Dict with ``'position'``, ``'particle_type'`` and
            ``'n_particles_per_example'`` tensors.
        num_steps: Number of steps to roll out.
        device: Device for the kinematic mask.

    Returns:
        Tuple ``(output_dict, loss)``: NumPy arrays of the initial positions,
        predicted and ground-truth rollouts and particle types, and the
        element-wise squared error between prediction and ground truth.
    """
    initial_positions = features['position'][:, 0:INPUT_SEQUENCE_LENGTH]
    ground_truth_positions = features['position'][:, INPUT_SEQUENCE_LENGTH:]

    # The mask does not change between steps, so build it once. It is expanded
    # to the spatial size of the data instead of assuming 2-D positions.
    kinematic_mask = (features['particle_type'] == 3).to(device)
    kinematic_mask = kinematic_mask[:, None].expand(-1, ground_truth_positions.shape[-1])

    current_positions = initial_positions
    predictions = []
    for step in range(num_steps):
        next_position = simulator.predict_positions(
            current_positions,
            n_particles_per_example=features['n_particles_per_example'],
            particle_types=features['particle_type'],
        )
        # Update kinematic particles from prescribed trajectory.
        next_position_ground_truth = ground_truth_positions[:, step]
        next_position = torch.where(kinematic_mask, next_position_ground_truth, next_position)
        predictions.append(next_position)
        # Slide the input window: drop the oldest step, append the new one.
        current_positions = torch.cat([current_positions[:, 1:], next_position[:, None, :]], dim=1)

    predictions = torch.stack(predictions) # time first
    ground_truth_positions = ground_truth_positions.permute(1,0,2)
    loss = (predictions - ground_truth_positions) ** 2
    output_dict = {
        'initial_positions': initial_positions.permute(1,0,2).cpu().numpy(),
        'predicted_rollout': predictions.cpu().numpy(),
        'ground_truth_rollout': ground_truth_positions.cpu().numpy(),
        'particle_types': features['particle_type'].cpu().numpy(),
    }
    return output_dict, loss
def eval_rollout(ds, simulator, num_steps, num_eval_steps=1, save_results=False, device='cuda'):
    """Evaluate full rollouts on examples from ``ds``.

    Args:
        ds: Iterable of ``(features, labels)`` pairs holding NumPy arrays.
        simulator: Model to evaluate; switched to eval mode for the duration
            and back to train mode afterwards, even if a rollout raises.
        num_steps: Number of steps per rollout.
        num_eval_steps: Stop after this many examples (at least one example is
            always evaluated).
        save_results: When true, pickle each rollout to
            ``rollouts/rollout_<index>.pkl``.
        device: Device the feature tensors are moved to.

    Returns:
        The per-example squared-error tensors stacked and averaged over the
        example axis.
    """
    eval_loss = []
    simulator.eval()
    try:
        with torch.no_grad():
            for example_i, (features, _) in enumerate(ds):
                features['position'] = torch.tensor(features['position']).to(device)
                features['n_particles_per_example'] = torch.tensor(features['n_particles_per_example']).to(device)
                features['particle_type'] = torch.tensor(features['particle_type']).to(device)
                example_rollout, loss = eval_single_rollout(simulator, features, num_steps, device)
                example_rollout['metadata'] = metadata
                eval_loss.append(loss)
                if save_results:
                    filename = f'rollout_{example_i}.pkl'
                    filename = os.path.join('rollouts/', filename)
                    with open(filename, 'wb') as f:
                        pickle.dump(example_rollout, f)
                # Checked after the work so the next example is never fetched
                # needlessly.
                if example_i + 1 >= num_eval_steps:
                    break
    finally:
        simulator.train()
    # NOTE(review): torch.stack needs equally shaped losses; examples with
    # different particle counts would fail here when num_eval_steps > 1 —
    # confirm against the dataset.
    return torch.stack(eval_loss).mean(0)
def train(simulator):
    """Train ``simulator`` on the training set with one-step noisy targets.

    A fresh ``train_log/run<N>/`` directory is chosen for TensorBoard logs and
    checkpoints. Training runs until ``training_steps`` optimizer steps have
    been taken or the user interrupts with Ctrl-C; in both cases the model is
    saved to ``<run dir>/model.pth`` before returning.

    Args:
        simulator: The ``Simulator`` to optimize in place.
    """
    # Pick the first run directory that does not exist yet.
    i = 0
    while os.path.isdir('train_log/run'+str(i)):
        i += 1
    LOG_DIR = 'train_log/run'+str(i)+'/'
    writer = SummaryWriter(LOG_DIR)

    lr_init = 1e-4
    # NOTE(review): lr_min is never applied to the schedule below — confirm
    # whether a lower bound on the learning rate was intended.
    lr_min = 1e-6
    lr_decay = 0.1
    lr_decay_steps = int(5e6)
    lr_new = lr_init

    optimizer = torch.optim.Adam(simulator.parameters(), lr=lr_init)

    ds = prepare_data_from_tfds(batch_size=batch_size)
    # ds_eval = prepare_data_from_tfds(data_path='data/valid.tfrecord', is_rollout=True)

    step = 0
    try:
        for features, labels in ds:
            features['position'] = torch.tensor(features['position']).to(device)
            features['n_particles_per_example'] = torch.tensor(features['n_particles_per_example']).to(device)
            features['particle_type'] = torch.tensor(features['particle_type']).to(device)
            labels = torch.tensor(labels).to(device)

            # Perturb the inputs with random-walk noise; particles of type 3
            # (excluded from the loss below) are left unperturbed.
            sampled_noise = get_random_walk_noise_for_position_sequence(features['position'], noise_std_last_step=noise_std).to(device)
            non_kinematic_mask = (features['particle_type'] != 3).clone().detach().to(device)
            sampled_noise *= non_kinematic_mask.view(-1, 1, 1)

            pred, target = simulator.predict_accelerations(
                next_position=labels,
                position_sequence_noise=sampled_noise,
                position_sequence=features['position'],
                n_particles_per_example=features['n_particles_per_example'],
                particle_types=features['particle_type'],
            )

            # Squared error summed over the last axis, averaged over the
            # particles selected by the mask.
            loss = (pred - target) ** 2
            loss = loss.sum(dim=-1)
            num_non_kinematic = non_kinematic_mask.sum()
            loss = torch.where(non_kinematic_mask.bool(), loss, torch.zeros_like(loss))
            loss = loss.sum() / num_non_kinematic

            if step % log_steps == 0:
                writer.add_scalar("training_loss", loss, step)
                writer.add_scalar("lr", lr_new, step)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Exponential decay: the rate is multiplied by `lr_decay` every
            # `lr_decay_steps` steps.
            lr_new = lr_init * (lr_decay ** (step/lr_decay_steps))
            for g in optimizer.param_groups:
                g['lr'] = lr_new

            step += 1
            print(f'Training step: {step}/{training_steps}. Loss: {loss}.', end="\r",)
            # The dataset repeats forever, so this is the only normal exit.
            if step >= training_steps:
                break

            # if step % eval_steps == 0:
            #     eval_loss = eval_rollout(ds_eval, simulator, num_steps, num_eval_steps=10, device=device)
            #     writer.add_scalar("eval_loss", eval_loss, step)

            if step % save_steps == 0:
                simulator.save(LOG_DIR+'model.pth')

    except KeyboardInterrupt:
        # Deliberate: Ctrl-C ends training early but still saves below.
        pass
    finally:
        # Flush pending TensorBoard events.
        writer.close()

    simulator.save(LOG_DIR+'model.pth')
def infer(simulator):
    """Roll out ``simulator`` on the validation set and pickle the results.

    Uses the module-level ``num_steps`` and ``device`` settings.
    """
    validation_ds = prepare_data_from_tfds(data_path='data/valid.tfrecord', is_rollout=True)
    eval_rollout(
        validation_ds,
        simulator,
        num_steps=num_steps,
        save_results=True,
        device=device,
    )
if __name__ == '__main__':
    # Model hyper-parameters; radius and bounds come from the dataset metadata.
    simulator_kwargs = dict(
        particle_dimension=2,
        node_in=30,
        edge_in=3,
        latent_dim=128,
        num_message_passing_steps=10,
        mlp_num_layers=2,
        mlp_hidden_dim=128,
        connectivity_radius=metadata['default_connectivity_radius'],
        boundaries=np.array(metadata['bounds']),
        normalization_stats=normalization_stats,
        num_particle_types=9,
        particle_type_embedding_size=16,
        device=device,
    )
    simulator = Simulator(**simulator_kwargs)

    # Optionally resume from a checkpoint before moving the model to the GPU.
    if model_path is not None:
        simulator.load(model_path)
    if device == 'cuda':
        simulator.cuda()

    # Switch the two calls below to run inference instead of training.
    train(simulator)
    # infer(simulator)