-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathutils.py
103 lines (80 loc) · 2.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import time
import json
import numpy as np
import torch.nn as nn
from numpy.random import uniform, normal, randint, choice
def save_results(r):
    """Serialize ``r`` to a timestamped JSON file under ``./results/``.

    The file is named ``YYYYmmdd_HHMMSS_results.json`` using the local time
    of the call and is written with ``indent=4`` and sorted keys.

    Args:
        r: A JSON-serializable object (e.g. a dict of results).

    Returns:
        The path of the file that was written.

    Note:
        Two calls within the same second produce the same filename, so the
        later call overwrites the earlier one.
    """
    results_dir = './results/'
    # Create the output directory on demand so a fresh working directory
    # does not fail with FileNotFoundError when opening the file below.
    os.makedirs(results_dir, exist_ok=True)
    date = time.strftime("%Y%m%d_%H%M%S")
    filename = date + '_results.json'
    param_path = os.path.join(results_dir, filename)
    with open(param_path, 'w', encoding='utf-8') as fp:
        json.dump(r, fp, indent=4, sort_keys=True)
    return param_path
_FIND_KEY_MISSING = object()


def find_key(params, partial_key, default=_FIND_KEY_MISSING):
    """Return the value of the first entry whose key contains ``partial_key``.

    Keys are scanned in the mapping's iteration order, so when several keys
    match, the earliest one wins.

    Args:
        params: Mapping to search. Each key must support ``partial_key in key``
            (e.g. ``str`` keys with a ``str`` ``partial_key``).
        partial_key: Substring to look for in each key.
        default: Optional value to return when no key matches.

    Returns:
        The matching value, or ``default`` when given and nothing matches.

    Raises:
        StopIteration: If nothing matches and no ``default`` was given. This
            is kept for backward compatibility.
    """
    matches = (v for k, v in params.items() if partial_key in k)
    if default is _FIND_KEY_MISSING:
        # NOTE: a StopIteration escaping into a generator body is converted
        # to RuntimeError (PEP 479); callers can pass ``default`` to avoid it.
        return next(matches)
    return next(matches, default)
def sample_from(space):
    """
    Sample a hyperparameter value from a distribution
    defined and parametrized in the list `space`.

    ``space[0]`` names the distribution. The name may carry an optional
    ``log_`` prefix followed by an optional ``q`` prefix
    (e.g. ``'log_quniform'``). The base distributions are:

    * ``['choice', options]`` -> ``numpy.random.choice(options)``
    * ``['randint', a, b]``   -> ``numpy.random.randint(a, b)``
    * ``['uniform', a, b]``   -> ``numpy.random.uniform(a, b)``
    * ``['normal', a, b]``    -> ``numpy.random.normal(a, b)``

    The ``log_`` and ``q`` prefixes are ignored for ``choice``. For the
    other distributions:

    * With ``log_``, the draw is passed through ``np.exp``.
    * With ``q``, the (possibly exponentiated) draw is then rounded to the
      nearest multiple of ``space[3]``.

    Raises:
        KeyError: If the base distribution name is not one listed above.
    """
    distrs = {
        'choice': choice,
        'randint': randint,
        'uniform': uniform,
        'normal': normal,
    }
    s = space[0]

    # Re-seeds NumPy's *global* RNG on every call from the wall clock plus a
    # jitter drawn from the current state. Any seed set by the caller is
    # discarded, so draws are not reproducible.
    # NOTE(review): presumably intended to decorrelate processes started at
    # the same time -- confirm before relying on it.
    np.random.seed(int(time.time() + np.random.randint(0, 300)))

    # Strip the optional 'log_' and then 'q' prefixes, remembering each flag.
    log = s.startswith('log_')
    s = s[len('log_'):] if log else s
    quantized = s.startswith('q')
    s = s[1:] if quantized else s
    distr = distrs[s]

    if s == 'choice':
        return distr(space[1])

    samp = distr(space[1], space[2])
    if log:
        samp = np.exp(samp)
    if quantized:
        # Divide before rounding so the result snaps onto the grid of step q;
        # round((samp / q) * q) would reduce to round(samp) and ignore q.
        q = space[3]
        samp = round(samp / q) * q
    return samp
def str2act(a):
    """Instantiate a ``torch.nn`` activation module from its name.

    Recognized names are 'relu', 'selu', 'elu', 'tanh' and 'sigmoid'. A new
    module instance is built on every call.

    Raises:
        ValueError: If ``a`` matches none of the recognized names.
    """
    # Ordered (name, constructor) pairs, compared with == in this order.
    known = (
        ('relu', nn.ReLU),
        ('selu', nn.SELU),
        ('elu', nn.ELU),
        ('tanh', nn.Tanh),
        ('sigmoid', nn.Sigmoid),
    )
    for label, ctor in known:
        if a == label:
            return ctor()
    raise ValueError('[!] Unsupported activation.')
def prepare_dirs(dirs):
    """Create every directory in ``dirs``, including missing parents.

    Args:
        dirs: An iterable of directory paths, or a single path
            (``str``, ``bytes`` or ``os.PathLike``).

    Directories that already exist are left as they are, so repeated calls
    are safe.

    Raises:
        FileExistsError: If a path exists but is not a directory.
    """
    # A bare string is itself iterable; without this check every character
    # would be treated as a separate directory name.
    if isinstance(dirs, (str, bytes, os.PathLike)):
        dirs = [dirs]
    for path in dirs:
        # exist_ok=True avoids the exists()-then-makedirs() race when several
        # processes prepare the same directories concurrently.
        os.makedirs(path, exist_ok=True)
class Reshape(nn.Module):
    """Module that views its input as ``(*args, -1)``.

    The constructor's positional arguments give the leading dimensions, and
    the last dimension is inferred by ``Tensor.view`` from the element count.
    For example, ``Reshape(n)`` produces an ``(n, -1)`` view, and
    ``Reshape()`` flattens to 1-D.
    """

    def __init__(self, *args):
        super(Reshape, self).__init__()
        # Leading target dimensions, stored as a tuple of ints.
        self.shape = args

    def forward(self, x):
        # Tensor.view accepts either one shape tuple or varargs ints, not a
        # mix of both; unpack the stored tuple so -1 becomes the last dim.
        # view() requires a layout-compatible input (e.g. contiguous).
        return x.view(*self.shape, -1)
class AverageMeter(object):
    """
    Computes and stores the average and
    current value.

    Attributes:
        val: The most recent value passed to ``update``.
        sum: Running weighted sum, accumulated as ``val * n`` per update.
        count: Running total of the weights ``n``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new value, e.g. a per-batch mean.
            n: Weight of ``val``, e.g. the number of samples it averages.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against division by zero when nothing has been counted yet
        # (e.g. n=0 on a fresh or reset meter); avg keeps its reset value 0.
        self.avg = self.sum / self.count if self.count else 0