-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextra_loss.py
84 lines (72 loc) · 3.61 KB
/
extra_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
# @Time : 2021/7/11 12:32 下午
# @Author : Bubble
# @FileName: extra_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingLoss(nn.Module):
"""
NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.01):
"""
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
def forward(self, x, target):
logprobs = torch.nn.functional.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2, num_classes=2, size_average=True):
"""
focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)
步骤详细的实现了 focal_loss损失函数.
:param alpha: 阿尔法α,类别权重. 当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.25
:param gamma: 伽马γ,难易样本调节参数. retainnet中设置为2
:param num_classes: 类别数量
:param size_average: 损失计算方式,默认取均值
"""
super(FocalLoss, self).__init__()
self.size_average = size_average
if isinstance(alpha, list):
assert len(alpha) == num_classes # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重
# print(" --- Focal_loss alpha = {}, 将对每一类权重进行精细化赋值 --- ".format(alpha))
self.alpha = torch.Tensor(alpha)
else:
assert alpha < 1 # 如果α为一个常数,则降低第一类的影响,在目标检测中为第一类
# print(" --- Focal_loss alpha = {} ,将对背景类进行衰减,请在目标检测任务中使用 --- ".format(alpha))
self.alpha = torch.zeros(num_classes)
self.alpha[0] += alpha
self.alpha[1:] += (1 - alpha) # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]
self.gamma = gamma
def forward(self, preds, labels):
"""
focal_loss损失计算
:param preds: 预测类别. size:[B,N,C] or [B,C] 分别对应与检测与分类任务, B 批次, N检测框数, C类别数
:param labels: 实际类别. size:[B,N] or [B]
:return:
"""
# assert preds.dim()==2 and labels.dim()==1
preds = preds.view(-1, preds.size(-1))
self.alpha = self.alpha.to(preds.device)
preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax
preds_softmax = torch.exp(preds_logsoft) # softmax
preds_softmax = preds_softmax.gather(1, labels.view(-1, 1)) # 这部分实现nll_loss ( crossempty = log_softmax + nll )
preds_logsoft = preds_logsoft.gather(1, labels.view(-1, 1))
self.alpha = self.alpha.gather(0, labels.view(-1))
loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma),
preds_logsoft) # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ
loss = torch.mul(self.alpha, loss.t())
if self.size_average:
loss = loss.mean()
else:
loss = loss.sum()
return loss