add limited supervision training (10hr)
Pherkel committed Sep 5, 2023
1 parent 46b23fd commit 4bd118d
Showing 2 changed files with 108 additions and 26 deletions.
25 changes: 12 additions & 13 deletions swr2_asr/train.py
@@ -15,8 +15,6 @@
 
 from .loss_scores import cer, wer
 
-# TODO: improve naming of functions
-
 
 class HParams(TypedDict):
     """Type for the hyperparameters of the model."""
@@ -157,10 +155,10 @@ def run(
 
     # load dataset
     train_dataset = MLSDataset(
-        dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
+        dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
     )
     valid_dataset = MLSDataset(
-        dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+        dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
     )
 
     # load tokenizer (bpe by default):
@@ -171,7 +169,6 @@ def run(
         dataset_path=dataset_path,
         language=language,
         split="all",
-        download=False,
         out_path="data/tokenizers/char_tokenizer_german.json",
     )
 
@@ -211,7 +208,7 @@ def run(
 
     # enable flag to find the most compatible algorithms in advance
     if use_cuda:
-        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.benchmark = True  # pylint: disable=no-member
 
     model = SpeechRecognitionModel(
         hparams["n_cnn_layers"],
@@ -253,7 +250,7 @@ def run(
             iter_meter,
         )
 
-        test_loss,avg_cer,avg_wer = test(
+        test_loss, avg_cer, avg_wer = test(
            model=model,
            device=device,
            test_loader=valid_loader,
@@ -262,12 +259,14 @@ def run(
         )
         print("saving epoch", str(epoch))
         torch.save(
-            {"epoch": epoch,
-            "model_state_dict": model.state_dict(),
-            "loss": loss,
-            "test_loss": test_loss,
-            "avg_cer": avg_cer,
-            "avg_wer": avg_wer},
+            {
+                "epoch": epoch,
+                "model_state_dict": model.state_dict(),
+                "loss": loss,
+                "test_loss": test_loss,
+                "avg_cer": avg_cer,
+                "avg_wer": avg_wer,
+            },
             path + str(epoch),
         )

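With this change, train.py builds both the train and validation sets from the MLS limited-supervision subset (the 9hr list plus the six 1hr lists, roughly 10 hours in total). A minimal sketch of constructing such a dataset directly, assuming a local "data" directory and the German MLS opus release (both values are illustrative, not the repo's configuration):

    from swr2_asr.utils import MLSDataset, Split

    # Sketch only: dataset_path and language are assumed values.
    # limited=True restricts TRAIN/VALID to the limited-supervision handles.
    train_dataset = MLSDataset(
        dataset_path="data",
        language="mls_german_opus",
        split=Split.TRAIN,
        limited=True,
        download=True,
        spectrogram_hparams=None,  # fall back to the dataset's default spectrogram settings
    )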
109 changes: 96 additions & 13 deletions swr2_asr/utils.py
@@ -3,8 +3,8 @@
 from enum import Enum
 from typing import TypedDict
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torchaudio
 from tokenizers import Tokenizer
@@ -95,6 +95,7 @@ def __init__(
         dataset_path: str,
         language: str,
         split: Split,
+        limited: bool,
         download: bool,
         spectrogram_hparams: dict | None,
     ):
@@ -124,10 +125,90 @@
 
         self._handle_download_dataset(download)
         self._validate_local_directory()
-        self.initialize()
+        if limited and (split == Split.TRAIN or split == Split.VALID):
+            self.initialize_limited()
+        else:
+            self.initialize()
+
+    def initialize_limited(self) -> None:
+        """Initializes the limited supervision dataset
+        (file handles, file paths, transcripts, train/valid split)
+        """
+        handles = set()
+
+        train_root_path = os.path.join(self.dataset_path, self.language, "train")
+
+        # get file handles for the 9hr split
+        with open(
+            os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"),
+            "r",
+            encoding="utf-8",
+        ) as file:
+            for line in file:
+                handles.add(line.strip())
+
+        # get file handles for the 1hr splits (directories "0" through "5")
+        for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")):
+            if handle_path not in {str(i) for i in range(6)}:
+                continue
+            with open(
+                os.path.join(
+                    train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt"
+                ),
+                "r",
+                encoding="utf-8",
+            ) as file:
+                for line in file:
+                    handles.add(line.strip())
+
+        # fix an iteration order; sets are unordered, so pairing paths and
+        # transcripts by set-iteration order would not be reproducible
+        ordered_handles = sorted(handles)
+
+        # get file paths for handles (handle format: speakerid_bookid_utteranceid)
+        file_paths = []
+        for handle in ordered_handles:
+            speaker_id, book_id, _ = handle.split("_")
+            file_paths.append(
+                os.path.join(
+                    train_root_path,
+                    "audio",
+                    speaker_id,
+                    book_id,
+                    handle + self.file_ext,
+                )
+            )
+
+        # get transcripts for handles, keyed by handle so they stay aligned
+        # with file_paths
+        transcripts_by_handle = {}
+        with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file:
+            for line in file:
+                handle = line.split("\t")[0]
+                if handle in handles:
+                    transcripts_by_handle[handle] = line.strip()
+        transcripts = [transcripts_by_handle[handle] for handle in ordered_handles]
+
+        # create disjoint train (80%) and valid (20%) splits with seed 42
+        np.random.seed(42)
+        indices = np.random.permutation(len(file_paths))
+        split_index = int(len(file_paths) * 0.8)
+        if self.split == Split.TRAIN:
+            indices = indices[:split_index]
+        elif self.split == Split.VALID:
+            indices = indices[split_index:]
+        file_paths = [file_paths[i] for i in indices]
+        transcripts = [transcripts[i] for i in indices]
+
+        # create dataset lookup
+        self.dataset_lookup = [
+            {
+                "speakerid": path.split("/")[-3],
+                "bookid": path.split("/")[-2],
+                "chapterid": path.split("/")[-1].split("_")[2].split(".")[0],
+                "utterance": utterance.split("\t")[1],
+            }
+            for path, utterance in zip(file_paths, transcripts, strict=True)
+        ]
 
     def initialize(self) -> None:
-        """Initializes the dataset
+        """Initializes the entire dataset
         Reads the transcripts.txt file and creates a lookup table
         """
@@ -189,7 +270,8 @@ def _handle_download_dataset(self, download: bool) -> None:
         # unzip the dataset
         if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
             print(
-                f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
+                "Unzipping the dataset at "
+                f"{os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
             )
             extract_archive(
                 os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True
@@ -236,7 +318,7 @@ def __getitem__(self, idx: int) -> Sample:
             + self.file_ext,
         )
 
-        waveform, sample_rate = torchaudio.load(audio_path)  # type: ignore
+        waveform, sample_rate = torchaudio.load(audio_path)  # pylint: disable=no-member
 
         # resample if necessary
         if sample_rate != self.spectrogram_hparams["sample_rate"]:
@@ -257,7 +339,7 @@ def __getitem__(self, idx: int) -> Sample:
 
         utterance = self.tokenizer.encode(utterance)
 
-        utterance = torch.LongTensor(utterance.ids)
+        utterance = torch.LongTensor(utterance.ids)  # pylint: disable=no-member
 
         return Sample(
             waveform=waveform,
@@ -311,24 +393,25 @@ def collate_fn(samples: list[Sample]) -> dict:
     split = Split.TRAIN
     DOWNLOAD = False
 
-    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None)
+    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None)
 
     tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
     dataset.set_tokenizer(tok)
 
 
-def plot(epochs,path):
-    losses = list()
+def plot(epochs, path):
+    """Plots the losses over the epochs"""
+    losses = list()
     test_losses = list()
     cers = list()
-    wers =list()
-    for epoch in range(1, epochs +1):
+    wers = list()
+    for epoch in range(1, epochs + 1):
         current_state = torch.load(path + str(epoch))
         losses.append(current_state["loss"])
         test_losses.append(current_state["test_loss"])
         cers.append(current_state["avg_cer"])
         wers.append(current_state["avg_wer"])
 
     plt.plot(losses)
     plt.plot(test_losses)
-    plt.savefig("losses.svg")
+    plt.savefig("losses.svg")
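Each checkpoint that train.py saves under path + str(epoch) is a dict with the keys epoch, model_state_dict, loss, test_loss, avg_cer, and avg_wer; plot reloads those files epoch by epoch and plots the stored train and test losses. A usage sketch (the checkpoint prefix is hypothetical):

    # After 10 epochs saved as "checkpoints/run1" .. "checkpoints/run10":
    plot(epochs=10, path="checkpoints/run")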
