# loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# Small constant used to avoid division by zero and log of zero.
EPS = 1e-6

class lossAV(nn.Module):
    """Cross-entropy classification head on the fused audio-visual embedding (256-d)."""

    def __init__(self):
        super(lossAV, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.FC = nn.Linear(256, 2)

    def forward(self, x, labels=None):
        x = x.squeeze(1)
        x = self.FC(x)
        if labels is None:
            # Inference: return the class-1 ("speaking") score per sample as a numpy array.
            predScore = x[:, 1]
            predScore = predScore.view(-1).detach().cpu().numpy()
            return predScore
        else:
            nloss = self.criterion(x, labels)
            predScore = F.softmax(x, dim=-1)
            # Rounding the class-1 probability gives a hard 0/1 prediction.
            predLabel = torch.round(predScore)[:, 1]
            correctNum = (predLabel == labels).sum().float()
            return nloss, predScore, predLabel, correctNum
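
# Usage sketch for lossAV (a minimal sketch; the shapes and the helper name
# below are illustrative assumptions, not part of the original training code):
# embeddings arrive as [batch, 1, 256] and labels as int64 class indices,
# with class 1 being the positive ("speaking") class per the forward above.
def _demo_lossAV():
    criterion = lossAV()
    feats = torch.randn(8, 1, 256)            # fused audio-visual embeddings
    labels = torch.randint(0, 2, (8,))        # binary activity labels
    nloss, predScore, predLabel, correctNum = criterion(feats, labels)
    scores = criterion(feats)                 # labels=None -> numpy scores
    return nloss, scores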

class lossA(nn.Module):
    """Cross-entropy classification head on the audio-branch embedding (128-d)."""

    def __init__(self):
        super(lossA, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.FC = nn.Linear(128, 2)

    def forward(self, x, labels):
        x = x.squeeze(1)
        x = self.FC(x)
        nloss = self.criterion(x, labels)
        return nloss

class lossV(nn.Module):
    """Cross-entropy classification head on the visual-branch embedding (128-d)."""

    def __init__(self):
        super(lossV, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.FC = nn.Linear(128, 2)

    def forward(self, x, labels):
        x = x.squeeze(1)
        x = self.FC(x)
        nloss = self.criterion(x, labels)
        return nloss
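
# Minimal sketch of the auxiliary single-modality heads (hypothetical shapes,
# assuming 128-d branch embeddings as declared in the nn.Linear layers above;
# the helper name is illustrative):
def _demo_single_modality_losses():
    audio_feats = torch.randn(8, 1, 128)
    video_feats = torch.randn(8, 1, 128)
    labels = torch.randint(0, 2, (8,))
    la = lossA()(audio_feats, labels)
    lv = lossV()(video_feats, labels)
    return la + lv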

class loss_aud(nn.Module):
    """Negative SI-SNR loss for speech enhancement, optionally re-weighting
    the estimate with active-speaker-detection (ASD) scores."""

    def __init__(self):
        super(loss_aud, self).__init__()

    def sisnr(self, x, s, asd_pred=None, asd_gt=None, eps=1e-8):
        def l2norm(mat, keepdim=False):
            return torch.norm(mat, dim=-1, keepdim=keepdim)

        if x.shape != s.shape:
            raise RuntimeError(
                "Dimension mismatch when calculating SI-SNR, {} vs {}".format(
                    x.shape, s.shape))
        if asd_pred is not None and asd_gt is not None:
            # Expand the frame-level ASD scores to waveform resolution: each of
            # the asd_gt.size(0) frames covers speech_gap_e_frame samples.
            speech_gap_e_frame = x.size(0) * x.size(1) // asd_gt.size(0)
            frame_gap = asd_gt.size(0) // x.size(0)
            asd_p = torch.zeros_like(x)
            asd_g = torch.zeros_like(x)
            k = 0
            for n in range(x.size(0)):
                for i in range(0, x.size(1), speech_gap_e_frame):
                    asd_p[n][i:i + speech_gap_e_frame] = asd_pred[k]
                    asd_g[n][i:i + speech_gap_e_frame] = asd_gt[k]
                    k = n * frame_gap + i // speech_gap_e_frame + 1
            # Re-weight the estimate by the predicted ASD score: x * (1 + asd_p).
            # x = x + torch.mul(x, (asd_p + asd_g))
            x = x + torch.mul(x, asd_p)
        x_zm = x - torch.mean(x, dim=-1, keepdim=True)
        s_zm = s - torch.mean(s, dim=-1, keepdim=True)
        # s_target = <x_zm, s_zm> * s_zm / ||s_zm||^2
        t = torch.sum(x_zm * s_zm, dim=-1, keepdim=True) * s_zm / (
            l2norm(s_zm, keepdim=True)**2 + eps)
        # e_noise = x_zm - s_target
        # SI-SNR = 20 * log10(||s_target|| / ||e_noise||)
        return 20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
    @staticmethod
    def cal_SISNR(source, estimate_source):
        """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR).
        Args:
            source: torch tensor, [batch size, sequence length]
            estimate_source: torch tensor, [batch size, sequence length]
        Returns:
            SISNR, [batch size]
        """
        assert source.size() == estimate_source.size()
        # Step 1. Zero-mean norm
        source = source - torch.mean(source, dim=-1, keepdim=True)
        estimate_source = estimate_source - torch.mean(estimate_source, dim=-1, keepdim=True)
        # Step 2. SI-SNR
        # s_target = <s', s>s / ||s||^2
        ref_energy = torch.sum(source ** 2, dim=-1, keepdim=True) + EPS
        proj = torch.sum(source * estimate_source, dim=-1, keepdim=True) * source / ref_energy
        # e_noise = s' - s_target
        noise = estimate_source - proj
        # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
        ratio = torch.sum(proj ** 2, dim=-1) / (torch.sum(noise ** 2, dim=-1) + EPS)
        sisnr = 10 * torch.log10(ratio + EPS)
        return sisnr
    def forward(self, pre, label, asd_pred=None, asd_gt=None):
        # Maximize SI-SNR by minimizing its negative mean over the batch.
        # loss = 0
        # for i in range(pre.shape[0]):
        #     loss += self.sisnr(pre[i], label[i])
        return -torch.mean(self.sisnr(pre, label, asd_pred, asd_gt))
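
# Usage sketch for loss_aud (a minimal sketch; all sizes and the helper name
# are illustrative assumptions): with asd_pred/asd_gt of length batch * frames,
# each frame score is broadcast over waveform_len // frames samples inside
# sisnr before the SI-SNR is computed.
def _demo_loss_aud():
    criterion = loss_aud()
    est = torch.randn(2, 16000)               # estimated waveforms [B, T]
    ref = torch.randn(2, 16000)               # reference waveforms [B, T]
    plain = criterion(est, ref)               # negative mean SI-SNR
    asd_pred = torch.rand(50)                 # 25 ASD frames per utterance
    asd_gt = torch.randint(0, 2, (50,)).float()
    weighted = criterion(est, ref, asd_pred, asd_gt)
    return plain, weighted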

class BaseLoss(nn.Module):
    """Base class for weighted element-wise losses; subclasses implement
    _forward(pred, target, weight)."""

    def __init__(self):
        super(BaseLoss, self).__init__()

    def forward(self, preds, targets, weight=None):
        if isinstance(preds, list):
            N = len(preds)
            if weight is None:
                # One unit weight per list entry, so weight[n] below is valid.
                weight = preds[0].new_ones(N)
            errs = [self._forward(preds[n], targets[n], weight[n])
                    for n in range(N)]
            err = torch.mean(torch.stack(errs))
        elif isinstance(preds, torch.Tensor):
            if weight is None:
                weight = preds.new_ones(1)
            err = self._forward(preds, targets, weight)
        return err

class L1Loss(BaseLoss):
    def __init__(self):
        super(L1Loss, self).__init__()

    def _forward(self, pred, target, weight):
        # Weighted mean absolute error.
        return torch.mean(weight * torch.abs(pred - target))


class L2Loss(BaseLoss):
    def __init__(self):
        super(L2Loss, self).__init__()

    def _forward(self, pred, target, weight):
        # Weighted mean squared error.
        return torch.mean(weight * torch.pow(pred - target, 2))
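
# Quick self-check for the weighted regression losses (hypothetical data; the
# list branch of BaseLoss averages one _forward result per list entry):
if __name__ == "__main__":
    pred, target = torch.randn(4, 3), torch.randn(4, 3)
    print(L1Loss()(pred, target))             # weighted MAE, unit weight
    print(L2Loss()(pred, target))             # weighted MSE, unit weight
    preds = [torch.randn(4, 3) for _ in range(2)]
    targets = [torch.randn(4, 3) for _ in range(2)]
    print(L2Loss()(preds, targets))           # list input: mean of per-entry losses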