forked from szagoruyko/wide-residual-networks
-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
74 lines (59 loc) · 2.25 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import math
import torch
import torch.cuda.comm as comm
from torch.nn.parallel._functions import Broadcast
from torch.nn.parallel import scatter, parallel_apply, gather
from functools import partial
from torch.autograd import Variable
from nested_dict import nested_dict
from collections import OrderedDict
def cast(params, dtype='float'):
    """Convert a tensor, or every tensor in a (nested) dict, to ``dtype``.

    Args:
        params: a tensor-like object, or a dict whose values are tensors or
            further dicts of the same form.
        dtype: name of the conversion method to call on each tensor
            (for example ``'float'``, ``'double'`` or ``'half'``).

    Returns:
        A plain dict with the same keys when ``params`` is a dict, otherwise
        the converted tensor. Tensors are moved to the current CUDA device
        first when CUDA is available, and stay on their device otherwise.
    """
    if isinstance(params, dict):
        return {k: cast(v, dtype) for k, v in params.items()}
    # Only move to the GPU when one is usable; calling .cuda() unconditionally
    # makes every helper built on top of this function fail on CPU-only hosts.
    if torch.cuda.is_available():
        params = params.cuda()
    return getattr(params, dtype)()
def conv_params(ni, no, k=1, g=1):
    """Create a randomly initialised convolution weight.

    The weight is allocated with size ``(no, ni // g, k, k)``, filled in
    place from a normal distribution with mean 0 and standard deviation
    ``2 / sqrt(ni * k * k)``, and then passed through ``cast``.

    ``ni`` must be divisible by ``g``; this is checked with an ``assert``.
    """
    assert ni % g == 0
    weight = torch.Tensor(no, ni // g, k, k)
    weight.normal_(0, 2 / math.sqrt(ni * k * k))
    return cast(weight)
def linear_params(ni, no):
    """Create the ``weight`` and ``bias`` entries of a linear layer.

    ``weight`` is allocated with size ``(no, ni)`` and filled in place from
    a normal distribution with mean 0 and standard deviation
    ``2 / sqrt(ni)``; ``bias`` is ``no`` zeros. The resulting dict is
    passed through ``cast`` before being returned.
    """
    weight = torch.Tensor(no, ni)
    weight.normal_(0, 2 / math.sqrt(ni))
    bias = torch.zeros(no)
    return cast({'weight': weight, 'bias': bias})
def bnparams(n):
    """Return a ``weight``/``bias`` dict of length-``n`` tensors.

    ``weight`` is all ones and ``bias`` is all zeros; the dict is passed
    through ``cast`` before being returned.
    """
    # An alternative initialisation, torch.Tensor(n).uniform_(), is left
    # disabled in favour of ones.
    scale = torch.ones(n)
    shift = torch.zeros(n)
    return cast({'weight': scale, 'bias': shift})
def bnstats(n):
    """Return a ``running_mean``/``running_var`` dict of length-``n`` tensors.

    ``running_mean`` is all zeros and ``running_var`` is all ones; the dict
    is passed through ``cast`` before being returned.
    """
    mean = torch.zeros(n)
    var = torch.ones(n)
    return cast({'running_mean': mean, 'running_var': var})
def data_parallel(f, input, params, stats, mode, device_ids, output_device=None):
    """Evaluate ``f`` on a single device or split the call over several.

    Args:
        f: callable invoked as ``f(input, params, stats, mode)``; in the
            multi-device path ``params``, ``stats`` and ``mode`` are bound
            as keyword arguments.
        input: the input handed to ``f`` (scattered over ``device_ids`` in
            the multi-device path).
        params: dict mapping names to parameters.
        stats: dict mapping names to statistics tensors.
        mode: forwarded unchanged to ``f``.
        device_ids: sequence of device ids; must be non-empty.
        output_device: device on which results are gathered; defaults to
            ``device_ids[0]``.

    Returns:
        ``f(input, params, stats, mode)`` when there is exactly one device,
        otherwise the per-device outputs gathered on ``output_device``.
    """
    if output_device is None:
        output_device = device_ids[0]

    # With a single device there is nothing to replicate or scatter.
    if len(device_ids) == 1:
        return f(input, params, stats, mode)

    def replicate(param_dict, g):
        # g(v) yields one copy of v per device; copy i is stored in replica i.
        replicas = [{} for _ in device_ids]
        # dict.items() exists on both Python 2 and 3; dict.iteritems() is
        # Python 2 only and raises AttributeError on Python 3.
        for k, v in param_dict.items():
            for i, u in enumerate(g(v)):
                replicas[i][k] = u
        return replicas

    # NOTE(review): Broadcast(device_ids)(x) is the legacy autograd Function
    # call style; newer PyTorch releases expect Broadcast.apply(...) instead.
    # Confirm against the installed PyTorch version before changing it.
    params_replicas = replicate(params, lambda x: Broadcast(device_ids)(x))
    stats_replicas = replicate(stats, lambda x: comm.broadcast(x, device_ids))

    replicas = [partial(f, params=p, stats=s, mode=mode)
                for p, s in zip(params_replicas, stats_replicas)]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)
def flatten_params(params):
    """Flatten a nested parameter dict into an ``OrderedDict`` of Variables.

    ``params`` is wrapped in ``nested_dict`` and its flat items are
    traversed. Each entry's key sequence is joined with ``'.'`` to form
    the flat key, and each value is wrapped in
    ``Variable(..., requires_grad=True)``. Entries whose value is ``None``
    are skipped.
    """
    return OrderedDict(
        ('.'.join(path), Variable(tensor, requires_grad=True))
        for path, tensor in nested_dict(params).iteritems_flat()
        if tensor is not None
    )
def flatten_stats(stats):
    """Flatten a nested statistics dict into an ``OrderedDict``.

    ``stats`` is wrapped in ``nested_dict`` and its flat items are
    traversed. Each entry's key sequence is joined with ``'.'`` to form
    the flat key; values are stored unchanged.
    """
    return OrderedDict(
        ('.'.join(path), value)
        for path, value in nested_dict(stats).iteritems_flat()
    )