Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Pherkel committed Sep 18, 2023
1 parent c09ff76 commit d44cf7b
Show file tree
Hide file tree
Showing 11 changed files with 456 additions and 60 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# pictures
**/*.png

# Training files
data/*
!data/tokenizers
!data/own
!data/metrics.csv

# Mac
**/.DS_Store
Expand Down
13 changes: 0 additions & 13 deletions Dockerfile

This file was deleted.

9 changes: 0 additions & 9 deletions Makefile

This file was deleted.

34 changes: 0 additions & 34 deletions config.cluster.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion config.philipp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ tokenizer:
tokenizer_path: "data/tokenizers/char_tokenizer_german.json"

decoder:
type: "lm" # greedy, or lm (beam search)
type: "greedy" # greedy, or lm (beam search)

lm: # config for lm decoder
language_model_path: "data" # path where model and supplementary files are stored
Expand Down
Binary file added data/own/Philipp_HerrK.flac
Binary file not shown.
245 changes: 245 additions & 0 deletions lm_decoder_hparams.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"lm_weights = [0, 1.0, 2.5,]\n",
"word_score = [-1.5, 0.0, 1.5]\n",
"beam_sizes = [50, 500]\n",
"beam_thresholds = [50]\n",
"beam_size_token = [10, 38]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" from tqdm.autonotebook import tqdm\n",
"/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from tqdm.autonotebook import tqdm\n",
"\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn.functional as F\n",
"\n",
"from swr2_asr.utils.decoder import decoder_factory\n",
"from swr2_asr.utils.tokenizer import CharTokenizer\n",
"from swr2_asr.model_deep_speech import SpeechRecognitionModel\n",
"from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n",
"from swr2_asr.utils.loss_scores import cer, wer"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "34aafd9aca2541748dc41d8550334536",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/144 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Download flag not set, skipping download\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/functional/functional.py:576: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (201) may be set too low.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"New best WER: 0.8266228565397248 CER: 0.6048691547202959\n",
"Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 10, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
"LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 10 Beam Size Token: 10\n",
"--------------------------------------------------------------\n",
"New best WER: 0.7900706123452581 CER: 0.49197597466135945\n",
"Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
"LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 50 Beam Size Token: 10\n",
"--------------------------------------------------------------\n",
"New best WER: 0.7877685082828738 CER: 0.48660732878914315\n",
"Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 100, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
"LM Weight: 0 Word Score: -1.5 Beam Size: 100 Beam Threshold: 50 Beam Size Token: 10\n",
"--------------------------------------------------------------\n"
]
}
],
"source": [
"\n",
"\n",
"tokenizer = CharTokenizer.from_file(\"data/tokenizers/char_tokenizer_german.json\")\n",
"\n",
"# manually increment tqdm progress bar\n",
"pbar = tqdm(total=len(lm_weights) * len(word_score) * len(beam_sizes) * len(beam_thresholds) * len(beam_size_token))\n",
"\n",
"base_config = {\n",
" \"language\": \"german\",\n",
" \"language_model_path\": \"data\", # path where model and supplementary files are stored\n",
" \"n_gram\": 3, # n-gram size of ,the language model, 3 or 5\n",
" \"beam_size\": 50 ,\n",
" \"beam_threshold\": 50,\n",
" \"n_best\": 1,\n",
" \"lm_weight\": 2,\n",
" \"word_score\": 0,\n",
" }\n",
"\n",
"dataset_params = {\n",
" \"dataset_path\": \"/Volumes/pherkel 2/SWR2-ASR\",\n",
" \"language\": \"mls_german_opus\",\n",
" \"split\": Split.DEV,\n",
" \"limited\": True,\n",
" \"download\": False,\n",
" \"size\": 0.01,\n",
"}\n",
" \n",
"\n",
"model_params = {\n",
" \"n_cnn_layers\": 3,\n",
" \"n_rnn_layers\": 5,\n",
" \"rnn_dim\": 512,\n",
" \"n_class\": tokenizer.get_vocab_size(),\n",
" \"n_feats\": 128,\n",
" \"stride\": 2,\n",
" \"dropout\": 0.1,\n",
"}\n",
"\n",
"model = SpeechRecognitionModel(**model_params)\n",
"\n",
"checkpoint = torch.load(\"data/epoch67\", map_location=torch.device(\"cpu\"))\n",
"\n",
"state_dict = {\n",
" k[len(\"module.\") :] if k.startswith(\"module.\") else k: v\n",
" for k, v in checkpoint[\"model_state_dict\"].items()\n",
"}\n",
"model.load_state_dict(state_dict, strict=True)\n",
"model.eval()\n",
"\n",
"\n",
"dataset = MLSDataset(**dataset_params,)\n",
"\n",
"data_processing = DataProcessing(\"valid\", tokenizer, {\"n_feats\": model_params[\"n_feats\"]})\n",
"\n",
"dataloader = DataLoader(\n",
" dataset=dataset,\n",
" batch_size=16,\n",
" shuffle = False,\n",
" collate_fn=data_processing,\n",
" num_workers=8,\n",
" pin_memory=True,\n",
")\n",
"\n",
"best_wer = 1.0\n",
"best_cer = 1.0\n",
"best_config = None\n",
"\n",
"for lm_weight in lm_weights:\n",
" for ws in word_score:\n",
" for beam_size in beam_sizes:\n",
" for beam_threshold in beam_thresholds:\n",
" for beam_size_t in beam_size_token:\n",
" config = base_config.copy()\n",
" config[\"lm_weight\"] = lm_weight\n",
" config[\"word_score\"] = ws\n",
" config[\"beam_size\"] = beam_size\n",
" config[\"beam_threshold\"] = beam_threshold\n",
" config[\"beam_size_token\"] = beam_size_t\n",
" \n",
" decoder = decoder_factory(\"lm\")(tokenizer, {\"lm\": config})\n",
" \n",
" test_cer, test_wer = [], []\n",
" with torch.no_grad():\n",
" model.eval()\n",
" for batch in dataloader:\n",
" # perform inference, decode, compute WER and CER\n",
" spectrograms, labels, input_lengths, label_lengths = batch\n",
" \n",
" output = model(spectrograms)\n",
" output = F.log_softmax(output, dim=2)\n",
" \n",
" decoded_preds = decoder(output)\n",
" decoded_targets = tokenizer.decode_batch(labels)\n",
" \n",
" for j, _ in enumerate(decoded_preds):\n",
" if j >= len(decoded_targets):\n",
" break\n",
" pred = \" \".join(decoded_preds[j][0].words).strip()\n",
" target = decoded_targets[j]\n",
" \n",
" test_cer.append(cer(pred, target))\n",
" test_wer.append(wer(pred, target))\n",
"\n",
" avg_cer = sum(test_cer) / len(test_cer)\n",
" avg_wer = sum(test_wer) / len(test_wer)\n",
" \n",
" if avg_wer < best_wer:\n",
" best_wer = avg_wer\n",
" best_cer = avg_cer\n",
" best_config = config\n",
" print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n",
" print(\"Config: \", best_config)\n",
" print(\"LM Weight: \", lm_weight, \n",
" \" Word Score: \", ws, \n",
" \" Beam Size: \", beam_size, \n",
" \" Beam Threshold: \", beam_threshold, \n",
" \" Beam Size Token: \", beam_size_t)\n",
" print(\"--------------------------------------------------------------\")\n",
" \n",
" pbar.update(1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
69 changes: 69 additions & 0 deletions metrics.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
epoch,train_loss,test_loss,cer,wer
0.0,3.25246262550354,3.0130836963653564,1.0,0.9999533337969454
1.0,2.791025161743164,0.0,0.0,0.0
2.0,1.5954065322875977,0.0,0.0,0.0
3.0,1.3106564283370972,0.0,0.0,0.0
4.0,1.206541895866394,0.0,0.0,0.0
5.0,1.1116338968276978,0.9584052684355759,0.26248163774768096,0.8057431713202183
6.0,1.0295032262802124,0.0,0.0,0.0
7.0,0.957234263420105,0.0,0.0,0.0
8.0,0.8958202004432678,0.0,0.0,0.0
9.0,0.8403098583221436,0.0,0.0,0.0
10.0,0.7934719324111938,0.577774976386505,0.1647645650587519,0.5597785267513198
11.0,0.7537956833839417,0.0,0.0,0.0
12.0,0.7180628776550293,0.0,0.0,0.0
13.0,0.6870554089546204,0.0,0.0,0.0
14.0,0.6595032811164856,0.0,0.0,0.0
15.0,0.6374552845954895,0.42232042328030084,0.12030436712014228,0.43601402176865556
16.0,0.6134707927703857,0.0,0.0,0.0
17.0,0.5946973562240601,0.0,0.0,0.0
18.0,0.577201783657074,0.0,0.0,0.0
19.0,0.5612062811851501,0.0,0.0,0.0
20.0,0.5256602764129639,0.33855139215787244,0.09390776269838304,0.35605188295180307
21.0,0.5190389752388,0.0,0.0,0.0
22.0,0.5163558721542358,0.0,0.0,0.0
23.0,0.5132778286933899,0.0,0.0,0.0
24.0,0.5090991854667664,0.0,0.0,0.0
25.0,0.5072354078292847,0.32589933276176464,0.08999255619329079,0.341225825396658
26.0,0.5023046731948853,0.0,0.0,0.0
27.0,0.4994561970233917,0.0,0.0,0.0
28.0,0.4942632019519806,0.0,0.0,0.0
29.0,0.4906529486179352,0.0,0.0,0.0
30.0,0.4855062663555145,0.29864962175995297,0.08296308087950884,0.3177622785738594
31.0,0.4822919964790344,0.0,0.0,0.0
32.0,0.4456436336040497,0.0,0.0,0.0
33.0,0.4389857053756714,0.0,0.0,0.0
34.0,0.43762147426605225,0.0,0.0,0.0
35.0,0.4351556599140167,0.5776603897412618,0.16294622142152407,0.5232870602289124
36.0,0.43377435207366943,0.0,0.0,0.0
37.0,0.4318349063396454,0.0,0.0,0.0
38.0,0.43010208010673523,0.0,0.0,0.0
39.0,0.4276123046875,0.0,0.0,0.0
40.0,0.4253982901573181,0.5735072294871012,0.1586969400218906,0.5131595862326734
41.0,0.4236880838871002,0.0,0.0,0.0
42.0,0.42077934741973877,0.0,0.0,0.0
43.0,0.4181424081325531,0.0,0.0,0.0
44.0,0.4154696464538574,0.0,0.0,0.0
45.0,0.419731080532074,0.5696070055166881,0.15437095897735878,0.5002024974353078
46.0,0.4099026024341583,0.0,0.0,0.0
47.0,0.4078012704849243,0.0,0.0,0.0
48.0,0.40490180253982544,0.0,0.0,0.0
49.0,0.4024839699268341,0.0,0.0,0.0
50.0,0.3694721758365631,0.5247387786706288,0.1450933666590186,0.4700957797096995
51.0,0.36624056100845337,0.0,0.0,0.0
52.0,0.36418089270591736,0.0,0.0,0.0
53.0,0.36366793513298035,0.0,0.0,0.0
54.0,0.36317530274391174,0.0,0.0,0.0
55.0,0.3624136447906494,0.510421613852183,0.14174752623520492,0.4632967062415951
56.0,0.36174166202545166,0.0,0.0,0.0
57.0,0.36113062500953674,0.0,0.0,0.0
58.0,0.36098596453666687,0.0,0.0,0.0
59.0,0.35909315943717957,0.0,0.0,0.0
60.0,0.36021551489830017,0.5095615088939668,0.14084592211118552,0.45461000263956114
61.0,0.35837724804878235,0.0,0.0,0.0
62.0,0.3567410409450531,0.0,0.0,0.0
63.0,0.3565385341644287,0.0,0.0,0.0
64.0,0.35535314679145813,0.0,0.0,0.0
65.0,0.35792484879493713,0.5086047914293077,0.13893481611889835,0.45137245514066726
66.0,0.35215333104133606,0.0,0.0,0.0
67.0,0.35401859879493713,0.0,0.0,0.0
Loading

0 comments on commit d44cf7b

Please sign in to comment.