-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss_function.py
67 lines (53 loc) · 2.2 KB
/
loss_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import torch
from torch import nn
class DiceBCELoss(nn.Module):
    """Sum of the soft Dice loss and the mean binary cross-entropy (BCE).

    ``outputs`` are passed directly to :class:`torch.nn.BCELoss`, so they must
    already be probabilities in ``[0, 1]`` (e.g. after a sigmoid), not logits.
    """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted only for signature
        # compatibility; they are not used and BCE is always the plain mean.
        super().__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, outputs, targets, smooth=0):
        """
        DiceBCELoss - Compute the Dice-BCE Loss.

        loss = (1 - (2 * |O * T| + smooth) / (|O| + |T| + smooth)) + BCE(O, T)

        Args:
            outputs (tensor): predicted probabilities in [0, 1]; any shape.
            targets (tensor): ground-truth labels with the same number of
                elements as ``outputs``; converted to ``outputs.dtype`` so
                integer or boolean masks are accepted.
            smooth (float): additive smoothing applied to both the Dice
                numerator and denominator.

        Returns:
            dice_BCE_loss (tensor): scalar Dice loss plus mean BCE loss.
        """
        # Flatten output and target tensors. ``reshape`` (unlike ``view``)
        # also works for non-contiguous inputs such as permuted tensors.
        outputs = outputs.reshape(-1)
        # BCELoss requires floating targets of the same dtype as the outputs.
        targets = targets.reshape(-1).to(outputs.dtype)

        # Compute the Dice loss.
        intersection = (outputs * targets).sum()
        denominator = outputs.sum() + targets.sum() + smooth
        # With smooth=0, an all-zero prediction against an all-zero target
        # yields 0/0 = NaN. Treat that case as perfect overlap (Dice = 1).
        # Divide by a placeholder there as well, so no NaN leaks into the
        # gradient through the unselected torch.where branch.
        nonzero = denominator != 0
        ones = torch.ones_like(denominator)
        dice = torch.where(
            nonzero,
            (2. * intersection + smooth) / torch.where(nonzero, denominator, ones),
            ones,
        )
        dice_loss = 1 - dice

        # Compute the standard binary cross-entropy (BCE) loss.
        BCE_loss = self.bce_loss(outputs, targets)
        return dice_loss + BCE_loss
class BCEIoULoss(nn.Module):
    """Weighted mix of a modulated BCE term and a soft IoU (Jaccard) loss.

    ``outputs`` are passed directly to :class:`torch.nn.BCELoss`, so they must
    already be probabilities in ``[0, 1]`` (e.g. after a sigmoid), not logits.
    """

    def __init__(self, weight=None, size_average=True):
        # ``weight`` and ``size_average`` are accepted only for signature
        # compatibility; they are not used and BCE is always the plain mean.
        super().__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, outputs, targets, beta=0.6, alpha=0.25, gamma=3, smooth=0):
        """
        BCEIoULoss - Compute the BCE-IoU Loss.

        loss = beta * alpha * (1 - exp(-BCE)) ** gamma * BCE
               + (1 - beta) * (1 - (|O * T| + smooth) / (|O + T| - |O * T| + smooth))

        Args:
            outputs (tensor): predicted probabilities in [0, 1]; any shape.
            targets (tensor): ground-truth labels with the same number of
                elements as ``outputs``; converted to ``outputs.dtype`` so
                integer or boolean masks are accepted.
            beta (float): weight of the modulated BCE term; the IoU term
                gets ``1 - beta``.
            alpha (float): scale factor on the modulated BCE term.
            gamma (float): exponent of the ``(1 - exp(-BCE))`` modulation.
            smooth (float): additive smoothing applied to both the IoU
                numerator and denominator.

        Returns:
            BCE_IoU_loss (tensor): scalar combined loss.
        """
        # Flatten output and target tensors. ``reshape`` (unlike ``view``)
        # also works for non-contiguous inputs such as permuted tensors.
        outputs = outputs.reshape(-1)
        # BCELoss requires floating targets of the same dtype as the outputs.
        targets = targets.reshape(-1).to(outputs.dtype)

        # Compute the intersection-over-union (IoU) loss.
        intersection = (outputs * targets).sum()
        total = (outputs + targets).sum()
        union = total - intersection
        denominator = union + smooth
        # With smooth=0, an all-zero prediction against an all-zero target
        # yields 0/0 = NaN. Treat that case as perfect overlap (IoU = 1).
        # Divide by a placeholder there as well, so no NaN leaks into the
        # gradient through the unselected torch.where branch.
        nonzero = denominator != 0
        ones = torch.ones_like(denominator)
        iou = torch.where(
            nonzero,
            (intersection + smooth) / torch.where(nonzero, denominator, ones),
            ones,
        )
        IoU_loss = 1 - iou

        # Compute the modified BCE loss.
        BCE_loss = self.bce_loss(outputs, targets)
        # NOTE(review): the focal-style factor alpha * (1 - exp(-BCE)) ** gamma
        # is applied to the batch-*mean* BCE. It therefore rescales the whole
        # batch rather than down-weighting individual easy elements, as a
        # per-element focal loss would. Kept unchanged to preserve existing
        # training behaviour; confirm this is intended.
        BCE_exp = torch.exp(-BCE_loss)
        modified_BCE_loss = alpha * (1 - BCE_exp) ** gamma * BCE_loss
        return beta * modified_BCE_loss + (1 - beta) * IoU_loss