-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathGHM_loss.py
73 lines (53 loc) · 2.06 KB
/
GHM_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from torch import nn
import torch.nn.functional as F
class GHM_Loss(nn.Module):
    """Base class for gradient-harmonized (GHM) losses.

    Every element of the input is placed into one of ``bins`` equal-width
    bins according to the magnitude of its loss gradient.  Elements that fall
    into crowded bins are down-weighted and elements in sparse bins are
    up-weighted; the resulting per-element weights are handed to the
    subclass-specific loss.

    Subclasses must implement :meth:`_custom_loss` and
    :meth:`_custom_loss_grad`.

    Args:
        bins: Number of gradient-magnitude bins.
        alpha: Momentum of the running average over bin counts.  ``0`` uses
            only the current batch; values closer to ``1`` give more weight
            to the history.
    """

    def __init__(self, bins, alpha):
        super().__init__()
        self._bins = bins
        self._alpha = alpha
        # Running (smoothed) bin population from previous forward calls;
        # ``None`` until the first call.
        self._last_bin_count = None

    def _g2bin(self, g):
        """Map gradient magnitudes to integer bin indices in ``[0, bins - 1]``.

        Args:
            g: Tensor of gradient magnitudes, expected to lie in ``[0, 1]``.

        Returns:
            A ``long`` tensor of the same shape as ``g``.
        """
        # The small offset keeps g == 1 in the last bin instead of producing
        # index ``bins``.
        idx = torch.floor(g * (self._bins - 0.0001)).long()
        # Float rounding (large ``bins``) or magnitudes above 1 can still
        # yield an index past the last bin; clamp so later indexing is safe.
        return idx.clamp(min=0, max=self._bins - 1)

    def _custom_loss(self, x, target, weight):
        """Return the loss for ``x`` vs ``target`` using per-element ``weight``."""
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        """Return the per-element gradient of the unweighted loss w.r.t. ``x``."""
        raise NotImplementedError

    def forward(self, x, target):
        """Compute the gradient-harmonized loss.

        Updates ``self._last_bin_count`` with the (smoothed) bin population
        as a side effect.

        Args:
            x: Predictions; must have at least two dimensions because the
                normaliser is ``x.size(0) * x.size(1)``.
            target: Targets, passed through to the subclass hooks.

        Returns:
            Whatever the subclass ``_custom_loss`` returns.
        """
        # Gradient magnitudes only drive the weighting, so they are detached.
        g = torch.abs(self._custom_loss_grad(x, target)).detach()
        bin_idx = self._g2bin(g)

        # Histogram in one call, on the same device as the indices, instead
        # of one host-synchronising ``.item()`` per bin on a CPU tensor.
        bin_count = torch.bincount(
            bin_idx.reshape(-1), minlength=self._bins
        ).to(torch.get_default_dtype())

        N = x.size(0) * x.size(1)

        if self._last_bin_count is not None:
            # Blend with history; move the stored counts in case the module
            # is now being fed tensors from a different device.
            previous = self._last_bin_count.to(bin_count.device)
            bin_count = self._alpha * previous + (1 - self._alpha) * bin_count
        self._last_bin_count = bin_count

        nonempty_bins = (bin_count > 0).sum()

        # Gradient density per bin; clamped so empty bins do not divide by 0.
        gd = torch.clamp(bin_count * nonempty_bins, min=0.0001)
        beta = N / gd

        return self._custom_loss(x, target, beta[bin_idx])
class GHMC_Loss(GHM_Loss):
    """GHM loss for classification, built on binary cross-entropy with logits.

    Args:
        bins: Number of gradient-magnitude bins (forwarded to the base class).
        alpha: Momentum of the running bin-count average (forwarded to the
            base class).
    """

    def __init__(self, bins, alpha):
        super().__init__(bins, alpha)

    def _custom_loss(self, x, target, weight):
        """Return the weighted binary cross-entropy of logits ``x`` vs ``target``."""
        weighted_bce = F.binary_cross_entropy_with_logits(
            x, target, weight=weight
        )
        return weighted_bce

    def _custom_loss_grad(self, x, target):
        """Return ``sigmoid(x) - target`` with the sigmoid detached from autograd."""
        probabilities = torch.sigmoid(x).detach()
        return probabilities - target
class GHMR_Loss(GHM_Loss):
    """GHM loss for regression using the smooth penalty ``sqrt(d^2 + mu^2) - mu``.

    Args:
        bins: Number of gradient-magnitude bins (forwarded to the base class).
        alpha: Momentum of the running bin-count average (forwarded to the
            base class).
        mu: Smoothing constant of the penalty.
    """

    def __init__(self, bins, alpha, mu):
        super().__init__(bins, alpha)
        self._mu = mu

    def _custom_loss(self, x, target, weight):
        """Return the weighted penalty summed and divided by ``x.size(0) * x.size(1)``."""
        smoothing = self._mu
        residual = x - target
        penalty = torch.sqrt(residual * residual + smoothing * smoothing) - smoothing
        weighted_total = (penalty * weight).sum()
        element_count = x.size(0) * x.size(1)
        return weighted_total / element_count

    def _custom_loss_grad(self, x, target):
        """Return the penalty's derivative w.r.t. ``x``: ``d / sqrt(d^2 + mu^2)``."""
        smoothing = self._mu
        residual = x - target
        denominator = torch.sqrt(residual * residual + smoothing * smoothing)
        return residual / denominator