import os
import pickle

import numpy as np
import torch
from tqdm.auto import tqdm

from datasets.common import get_dataloader, maybe_dictionarize
from datasets.registry import get_dataset


def torch_save(model, save_path):
    """Save a model to `save_path`, creating the parent directory if needed."""
    if os.path.dirname(save_path) != "":
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model, save_path)


def torch_load(save_path, device=None):
    """Load a model onto CPU, then optionally move it to `device`."""
    model = torch.load(save_path, map_location="cpu")
    if device is not None:
        model = model.to(device)
    return model


class DotDict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
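
# Illustrative note (not part of the original module): DotDict gives attribute-style
# access to dictionary entries, and missing keys fall back to dict.get, returning None.
#
#     cfg = DotDict({"lr": 1e-3, "batch_size": 128})
#     cfg.lr            # 1e-3
#     cfg.weight_decay  # None, since the key is absent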


def train_diag_fim_logtr(
        args,
        model,
        dataset_name: str,
        samples_nr: int = 2000):
    """Estimate the log-trace of the diagonal Fisher Information Matrix of `model`
    on the validation split of `dataset_name`, using `samples_nr` samples."""
    model.cuda()

    if not dataset_name.endswith('Val'):
        dataset_name += 'Val'
    dataset = get_dataset(
        dataset_name,
        model.val_preprocess,
        location=args.data_location,
        batch_size=args.batch_size,
        num_workers=0
    )
    data_loader = torch.utils.data.DataLoader(
        dataset.train_dataset,
        batch_size=args.batch_size,
        num_workers=0, shuffle=False)

    # One zero-initialized accumulator per trainable parameter (diagonal of the FIM).
    fim = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            fim[name] = torch.zeros_like(param)

    progress_bar = tqdm(total=samples_nr)
    seen_nr = 0
    data_iterator = iter(data_loader)
    while seen_nr < samples_nr:
        try:
            data = next(data_iterator)
        except StopIteration:
            # Restart the loader if it is exhausted before samples_nr is reached.
            data_iterator = iter(data_loader)
            data = next(data_iterator)
        data = maybe_dictionarize(data)
        x, y = data['images'], data['labels']
        x, y = x.cuda(), y.cuda()

        logits = model(x)
        # Sample a label for each example from the model's own predictive
        # distribution (rather than using the ground-truth label).
        outdx = torch.distributions.Categorical(logits=logits).sample().unsqueeze(1).detach()
        samples = logits.gather(1, outdx)

        batch_size = x.size(0)
        for idx in range(batch_size):
            # Per-sample backward pass on the sampled-class logit;
            # accumulate the squared gradients into the diagonal FIM estimate.
            model.zero_grad()
            torch.autograd.backward(samples[idx], retain_graph=True)
            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    fim[name] += (param.grad * param.grad)
                    fim[name].detach_()
            seen_nr += 1
            progress_bar.update(1)
            if seen_nr >= samples_nr:
                break
    progress_bar.close()

    # Sum the accumulated squared gradients over all parameters,
    # average over the number of samples, and return the log of the trace.
    fim_trace = 0.0
    for name, grad2 in fim.items():
        fim_trace += grad2.sum()
    fim_trace = torch.log(fim_trace / samples_nr).item()
    return fim_trace
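

# Usage sketch (illustrative, with assumed inputs): the `args` fields, dataset name,
# and checkpoint path below are hypothetical; any model exposing a `val_preprocess`
# attribute compatible with get_dataset would do.
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(data_location="~/data", batch_size=128)
#     model = torch_load("checkpoints/finetuned.pt", device="cuda")
#     logtr = train_diag_fim_logtr(args, model, "MNIST", samples_nr=2000)
#     print(f"log tr(FIM) = {logtr:.4f}")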