-
Notifications
You must be signed in to change notification settings - Fork 2
/
GCE.py
35 lines (30 loc) · 1.27 KB
/
GCE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Number of target classes; set to 10 for the MNIST dataset.
classes = 10
class GuidedComplementEntropy(nn.Module):
    """Guided Complement Entropy (GCE) loss.

    For each sample, the predicted probabilities are divided by
    ``1 - p_g`` (``p_g`` = probability of the ground-truth class), and
    ``sum_{j != g} q_j * log(q_j)`` is accumulated over the complement
    classes. Each sample's term is scaled by the guiding factor
    ``p_g ** alpha``. The per-sample terms are then averaged over the batch
    and divided by ``log(K)``, where ``K`` is the number of classes.

    Args:
        alpha: Exponent applied to the ground-truth probability to form the
            guiding factor.
    """

    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha

    # Implemented step by step to mirror the formula described in the paper.
    def forward(self, yHat, y):
        """Compute the GCE loss for a batch.

        Args:
            yHat: Unnormalised logits of shape ``(batch, K)``; softmax is
                applied here along dim 1.
            y: 1-D int64 tensor of ground-truth class indices, length
                ``batch``, on the same device as ``yHat``.

        Returns:
            A 0-dim tensor holding the loss.
        """
        # Stored on the module (as before) so existing callers that read
        # these attributes keep working.
        self.batch_size = len(y)
        # Infer K from the logits rather than a hard-coded constant, so the
        # loss works for any number of classes.
        self.classes = yHat.size(1)

        yHat = F.softmax(yHat, dim=1)
        Yg = torch.gather(yHat, 1, torch.unsqueeze(y, 1))  # (batch, 1)
        # Epsilon avoids dividing by zero when p_g -> 1 (first).
        Yg_ = (1 - Yg) + 1e-7
        # Epsilon keeps p_g away from 0, where the gradient of x ** alpha is
        # infinite for alpha < 1 (second).
        guided_factor = (Yg + 1e-7) ** self.alpha
        Px = yHat / Yg_  # (batch, K) / (batch, 1) broadcasts per row
        Px_log = torch.log(Px + 1e-10)  # avoids log(0) (third)

        # Mask that zeroes the ground-truth column. It is created on yHat's
        # device and dtype so the loss runs on CPU as well as GPU.
        y_zerohot = torch.ones_like(yHat).scatter_(1, y.view(-1, 1), 0)
        output = Px * Px_log * y_zerohot

        guided_output = guided_factor.squeeze(1) * torch.sum(output, dim=1)
        loss = torch.sum(guided_output)
        loss /= float(self.batch_size)
        loss /= math.log(float(self.classes))
        return loss