reset commit history
JoJoBarthold2 committed Sep 18, 2023
1 parent ec8bfe9 commit 0f14789
Showing 3 changed files with 35 additions and 5 deletions.
5 changes: 4 additions & 1 deletion config.yaml
@@ -31,4 +31,7 @@ checkpoints:
inference:
  model_load_path: "YOUR/PATH" # path to load model from
  beam_width: 10 # beam width for beam search
  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically

lang_model:
  path: "data/mls_lm_german" # path where model and supplementary files are stored
6 changes: 4 additions & 2 deletions swr2_asr/utils/data.py
@@ -344,7 +344,7 @@ def __getitem__(self, idx: int) -> tuple[Tensor, int, str, int, int, int]:
            idx,
        )  # type: ignore

def create_lexicon(vocab_counts_path, lexicon_path):
    """Create a lexicon file mapping each vocabulary word to its character tokens."""
    words_list = []
    with open(vocab_counts_path, 'r') as file:
@@ -361,6 +361,8 @@ def create_lexicon(vocab_counts_path, lexicon_path):
file.write(f"{word} ")
for char in word:
file.write(char + ' ')
file.write("|")
file.write("<SPACE>")




29 changes: 27 additions & 2 deletions swr2_asr/utils/decoder.py
@@ -2,8 +2,11 @@
import os

import torch
from torchaudio.datasets.utils import _extract_tar
from torchaudio.models.decoder import ctc_decoder

from swr2_asr.utils.data import create_lexicon
from swr2_asr.utils.tokenizer import CharTokenizer

LEXICON = "lexicon.txt"
# TODO: refactor to use torch CTC decoder class
def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
"""Greedily decode a sequence."""
@@ -25,3 +28,25 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):

# TODO: add beam search decoder

def beam_search_decoder(output, tokenizer: CharTokenizer, tokenizer_txt_path, lang_model_path):
    """Build a torchaudio CTC beam search decoder backed by the German MLS language model."""
    # Download and unpack the MLS language model files on first use.
    if not os.path.isdir(lang_model_path):
        url = "https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz"
        torch.hub.download_url_to_file(url, "data/mls_lm_german.tar.gz")
        _extract_tar("data/mls_lm_german.tar.gz", overwrite=True)
    if not os.path.isfile(tokenizer_txt_path):
        tokenizer.create_txt(tokenizer_txt_path)

    # Build the lexicon from the vocabulary counts if it does not exist yet.
    lexicon_path = os.path.join(lang_model_path, LEXICON)
    if not os.path.isfile(lexicon_path):
        occurrences_path = os.path.join(lang_model_path, "vocab_counts.txt")
        create_lexicon(occurrences_path, lexicon_path)

    lm_path = os.path.join(lang_model_path, "3-gram_lm.arpa")
    decoder = ctc_decoder(
        lexicon=lexicon_path,
        tokens=tokenizer_txt_path,
        lm=lm_path,
        blank_token='_',
        nbest=1,
        sil_token='<SPACE>',
        unk_word='<UNK>',
    )
    return decoder
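A rough usage sketch for the returned decoder, assuming emissions is a CPU float32 tensor of log probabilities with shape (batch, time, n_tokens) from the acoustic model; the variable names and paths here are illustrative, not part of this commit:

decoder = beam_search_decoder(emissions, tokenizer, "data/tokenizer.txt", "data/mls_lm_german")
hypotheses = decoder(emissions)  # per batch item, a list of nbest CTCHypothesis objects
transcript = " ".join(hypotheses[0][0].words)  # best hypothesis for the first utterance

Since output is unused inside the function, decoding only happens when the returned decoder is called on the emissions.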
