Skip to content

Commit

Permalink
remove ctc
Browse files Browse the repository at this point in the history
  • Loading branch information
tuanio committed May 17, 2022
1 parent 8613abe commit 8ba8206
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 47 deletions.
5 changes: 0 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
@hydra.main(config_path=args.cp, config_name=args.cn)
def main(cfg: DictConfig):
text_process = TextProcess(**cfg.text_process)
- if cfg.decoder.type == "beamsearch":
- ctc_decoder = CTCDecoder(text_process=text_process, **cfg.ctcdecoder)
- else:
- ctc_decoder = None

trainset = VivosDataset(**cfg.dataset, subset="train")
testset = VivosDataset(**cfg.dataset, subset="test")
Expand All @@ -32,7 +28,6 @@ def main(cfg: DictConfig):
model = DeepSpeechModule(
n_class=n_class,
text_process=text_process,
- ctc_decoder=ctc_decoder,
cfg_optim=cfg.optimizer,
**cfg.model
)
Expand Down
18 changes: 4 additions & 14 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(
n_class: int,
lr: float,
text_process: TextProcess,
- ctc_decoder: CTCDecoder,
cfg_optim: dict,
):
super().__init__()
Expand All @@ -25,7 +24,6 @@ def __init__(
)
self.lr = lr
self.text_process = text_process
- self.ctc_decoder = ctc_decoder
self.cal_wer = torchmetrics.WordErrorRate()
self.cfg_optim = cfg_optim
self.criterion = nn.CTCLoss(zero_infinity=True)
Expand Down Expand Up @@ -63,12 +61,8 @@ def validation_step(self, batch, batch_idx):
outputs.permute(1, 0, 2), targets, input_lengths, target_lengths
)

- if self.ctc_decoder:
- # unsqueeze for batchsize 1
- predicts = [self.ctc_decoder(sent.unsqueeze(0)) for sent in outputs]
- else:
- decode = outputs.argmax(dim=-1)
- predicts = [self.text_process.decode(sent) for sent in decode]
+ decode = outputs.argmax(dim=-1)
+ predicts = [self.text_process.decode(sent) for sent in decode]

targets = [self.text_process.int2text(sent) for sent in targets]

Expand All @@ -92,12 +86,8 @@ def test_step(self, batch, batch_idx):
outputs.permute(1, 0, 2), targets, input_lengths, target_lengths
)

- if self.ctc_decoder:
- # unsqueeze for batchsize 1
- predicts = [self.ctc_decoder(sent.unsqueeze(0)) for sent in outputs]
- else:
- decode = outputs.argmax(dim=-1)
- predicts = [self.text_process.decode(sent) for sent in decode]
+ decode = outputs.argmax(dim=-1)
+ predicts = [self.text_process.decode(sent) for sent in decode]

targets = [self.text_process.int2text(sent) for sent in targets]

Expand Down
55 changes: 27 additions & 28 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
- import ctcdecode


class TextProcess:
Expand Down Expand Up @@ -45,32 +44,32 @@ def int2text(self, s: torch.Tensor) -> str:
return "".join([self.list_vocab[i] for i in s if i > 2])


- class CTCDecoder:
- def __init__(
- self,
- alpha: float = 0.5,
- beta: float = 0.96,
- beam_size: int = 100,
- kenlm_path: str = None,
- text_process: TextProcess = None,
- ):
- self.text_process = text_process
- labels = text_process.list_vocab
- blank_id = labels.index("<p>")
+ # class CTCDecoder:
+ # def __init__(
+ # self,
+ # alpha: float = 0.5,
+ # beta: float = 0.96,
+ # beam_size: int = 100,
+ # kenlm_path: str = None,
+ # text_process: TextProcess = None,
+ # ):
+ # self.text_process = text_process
+ # labels = text_process.list_vocab
+ # blank_id = labels.index("<p>")

- print("loading beam search with lm...")
- self.decoder = ctcdecode.CTCBeamDecoder(
- labels,
- alpha=alpha,
- beta=beta,
- beam_width=beam_size,
- blank_id=blank_id,
- model_path=kenlm_path,
- )
- print("finished loading beam search")
+ # print("loading beam search with lm...")
+ # self.decoder = ctcdecode.CTCBeamDecoder(
+ # labels,
+ # alpha=alpha,
+ # beta=beta,
+ # beam_width=beam_size,
+ # blank_id=blank_id,
+ # model_path=kenlm_path,
+ # )
+ # print("finished loading beam search")

- def __call__(self, output: torch.Tensor) -> str:
- beam_result, beam_scores, timesteps, out_seq_len = self.decoder.decode(output)
- tokens = beam_result[0][0]
- seq_len = out_seq_len[0][0]
- return self.text_process.int2text(tokens[:seq_len])
+ # def __call__(self, output: torch.Tensor) -> str:
+ # beam_result, beam_scores, timesteps, out_seq_len = self.decoder.decode(output)
+ # tokens = beam_result[0][0]
+ # seq_len = out_seq_len[0][0]
+ # return self.text_process.int2text(tokens[:seq_len])

0 comments on commit 8ba8206

Please sign in to comment.