-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
34 lines (29 loc) · 1.08 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import torch
import random
import numpy as np
# Module-level switch read by seed_everything(): when True, cuDNN is configured
# for reproducible (deterministic, non-benchmarked) convolution algorithms.
cudnn_deterministic = True


def seed_everything(seed=0):
    """Fix all random seeds.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU and
    CUDA generators with ``seed``, records ``PYTHONHASHSEED`` in the process
    environment, and configures cuDNN according to the module-level
    ``cudnn_deterministic`` flag.

    Args:
        seed: Seed value applied to every generator (default 0).

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting it
        here only affects child processes launched afterwards; it does not
        change string-hash randomization of the current process.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = cudnn_deterministic
    if cudnn_deterministic:
        # `deterministic` alone is not enough for run-to-run reproducibility:
        # with benchmark mode on, cuDNN times candidate algorithms and may pick
        # a different one on each run. PyTorch's reproducibility notes
        # recommend disabling benchmarking as well.
        torch.backends.cudnn.benchmark = False
def print_summary(acc_taw, acc_tag, forg_taw, forg_tag):
    """Print the four results matrices as percentage tables.

    The matrices are printed in this order: 'TAw Acc', 'TAg Acc',
    'TAw Forg', 'TAg Forg'. Each table is preceded by a rule of 108
    asterisks and its label, and a final rule closes the output. Each
    entry is printed as ``100 * value`` with one decimal place.

    A per-row average is appended, chosen by the matrix's trace:
      * trace exactly 0.0 -> row ``i`` is averaged over columns ``[0, i)``;
        row 0 gets no average;
      * otherwise         -> row ``i`` is averaged over columns ``[0, i]``.

    Args:
        acc_taw, acc_tag, forg_taw, forg_tag: 2-D matrices indexed as
            ``m[i, j]``, exposing ``.shape`` and accepted by ``np.trace``.
            Presumably NumPy arrays holding fractions — TODO confirm with
            the caller.
    """
    separator = '*' * 108
    labelled_tables = (
        ('TAw Acc', acc_taw),
        ('TAg Acc', acc_tag),
        ('TAw Forg', forg_taw),
        ('TAg Forg', forg_tag),
    )
    for label, table in labelled_tables:
        print(separator)
        print(label)
        for row in range(table.shape[0]):
            # Assemble the whole row first, then emit it with a single print.
            line = '\t' + ''.join(
                '{:5.1f}% '.format(100 * table[row, col])
                for col in range(table.shape[1])
            )
            if np.trace(table) != 0.0:
                # Non-zero diagonal: include the diagonal entry in the average.
                line += '\tAvg.:{:5.1f}% '.format(100 * table[row, :row + 1].mean())
            elif row > 0:
                # Zero diagonal: average only the entries left of the diagonal.
                line += '\tAvg.:{:5.1f}% '.format(100 * table[row, :row].mean())
            print(line)
    print(separator)