-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfocal_loss.py
145 lines (105 loc) · 4.92 KB
/
focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from __init__ import *
import torch.nn as nn
from anchor_assigner import Anchor_Assigner
class Focal_Loss(nn.Module):
    r"""
    Focal loss (classification) plus smooth-L1 loss (box regression) for anchor-based detection.

    Anchors are split into foreground / background sets by ``Anchor_Assigner``;
    anchors that land in neither set contribute nothing to either loss.

    Args:
        fore_th: threshold passed to Anchor_Assigner to select foreground anchors
        back_th: threshold passed to Anchor_Assigner to select background anchors
        alpha: alpha of focal loss (weight of positive class entries; negatives get 1 - alpha)
        gamma: gamma (focusing exponent) of focal loss
        beta: beta of smooth-L1 loss (switch point between quadratic and linear regions)
        fore_mean: average each loss by the number of foregrounds (the regression
            loss is divided by 4 * number of foregrounds, i.e. per box coordinate)
        reg_weight: weight for regression loss. if None, it is 1.0 (an explicit 0.0 disables it)
        average: average the losses by the batch size
        bbox_format: bbox format for batch_iou
    Output:
        total loss, classification loss, regression loss
    """

    def __init__(self,
                 fore_th: float = 0.5,
                 back_th: float = 0.4,
                 alpha: float = 0.25,
                 gamma: float = 1.5,
                 beta: float = 0.1,
                 fore_mean: bool = True,
                 reg_weight: Optional[float] = None,
                 average: bool = True,
                 bbox_format: str = 'cxcywh'
                 ):
        super().__init__()
        self.fore_th = fore_th
        self.back_th = back_th
        # The two positional False flags are passed through unchanged; their
        # meaning is defined by Anchor_Assigner (not visible here).
        self.anchor_assigner = Anchor_Assigner(fore_th, back_th, False, False, bbox_format)
        self.alpha = alpha
        self.gamma = gamma
        self.beta = beta
        self.fore_mean = fore_mean
        # Test against None (not truthiness) so reg_weight=0.0 really disables
        # the regression term instead of silently becoming 1.0.
        self.reg_weight = 1.0 if reg_weight is None else reg_weight
        self.average = average

    @classmethod
    def focal_loss(cls, cls_pred, fore_idx, back_idx, fore_label_cls, alpha, gamma, mean):
        """Focal loss of one image, summed over its foreground and background anchors.

        Args:
            cls_pred: per-anchor class scores of one image, indexed by
                ``fore_idx`` / ``back_idx`` along the first dim. Assumed to be
                probabilities in (0, 1) -- ``forward`` clamps them before calling.
            fore_idx: indices of foreground anchors (``size(0)`` is their count).
            back_idx: indices of background anchors.
            fore_label_cls: class targets of the foreground anchors; entries equal
                to 1 are positives, every other value is treated as negative.
            alpha: weight of positive entries; negative entries get ``1 - alpha``.
            gamma: focusing exponent applied to ``(1 - p_t)``.
            mean: if True, divide by the number of foreground anchors (skipped
                when there are none).

        Returns:
            Scalar loss tensor.
        """
        fore_pred = cls_pred[fore_idx]
        back_pred = cls_pred[back_idx]
        is_pos = fore_label_cls == 1
        # p_t: the probability the model gives to the correct outcome of each entry.
        fore_pred_t = torch.where(is_pos, fore_pred, 1 - fore_pred)
        # Every class entry of a background anchor is a negative.
        back_pred_t = 1 - back_pred
        # 0-d tensors instead of Python floats keep torch.where working on torch
        # versions without scalar/scalar support and match dtype/device.
        fore_alpha_t = torch.where(is_pos, fore_pred.new_tensor(alpha), fore_pred.new_tensor(1 - alpha))
        back_alpha_t = 1 - alpha
        fore_weight = -1 * fore_alpha_t * torch.pow(1 - fore_pred_t, gamma)
        back_weight = -1 * back_alpha_t * torch.pow(1 - back_pred_t, gamma)
        fore_loss = fore_weight * torch.log(fore_pred_t)
        back_loss = back_weight * torch.log(back_pred_t)
        loss = torch.sum(fore_loss) + torch.sum(back_loss)
        if mean:
            num = fore_idx.size(0)
            if num > 0:
                loss = loss / num
        return loss

    @classmethod
    def smooothL1_loss(cls, reg_pred, anchors, fore_idx, fore_label_bbox, beta, mean):
        """Smooth-L1 loss between predicted and encoded box offsets of one image.

        Targets are encoded relative to the matched anchor, treating columns as
        (cx, cy, w, h): ``(dx / w_a, dy / h_a, log(w / w_a), log(h / h_a))``, with
        target width/height clamped to at least 1 before the log.

        Args:
            reg_pred: per-anchor box offsets of one image, indexed by ``fore_idx``.
            anchors: anchor boxes; all leading dims are flattened. Assumes every
                image shares the same anchor set (e.g. shape (1, N, 4)) -- TODO confirm.
            fore_idx: indices of foreground anchors.
            fore_label_bbox: target boxes of the foreground anchors, 4 columns.
            beta: quadratic/linear switch point of the smooth-L1 loss.
            mean: if True, divide by ``4 * number of foregrounds`` (skipped when 0).

        Returns:
            Scalar loss tensor.
        """
        fore_pred = reg_pred[fore_idx]
        # Flatten only the leading dims. A bare squeeze() also removed the anchor
        # dim when there was exactly one anchor, which broke the indexing below.
        flat_anchors = anchors.reshape(-1, anchors.size(-1))
        fore_anchor = flat_anchors[fore_idx]
        reg_label = torch.zeros_like(fore_label_bbox)
        reg_label[..., 0] = (fore_label_bbox[..., 0] - fore_anchor[..., 0]) / fore_anchor[..., 2]
        reg_label[..., 1] = (fore_label_bbox[..., 1] - fore_anchor[..., 1]) / fore_anchor[..., 3]
        reg_label[..., 2] = torch.log(fore_label_bbox[..., 2].clamp(min=1) / fore_anchor[..., 2])
        reg_label[..., 3] = torch.log(fore_label_bbox[..., 3].clamp(min=1) / fore_anchor[..., 3])
        mae = torch.abs(reg_label - fore_pred)
        loss = torch.where(torch.le(mae, beta), 0.5 * (mae ** 2) / beta, mae - 0.5 * beta)
        loss = torch.sum(loss)
        if mean:
            num = 4 * fore_idx.size(0)
            if num > 0:
                loss = loss / num
        return loss

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    smoothL1_loss = smooothL1_loss

    def forward(self,
                preds: Tensor,
                anchors: Tensor,
                labels: Tensor
                ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the detection loss for a batch.

        Args:
            preds: 3-d tensor whose last dim holds 4 box-offset columns followed by
                class scores (presumably (batch, anchors, 4 + classes)).
            anchors: 3-d tensor of anchor boxes.
            labels: 3-d tensor of ground-truth rows, forwarded to the anchor
                assigner (presumably (batch, objects, 4 + classes)).

        Returns:
            (total_loss, cls_loss, reg_loss) as tensors; each is divided by the
            batch size when ``average`` is set and the batch is non-empty.

        Raises:
            ValueError: if any input is not 3-d.
        """
        if preds.dim() != 3:
            raise ValueError("preds should be given in 3d tensor")
        if anchors.dim() != 3:
            raise ValueError("anchors should be given in 3d tensor")
        if labels.dim() != 3:
            raise ValueError("labels should be given in 3d tensor")
        reg_preds = preds[..., :4]
        # Keep class scores away from 0 and 1 so the logs in focal_loss stay finite.
        cls_preds = preds[..., 4:].clamp(1e-5, 1.0 - 1e-5)
        target_assigns = self.anchor_assigner(labels, anchors)
        cls_losses, reg_losses = [], []
        for i, assign in enumerate(target_assigns):
            fore_idx, fore_label = assign['foreground'][0], assign['foreground'][1]
            back_idx = assign['background'][0]
            fore_label_cls = fore_label[..., 4:]
            fore_label_bbox = fore_label[..., :4]
            cls_losses.append(self.focal_loss(cls_preds[i], fore_idx, back_idx, fore_label_cls,
                                              self.alpha, self.gamma, self.fore_mean))
            reg_losses.append(self.smooothL1_loss(reg_preds[i], anchors, fore_idx, fore_label_bbox,
                                                  self.beta, self.fore_mean))
        # Start the sums from a 0-d tensor so the results are tensors even when
        # the batch is empty.
        zero = preds.new_zeros(())
        cls_loss = sum(cls_losses, zero)
        reg_loss = sum(reg_losses, zero)
        total_loss = cls_loss + self.reg_weight * reg_loss
        batch = len(target_assigns)
        # Out-of-place division; skip it for an empty batch (it would divide by zero).
        if self.average and batch > 0:
            total_loss = total_loss / batch
            cls_loss = cls_loss / batch
            reg_loss = reg_loss / batch
        return total_loss, cls_loss, reg_loss