Commit d5e482b: decoder changes
Pherkel committed Sep 18, 2023 (1 parent: d568904)

Showing 4 changed files with 161 additions and 93 deletions.
config.philipp.yaml (33 additions & 22 deletions)

@@ -1,34 +1,45 @@
-dataset:
-  download: True
-  dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir
-  language_name: "mls_german_opus"
-  limited_supervision: True # set to True if you want to use limited supervision
-  dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%)
-  shuffle: True
-
 model:
   n_cnn_layers: 3
   n_rnn_layers: 5
   rnn_dim: 512
   n_feats: 128 # number of mel features
   stride: 2
-  dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
+  dropout: 0.6 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
 
-training:
-  learning_rate: 0.0005
-  batch_size: 32 # recommended: the maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
-  epochs: 150
-  eval_every_n: 5 # evaluate every n epochs
-  num_workers: 4 # number of workers for dataloader
-  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
+dataset:
+  download: true
+  dataset_root_path: "data" # files will be downloaded into this dir
+  language_name: "mls_german_opus"
+  limited_supervision: false # set to True if you want to use limited supervision
+  dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%)
+  shuffle: true
 
 tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
 
-checkpoints:
-  model_load_path: "data/runs/epoch31" # path to load model from
-  model_save_path: "data/runs/epoch" # path to save model to
+decoder:
+  type: "lm" # "greedy", or "lm" (beam search)
+
+  lm: # config for the lm decoder
+    language_model_path: "data" # path where the model and supplementary files are stored
+    language: "german"
+    n_gram: 3 # n-gram size of the language model, 3 or 5
+    beam_size: 50
+    beam_threshold: 50
+    n_best: 1
+    lm_weight: 2
+    word_score: 0
+
+training:
+  learning_rate: 0.0005
+  batch_size: 8 # recommended: the maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
+  epochs: 3
+  eval_every_n: 3 # evaluate every n epochs
+  num_workers: 8 # number of workers for dataloader
+
+checkpoints: # use "~" to disable saving/loading
+  model_load_path: "YOUR/PATH" # path to load model from
+  model_save_path: "YOUR/PATH" # path to save model to
 
 inference:
-  model_load_path: "data/runs/epoch30" # path to load model from
   device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
+  model_load_path: "data/epoch67" # path to load model from
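
The new decoder block is what swr2_asr/inference.py now reads to pick the decoding strategy. A minimal sketch of how it can be consumed (PyYAML assumed; the config filename and the "greedy" fallback are illustrative, not part of this commit):

import yaml

from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.tokenizer import CharTokenizer

# Load the YAML config and build the configured decoder. decoder_factory and
# CharTokenizer are this repo's APIs; everything else here is illustrative.
with open("config.philipp.yaml", encoding="utf-8") as file:
    config = yaml.safe_load(file)

decoder_config = config.get("decoder", {})  # e.g. {"type": "lm", "lm": {...}}
tokenizer = CharTokenizer.from_file(config["tokenizer"]["tokenizer_path"])

# decoder_factory maps "greedy"/"lm" to a constructor; calling that constructor
# with the tokenizer and the decoder config yields the decoder callable.
decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config)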
config.yaml (26 additions & 17 deletions)

@@ -1,3 +1,11 @@
+dataset:
+  download: True
+  dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
+  language_name: "mls_german_opus"
+  limited_supervision: False # set to True if you want to use limited supervision
+  dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+  shuffle: True
+
 model:
   n_cnn_layers: 3
   n_rnn_layers: 5
@@ -6,32 +14,33 @@ model:
   stride: 2
   dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
 
+tokenizer:
+  tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
+
+decoder:
+  type: "greedy" # "greedy", or "lm" (beam search)
+
+  lm: # config for the lm decoder
+    language_model_path: "data" # path where the model and supplementary files are stored
+    language: "german"
+    n_gram: 3 # n-gram size of the language model, 3 or 5
+    beam_size: 50
+    beam_threshold: 50
+    n_best: 1
+    lm_weight: 2
+    word_score: 0
+
 training:
-  learning_rate: 5e-4
+  learning_rate: 0.0005
   batch_size: 8 # recommended: the maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
   epochs: 3
   eval_every_n: 3 # evaluate every n epochs
   num_workers: 8 # number of workers for dataloader
 
-dataset:
-  download: True
-  dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
-  language_name: "mls_german_opus"
-  limited_supervision: False # set to True if you want to use limited supervision
-  dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
-  shuffle: True
-
-tokenizer:
-  tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml"
-
-checkpoints:
+checkpoints: # use "~" to disable saving/loading
   model_load_path: "YOUR/PATH" # path to load model from
   model_save_path: "YOUR/PATH" # path to save model to
 
 inference:
   model_load_path: "YOUR/PATH" # path to load model from
-  beam_width: 10 # beam width for beam search
   device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-
-lang_model:
-  path: "data/mls_lm_german" # path where the model and supplementary files are stored
swr2_asr/inference.py (15 additions & 20 deletions)

@@ -6,25 +6,10 @@
 import yaml
 
 from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.decoder import decoder_factory
 from swr2_asr.utils.tokenizer import CharTokenizer
 
 
-def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
-    """Greedily decode a sequence."""
-    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
-    blank_label = tokenizer.get_blank_token()
-    decodes = []
-    for args in arg_maxes:
-        decode = []
-        for j, index in enumerate(args):
-            if index != blank_label:
-                if collapse_repeated and j != 0 and index == args[j - 1]:
-                    continue
-                decode.append(index.item())
-        decodes.append(tokenizer.decode(decode))
-    return decodes
-
-
 @click.command()
 @click.option(
     "--config_path",
@@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None:
     model_config = config_dict.get("model", {})
     tokenizer_config = config_dict.get("tokenizer", {})
     inference_config = config_dict.get("inference", {})
+    decoder_config = config_dict.get("decoder", {})
 
-    if inference_config["device"] == "cpu":
+    if inference_config.get("device", "") == "cpu":
         device = "cpu"
-    elif inference_config["device"] == "cuda":
+    elif inference_config.get("device", "") == "cuda":
         device = "cuda" if torch.cuda.is_available() else "cpu"
+    elif inference_config.get("device", "") == "mps":
+        device = "mps"
+    else:
+        device = "cpu"
     device = torch.device(device)  # pylint: disable=no-member
 
     tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
@@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None:
     spec = spec.unsqueeze(0)
     spec = spec.transpose(1, 2)
     spec = spec.unsqueeze(0)
+    spec = spec.to(device)
     output = model(spec)  # pylint: disable=not-callable
     output = F.log_softmax(output, dim=2)  # (batch, time, n_class)
-    decoded_preds = greedy_decoder(output, tokenizer)
 
-    print(decoded_preds)
+    decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)
+
+    preds = decoder(output)
+    preds = " ".join(preds[0][0].words).strip()
+
+    print(preds)
 
 
 if __name__ == "__main__":
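
The rewritten device selection accepts "mps" and falls back to "cpu" for unknown values, but it does not verify that MPS is actually present the way it does for CUDA. A hedged variant with that check (the helper name is illustrative, not part of this commit):

import torch

def pick_device(name: str = "") -> torch.device:
    """Resolve a config string to an available torch device (sketch)."""
    if name == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if name == "mps":
        # The commit trusts the config here; checking availability avoids a
        # runtime error on machines without Apple-silicon GPU support.
        return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    return torch.device("cpu")  # covers "cpu" and any unrecognized value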
swr2_asr/utils/decoder.py (87 additions & 34 deletions)

@@ -1,4 +1,5 @@
 """Decoder for CTC-based ASR."""
+from dataclasses import dataclass
 import os
 
 import torch
@@ -9,37 +10,39 @@
 from swr2_asr.utils.tokenizer import CharTokenizer
 
 
-# TODO: refactor to use torch CTC decoder class
-def greedy_decoder(
-    output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True
-):  # pylint: disable=redefined-outer-name
-    """Greedily decode a sequence."""
-    blank_label = tokenizer.get_blank_token()
-    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
-    decodes = []
-    targets = []
-    for i, args in enumerate(arg_maxes):
-        decode = []
-        targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
-        for j, index in enumerate(args):
-            if index != blank_label:
-                if collapse_repeated and j != 0 and index == args[j - 1]:
-                    continue
-                decode.append(index.item())
-        decodes.append(tokenizer.decode(decode))
-    return decodes, targets
-
-
-def beam_search_decoder(
+@dataclass
+class DecoderOutput:
+    """Decoder output."""
+
+    words: list[str]
+
+
+def decoder_factory(decoder_type: str = "greedy") -> callable:
+    """Decoder factory."""
+    if decoder_type == "greedy":
+        return get_greedy_decoder
+    if decoder_type == "lm":
+        return get_beam_search_decoder
+    raise NotImplementedError
+
+
+def get_greedy_decoder(
+    tokenizer: CharTokenizer,  # pylint: disable=redefined-outer-name
+    *_,
+):
+    """Greedy decoder."""
+    return GreedyDecoder(tokenizer)
+
+
+def get_beam_search_decoder(
     tokenizer: CharTokenizer,  # pylint: disable=redefined-outer-name
-    tokens_path: str,
-    lang_model_path: str,
-    language: str,
     hparams: dict,  # pylint: disable=redefined-outer-name
 ):
     """Beam search decoder."""
 
-    n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+    hparams = hparams.get("lm", {})
+    language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+        hparams["language"],
+        hparams["language_model_path"],
         hparams["n_gram"],
         hparams["beam_size"],
         hparams["beam_threshold"],
@@ -53,6 +56,7 @@ def beam_search_decoder(
         torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz")
         _extract_tar(f"data/mls_lm_{language}.tar.gz", overwrite=True)
 
+    tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt")
     if not os.path.isfile(tokens_path):
         tokenizer.create_tokens_txt(tokens_path)
 
@@ -79,11 +83,66 @@ def beam_search_decoder(
     return decoder
 
 
+class GreedyDecoder:
+    """Greedy decoder."""
+
+    def __init__(self, tokenizer: CharTokenizer):  # pylint: disable=redefined-outer-name
+        self.tokenizer = tokenizer
+
+    def __call__(
+        self, output, greedy_type: str = "inference", labels=None, label_lengths=None
+    ):  # pylint: disable=redefined-outer-name
+        """Greedily decode a sequence."""
+        if greedy_type == "train":
+            res = self.train(output, labels, label_lengths)
+        elif greedy_type == "inference":
+            res = self.inference(output)
+        else:
+            raise ValueError(f"unknown greedy_type: {greedy_type}")
+
+        res = [[DecoderOutput(words=res)]]
+        return res
+
+    def train(self, output, labels, label_lengths):
+        """Greedily decode a sequence with known labels."""
+        blank_label = self.tokenizer.get_blank_token()
+        arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+        decodes = []
+        targets = []
+        for i, args in enumerate(arg_maxes):
+            decode = []
+            targets.append(self.tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+            for j, index in enumerate(args):
+                if index != blank_label:
+                    if j != 0 and index == args[j - 1]:
+                        continue
+                    decode.append(index.item())
+            decodes.append(self.tokenizer.decode(decode))
+        return decodes, targets
+
+    def inference(self, output):
+        """Greedily decode a sequence."""
+        collapse_repeated = True
+        arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+        blank_label = self.tokenizer.get_blank_token()
+        decodes = []
+        for args in arg_maxes:
+            decode = []
+            for j, index in enumerate(args):
+                if index != blank_label:
+                    if collapse_repeated and j != 0 and index == args[j - 1]:
+                        continue
+                    decode.append(index.item())
+            decodes.append(self.tokenizer.decode(decode))
+
+        return decodes
+
+
 if __name__ == "__main__":
     tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
-    tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")
 
     hparams = {
+        "language": "german",
+        "lang_model_path": "data",
         "n_gram": 3,
         "beam_size": 100,
         "beam_threshold": 100,
@@ -92,10 +151,4 @@ def beam_search_decoder(
         "word_score": 1.0,
     }
 
-    beam_search_decoder(
-        tokenizer,
-        "data/tokenizers/tokens_german.txt",
-        "data",
-        "german",
-        hparams,
-    )
+    get_beam_search_decoder(tokenizer, hparams)
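
GreedyDecoder.inference implements the standard CTC collapse: take the argmax per frame, merge consecutive repeats, then drop blanks. A self-contained toy check of that rule (the 4-symbol vocabulary and blank id 0 are invented for the example):

import torch

# Toy emissions: 1 utterance, 6 frames, 4 classes (0 = blank, 1 = "a",
# 2 = "b", 3 = "c"). The argmax path [1, 1, 0, 2, 2, 3] should collapse to "abc".
emissions = torch.full((1, 6, 4), -10.0)
for t, cls in enumerate([1, 1, 0, 2, 2, 3]):
    emissions[0, t, cls] = 0.0  # make `cls` the argmax at frame t

blank = 0
decoded = []
prev = None
for idx in emissions.argmax(dim=2)[0].tolist():
    if idx != blank and idx != prev:  # drop blanks, merge consecutive repeats
        decoded.append(idx)
    prev = idx

assert decoded == [1, 2, 3]  # i.e. "a" + "b" + "c"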
