diff --git a/model.py b/model.py index 261ba3e..0a37c08 100644 --- a/model.py +++ b/model.py @@ -27,6 +27,7 @@ def __init__( self.ctc_decoder = ctc_decoder self.cal_wer = torchmetrics.WordErrorRate() self.cfg_optim = cfg_optim + self.criterion = nn.CTCLoss(zero_infinity=True) def forward(self, inputs): """predicting function""" @@ -45,7 +46,7 @@ def configure_optimizers(self): def training_step(self, batch, batch_idx): inputs, input_lengths, targets, target_lengths = batch outputs = self.deepspeech(inputs) - loss = F.ctc_loss( + loss = self.criterion( outputs.permute(1, 0, 2), targets, input_lengths, target_lengths ) @@ -57,7 +58,7 @@ def validation_step(self, batch, batch_idx): inputs, input_lengths, targets, target_lengths = batch outputs = self.deepspeech(inputs) - loss = F.ctc_loss( + loss = self.criterion( outputs.permute(1, 0, 2), targets, input_lengths, target_lengths ) @@ -86,7 +87,7 @@ def test_step(self, batch, batch_idx): inputs, input_lengths, targets, target_lengths = batch outputs = self.deepspeech(inputs) - loss = F.ctc_loss( + loss = self.criterion( outputs.permute(1, 0, 2), targets, input_lengths, target_lengths )