-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmethods.py
executable file
·75 lines (61 loc) · 2.45 KB
/
methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def reg_ETF(output, label, classifier, mse_loss):
    """Regress each feature vector onto its class column of ``classifier.cur_M``.

    Args:
        output: (B, d) feature batch.
        label: (B,) integer class labels used to index columns of ``cur_M``.
        classifier: object exposing ``cur_M`` of shape (d, C) — presumably the
            fixed ETF classifier matrix; confirm against the classifier module.
        mse_loss: an MSE criterion such as ``nn.MSELoss()``.

    Returns:
        Scalar MSE loss between features and their per-sample target columns.
    """
    # Gather one column per sample and transpose to (B, d) to match `output`.
    per_sample_target = classifier.cur_M[:, label].T
    return mse_loss(output, per_sample_target)
def dot_loss(output, label, cur_M, classifier, criterion, H_length, reg_lam=0):
    """Dot-product loss between features and their class columns of ``cur_M``.

    Args:
        output: (B, d) feature batch.
        label: (B,) integer class labels indexing columns of ``cur_M``.
        cur_M: (d, C) classifier matrix.
        classifier: unused here; kept for interface compatibility with callers.
        criterion: ``'dot_loss'`` or ``'reg_dot_loss'``.
        H_length: scalar target feature length used by ``'reg_dot_loss'``.
        reg_lam: optional L2 coefficient on the mean feature norm (> 0 enables).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if ``criterion`` is not one of the two supported names.
    """
    target = cur_M[:, label].T  # (B, d): one classifier column per sample

    if criterion == 'dot_loss':
        # Maximize the per-sample feature/target dot product.
        loss = - torch.bmm(output.unsqueeze(1), target.unsqueeze(2)).view(-1).mean()
    elif criterion == 'reg_dot_loss':
        dot = torch.bmm(output.unsqueeze(1), target.unsqueeze(2)).view(-1)

        with torch.no_grad():
            # Per-sample target norm; the scaling term carries no gradient.
            # `keepdim` is the canonical torch spelling (was `keepdims`).
            M_length = torch.sqrt(torch.sum(target ** 2, dim=1, keepdim=False))

        loss = (1 / 2) * torch.mean(((dot - (M_length * H_length)) ** 2) / H_length)

        if reg_lam > 0:
            # L2 penalty on the mean per-sample feature norm.
            reg_Eh_l2 = torch.mean(torch.sqrt(torch.sum(output ** 2, dim=1, keepdim=True)))
            loss = loss + reg_Eh_l2 * reg_lam
    else:
        # Previously an unknown criterion fell through and raised a confusing
        # NameError on `return loss`; fail loudly and explicitly instead.
        raise ValueError(f"unknown criterion: {criterion!r}")
    return loss
def produce_Ew(label, num_classes):
    """Per-class scaling weights for the classes present in a batch.

    A class that appears ``count`` times in the batch gets weight
    ``sqrt(gamma / count)`` where ``gamma`` is the batch's average count per
    present class; classes absent from the batch keep weight 1, so rarer
    classes in the batch are up-weighted.

    Args:
        label: (B,) integer label tensor.
        num_classes: total number of classes C.

    Returns:
        (1, C) float tensor of weights on the same device as ``label``.
    """
    uni_label, count = torch.unique(label, return_counts=True)
    batch_size = label.size(0)
    uni_label_num = uni_label.size(0)
    assert batch_size == torch.sum(count)
    gamma = batch_size / uni_label_num
    # Allocate on `label`'s device instead of the previous hard-coded
    # `.cuda(...)`, so the function also works with CPU tensors; on CUDA
    # inputs the result is unchanged.
    Ew = torch.ones(1, num_classes, device=label.device)
    for i in range(uni_label_num):
        label_id = uni_label[i]
        label_count = count[i]
        length = torch.sqrt(gamma / label_count)
        Ew[0, label_id] = length
    return Ew
def produce_global_Ew(cls_num_list, device=None):
    """Global per-class scaling weights from the full class-count list.

    Each class gets weight ``sqrt(gamma / n_c)`` where ``gamma`` is the mean
    class count, so minority classes are up-weighted.

    Args:
        cls_num_list: sequence of per-class sample counts, length C.
        device: target device for the result. ``None`` (the default) keeps the
            original behavior of placing the tensor on CUDA when a GPU is
            available, and falls back to CPU otherwise instead of crashing.

    Returns:
        (1, C) float tensor of weights.
    """
    num_classes = len(cls_num_list)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    cls_num_list = torch.tensor(cls_num_list, device=device)
    total_num = torch.sum(cls_num_list)
    gamma = total_num / num_classes
    Ew = torch.sqrt(gamma / cls_num_list)
    Ew = Ew.unsqueeze(0)
    return Ew
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    # Sample the mixing coefficient; alpha <= 0 disables mixing (lam = 1).
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    n = x.size()[0]
    # Random pairing of samples within the batch.
    perm = torch.randperm(n)
    if use_cuda:
        perm = perm.cuda()

    mixed_x = lam * x + (1 - lam) * x[perm, :]
    return mixed_x, y, y[perm], lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Convex combination of the loss against both mixup targets."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b