-
Notifications
You must be signed in to change notification settings - Fork 1
/
util.py
51 lines (41 loc) · 1.79 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class GeneralizedCELoss(nn.Module):
    """Per-element cross-entropy re-weighted by ``q * p_y ** q``.

    ``p_y`` is the softmax probability the model assigns to the target class.
    The weight is detached, so it only rescales the cross-entropy gradient
    (the gradient becomes ``q * p_y ** q`` times the plain CE gradient) and is
    not itself differentiated.

    Args:
        q: Exponent (and multiplier) applied to the target-class probability.
    """

    def __init__(self, q=0.7):
        super().__init__()
        self.q = q

    def forward(self, logits, targets):
        """Return the unreduced, re-weighted cross-entropy loss.

        Args:
            logits: Unnormalised scores with classes along dim 1, e.g.
                ``(N, C)`` or ``(N, C, d1, ...)``.
            targets: Integer class indices shaped like ``logits`` without
                dim 1, e.g. ``(N,)`` or ``(N, d1, ...)``.

        Returns:
            Tensor shaped like ``targets`` holding the per-element loss
            (no reduction is applied).

        Raises:
            NameError: ``'GCE_p'`` if the softmax output contains NaN, or
                ``'GCE_Yg'`` if the gathered target probabilities do.
        """
        p = F.softmax(logits, dim=1)
        if torch.isnan(p.mean()):
            raise NameError('GCE_p')
        # Probability of the target class for every element; the class dim is
        # kept (size 1) so gather sees the same rank as ``p``.
        p_target = torch.gather(p, 1, torch.unsqueeze(targets, 1))
        if torch.isnan(p_target.mean()):
            raise NameError('GCE_Yg')
        # Squeeze only the gathered class dim. A bare ``squeeze()`` would also
        # drop any other size-1 dim (e.g. a spatial dim of 1), and the product
        # below would then broadcast to the wrong shape.
        # Detached: modifies the gradient of cross entropy without being
        # differentiated itself.
        loss_weight = (p_target.squeeze(1).detach() ** self.q) * self.q
        loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight
        return loss
class EMA:
    """Per-sample exponential moving average of a scalar value.

    One slot is kept per entry of ``label``. The first update of a slot stores
    the incoming value as-is; later updates blend it with the stored value.

    Args:
        label: 1-D tensor with one entry per sample; presumably integer class
            labels (they are compared against a single label in ``max_loss``).
        num_classes: Length of the ``max`` buffer. If ``None`` it is inferred
            as ``label.max() + 1`` (assumes labels are class indices
            ``0..C-1`` -- pass it explicitly otherwise).
        alpha: Smoothing factor used when ``update`` is called without
            ``curve``.
        device: Device on which all internal tensors are kept.
    """

    def __init__(self, label, num_classes=None, alpha=0.9, device='cuda:0'):
        self.device = device
        self.label = label.to(self.device)
        self.alpha = alpha
        num_samples = label.size(0)
        # Allocate on ``device`` up front so ``max_loss`` also works before the
        # first ``update`` (indexing a CPU tensor with CUDA indices fails).
        self.parameter = torch.zeros(num_samples, device=self.device)
        self.updated = torch.zeros(num_samples, device=self.device)
        if num_classes is None:
            # torch.zeros(None) raises TypeError; infer the count instead.
            num_classes = int(label.max().item()) + 1 if label.numel() > 0 else 0
        self.num_classes = num_classes
        # Not read by this class; presumably used by callers -- verify.
        self.max = torch.zeros(self.num_classes).to(self.device)

    def update(self, data, index, curve=None, iter_range=None, step=None):
        """Blend ``data`` into the running averages at ``index``.

        ``new = a * old + (1 - a * updated) * data``, where ``updated`` is 1
        for slots written before and 0 otherwise, so a slot's first update
        stores ``data`` unchanged.

        Args:
            data: New values, one per entry of ``index`` (or broadcastable).
            index: Indices of the samples to update.
            curve: If given, ``a = curve ** -(step / iter_range)`` is used
                instead of ``self.alpha``.
            iter_range: Denominator of the exponent when ``curve`` is given.
            step: Numerator of the exponent when ``curve`` is given.
        """
        # No-ops when the buffers already live on ``device``; kept in case
        # they were replaced from outside.
        self.parameter = self.parameter.to(self.device)
        self.updated = self.updated.to(self.device)
        index = index.to(self.device)
        if curve is None:
            alpha = self.alpha
        else:
            alpha = curve ** -(step / iter_range)
        self.parameter[index] = alpha * self.parameter[index] + (1 - alpha * self.updated[index]) * data
        self.updated[index] = 1

    def max_loss(self, label):
        """Return the largest stored average among samples whose label is ``label``.

        torch raises if no sample carries ``label`` (max over an empty tensor).
        """
        label_index = torch.where(self.label == label)[0]
        return self.parameter[label_index].max()