# -*- coding: utf-8 -*-
from __future__ import print_function
import time
import torch
import numpy as np
import argparse
import os
import settings
import utils
import logging
from dl_trainer import DLTrainer, _support_datasets, _support_dnns
import hv_distributed_optimizer as hvd
from tensorboardX import SummaryWriter
from compression import compressors
from profiling import benchmark
from mpi4py import MPI
from settings import logger, formatter

comm = MPI.COMM_WORLD
writer = None  # Optionally replaced by a SummaryWriter instance in __main__.
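
# Example launch (hypothetical values; adjust flags and process counts to your setup):
#   horovodrun -np 4 python horovod_trainer.py --dnn resnet50 --dataset imagenet \
#       --data-dir ./data --nworkers 4 --nwpernode 4 --density 0.01 --compressor topk
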
def ssgd_with_horovod(dnn, dataset, data_dir, nworkers, lr, batch_size, nsteps_update, max_epochs, nwpernode, pretrain, num_steps, compressor, density, threshold):
    rank = hvd.rank()
    torch.cuda.set_device(rank % nwpernode)
    if rank != 0:
        # Only rank 0 loads the pretrained model; its weights are broadcast to
        # the other workers below.
        pretrain = None
    trainer = DLTrainer(rank, nworkers, dist=False, batch_size=batch_size, is_weak_scaling=True, ngpus=1, data_dir=data_dir, dataset=dataset, dnn=dnn, lr=lr, nworkers=nworkers, prefix='allreduce', pretrain=pretrain, num_steps=num_steps, tb_writer=writer)
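
    # Resume bookkeeping: broadcast rank 0's saved epoch/iteration counters so
    # that every worker restarts training from the same position.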
    init_epoch = torch.ones(1) * trainer.get_train_epoch()
    init_iter = torch.ones(1) * trainer.get_train_iter()
    trainer.set_train_epoch(int(hvd.broadcast(init_epoch, root_rank=0)[0]))
    trainer.set_train_iter(int(hvd.broadcast(init_iter, root_rank=0)[0]))
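
    # Gradient sparsification is active only when density < 1; at full density
    # the compressor is bypassed and a dense allreduce is used instead.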
    is_sparse = density < 1
    if not is_sparse:
        compressor = None
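
    # Presumably used for adaptive gradient merging/sparsification: profile
    # per-layer backward times and sizes once, then broadcast the timings from
    # rank 0 so all workers agree on the same schedule.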
    if settings.ADAPTIVE_MERGE or settings.ADAPTIVE_SPARSE:
        seq_layernames, layerwise_times, layerwise_sizes = benchmark(trainer)
        layerwise_times = comm.bcast(layerwise_times, root=0)
        if rank == 0:
            logger.info('layerwise backward times: %s', list(layerwise_times))
            logger.info('layerwise backward sizes: %s', list(layerwise_sizes))
            logger.info('Benchmarked backward time: %f', np.sum(layerwise_times))
            logger.info('Model size: %d', np.sum(layerwise_sizes))
    else:
        seq_layernames, layerwise_times, layerwise_sizes = None, None, None
    norm_clip = None
    if dnn == 'lstm':
        norm_clip = 0.25
    elif dnn == 'lstman4':
        norm_clip = 400
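
    # Wrap the local optimizer in a Horovod-style distributed optimizer so each
    # update performs a (possibly compressed) allreduce, then broadcast rank 0's
    # initial weights so all workers start from identical parameters.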
    optimizer = hvd.DistributedOptimizer(trainer.optimizer, named_parameters=trainer.net.named_parameters(), compression=compressors[compressor], is_sparse=is_sparse, density=density, seq_layernames=seq_layernames, layerwise_times=layerwise_times, norm_clip=norm_clip, threshold=threshold, writer=writer)
    hvd.broadcast_parameters(trainer.net.state_dict(), root_rank=0)
    trainer.update_optimizer(optimizer)

    iters_per_epoch = trainer.get_num_of_training_samples() // (nworkers * batch_size * nsteps_update)
    times = []
    logger.info('max_epochs: %d', max_epochs)
    display = 40 if iters_per_epoch > 40 else iters_per_epoch - 1
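
    # Each outer iteration accumulates gradients over nsteps_update forward/backward
    # sub-steps; setting optimizer.local = True presumably skips communication on
    # all but the final sub-step, so only one allreduce happens per update.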
    for epoch in range(max_epochs):
        hidden = None
        if dnn == 'lstm':
            hidden = trainer.net.init_hidden()
        for i in range(iters_per_epoch):
            s = time.time()
            optimizer.zero_grad()
            for j in range(nsteps_update):
                if j < nsteps_update - 1 and nsteps_update > 1:
                    optimizer.local = True
                else:
                    optimizer.local = False
                if dnn == 'lstm':
                    _, hidden = trainer.train(1, hidden=hidden)
                else:
                    trainer.train(1)
            # The recurrent models clip gradients after synchronization to keep
            # training stable.
            if dnn == 'lstm':
                optimizer.synchronize()
                torch.nn.utils.clip_grad_norm_(trainer.net.parameters(), 0.25)
            elif dnn == 'lstman4':
                optimizer.synchronize()
                torch.nn.utils.clip_grad_norm_(trainer.net.parameters(), 400)
            trainer.update_model()
            times.append(time.time() - s)
            if i % display == 0 and i > 0:
                time_per_iter = np.mean(times)
                logger.warning('Time per iteration including communication: %f, Speed: %f images/s', time_per_iter, batch_size * nsteps_update / time_per_iter)
                times = []
        optimizer.train_epoch += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="AllReduce trainer")
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--nsteps-update', type=int, default=1)
    parser.add_argument('--nworkers', type=int, default=1, help='Number of workers (for experiments only; not intended for production use)')
    parser.add_argument('--nwpernode', type=int, default=1, help='Number of workers per node')
    parser.add_argument('--dataset', type=str, default='imagenet', choices=_support_datasets, help='Specify the dataset for training')
    parser.add_argument('--dnn', type=str, default='resnet50', choices=_support_dnns, help='Specify the neural network for training')
    parser.add_argument('--data-dir', type=str, default='./data', help='Specify the data root path')
    parser.add_argument('--lr', type=float, default=0.1, help='Default learning rate')
    parser.add_argument('--max-epochs', type=int, default=10, help='Default maximum epochs to train')
    parser.add_argument('--pretrain', type=str, default=None, help='Specify the pretrained model path')
    parser.add_argument('--num-steps', type=int, default=35)
    parser.add_argument('--compressor', type=str, default='topk', choices=compressors.keys(), help='Specify the compressor to use when compression is enabled (density < 1)')
    parser.add_argument('--density', type=float, default=1, help='Density for sparsification')
    parser.add_argument('--threshold', type=int, default=0, help='Specify the threshold (in bytes) for gradient merging')
    args = parser.parse_args()
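
    # The effective per-worker batch size is batch_size x nsteps_update, since
    # gradients from nsteps_update sub-steps are accumulated before each update.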
    batch_size = args.batch_size * args.nsteps_update
    prefix = settings.PREFIX
    if args.density < 1:
        prefix = 'comp-' + args.compressor + '-' + prefix
    logdir = 'allreduce-%s-thres-%dkbytes/%s-n%d-bs%d-lr%.4f-ns%d-ds%s' % (prefix, args.threshold / 1024, args.dnn, args.nworkers, batch_size, args.lr, args.nsteps_update, str(args.density))
    relative_path = './logs/%s' % logdir
    utils.create_path(relative_path)
    rank = 0
    if args.nworkers > 1:
        hvd.init()
        rank = hvd.rank()
    if rank == 0:
        tb_runs = './runs/%s' % logdir
        writer = None  # SummaryWriter(tb_runs)
    logfile = os.path.join(relative_path, settings.hostname + '-' + str(rank) + '.log')
    hdlr = logging.FileHandler(logfile)
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.info('Configurations: %s', args)
    ssgd_with_horovod(args.dnn, args.dataset, args.data_dir, args.nworkers, args.lr, args.batch_size, args.nsteps_update, args.max_epochs, args.nwpernode, args.pretrain, args.num_steps, args.compressor, args.density, args.threshold)