train.py
import torch
import sys
import os
from engine import Engine
from misc.utils import log
import options
import argparse
from misc.dist_utils import get_dist_info, init_dist, setup_for_distributed
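
# Training entry point for MuRF: parse the options, optionally initialize
# distributed (NCCL) training, build the Engine, load the datasets, and train.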
def main():
    log.process(os.getpid())
    log.title("[{}] (PyTorch code for training MuRF)".format(sys.argv[0]))

    opt_cmd = options.parse_arguments(sys.argv[1:])
    opt = options.set(opt_cmd=opt_cmd, load_confd=True)
    options.save_options_file(opt)
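    # opt is a namespace built from the command line merged with the config
    # defaults in options.py (an assumption about options.set); it is saved to
    # disk above and consumed by everything below.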

    # distributed training
    if getattr(opt, 'dist', False):
        print('distributed training')
        dist_params = dict(backend='nccl')
        launcher = getattr(opt, 'launcher', 'pytorch')
        init_dist(launcher, **dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        opt.gpu_ids = range(world_size)
        opt.local_rank = int(os.environ['LOCAL_RANK'])
        opt.device = torch.device('cuda:{}'.format(opt.local_rank))
        setup_for_distributed(opt.local_rank == 0)
    else:
        opt.local_rank = 0
        opt.dist = False
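
    # NOTE: the distributed branch above expects LOCAL_RANK to be set by the
    # process launcher (e.g. torchrun, handled through misc.dist_utils.init_dist
    # with the 'pytorch' launcher); how 'dist' and 'launcher' end up in opt
    # depends on options.py.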

    m = Engine(opt)

    # setup model
    m.build_networks()

    # setup dataset
    if getattr(opt, 'no_val', False):
        m.load_dataset(splits=['train', 'test'])
    else:
        m.load_dataset(splits=['train', 'val', 'test'])

    # setup training utils
    m.setup_visualizer()
    m.setup_optimizer()

    if opt.resume or opt.load:
        m.restore_checkpoint()

    m.train_model()


if __name__ == "__main__":
    main()
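

# Example launch commands (the script's own flags are defined in options.py and
# are not reproduced here; torchrun is PyTorch's standard launcher and sets the
# LOCAL_RANK environment variable read above):
#   single GPU:  python train.py [options]
#   multi GPU:   torchrun --nproc_per_node=<num_gpus> train.py [options]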