# training_loop.py
import copy
import os
import pickle
import time
import numpy as np
import torch
import util
import wandb
from dataset import InfiniteSampler
from lightning.fabric import Fabric
from thor.checkpoint import CheckpointIO
from thor.score import DefaultScoreFunction
# The training procedure is largely based on
# https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/training_loop.py#L25
# and adapted for distributed training using PyTorch Lightning (Fabric).
def training_loop(
fabric: Fabric,
run_dir,
#
dataset_kwargs,
network_kwargs,
pipeline_kwargs,
optimizer_kwargs,
lr_kwargs,
#
batch_size,
batch_gpu,
total_ndata,
log_ndata,
status_ndata,
snapshot_ndata,
checkpoint_ndata,
valid_ndata,
#
ema_kwargs=None,
slice_ndata=None,
seed=0,
loss_scaling=1,
cudnn_benchmark=True,
logger=None,
):
# Initialize.
prev_status_time = time.time()
util.set_random_seed(seed, fabric.global_rank)
torch.backends.cudnn.benchmark = cudnn_benchmark
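# Disable TF32 tensor-core math and reduced-precision fp16 reductions so matmuls run in full precision.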
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
DO_LOG = logger is not None and log_ndata is not None
# Validate batch size.
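# batch_size is the global batch: batch_size = batch_gpu * num_accumulation_rounds * world_size.
# If batch_gpu is unset or exceeds the per-rank share, fall back to the full per-rank share (no accumulation).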
batch_gpu_total = batch_size // fabric.world_size
if batch_gpu is None or batch_gpu > batch_gpu_total:
batch_gpu = batch_gpu_total
num_accumulation_rounds = batch_gpu_total // batch_gpu
assert batch_size == batch_gpu * num_accumulation_rounds * fabric.world_size
assert total_ndata % batch_size == 0
assert slice_ndata is None or slice_ndata % batch_size == 0
assert log_ndata is None or log_ndata % batch_size == 0
assert status_ndata is None or status_ndata % batch_size == 0
assert snapshot_ndata is None or (
snapshot_ndata % batch_size == 0 and snapshot_ndata % 1024 == 0
)
assert checkpoint_ndata is None or (
checkpoint_ndata % batch_size == 0 and checkpoint_ndata % 1024 == 0
)
# ==| Dataset(s)
fabric.print("Setting up datasets...")
train_dataset = util.construct_class_by_name(**dataset_kwargs.train)
DO_VALIDATION = False
if "valid" in dataset_kwargs:
fabric.print(
"WARNING: Validation dataset provided but currently not supported."
)
DO_VALIDATION = True
valid_dataset = util.construct_class_by_name(**dataset_kwargs.valid)
# ==| Network
fabric.print("Setting up network...")
with fabric.init_module():
net = util.construct_class_by_name(**network_kwargs)
net = fabric.to_device(net)
net.train()
# --| Print model summary
if fabric.is_global_zero:
ref_data = train_dataset[0]
fabric.print(f"Data shape: {ref_data.shape}")
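# Trace a dummy forward pass for the summary: a zeros tensor with the sample shape plus a scalar second input (presumably the noise level / time conditioning).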
with fabric.autocast():
util.print_module_summary(
net,
[
torch.zeros(
[1, *ref_data.shape],
device=fabric.device,
),
torch.ones([1], device=fabric.device),
],
max_nesting=6,
)
# ==| Setup training state.
fabric.print("Setting up training state...")
state = util.EasyDict(cur_ndata=0, total_elapsed_time=0)
# Prepare the model for distributed training.
# The model is moved automatically to the right device.
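# `ddp` wraps the network for forward/backward and gradient sync; the raw `net` is kept for the optimizer, EMA, and checkpointing.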
ddp = fabric.setup_module(net) # NOTE: `net is ddp.module` ~> `True`
pipeline = util.construct_class_by_name(**pipeline_kwargs)
optimizer = util.construct_class_by_name(
params=net.parameters(), **optimizer_kwargs
)
# Prepare the optimizer for distributed training.
optimizer = fabric.setup_optimizers(optimizer)
ema = (
util.construct_class_by_name(net=net, **ema_kwargs)
if ema_kwargs is not None
else None
)
# Load previous checkpoint and decide how long to train.
checkpoint = CheckpointIO(
state=state,
net=net,
pipeline=pipeline,
optimizer=optimizer,
ema=ema,
)
checkpoint.load_latest(fabric, run_dir)
stop_at_ndata = total_ndata
if slice_ndata is not None:
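# Align the end of this slice with a checkpoint/snapshot boundary (falling back to batch_size) so the run stops where it can be resumed.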
granularity = (
checkpoint_ndata
if checkpoint_ndata is not None
else snapshot_ndata if snapshot_ndata is not None else batch_size
)
slice_end_ndata = (
(state.cur_ndata + slice_ndata) // granularity * granularity
) # round down
stop_at_ndata = min(stop_at_ndata, slice_end_ndata)
assert stop_at_ndata > state.cur_ndata
fabric.print(
f"Training from {state.cur_ndata // 1000} kdata to {stop_at_ndata // 1000} kdata:"
)
fabric.print()
fabric.print(
f"Batch size: {batch_size} (per device: {batch_gpu}; number of accumulation rounds: {num_accumulation_rounds})"
)
fabric.print()
# ==| Main training loop
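# The InfiniteSampler shards samples across ranks and starts at state.cur_ndata, so a resumed run continues from the same position in the seeded shuffle.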
dataset_sampler = InfiniteSampler(
dataset=train_dataset,
rank=fabric.global_rank,
num_replicas=fabric.world_size,
shuffle=True,
seed=seed,
start_idx=state.cur_ndata,
)
# ==| Dataloaders
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=dataset_sampler,
batch_size=batch_gpu,
pin_memory=True,
num_workers=2,
prefetch_factor=2,
)
if DO_VALIDATION:
valid_loader = torch.utils.data.DataLoader(
valid_dataset,
shuffle=False,
batch_size=batch_gpu,
pin_memory=True,
num_workers=2,
prefetch_factor=2,
drop_last=False,
)
valid_loader = fabric.setup_dataloaders(valid_loader)
train_loader = fabric.setup_dataloaders(train_loader, use_distributed_sampler=False)
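# NOTE: use_distributed_sampler=False because the InfiniteSampler above already performs per-rank sharding.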
dataset_iterator = iter(train_loader)
prev_status_ndata = state.cur_ndata
cumulative_training_time = 0
start_ndata = state.cur_ndata
# def sampling_proc_x0(x):
# return torch.clamp(x, -1.5, 1.5)
losses_accum = []
while True:
done = state.cur_ndata >= total_ndata
# Report status.
if (
status_ndata is not None
and (done or state.cur_ndata % status_ndata == 0)
and (state.cur_ndata != start_ndata or start_ndata == 0)
):
cur_time = time.time()
state.total_elapsed_time += cur_time - prev_status_time
fabric.print(
" +++ ".join(
[
"Status:",
f"{state.cur_ndata} / {total_ndata} ({state.cur_ndata/total_ndata:.2%})",
f"{state.total_elapsed_time:.2f} sec total",
f"{cur_time - prev_status_time:.2f} sec/tick",
f"{cumulative_training_time / max(state.cur_ndata - prev_status_ndata, 1) * 1e3:.3f} sec/kdata",
]
)
)
cumulative_training_time = 0
prev_status_ndata = state.cur_ndata
prev_status_time = cur_time
# Save network snapshot.
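# Pickle each EMA network (falling back to an optimizer-provided EMA or the raw net) as a frozen fp16 CPU copy alongside the pipeline and dataset kwargs.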
if (
snapshot_ndata is not None
and state.cur_ndata % snapshot_ndata == 0
and (state.cur_ndata != start_ndata)
and fabric.is_global_zero
):
ema_list = (
ema.get()
if ema is not None
else optimizer.get_ema(net) if hasattr(optimizer, "get_ema") else net
)
ema_list = ema_list if isinstance(ema_list, list) else [(ema_list, "")]
for ema_net, ema_suffix in ema_list:
snap_data = util.EasyDict(
dataset_kwargs=dataset_kwargs, pipeline=pipeline
)
snap_data.ema = (
copy.deepcopy(ema_net)
.cpu()
.eval()
.requires_grad_(False)
.to(torch.float16)
)
fname = f"network-snapshot-{state.cur_ndata//1000:07d}{ema_suffix}.pkl"
fabric.print(f"Saving {fname} ... ", end="", flush=True)
with open(os.path.join(run_dir, fname), "wb") as f:
pickle.dump(snap_data, f)
fabric.print("done")
del snap_data # conserve memory
# Validation
if (
valid_ndata is not None
and state.cur_ndata % valid_ndata == 0
and (state.cur_ndata != start_ndata or start_ndata == 0)
):
# Log samples
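# Draw a pure-noise trajectory of shape (window, num_features, spatial_res, spatial_res), run the sampler for 100 steps with each EMA network, and log an image grid and value histogram.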
if fabric.is_global_zero:
ema_list = (
ema.get()
if ema is not None
else (
optimizer.get_ema(net) if hasattr(optimizer, "get_ema") else net
)
)
ema_list = ema_list if isinstance(ema_list, list) else [(ema_list, "")]
with fabric.autocast():
noisevec = torch.randn(
dataset_kwargs.train.window,
dataset_kwargs.train.num_features,
dataset_kwargs.train.spatial_res,
dataset_kwargs.train.spatial_res,
).to(device=fabric.device)
for ema_net, ema_suffix in ema_list:
_ema_module_train = ema_net.training
ema_net.train(False) # Set to eval mode
score_function = DefaultScoreFunction(
ema_net,
markov_order=dataset_kwargs.train.window // 2,
noise_process=pipeline,
)
with torch.no_grad():
with fabric.autocast():
gen_sample = pipeline.sample(
score_function,
noisevec,
steps=100,
# proc_x0=sampling_proc_x0, # 1
device=fabric.device,
).cpu()
ema_net.train(_ema_module_train) # Restore training/eval mode
img_array = util.trajectory_to_imgrid(gen_sample)
if DO_LOG:
_hist_fig = wandb.Image(util.value_histogram(gen_sample))
imgs = wandb.Image(
img_array, caption="Samples [time x features]"
)
logger.log(
{
f"gen_sample{ema_suffix}": imgs,
f"value_histogram{ema_suffix}": _hist_fig,
},
commit=False,
)
if DO_VALIDATION:
fabric.print(
"WARNING: Validation dataset provided but currently not supported."
)
# Logging
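# Log the mean loss accumulated since the previous logging event, together with progress counters and per-group learning rates.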
if (
DO_LOG
and log_ndata is not None
and (done or state.cur_ndata % log_ndata == 0)
and (state.cur_ndata != start_ndata)
):
logger.log(
{
"train/loss": np.mean(losses_accum),
"train/kdata": state.cur_ndata // 1000,
"train/elapsed_time": state.total_elapsed_time,
**{
f"train/lr-{i}": g["lr"]
for i, g in enumerate(optimizer.param_groups, 1)
},
},
)
losses_accum = []
# Save state checkpoint.
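# The checkpoint captures the full resumable training state: counters, network, pipeline, optimizer, and EMA.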
if (
checkpoint_ndata is not None
and (done or state.cur_ndata % checkpoint_ndata == 0)
and state.cur_ndata != start_ndata
):
checkpoint.save(
fabric,
os.path.join(
run_dir, f"training-state-{state.cur_ndata//1000:07d}.ckpt"
),
)
# Done?
if done:
break
# Evaluate loss and accumulate gradients.
batch_start_time = time.time()
# util.set_random_seed(seed, fabric.global_rank, state.cur_ndata)
optimizer.zero_grad()
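# Gradient accumulation: gradient synchronization across ranks is skipped on all but the last round; fabric.backward() keeps adding into .grad until optimizer.step().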
for round_idx in range(num_accumulation_rounds):
is_accumulating = round_idx != num_accumulation_rounds - 1
with fabric.no_backward_sync(ddp, enabled=is_accumulating):
data = next(dataset_iterator)
loss = pipeline.loss(net=ddp, x=data).mean().mul(loss_scaling)
fabric.backward(loss)
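# Re-evaluate the learning-rate schedule from the current data count and apply it to every parameter group before stepping.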
lr = util.call_func_by_name(cur_ndata=state.cur_ndata, **lr_kwargs)
for g in optimizer.param_groups:
g["lr"] = lr
optimizer.step()
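# NOTE: only the loss of the final accumulation round is recorded for logging.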
losses_accum.append(loss.detach().item())
# Update EMA and training state.
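# cur_ndata counts samples seen across all ranks, so it advances by the global batch_size every step.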
state.cur_ndata += batch_size
if ema is not None:
ema.update(cur_ndata=state.cur_ndata, batch_size=batch_size)
cumulative_training_time += time.time() - batch_start_time