
Commit

added support for lm decoder during training
Pherkel committed Sep 18, 2023
1 parent d5e482b commit c09ff76
Showing 5 changed files with 33 additions and 21 deletions.
10 changes: 5 additions & 5 deletions config.philipp.yaml
@@ -3,7 +3,7 @@ dataset:
dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir
language_name: "mls_german_opus"
limited_supervision: True # set to True if you want to use limited supervision
dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%)
dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%)
shuffle: True

model:
@@ -33,13 +33,13 @@ decoder:
training:
learning_rate: 0.0005
batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
epochs: 3
eval_every_n: 3 # evaluate every n epochs
epochs: 100
eval_every_n: 1 # evaluate every n epochs
num_workers: 8 # number of workers for dataloader

checkpoints: # use "~" to disable saving/loading
model_load_path: "YOUR/PATH" # path to load model from
model_save_path: "YOUR/PATH" # path to save model to
model_load_path: "data/epoch67" # path to load model from
model_save_path: ~ # path to save model to

inference:
model_load_path: "data/epoch67" # path to load model from
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ target-version = "py310"
line-length = 100

[tool.poetry.scripts]
train = "swr2_asr.train:run_cli"
train = "swr2_asr.train:main"
train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer"
train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer"

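The console script now points at swr2_asr.train:main instead of run_cli. For reference, a Poetry script entry of the form name = "module:function" generates a launcher that imports the module and calls the attribute with no arguments; a rough, hypothetical equivalent is below (how main() then obtains its config path, e.g. via a CLI parser, is up to that function):

from importlib import import_module

entry_point = "swr2_asr.train:main"  # value from [tool.poetry.scripts]
module_name, attr = entry_point.split(":")
main = getattr(import_module(module_name), attr)
main()  # what `poetry run train` effectively executes
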
32 changes: 18 additions & 14 deletions swr2_asr/train.py
@@ -12,10 +12,9 @@

from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
from swr2_asr.utils.decoder import greedy_decoder
from swr2_asr.utils.tokenizer import CharTokenizer

from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.loss_scores import cer, wer
from swr2_asr.utils.tokenizer import CharTokenizer


class IterMeter:
@@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
# get values from test_args:
model, device, test_loader, criterion, tokenizer, decoder = test_args.values()

if decoder == "greedy":
decoder = greedy_decoder

model.eval()
test_loss = 0
test_cer, test_wer = [], []
@@ -141,12 +137,13 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)

decoded_preds, decoded_targets = greedy_decoder(
output.transpose(0, 1), labels, label_lengths, tokenizer
)
decoded_targets = tokenizer.decode_batch(labels)
decoded_preds = decoder(output.transpose(0, 1))
for j, _ in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
if j >= len(decoded_targets):
break
test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0]))
test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0]))

avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
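
Two things change in the evaluation path above: references now come from tokenizer.decode_batch(labels) instead of the old greedy_decoder call, and predictions come from the injected decoder, indexed as decoded_preds[j][0].words[...], which suggests each utterance yields a list of beam hypotheses with a words attribute (torchaudio-style). A self-contained sketch of that scoring step under those assumptions (it joins the whole word list of the top beam; not copied from the repo):

from swr2_asr.utils.loss_scores import cer, wer  # scoring helpers already used above

def score_batch(decoded_preds, decoded_targets):
    """Score the top beam of each utterance against its reference string."""
    cers, wers = [], []
    for hyps, reference in zip(decoded_preds, decoded_targets):
        prediction = " ".join(hyps[0].words)  # hyps[0] = best beam; words = list[str]
        cers.append(cer(reference, prediction))
        wers.append(wer(reference, prediction))
    return cers, wers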
@@ -187,6 +184,7 @@ def main(config_path: str):
dataset_config = config_dict.get("dataset", {})
tokenizer_config = config_dict.get("tokenizer", {})
checkpoints_config = config_dict.get("checkpoints", {})
decoder_config = config_dict.get("decoder", {})

if not os.path.isdir(dataset_config["dataset_root_path"]):
os.makedirs(dataset_config["dataset_root_path"])
@@ -262,12 +260,19 @@

if checkpoints_config["model_load_path"] is not None:
checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
state_dict = {
k[len("module.") :] if k.startswith("module.") else k: v
for k, v in checkpoint["model_state_dict"].items()
}

model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]

iter_meter = IterMeter()

decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config)

for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
train_args: TrainArgs = {
"model": model,
@@ -283,14 +288,13 @@ def main(config_path: str):
train_loss = train(train_args)

test_loss, test_cer, test_wer = 0, 0, 0

test_args: TestArgs = {
"model": model,
"device": device,
"test_loader": valid_loader,
"criterion": criterion,
"tokenizer": tokenizer,
"decoder": "greedy",
"decoder": decoder,
}

if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
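The new key-renaming when loading a checkpoint handles models that were saved while wrapped in nn.DataParallel (or DistributedDataParallel): the wrapper prefixes every parameter name with "module.", which a bare model refuses to load. A small self-contained illustration of the mismatch and the fix (not from the repo):

from torch import nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)          # wrapping nests the model under "module."
print(list(wrapped.state_dict())[:1])     # ['module.weight']
print(list(model.state_dict())[:1])       # ['weight']

# Stripping the prefix lets a checkpoint saved from the wrapped model load into a bare one:
cleaned = {k.removeprefix("module."): v for k, v in wrapped.state_dict().items()}
model.load_state_dict(cleaned)
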
2 changes: 1 addition & 1 deletion swr2_asr/utils/decoder.py
@@ -1,6 +1,6 @@
"""Decoder for CTC-based ASR.""" ""
from dataclasses import dataclass
import os
from dataclasses import dataclass

import torch
from torchaudio.datasets.utils import _extract_tar
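decoder.py is where the decoder_factory used by train.py lives; this commit only reorders its imports, so the factory itself is not visible here. From the call site (decoder_factory(type)(tokenizer, decoder_config), then decoder(emissions)), one plausible shape is a type-keyed factory whose "lm" branch wraps torchaudio's CTC beam-search decoder with a KenLM language model. The sketch below is hypothetical: the config keys (lexicon_path, lm_path, ...), the blank index, and the greedy fallback are assumptions, not the repo's code.

from dataclasses import dataclass
import torch
from torchaudio.models.decoder import ctc_decoder

@dataclass
class GreedyHypothesis:
    words: list[str]  # mimics the .words attribute of torchaudio's CTCHypothesis

def decoder_factory(decoder_type: str = "greedy"):
    if decoder_type == "lm":
        def build_lm_decoder(tokenizer, decoder_config: dict):
            # Beam search over CTC emissions, rescored with a KenLM language model.
            return ctc_decoder(
                lexicon=decoder_config["lexicon_path"],
                tokens=decoder_config["tokens_path"],
                lm=decoder_config["lm_path"],
                beam_size=decoder_config.get("beam_size", 50),
                lm_weight=decoder_config.get("lm_weight", 3.0),
            )
        return build_lm_decoder

    def build_greedy_decoder(tokenizer, decoder_config: dict):
        blank_id = decoder_config.get("blank_id", 28)  # assumed blank index
        def greedy(emissions: torch.Tensor) -> list[list[GreedyHypothesis]]:
            hypotheses = []
            for frame_ids in emissions.argmax(dim=-1):        # (batch, time) -> best id per frame
                ids = torch.unique_consecutive(frame_ids)     # collapse CTC repeats
                ids = [int(i) for i in ids if int(i) != blank_id]
                hypotheses.append([GreedyHypothesis(tokenizer.decode(ids).split())])
            return hypotheses
        return greedy
    return build_greedy_decoder

Both branches return hypotheses of the same shape, so test() can score the top beam uniformly regardless of which decoder the config selects.
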
8 changes: 8 additions & 0 deletions swr2_asr/utils/tokenizer.py
@@ -29,9 +29,17 @@ def decode(self, labels: list[int]) -> str:
"""Use a character map and convert integer labels to an text sequence"""
string = []
for i in labels:
i = int(i)
string.append(self.index_map[i])
return "".join(string).replace("<SPACE>", " ")

def decode_batch(self, labels: list[list[int]]) -> list[str]:
"""Use a character map and convert integer labels to an text sequence"""
string = []
for label in labels:
string.append(self.decode(label))
return string

def get_vocab_size(self) -> int:
"""Get the number of unique characters in the dataset"""
return len(self.char_map)
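decode_batch simply maps decode over a batch of label sequences, and the new int(i) cast means rows of a label tensor work as well as plain Python lists. A toy illustration of the behaviour with a made-up index map (the real CharTokenizer builds index_map from its trained char map):

index_map = {0: "<SPACE>", 1: "h", 2: "a", 3: "l", 4: "o"}

def decode(labels):
    # int(i) is what lets tensor elements pass through as dictionary keys
    return "".join(index_map[int(i)] for i in labels).replace("<SPACE>", " ")

def decode_batch(batch):
    return [decode(labels) for labels in batch]

print(decode_batch([[1, 2, 3, 3, 4], [4, 1, 2, 0, 1, 2]]))  # ['hallo', 'oha ha']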
