import os
import logging
import queue
import shutil

import torch
import tqdm
import yaml

def get_logger(log_dir, name, verbose=True, log_file='log.txt'):
    """Adapted from https://github.com/chrischute/squad."""
    class StreamHandlerWithTQDM(logging.Handler):
        """Let `logging` print without breaking `tqdm` progress bars.

        See also:
            https://stackoverflow.com/questions/38543506
        """
        def emit(self, record):
            try:
                msg = self.format(record)
                tqdm.tqdm.write(msg)
                self.flush()
            except (KeyboardInterrupt, SystemExit):
                raise
            except:
                self.handleError(record)

    # Create logger
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(logging.DEBUG)

    if verbose:
        # Log everything (i.e., DEBUG level and above) to a file
        log_path = os.path.join(log_dir, log_file)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.DEBUG)

        # Log everything except DEBUG level (i.e., INFO level and above) to console
        console_handler = StreamHandlerWithTQDM()
        console_handler.setLevel(logging.INFO)
    else:
        # Log INFO level and above to a file
        log_path = os.path.join(log_dir, log_file)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.INFO)

        # Log WARN level and above to console
        console_handler = StreamHandlerWithTQDM()
        console_handler.setLevel(logging.WARN)

    # Create format for the logs
    file_formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                       datefmt='%m.%d.%y %H:%M:%S')
    file_handler.setFormatter(file_formatter)
    console_formatter = logging.Formatter('[%(asctime)s] %(message)s',
                                          datefmt='%m.%d.%y %H:%M:%S')
    console_handler.setFormatter(console_formatter)

    # Add the handlers to the logger
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
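
# Illustrative usage sketch (the directory and logger name are hypothetical
# placeholders, not values defined in this module):
#
#     logger = get_logger(log_dir='save/train/example_01', name='train')
#     logger.info('Starting training...')
#
# With the default verbose=True, DEBUG-level messages go only to the log file;
# the console handler writes through tqdm so progress bars are not broken.
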
class CheckpointSaver:
    """Adapted from https://github.com/chrischute/squad."""
    def __init__(self, save_dir, max_checkpoints, primary_metric,
                 maximize_metric=True, logger=None):
        super(CheckpointSaver, self).__init__()
        self.save_dir = save_dir
        self.max_checkpoints = max_checkpoints
        self.primary_metric = primary_metric
        self.maximize_metric = maximize_metric
        self.best_val = None
        self.ckpt_paths = queue.PriorityQueue()
        self.logger = logger
        self._print(f'Saver will {"max" if maximize_metric else "min"}imize {primary_metric}.')

    def is_best(self, metric_val):
        if metric_val is None:
            # No metric reported
            return False
        if self.best_val is None:
            # No checkpoint saved yet
            return True
        return ((self.maximize_metric and self.best_val < metric_val)
                or (not self.maximize_metric and self.best_val > metric_val))

    def _print(self, message):
        if self.logger is not None:
            self.logger.info(message)

    def save(self, model, step, eval_results, optimizer=None):
        self._print('Saving model...')
        if hasattr(model, 'module'):
            model = model.module
        metric_val = eval_results[self.primary_metric]

        checkpoint_path = os.path.join(self.save_dir, f'model_step_{step}.bin')
        torch.save(model.state_dict(), checkpoint_path)
        if optimizer is not None:
            torch.save(optimizer.state_dict(), checkpoint_path + '.optim')
        self._print(f'Saved checkpoint: {checkpoint_path}')

        # Last checkpoint
        last_path = os.path.join(self.save_dir, 'model_last.bin')
        shutil.copy(checkpoint_path, last_path)
        if optimizer is not None:
            shutil.copy(checkpoint_path + '.optim', last_path + '.optim')
        self._print(f'{last_path} is now checkpoint from step {step}.')

        if self.is_best(metric_val):
            # Save the best model
            self.best_val = metric_val
            best_path = os.path.join(self.save_dir, 'model_best.bin')
            shutil.copy(checkpoint_path, best_path)
            if optimizer is not None:
                shutil.copy(checkpoint_path + '.optim', best_path + '.optim')
            self._print('New best checkpoint!')
            self._print(f'{best_path} is now checkpoint from step {step}.')

        # Add checkpoint path to priority queue (lowest priority removed first)
        if self.maximize_metric:
            priority_order = metric_val
        else:
            priority_order = -metric_val
        self.ckpt_paths.put((priority_order, checkpoint_path))

        # Remove a checkpoint if more than max_checkpoints have been saved
        if self.ckpt_paths.qsize() > self.max_checkpoints:
            _, worst_ckpt = self.ckpt_paths.get()
            try:
                os.remove(worst_ckpt)
                if optimizer is not None:
                    os.remove(worst_ckpt + '.optim')
                self._print(f'Removed checkpoint: {worst_ckpt}')
            except OSError:
                # Avoid crashing if checkpoint has been removed or protected
                pass
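
# Illustrative usage sketch (model, optimizer, and the metric name 'f1' are
# hypothetical placeholders, not objects defined in this module):
#
#     saver = CheckpointSaver(save_dir='save/train/example_01', max_checkpoints=3,
#                             primary_metric='f1', maximize_metric=True, logger=logger)
#     saver.save(model, step=1000, eval_results={'f1': 87.5}, optimizer=optimizer)
#
# Each call writes model_step_<step>.bin, refreshes model_last.bin, updates
# model_best.bin when the primary metric improves, and deletes the worst
# retained checkpoint once more than max_checkpoints have been saved.
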
class AverageMeter:
    """Taken from https://github.com/chrischute/squad."""
    def __init__(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.__init__()

    def update(self, val, num_samples=1):
        self.count += num_samples
        self.sum += val * num_samples
        self.avg = self.sum / self.count
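
# Illustrative usage sketch (the loss values and batch size are made up):
#
#     loss_meter = AverageMeter()
#     loss_meter.update(0.52, num_samples=32)
#     loss_meter.update(0.48, num_samples=32)
#     loss_meter.avg  # sample-weighted running mean: 0.50
#
# reset() simply re-runs __init__, so one instance can be reused across epochs.
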
def get_save_dir(base_dir, name, subdir='train', max_idx=100, use_existing_dir=False):
    """Adapted from https://github.com/chrischute/squad."""
    for idx in range(1, max_idx):
        save_dir = os.path.join(base_dir, subdir, f'{name}_{idx:02d}')
        if not os.path.exists(save_dir):
            if not use_existing_dir:
                os.makedirs(save_dir)
                return save_dir
            else:
                save_dir = os.path.join(base_dir, subdir, f'{name}_{idx - 1:02d}')
                return save_dir
    raise RuntimeError('Too many save directories created with the same name. '
                       'Delete old save directories or use another name.')
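
# Illustrative usage sketch (the base directory and run name are hypothetical):
#
#     save_dir = get_save_dir('save', 'roberta-base', subdir='train')
#
# The first call creates and returns save/train/roberta-base_01, the next run
# gets ..._02, and so on; with use_existing_dir=True, the most recently created
# directory (index idx - 1) is returned instead of a new one being made.
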
def get_data_sizes(data_dir, num_epochs, logger=None):
    num_train_samples_per_epoch = []
    config_file = os.path.join(data_dir, 'data_config.yaml')
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    num_dev_samples = config['dev_size']

    for epoch in range(1, num_epochs + 1):
        if f'epoch_{epoch}_size' in config:
            num_train_samples_per_epoch.append(config[f'epoch_{epoch}_size'])
        else:
            break

    num_unique_train_epochs = len(num_train_samples_per_epoch)
    if logger is not None:
        logger.info(f'{num_unique_train_epochs} unique epochs of data found.')

    for i in range(num_epochs - len(num_train_samples_per_epoch)):
        num_train_samples_per_epoch.append(num_train_samples_per_epoch[i])
    if logger is not None:
        logger.info(f'Number of samples per epoch: {num_train_samples_per_epoch}')

    return num_train_samples_per_epoch, num_dev_samples, num_unique_train_epochs
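
# Illustrative sketch of the data_config.yaml layout this function expects
# (key names follow the lookups above; the sizes are made-up examples):
#
#     dev_size: 10000
#     epoch_1_size: 500000
#     epoch_2_size: 500000
#
# With num_epochs=4 and this config, the two unique epoch sizes are recycled to
# give [500000, 500000, 500000, 500000] and num_unique_train_epochs == 2.
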
def get_parameter_groups(model):
    # Parameters whose names contain any of these substrings are exempt from
    # weight decay (biases and LayerNorm weights).
    no_decay = ['bias', 'LayerNorm.weight']
    parameter_groups = [
        {
            'params': [param for name, param in model.named_parameters()
                       if not any(nd in name for nd in no_decay)]
        },
        {
            'params': [param for name, param in model.named_parameters()
                       if any(nd in name for nd in no_decay)],
            'weight_decay': 0
        }
    ]
    return parameter_groups
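
# Illustrative usage sketch (the model and hyperparameter values are hypothetical):
#
#     optimizer = torch.optim.AdamW(get_parameter_groups(model),
#                                   lr=3e-5, weight_decay=0.01)
#
# The optimizer-level weight_decay applies only to the first group; the second
# group overrides it with 0 so biases and LayerNorm weights are not decayed.
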
def get_lr_scheduler(optimizer, num_steps, num_warmup_steps=None, warmup_proportion=None, last_step=-1):
    """Creates a learning rate scheduler with linear warmup and linear decay to zero.

    Either num_warmup_steps or warmup_proportion should be provided.
    If both are provided, num_warmup_steps is used.
    """
    if num_warmup_steps is None and warmup_proportion is None:
        raise ValueError('Either num_warmup_steps or warmup_proportion should be provided.')
    if num_warmup_steps is None:
        num_warmup_steps = int(num_steps * warmup_proportion)

    def get_lr_multiplier(step):
        if step < num_warmup_steps:
            return (step + 1) / (num_warmup_steps + 1)
        else:
            return (num_steps - step) / (num_steps - num_warmup_steps)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr_multiplier, last_epoch=last_step)
    return lr_scheduler
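
# Illustrative usage sketch (the optimizer and step counts are made up):
#
#     lr_scheduler = get_lr_scheduler(optimizer, num_steps=10000, warmup_proportion=0.1)
#     for step in range(10000):
#         ...  # forward/backward pass and optimizer.step()
#         lr_scheduler.step()
#
# The multiplier ramps roughly linearly from near 0 to 1 over the first 1,000
# steps, then decays linearly back to 0 at step 10,000.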