try fix
Pherkel committed Sep 6, 2023
1 parent 4bd118d commit 6b7bd14
Showing 3 changed files with 25 additions and 32 deletions.
swr2_asr/tokenizer.py (28 changes: 4 additions & 24 deletions)
@@ -77,7 +77,6 @@ class CharTokenizer(TokenizerType):
     def __init__(self):
         self.char_map = {}
         self.index_map = {}
-        self.add_tokens(["<UNK>", "<SPACE>"])
 
     def add_tokens(self, tokens: list[str]):
         """Manually add tokens to the tokenizer
@@ -128,22 +127,11 @@ def train(self, dataset_path: str, language: str, split: str):
             self.index_map[i] = char
 
     def encode(self, sequence: str):
-        """Use a character map and convert text to an integer sequence
-        automatically maps spaces to <SPACE> and makes everything lowercase
-        unknown characters are mapped to the <UNK> token
-        """
+        """Use a character map and convert text to an integer sequence"""
         int_sequence = []
-        sequence = sequence.lower()
         for char in sequence:
-            if char == " ":
-                mapped_char = self.char_map["<SPACE>"]
-            elif char not in self.char_map:
-                mapped_char = self.char_map["<UNK>"]
-            else:
-                mapped_char = self.char_map[char]
-            int_sequence.append(mapped_char)
+            int_sequence.append(self.char_map[char])
         return Encoding(ids=int_sequence, tokens=list(sequence))
 
     def decode(self, labels: list[int], remove_special_tokens: bool = True):
@@ -156,25 +144,17 @@ def decode(self, labels: list[int], remove_special_tokens: bool = True):
         """
         string = []
         for i in labels:
-            if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>":
-                continue
-            if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>":
-                string.append(" ")
             string.append(self.index_map[f"{i}"])
-        return "".join(string).replace("<SPACE>", " ")
+        return "".join(string)
 
     def decode_batch(self, labels: list[list[int]]):
         """Use a character map and convert integer labels to an text sequence"""
         strings = []
         for label in labels:
             string = []
             for i in label:
-                if self.index_map[i] == "<UNK>":
-                    continue
-                if self.index_map[i] == "<SPACE>":
-                    string.append(" ")
                 string.append(self.index_map[i])
-            strings.append("".join(string).replace("<SPACE>", " "))
+            strings.append("".join(string))
         return strings
 
     def get_vocab_size(self):
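For orientation, a minimal sketch of how the simplified CharTokenizer behaves after this change (illustrative only, not part of the diff): encode now maps each character straight through char_map, with no lowercasing and no <UNK>/<SPACE> fallback, so an unseen character raises a KeyError; decode joins the mapped characters back without any <SPACE> substitution.

    # Illustrative usage only; assumes the tokenizer's char_map already
    # contains every character of the input text (e.g. after add_tokens or train()).
    tokenizer = CharTokenizer()
    tokenizer.add_tokens(list("abcdefghijklmnopqrstuvwxyz "))

    encoding = tokenizer.encode("hallo welt")  # direct char_map lookups, spaces are ordinary characters
    print(encoding.ids)                        # list of integer ids, one per character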
swr2_asr/train.py (10 changes: 6 additions & 4 deletions)
@@ -123,6 +123,8 @@ def test(model, device, test_loader, criterion, tokenizer):
                 label_lengths=_data["utterance_length"],
                 tokenizer=tokenizer,
             )
+
+            print(f"decoded_preds: {decoded_preds}")
             for j, pred in enumerate(decoded_preds):
                 test_cer.append(cer(decoded_targets[j], pred))
                 test_wer.append(wer(decoded_targets[j], pred))
@@ -155,10 +157,10 @@ def run(
 
     # load dataset
     train_dataset = MLSDataset(
-        dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
+        dataset_path, language, Split.TRAIN, download=False, spectrogram_hparams=None, limited=True
     )
     valid_dataset = MLSDataset(
-        dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
+        dataset_path, language, Split.VALID, download=False, spectrogram_hparams=None, limited=True
     )
 
     # load tokenizer (bpe by default):
@@ -196,14 +198,14 @@
         train_dataset,
         batch_size=hparams["batch_size"],
         shuffle=True,
-        collate_fn=lambda x: collate_fn(x),
+        collate_fn=lambda x: collate_fn(x, tokenizer.encode(" ").ids[0]),
     )
 
     valid_loader = DataLoader(
         valid_dataset,
         batch_size=hparams["batch_size"],
         shuffle=True,
-        collate_fn=lambda x: collate_fn(x),
+        collate_fn=lambda x: collate_fn(x, tokenizer.encode(" ").ids[0]),
     )
 
     # enable flag to find the most compatible algorithms in advance
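Both data loaders now pass collate_fn the id the tokenizer assigns to a single space, which serves as the padding value for each batch. A small sketch of the same wiring, hoisting the encode call out of the lambda so it runs once instead of on every batch (assuming tokenizer, train_dataset, and hparams are set up as above):

    from torch.utils.data import DataLoader

    # Same idea as the change above, but the space id is computed once up front.
    space_id = tokenizer.encode(" ").ids[0]

    train_loader = DataLoader(
        train_dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
        collate_fn=lambda batch: collate_fn(batch, space_id),
    )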
swr2_asr/utils.py (19 changes: 15 additions & 4 deletions)
@@ -1,7 +1,7 @@
 """Class containing utils for the ASR system."""
 import os
 from enum import Enum
-from typing import TypedDict
+from typing import TypedDict, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -354,11 +354,14 @@ def __getitem__(self, idx: int) -> Sample:
         )
 
 
-def collate_fn(samples: list[Sample]) -> dict:
+def collate_fn(samples: list[Sample], padding_id: Union[float, int]) -> dict:
     """Collate function for the dataloader
 
     pads all tensors within a batch to the same dimensions
     """
+    if isinstance(padding_id, int):
+        padding_id = float(padding_id)
+
     waveforms = []
     spectrograms = []
     labels = []
@@ -399,19 +402,27 @@ def collate_fn(samples: list[Sample]) -> dict:
     dataset.set_tokenizer(tok)
 
 
-def plot(epochs, path):
+def plot(epochs, path, title):
     """Plots the losses over the epochs"""
     losses = list()
     test_losses = list()
     cers = list()
     wers = list()
     for epoch in range(1, epochs + 1):
-        current_state = torch.load(path + str(epoch))
+        current_state = torch.load(path + str(epoch), map_location=torch.device("cpu"))
         losses.append(current_state["loss"])
         test_losses.append(current_state["test_loss"])
         cers.append(current_state["avg_cer"])
         wers.append(current_state["avg_wer"])
 
     # plot losses and cers
     plt.plot(losses)
+
+    plt.plot(test_losses)
+    plt.plot(cers)
+    plt.plot(wers)
+
+    plt.legend(["train loss", "test loss", "cer (test)", "wer (tes)"])
+    plt.title(title)
+
     plt.savefig("losses.svg")
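The utils.py hunks show collate_fn accepting padding_id and coercing it to float, but not how the value is consumed further down. A plausible use, shown purely as an assumption rather than as the repository's actual code, is padding variable-length label tensors to a common length:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Assumed usage, not taken from this diff: pad each label tensor in the
    # batch up to the longest one, filling with padding_id.
    labels = [torch.tensor([4.0, 8.0, 15.0]), torch.tensor([16.0, 23.0])]
    padding_id = 2.0  # e.g. float(tokenizer.encode(" ").ids[0])
    padded = pad_sequence(labels, batch_first=True, padding_value=padding_id)
    # padded.shape == (2, 3); the shorter row is filled with 2.0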
