-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
38 lines (34 loc) · 1.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import shutil

import h5py
import numpy as np
import torch
import torch.nn.functional as F
def save_net(fname, net):
    """Write every entry of ``net.state_dict()`` to an HDF5 file.

    Each state-dict key becomes a dataset name in ``fname``. The file is
    opened with h5py mode ``'w'`` (create, truncating any existing file).
    Each tensor is moved to CPU and stored as a NumPy array.

    Args:
        fname: Path of the HDF5 file to create/overwrite.
        net: Object exposing ``state_dict()`` (e.g. a torch module).
    """
    state = net.state_dict()
    with h5py.File(fname, 'w') as store:
        for name in state:
            store.create_dataset(name, data=state[name].cpu().numpy())
def load_net(fname, net):
    """Copy arrays saved by ``save_net`` back into ``net``'s state-dict tensors.

    For every key of ``net.state_dict()``, the dataset of the same name is
    read from the HDF5 file ``fname`` and copied in place into the
    corresponding tensor. For torch modules, ``state_dict()`` holds
    references to the live parameters/buffers, so the module is updated.

    Args:
        fname: Path of an HDF5 file written by ``save_net``.
        net: Object exposing ``state_dict()``; its tensors are overwritten.

    Raises:
        KeyError: if a state-dict key has no matching dataset in the file.
        RuntimeError: if a stored array's shape differs from the tensor's.
    """
    with h5py.File(fname, 'r') as h5f:
        for k, v in net.state_dict().items():
            param = torch.from_numpy(np.asarray(h5f[k]))
            # Tensor.copy_ broadcasts, so e.g. a (3,) array would silently be
            # copied into every row of a (2, 3) tensor; demand an exact match.
            if param.shape != v.shape:
                raise RuntimeError(
                    f"size mismatch for {k}: file has {tuple(param.shape)}, "
                    f"model has {tuple(v.shape)}"
                )
            v.copy_(param)
def save_checkpoint(state, is_best, task_id, filename='checkpoint.pth.tar'):
    """Serialize ``state`` with ``torch.save``, optionally keeping a "best" copy.

    Paths are built by plain string concatenation, so ``task_id`` acts as a
    prefix for both files. It may include a directory and/or a tag.

    Args:
        state: Object passed to ``torch.save``.
        is_best: When truthy, the freshly written file is also copied to
            ``task_id + 'model_best.pth.tar'``.
        task_id: String prefix prepended to both file names.
        filename: Checkpoint file name appended to ``task_id``.
    """
    ckpt_path = task_id + filename
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copyfile(ckpt_path, task_id + 'model_best.pth.tar')
def auto_loss(snet_out, cls_err, variables, M, alpha):
    """Margin-weighted loss on ``snet_out`` plus an L1 penalty on ``variables``.

    Computes::

        sum((1 - snet_out) * cls_err + snet_out * max(M - cls_err, 0))
            + alpha * mean(|w|)

    where ``w`` is every tensor in ``variables`` flattened and concatenated.
    NOTE(review): ``snet_out`` presumably is a gate/selection score in
    [0, 1]; confirm with the caller, as nothing here enforces it.

    Args:
        snet_out: Tensor of per-element weights, broadcast against ``cls_err``.
        cls_err: Tensor of per-element errors.
        variables: Iterable of tensors (e.g. parameters) to L1-regularize.
            May be empty, in which case the penalty term is 0.
        M: Margin used in the hinge term ``max(M - cls_err, 0)``.
        alpha: Weight of the L1 penalty.

    Returns:
        A 0-d tensor holding the total loss.
    """
    # reshape (not view) also flattens non-contiguous tensors. One cat over
    # the list avoids re-copying the growing vector on every iteration.
    flat = [p.reshape(-1) for p in variables]
    if flat:
        w = torch.cat(flat)
        l1 = F.l1_loss(w, torch.zeros_like(w))  # mean(|w|)
    else:
        l1 = 0.0
    loss = (1 - snet_out) * cls_err
    loss = loss + snet_out * torch.clamp(M - cls_err, min=0)
    return torch.sum(loss) + alpha * l1
def mse_loss(output, target):
    """Return the element-wise squared error ``(output - target) ** 2``.

    No reduction is applied: the result has the broadcast shape of the two
    inputs, so callers must sum or average it themselves.
    """
    residual = output - target
    return torch.pow(residual, 2)