utils.py

import os
import yaml
import numpy as np
import torch


class Logger:
    """Tracks the best validation/test metrics across runs and saves results as YAML."""

    def __init__(self, args, metric, num_data_splits):
        self.save_dir = self.get_save_dir(base_dir=args.save_dir, dataset=args.dataset, name=args.name)
        self.verbose = args.verbose
        self.metric = metric
        self.val_metrics = []
        self.test_metrics = []
        self.best_steps = []
        self.num_runs = args.num_runs
        self.num_data_splits = num_data_splits
        self.cur_run = None
        self.cur_data_split = None

        print(f'Results will be saved to {self.save_dir}.')

        # Save the run arguments next to the results for reproducibility.
        with open(os.path.join(self.save_dir, 'args.yaml'), 'w') as file:
            yaml.safe_dump(vars(args), file, sort_keys=False)

    def start_run(self, run, data_split):
        self.cur_run = run
        self.cur_data_split = data_split
        self.val_metrics.append(0)
        self.test_metrics.append(0)
        self.best_steps.append(None)
        if self.num_data_splits == 1:
            print(f'Starting run {run}/{self.num_runs}...')
        else:
            print(f'Starting run {run}/{self.num_runs} (using data split {data_split}/{self.num_data_splits})...')

    def update_metrics(self, metrics, step):
        # Keep the test metric from the step with the best validation metric seen so far in this run.
        if metrics[f'val {self.metric}'] > self.val_metrics[-1]:
            self.val_metrics[-1] = metrics[f'val {self.metric}']
            self.test_metrics[-1] = metrics[f'test {self.metric}']
            self.best_steps[-1] = step

        if self.verbose:
            print(f'run: {self.cur_run:02d}, step: {step:03d}, '
                  f'train {self.metric}: {metrics[f"train {self.metric}"]:.4f}, '
                  f'val {self.metric}: {metrics[f"val {self.metric}"]:.4f}, '
                  f'test {self.metric}: {metrics[f"test {self.metric}"]:.4f}')

    def finish_run(self):
        self.save_metrics()
        print(f'Finished run {self.cur_run}. '
              f'Best val {self.metric}: {self.val_metrics[-1]:.4f}, '
              f'corresponding test {self.metric}: {self.test_metrics[-1]:.4f} '
              f'(step {self.best_steps[-1]}).\n')

    def save_metrics(self):
        num_runs = len(self.val_metrics)
        val_metric_mean = np.mean(self.val_metrics).item()
        # The sample standard deviation (ddof=1) is undefined for a single run, so np.nan is used instead.
        val_metric_std = np.std(self.val_metrics, ddof=1).item() if len(self.val_metrics) > 1 else np.nan
        test_metric_mean = np.mean(self.test_metrics).item()
        test_metric_std = np.std(self.test_metrics, ddof=1).item() if len(self.test_metrics) > 1 else np.nan

        metrics = {
            'num runs': num_runs,
            f'val {self.metric} mean': val_metric_mean,
            f'val {self.metric} std': val_metric_std,
            f'test {self.metric} mean': test_metric_mean,
            f'test {self.metric} std': test_metric_std,
            f'val {self.metric} values': self.val_metrics,
            f'test {self.metric} values': self.test_metrics,
            'best steps': self.best_steps
        }

        with open(os.path.join(self.save_dir, 'metrics.yaml'), 'w') as file:
            yaml.safe_dump(metrics, file, sort_keys=False)
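
    # A sketch of what metrics.yaml might look like for, e.g., metric='accuracy'
    # after two runs (the values below are made up purely for illustration):
    #
    #   num runs: 2
    #   val accuracy mean: 0.9312
    #   val accuracy std: 0.0045
    #   test accuracy mean: 0.9105
    #   test accuracy std: 0.0051
    #   val accuracy values: [0.928, 0.9344]
    #   test accuracy values: [0.9069, 0.9141]
    #   best steps: [740, 812]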

    def print_metrics_summary(self):
        with open(os.path.join(self.save_dir, 'metrics.yaml'), 'r') as file:
            metrics = yaml.safe_load(file)

        print(f'Finished {metrics["num runs"]} runs.')
        print(f'Val {self.metric} mean: {metrics[f"val {self.metric} mean"]:.4f}')
        print(f'Val {self.metric} std: {metrics[f"val {self.metric} std"]:.4f}')
        print(f'Test {self.metric} mean: {metrics[f"test {self.metric} mean"]:.4f}')
        print(f'Test {self.metric} std: {metrics[f"test {self.metric} std"]:.4f}')

    @staticmethod
    def get_save_dir(base_dir, dataset, name):
        # Find the first unused index for this experiment name, e.g. name_01, name_02, ...
        idx = 1
        save_dir = os.path.join(base_dir, dataset, f'{name}_{idx:02d}')
        while os.path.exists(save_dir):
            idx += 1
            save_dir = os.path.join(base_dir, dataset, f'{name}_{idx:02d}')

        os.makedirs(save_dir)

        return save_dir
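

# A minimal usage sketch for Logger (not part of the original file). It assumes
# an args object with the attributes __init__ reads (save_dir, dataset, name,
# verbose, num_runs); argparse.Namespace is used here purely for illustration:
#
#   from argparse import Namespace
#
#   args = Namespace(save_dir='results', dataset='cora', name='experiment',
#                    verbose=True, num_runs=2)
#   logger = Logger(args, metric='accuracy', num_data_splits=1)
#   for run in range(1, args.num_runs + 1):
#       logger.start_run(run=run, data_split=1)
#       for step in range(1, num_steps + 1):  # num_steps comes from the training loop
#           metrics = {'train accuracy': ..., 'val accuracy': ..., 'test accuracy': ...}
#           logger.update_metrics(metrics, step)
#       logger.finish_run()
#   logger.print_metrics_summary()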


def get_parameter_groups(model):
    # Parameters whose names contain any of these substrings are excluded from weight decay.
    no_weight_decay_names = ['bias', 'normalization', 'label_embeddings']

    parameter_groups = [
        {
            'params': [param for name, param in model.named_parameters()
                       if not any(no_weight_decay_name in name for no_weight_decay_name in no_weight_decay_names)]
        },
        {
            'params': [param for name, param in model.named_parameters()
                       if any(no_weight_decay_name in name for no_weight_decay_name in no_weight_decay_names)],
            'weight_decay': 0
        },
    ]

    return parameter_groups
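

# A minimal usage sketch (not part of the original file), assuming a standard
# torch.nn.Module and AdamW. The weight_decay passed to the optimizer applies
# only to the first group, since the second group overrides it with 0:
#
#   model = torch.nn.Linear(16, 4)  # stand-in for the real model
#   parameter_groups = get_parameter_groups(model)
#   optimizer = torch.optim.AdamW(parameter_groups, lr=3e-4, weight_decay=1e-2)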


def get_lr_scheduler_with_warmup(optimizer, num_warmup_steps=None, num_steps=None, warmup_proportion=None,
                                 last_step=-1):
    if num_warmup_steps is None and (num_steps is None or warmup_proportion is None):
        raise ValueError('Either num_warmup_steps or both num_steps and warmup_proportion should be provided.')

    if num_warmup_steps is None:
        num_warmup_steps = int(num_steps * warmup_proportion)

    def get_lr_multiplier(step):
        # Increase the multiplier linearly from 1 / (num_warmup_steps + 1) to 1 during warmup,
        # then keep it at 1 for the rest of training.
        if step < num_warmup_steps:
            return (step + 1) / (num_warmup_steps + 1)
        else:
            return 1

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr_multiplier, last_epoch=last_step)

    return lr_scheduler
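

# A minimal usage sketch (not part of the original file): step the scheduler once
# per optimization step, so LambdaLR scales the optimizer's base learning rate up
# linearly over the warmup steps and then leaves it unchanged:
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   scheduler = get_lr_scheduler_with_warmup(optimizer, num_steps=1000, warmup_proportion=0.1)
#   for step in range(1000):
#       optimizer.step()
#       scheduler.step()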