From 455f16ac2bd1373b172680538a8e414f40141722 Mon Sep 17 00:00:00 2001 From: Georg Kucsko Date: Fri, 1 Oct 2021 10:37:21 -0400 Subject: [PATCH] Add robustness to automatic decoder instantiation (#4) Simplify automatic vocabulary parsing --- pyctcdecode/alphabet.py | 207 ++++++++++++++---------- pyctcdecode/decoder.py | 45 ++++-- pyctcdecode/language_model.py | 52 +++++- pyctcdecode/tests/test_alphabet.py | 131 +++++++-------- pyctcdecode/tests/test_decoder.py | 27 ++-- tutorials/00_basic_usage.ipynb | 4 +- tutorials/01_pipeline_nemo.ipynb | 50 +++--- tutorials/02_pipeline_huggingface.ipynb | 70 ++++---- 8 files changed, 332 insertions(+), 254 deletions(-) diff --git a/pyctcdecode/alphabet.py b/pyctcdecode/alphabet.py index 474660c..c14fa9b 100644 --- a/pyctcdecode/alphabet.py +++ b/pyctcdecode/alphabet.py @@ -1,93 +1,122 @@ # Copyright 2021-present Kensho Technologies, LLC. +from __future__ import division + import logging -from typing import List, Optional +import re +from typing import Collection, List + +BPE_TOKEN = "▁" # nosec # representation of token boundary in BPE alphabet +UNK_TOKEN = "⁇" # nosec # representation of special UNK token in regular alphabet +UNK_BPE_TOKEN = "▁⁇▁" # nosec # representation of special UNK token in BPE alphabet -BPE_CHAR = "▁" # character used for token boundary if BPE is used -UNK_BPE_CHAR = "▁⁇▁" # representation of unknown character in BPE +# special tokens are usually encode with things like `[]` or `<>` +SPECIAL_TOKEN_PTN = re.compile(r"^[<\[].+[>\]]$") +BLANK_TOKEN_PTN = re.compile(r"^[<\[]pad[>\]]$", flags=re.IGNORECASE) +UNK_TOKEN_PTN = re.compile(r"^[<\[]unk[>\]]$", flags=re.IGNORECASE) logger = logging.getLogger(__name__) -def _get_ctc_index(label_list: List[str]) -> int: - """Get index of ctc blank character in alphabet.""" - return len(label_list) - 1 if label_list[-1] == "" else -1 - - -def _normalize_alphabet(label_list: List[str], ctc_token_idx: Optional[int] = None) -> List[str]: - """Normalize alphabet for non-bpe decoder.""" - if any([len(c) > 1 for c in label_list]): - raise ValueError("For non-bpe alphabet only length 1 entries and blank token are allowed.") - if ctc_token_idx is None: - ctc_token_idx = _get_ctc_index(label_list) - clean_labels = label_list[:] - # check for space token - if " " not in clean_labels: - raise ValueError("Space token ' ' missing from vocabulary.") - # specify ctc blank token - if ctc_token_idx == -1: - clean_labels.append("") +def _check_if_bpe(labels: List[str]) -> bool: + """Check if input alphabet is BPE or not.""" + is_bpe = any([s.startswith("##") for s in labels]) or any( + [s.startswith(BPE_TOKEN) for s in labels] + ) + if is_bpe: + logger.info("Alphabet determined to be of BPE style.") else: - clean_labels[ctc_token_idx] = "" - return clean_labels + logger.info("Alphabet determined to be of regular style.") + return is_bpe + + +def _normalize_regular_alphabet(labels: List[str]) -> List[str]: + """Normalize non-bpe labels to alphabet for decoder.""" + normalized_labels = labels[:] + # substitute space characters + if "|" in normalized_labels and " " not in normalized_labels: + logger.info("Found '|' in vocabulary but not ' ', doing substitution.") + normalized_labels[normalized_labels.index("|")] = " " + # substituted ctc blank char + for n, label in enumerate(normalized_labels): + if BLANK_TOKEN_PTN.match(label): + logger.info( + "Found %s in vocabulary, interpreted as a CTC blank token, substituting with %s.", + label, + "", + ) + normalized_labels[n] = "" + if "_" in normalized_labels and "" not in normalized_labels: + logger.info("Found '_' in vocabulary but not '', doing substitution.") + normalized_labels[normalized_labels.index("_")] = "" + if "" not in normalized_labels: + logger.info("CTC blank char '' not found, appending to end.") + normalized_labels.append("") + # substitute unk + for n, label in enumerate(normalized_labels): + if UNK_TOKEN_PTN.match(label): + logger.info( + "Found %s in vocabulary, interpreting as unknown token, substituting with %s.", + label, + UNK_TOKEN, + ) + normalized_labels[n] = UNK_TOKEN + # additional checks + if any([len(c) > 1 for c in normalized_labels]): + logger.warning( + "Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the " + "alphabet was not recognized as BPE type. Is this correct?" + ) + if " " not in normalized_labels: + logger.warning("Space token ' ' missing from vocabulary.") + return normalized_labels -def _convert_bpe_format(token: str) -> str: - """Convert token from ## type bpe format to ▁ type.""" - if token[:2] == "##": +def _convert_bpe_token_style(token: str) -> str: + """Convert token from ## style bpe format to ▁ style.""" + if token.startswith("##"): return token[2:] - elif token == BPE_CHAR: - return token - elif token == "": # nosec + elif SPECIAL_TOKEN_PTN.match(token) or token in ("", BPE_TOKEN, UNK_BPE_TOKEN): return token - elif token in ("", UNK_BPE_CHAR): + elif token in ("", UNK_BPE_TOKEN): return token else: - return BPE_CHAR + token + return BPE_TOKEN + token -def _normalize_bpe_alphabet( - label_list: List[str], - unk_token_idx: Optional[int] = None, - ctc_token_idx: Optional[int] = None, -) -> List[str]: +def _normalize_bpe_alphabet(labels: List[str]) -> List[str]: """Normalize alphabet for bpe decoder.""" - if ctc_token_idx is None: - ctc_token_idx = _get_ctc_index(label_list) - # create copy - clean_labels = label_list[:] - # there are two common formats for BPE vocabulary - # 1) where ▁ indicates a space (this is the main format we use) - if any([s[:1] == BPE_CHAR and len(s) > 1 for s in clean_labels]): - # verify unk token and make sure it is consistently represented as ▁⁇▁ - if unk_token_idx is None and clean_labels[0] in ("", UNK_BPE_CHAR): - unk_token_idx = 0 - else: - raise ValueError( - "First token in vocab for BPE should be '▁⁇▁' or specify unk_token_idx." - ) - clean_labels[unk_token_idx] = UNK_BPE_CHAR - # 2) where ## indicates continuation of a token (note: also contains the single token: ▁) - elif any([s[:2] == "##" for s in clean_labels]): - # convert to standard format 1) - clean_labels = [_convert_bpe_format(c) for c in clean_labels] - # add unk token if needed - if clean_labels[0] in ("", UNK_BPE_CHAR): - clean_labels[0] = UNK_BPE_CHAR - else: - clean_labels = [UNK_BPE_CHAR] + clean_labels - ctc_token_idx += 1 - else: - raise ValueError( - "Unknown BPE format for vocabulary. Supported formats are 1) ▁ for indicating a" - " space and 2) ## for continuation of a word." - ) - # specify ctc blank token - if ctc_token_idx == -1: - clean_labels.append("") - else: - clean_labels[ctc_token_idx] = "" - return clean_labels + normalized_labels = labels[:] + # if BPE is of style '##' then convert it + if any([s.startswith("##") for s in labels]): + normalized_labels = [_convert_bpe_token_style(c) for c in normalized_labels] + # substituted ctc blank char + for n, label in enumerate(normalized_labels): + if BLANK_TOKEN_PTN.match(label): + logger.info("Found %s in vocabulary, substituting with %s.", label, "") + normalized_labels[n] = "" + if "" not in normalized_labels: + logger.info("CTC blank char '' not found, appending to end.") + normalized_labels.append("") + # substitute unk + for n, label in enumerate(normalized_labels): + if UNK_TOKEN_PTN.match(label): + logger.info("Found %s in vocabulary, substituting with %s.", label, UNK_BPE_TOKEN) + normalized_labels[n] = UNK_BPE_TOKEN + # additional checks + if UNK_BPE_TOKEN not in normalized_labels: + logger.warning("UNK token %s not found, is this a mistake?", UNK_BPE_TOKEN) + return normalized_labels + + +def _verify_alphabet(labels: List[str], is_bpe: bool) -> None: + """Verify basic alphabet labels.""" + # check if duplicates exist + if len(labels) != len(set(labels)): + raise ValueError("Alphabet contains duplicate entries, this is not allowed.") + # check if space character is absent in bpe alphabet + if is_bpe and any([" " in s for s in labels]): + raise ValueError("Space token ' ' found in vocabulary even though it looks like BPE.") class Alphabet: @@ -98,7 +127,7 @@ def __init__(self, labels: List[str], is_bpe: bool) -> None: @property def is_bpe(self) -> bool: - """Whether the alphabet is a bytepair encoded one.""" + """Whether the alphabet is bpe style.""" return self._is_bpe @property @@ -107,20 +136,20 @@ def labels(self) -> List[str]: return self._labels[:] # this is a copy @classmethod - def build_alphabet( - cls, label_list: List[str], ctc_token_idx: Optional[int] = None - ) -> "Alphabet": - """Make a non-BPE alphabet.""" - formatted_alphabet_list = _normalize_alphabet(label_list, ctc_token_idx) - return cls(formatted_alphabet_list, False) + def build_alphabet(cls, labels: List[str]) -> "Alphabet": + """Make an alphabet from labels in standardized format for decoder.""" + is_bpe = _check_if_bpe(labels) + _verify_alphabet(labels, is_bpe) + if is_bpe: + normalized_labels = _normalize_bpe_alphabet(labels) + else: + normalized_labels = _normalize_regular_alphabet(labels) + return cls(normalized_labels, is_bpe) - @classmethod - def build_bpe_alphabet( - cls, - label_list: List[str], - unk_token_idx: Optional[int] = None, - ctc_token_idx: Optional[int] = None, - ) -> "Alphabet": - """Make a BPE alphabet.""" - formatted_label_list = _normalize_bpe_alphabet(label_list, unk_token_idx, ctc_token_idx) - return cls(formatted_label_list, True) + +def verify_alphabet_coverage(alphabet: Alphabet, unigrams: Collection[str]) -> None: + """Verify if alphabet covers a given unigrams.""" + label_chars = set(alphabet.labels) + unigram_sample_chars = set("".join(unigrams)) + if len(unigram_sample_chars - label_chars) / len(unigram_sample_chars) > 0.2: + logger.warning("Unigrams and labels don't seem to agree.") diff --git a/pyctcdecode/decoder.py b/pyctcdecode/decoder.py index e241fb1..23df8d9 100644 --- a/pyctcdecode/decoder.py +++ b/pyctcdecode/decoder.py @@ -6,11 +6,11 @@ import logging import math import os -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union import numpy as np -from .alphabet import BPE_CHAR, Alphabet +from .alphabet import BPE_TOKEN, Alphabet, verify_alphabet_coverage from .constants import ( DEFAULT_ALPHA, DEFAULT_BEAM_WIDTH, @@ -23,7 +23,12 @@ DEFAULT_UNK_LOGP_OFFSET, MIN_TOKEN_CLIP_P, ) -from .language_model import AbstractLanguageModel, HotwordScorer, LanguageModel +from .language_model import ( + AbstractLanguageModel, + HotwordScorer, + LanguageModel, + load_unigram_set_from_arpa, +) logger = logging.getLogger(__name__) @@ -363,13 +368,13 @@ def _decode_logits( ) ) # if bpe and leading space char - elif self._is_bpe and (char[:1] == BPE_CHAR or force_next_break): + elif self._is_bpe and (char[:1] == BPE_TOKEN or force_next_break): force_next_break = False # some tokens are bounded on both sides like ▁⁇▁ clean_char = char - if char[:1] == BPE_CHAR: + if char[:1] == BPE_TOKEN: clean_char = clean_char[1:] - if char[-1:] == BPE_CHAR: + if char[-1:] == BPE_TOKEN: clean_char = clean_char[:-1] force_next_break = True new_frame_list = ( @@ -668,35 +673,41 @@ def decode_batch( def build_ctcdecoder( labels: List[str], - kenlm_model: Optional[kenlm.Model] = None, - unigrams: Optional[Iterable[str]] = None, + kenlm_model_path: Optional[str] = None, + unigrams: Optional[Collection[str]] = None, alpha: float = DEFAULT_ALPHA, beta: float = DEFAULT_BETA, unk_score_offset: float = DEFAULT_UNK_LOGP_OFFSET, lm_score_boundary: bool = DEFAULT_SCORE_LM_BOUNDARY, - ctc_token_idx: Optional[int] = None, - is_bpe: bool = False, ) -> BeamSearchDecoderCTC: """Build a BeamSearchDecoderCTC instance with main functionality. Args: labels: class containing the labels for input logit matrices - kenlm_model: instance of kenlm n-gram language model `kenlm.Model` + kenlm_model_path: path to kenlm n-gram language model unigrams: list of known word unigrams alpha: weight for language model during shallow fusion beta: weight for length score adjustment of during scoring unk_score_offset: amount of log score offset for unknown tokens lm_score_boundary: whether to have kenlm respect boundaries when scoring - ctc_token_idx: index of ctc blank token within the labels - is_bpe: indicate if labels are BPE type Returns: instance of BeamSearchDecoderCTC """ - if is_bpe: - alphabet = Alphabet.build_bpe_alphabet(labels, ctc_token_idx=ctc_token_idx) - else: - alphabet = Alphabet.build_alphabet(labels, ctc_token_idx=ctc_token_idx) + kenlm_model = None if kenlm_model_path is None else kenlm.Model(kenlm_model_path) + if kenlm_model_path is not None and kenlm_model_path.endswith(".arpa"): + logger.info("Using arpa instead of binary LM file, decoder instantiation might be slow.") + if unigrams is None and kenlm_model_path is not None: + if kenlm_model_path.endswith(".arpa"): + unigrams = load_unigram_set_from_arpa(kenlm_model_path) + else: + logger.warning( + "Unigrams not provided and cannot be automatically determined from LM file (only " + "arpa format). Decoding accuracy might be reduced." + ) + alphabet = Alphabet.build_alphabet(labels) + if unigrams is not None: + verify_alphabet_coverage(alphabet, unigrams) if kenlm_model is not None: language_model: Optional[AbstractLanguageModel] = LanguageModel( kenlm_model, diff --git a/pyctcdecode/language_model.py b/pyctcdecode/language_model.py index 94f5e66..87ffbfe 100644 --- a/pyctcdecode/language_model.py +++ b/pyctcdecode/language_model.py @@ -4,7 +4,7 @@ import abc import logging import re -from typing import Iterable, List, Optional, Pattern, Tuple, cast +from typing import Collection, Iterable, List, Optional, Pattern, Set, Tuple, cast import numpy as np from pygtrie import CharTrie # type: ignore @@ -32,6 +32,45 @@ ) +def load_unigram_set_from_arpa(arpa_path: str) -> Set[str]: + """Read unigrams from arpa file.""" + unigrams = set() + with open(arpa_path) as f: + start_1_gram = False + for line in f: + line = line.strip() + if line == "\\1-grams:": + start_1_gram = True + elif line == "\\2-grams:": + break + if start_1_gram and len(line) > 0: + parts = line.split("\t") + if len(parts) == 3: + unigrams.add(parts[1]) + if len(unigrams) == 0: + raise ValueError("No unigrams found in arpa file. Something is wrong with the file.") + return unigrams + + +def _prepare_unigram_set(unigrams: Collection[str], kenlm_model: kenlm.Model) -> Set[str]: + """Filter unigrams down to vocabulary that exists in kenlm_model.""" + if len(unigrams) < 1000: + logger.warning( + "Only %s unigrams passed as vocabulary. Is this small or artificial data?", + len(unigrams), + ) + unigram_set = set(unigrams) + unigram_set = set([t for t in unigram_set if t in kenlm_model]) + retained_fraction = 1.0 if len(unigrams) == 0 else len(unigram_set) / len(unigrams) + if retained_fraction < 0.1: + logger.warning( + "Only %s%% of unigrams in vocabulary found in kenlm model-- this might mean that your " + "vocabulary and language model are incompatible. Is this intentional?", + round(retained_fraction * 100, 1), + ) + return unigram_set + + def _get_empty_lm_state() -> kenlm.State: """Get unintialized kenlm state.""" try: @@ -146,7 +185,7 @@ class LanguageModel(AbstractLanguageModel): def __init__( self, kenlm_model: kenlm.Model, - unigrams: Optional[Iterable[str]] = None, + unigrams: Optional[Collection[str]] = None, alpha: float = DEFAULT_ALPHA, beta: float = DEFAULT_BETA, unk_score_offset: float = DEFAULT_UNK_LOGP_OFFSET, @@ -164,10 +203,11 @@ def __init__( """ self._kenlm_model = kenlm_model if unigrams is None: + logger.warning("No known unigrams provided, decoding results might be a lot worse.") unigram_set = set() char_trie = None else: - unigram_set = set([t for t in set(unigrams) if t in self._kenlm_model]) + unigram_set = _prepare_unigram_set(unigrams, self._kenlm_model) char_trie = CharTrie.fromkeys(unigram_set) self._unigram_set = unigram_set self._char_trie = char_trie @@ -202,8 +242,10 @@ def _get_raw_end_score(self, start_state: kenlm.State) -> float: def score_partial_token(self, partial_token: str) -> float: """Get partial token score.""" if self._char_trie is None: - return 0.0 - unk_score = self.unk_score_offset * int(self._char_trie.has_node(partial_token) == 0) + is_oov = 1.0 + else: + is_oov = int(self._char_trie.has_node(partial_token) == 0) + unk_score = self.unk_score_offset * is_oov # if unk token length exceeds expected length then additionally decrease score if len(partial_token) > AVG_TOKEN_LEN: unk_score = unk_score * len(partial_token) / AVG_TOKEN_LEN diff --git a/pyctcdecode/tests/test_alphabet.py b/pyctcdecode/tests/test_alphabet.py index 469ce15..e4960f0 100644 --- a/pyctcdecode/tests/test_alphabet.py +++ b/pyctcdecode/tests/test_alphabet.py @@ -1,7 +1,7 @@ # Copyright 2021-present Kensho Technologies, LLC. import unittest -from ..alphabet import Alphabet, _normalize_alphabet, _normalize_bpe_alphabet +from ..alphabet import Alphabet, _normalize_bpe_alphabet, _normalize_regular_alphabet def _approx_beams(beams, precis=5): @@ -9,87 +9,66 @@ def _approx_beams(beams, precis=5): return [tuple(list(b[:-1]) + [round(b[-1], precis)]) for b in beams] -class TestModelHelpers(unittest.TestCase): - def test_normalize_alphabet(self): - alphabet_list = [" ", "a", "b", ""] - norm_alphabet = _normalize_alphabet(alphabet_list) - expected_alphabet = [" ", "a", "b", ""] - self.assertListEqual(norm_alphabet, expected_alphabet) +KNOWN_MAPPINGS = [ + ( + [" ", "a", "b"], + [" ", "a", "b", ""], + False, + ), # nemo + ( + ["", "", "", "", "|", "A", "B"], + ["", "", "", "⁇", " ", "A", "B"], + False, + ), # huggingface + ( + ["", "▁", "##a", "##b", "a", "b"], + ["▁⁇▁", "▁", "a", "b", "▁a", "▁b", ""], + True, + ), # nemo-bpe +] + +TEST_MAPPINGS = [ + ( + [" ", "a", "b", ""], + [" ", "a", "b", ""], + ), # make sure ctc blank doesn"t get added if exists +] + +BPE_TEST_MAPPINGS = [ + ( + ["▁⁇▁", "▁", "a", "b", "▁a", "▁b"], + ["▁⁇▁", "▁", "a", "b", "▁a", "▁b", ""], + ), # bpe in correct form + ( + ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "##a", "##b", "a", "b"], + ["", "▁⁇▁", "[CLS]", "[SEP]", "[MASK]", "a", "b", "▁a", "▁b"], + ), # other special tokens +] - # missing blank char - alphabet_list = [" ", "a", "b"] - norm_alphabet = _normalize_alphabet(alphabet_list) - expected_alphabet = [" ", "a", "b", ""] - self.assertListEqual(norm_alphabet, expected_alphabet) - # invalid input - alphabet_list = [" ", "a", "bb"] - with self.assertRaises(ValueError): - _normalize_alphabet(alphabet_list) +class TestModelHelpers(unittest.TestCase): + def test_normalize_alphabet(self): + for labels, expected_labels in TEST_MAPPINGS: + norm_labels = _normalize_regular_alphabet(labels) + self.assertListEqual(norm_labels, expected_labels) def test_normalize_alphabet_bpe(self): - # style ▁ input - alphabet_list = ["▁⁇▁", "▁B", "ugs", "▁", "▁bunny", ""] - norm_alphabet = _normalize_bpe_alphabet(alphabet_list) - expected_alphabet = ["▁⁇▁", "▁B", "ugs", "▁", "▁bunny", ""] - self.assertListEqual(norm_alphabet, expected_alphabet) - - # style ▁ with missing blank char - alphabet_list = ["▁⁇▁", "▁B", "ugs"] - norm_alphabet = _normalize_bpe_alphabet(alphabet_list) - expected_alphabet = ["▁⁇▁", "▁B", "ugs", ""] - self.assertListEqual(norm_alphabet, expected_alphabet) - - # other unk style - alphabet_list = ["", "▁B", "ugs"] - norm_alphabet = _normalize_bpe_alphabet(alphabet_list) - expected_alphabet = ["▁⁇▁", "▁B", "ugs", ""] - self.assertListEqual(norm_alphabet, expected_alphabet) - - # style ## input - alphabet_list = ["B", "##ugs", ""] - norm_alphabet = _normalize_bpe_alphabet(alphabet_list) - expected_alphabet = ["▁⁇▁", "▁B", "ugs", ""] - self.assertListEqual(norm_alphabet, expected_alphabet) - - # style ## with single ▁ char - alphabet_list = ["B", "##ugs", "▁", ""] - norm_alphabet = _normalize_bpe_alphabet(alphabet_list) - expected_alphabet = ["▁⁇▁", "▁B", "ugs", "▁", ""] - self.assertListEqual(norm_alphabet, expected_alphabet) - - # invalid input - alphabet_list = ["B", "##ugs", "▁bunny", ""] - with self.assertRaises(ValueError): - _normalize_bpe_alphabet(alphabet_list) + for labels, expected_labels in BPE_TEST_MAPPINGS: + norm_labels = _normalize_bpe_alphabet(labels) + self.assertListEqual(norm_labels, expected_labels) def test_alphabets(self): - label_list = [" ", "a", "b", ""] - alphabet = Alphabet.build_alphabet(label_list) - expected_labels = [" ", "a", "b", ""] - self.assertFalse(alphabet.is_bpe) - self.assertListEqual(alphabet.labels, expected_labels) - - label_list = ["B", "##ugs", ""] - alphabet_bpe = Alphabet.build_bpe_alphabet(label_list) - expected_labels = ["▁⁇▁", "▁B", "ugs", ""] - self.assertTrue(alphabet_bpe.is_bpe) - self.assertListEqual(alphabet_bpe.labels, expected_labels) - - def test_missing_space(self): - """Ensure detection of missing space char in vocabulary.""" - label_list = ["a", "b", "c", ""] + for labels, expected_labels, expected_is_bpe in KNOWN_MAPPINGS: + alphabet = Alphabet.build_alphabet(labels) + self.assertListEqual(alphabet.labels, expected_labels) + self.assertEqual(alphabet.is_bpe, expected_is_bpe) + + def test_asserts(self): + # duplication + label_list = ["a", "a", "b", "c"] with self.assertRaises(ValueError): Alphabet.build_alphabet(label_list) - - def test_unknown_bpe_format(self): - """Ensure detection of a bad bpe format.""" - label_list = ["a", "b", "c", " ", ""] + # bpe with space + label_list = ["▁a", " "] with self.assertRaises(ValueError): - Alphabet.build_bpe_alphabet(label_list) - - def test_unk_bpe_char_assignment(self): - """Ensure assignment of unk_bpe_char in alphabet normalization.""" - label_list = ["##", "##hi"] - labels = Alphabet.build_bpe_alphabet(label_list).labels - self.assertEqual(labels, ["▁⁇▁", "hi", ""]) + Alphabet.build_alphabet(label_list) diff --git a/pyctcdecode/tests/test_decoder.py b/pyctcdecode/tests/test_decoder.py index 3c7c8f5..8d1051c 100644 --- a/pyctcdecode/tests/test_decoder.py +++ b/pyctcdecode/tests/test_decoder.py @@ -10,7 +10,7 @@ import kenlm # type: ignore import numpy as np -from ..alphabet import BPE_CHAR, UNK_BPE_CHAR, Alphabet +from ..alphabet import BPE_TOKEN, UNK_BPE_TOKEN, Alphabet from ..decoder import ( BeamSearchDecoderCTC, _merge_beams, @@ -157,8 +157,8 @@ def test_prune_history(self): ] # basic 2-gram kenlm model trained with 'bugs bunny' -KENLM_BINARY_PATH = os.path.join(CUR_PATH, "sample_data", "bugs_bunny_kenlm.arpa") -TEST_KENLM_MODEL = kenlm.Model(KENLM_BINARY_PATH) +KENLM_MODEL_PATH = os.path.join(CUR_PATH, "sample_data", "bugs_bunny_kenlm.arpa") +TEST_KENLM_MODEL = kenlm.Model(KENLM_MODEL_PATH) SAMPLE_LABELS = [" ", "b", "g", "n", "s", "u", "y", ""] SAMPLE_VOCAB = {c: n for n, c in enumerate(SAMPLE_LABELS)} @@ -250,19 +250,19 @@ def test_decoder(self): self.assertEqual(len(BeamSearchDecoderCTC.model_container), 0) def test_build_ctcdecoder(self): - decoder = build_ctcdecoder(SAMPLE_LABELS, TEST_KENLM_MODEL) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH) text = decoder.decode(TEST_LOGITS) self.assertEqual(text, "bugs bunny") def test_decode_batch(self): - decoder = build_ctcdecoder(SAMPLE_LABELS, TEST_KENLM_MODEL, TEST_UNIGRAMS) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH, TEST_UNIGRAMS) with multiprocessing.Pool() as pool: text_list = decoder.decode_batch(pool, [TEST_LOGITS] * 5) expected_text_list = ["bugs bunny"] * 5 self.assertListEqual(expected_text_list, text_list) def test_decode_beams_batch(self): - decoder = build_ctcdecoder(SAMPLE_LABELS, TEST_KENLM_MODEL, TEST_UNIGRAMS) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH, TEST_UNIGRAMS) with multiprocessing.Pool() as pool: text_list = decoder.decode_beams_batch(pool, [TEST_LOGITS] * 5) expected_text_list = [ @@ -327,7 +327,7 @@ def test_multi_lm(self): self.assertListEqual(_approx_lm_beams(beams_1), _approx_lm_beams(beams_2)) def test_pruning(self): - decoder = build_ctcdecoder(SAMPLE_LABELS, TEST_KENLM_MODEL) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH) text = decoder.decode(TEST_LOGITS) self.assertEqual(text, "bugs bunny") text = _greedy_decode(TEST_LOGITS, decoder._alphabet) # pylint: disable=W0212 @@ -364,7 +364,7 @@ def test_stateful(self): self.assertEqual(text, "bugs bugs") # now let's add a LM - decoder = build_ctcdecoder(SAMPLE_LABELS, TEST_KENLM_MODEL, TEST_UNIGRAMS) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH, TEST_UNIGRAMS) # LM correctly picks up the higher bigram probability for 'bugs bunny' over 'bugs bugs' text = decoder.decode(bunny_bunny_probs) @@ -380,7 +380,7 @@ def test_stateful(self): self.assertEqual(text, "bugs bunny") def test_hotwords(self): - decoder = build_ctcdecoder(SAMPLE_LABELS, TEST_KENLM_MODEL) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH) text = decoder.decode(TEST_LOGITS) self.assertEqual(text, "bugs bunny") @@ -420,7 +420,7 @@ def test_beam_results(self): # if we add the language model, that should push bugs bunny to the top, far enough to # remove all other beams from the output - decoder = build_ctcdecoder(SAMPLE_LABELS, TEST_KENLM_MODEL) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH) beams = decoder.decode_beams(TEST_LOGITS) self.assertEqual(len(beams), 1) top_beam = beams[0] @@ -481,10 +481,10 @@ def test_realistic_alphabet(self): self.assertEqual(len(beams[0][0].split()), len(beams[0][2])) # test with fake BPE vocab, spoof space with with ▁▁ - libri_labels_bpe = [UNK_BPE_CHAR, BPE_CHAR * 2] + LIBRI_LABELS[1:] + libri_labels_bpe = [UNK_BPE_TOKEN, BPE_TOKEN] + ["##" + c for c in LIBRI_LABELS[1:]] zero_row = np.array([[-100.0] * LIBRI_LOGITS.shape[0]]).T libri_logits_bpe = np.hstack([zero_row, LIBRI_LOGITS]) - decoder = build_ctcdecoder(libri_labels_bpe, is_bpe=True) + decoder = build_ctcdecoder(libri_labels_bpe) text = decoder.decode(libri_logits_bpe) expected_text = ( "i have a good deal of will you remember and what i have set my mind upon no doubt " @@ -520,8 +520,7 @@ def test_invalid_logit_inputs(self, logits: np.ndarray): lm_score_boundary=st.one_of(st.none(), st.booleans()), ) def test_fuzz_reset_params(self, alpha, beta, unk_score_offset, lm_score_boundary): - language_model = LanguageModel(TEST_KENLM_MODEL, alpha=0.0) - decoder = build_ctcdecoder(SAMPLE_LABELS, language_model) + decoder = build_ctcdecoder(SAMPLE_LABELS, KENLM_MODEL_PATH, alpha=0.0) decoder.reset_params( alpha=alpha, beta=beta, diff --git a/tutorials/00_basic_usage.ipynb b/tutorials/00_basic_usage.ipynb index d30dcd9..49133df 100644 --- a/tutorials/00_basic_usage.ipynb +++ b/tutorials/00_basic_usage.ipynb @@ -166,7 +166,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -180,7 +180,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.7.10" } }, "nbformat": 4, diff --git a/tutorials/01_pipeline_nemo.ipynb b/tutorials/01_pipeline_nemo.ipynb index c7d2b85..097ab63 100644 --- a/tutorials/01_pipeline_nemo.ipynb +++ b/tutorials/01_pipeline_nemo.ipynb @@ -14,7 +14,7 @@ "outputs": [], "source": [ "# install NeMo\n", - "!pip install \"nemo-toolkit[asr]==1.0.0rc1\"" + "!pip install \"nemo-toolkit[asr]==1.3.0\"" ] }, { @@ -26,16 +26,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "--2021-04-27 12:16:42-- https://dldata-public.s3.us-east-2.amazonaws.com/1919-142785-0028.wav\n", - "Resolving dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)... 52.219.101.58\n", - "Connecting to dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)|52.219.101.58|:443... connected.\n", + "--2021-10-01 10:20:31-- https://dldata-public.s3.us-east-2.amazonaws.com/1919-142785-0028.wav\n", + "Resolving dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)... 52.219.88.16\n", + "Connecting to dldata-public.s3.us-east-2.amazonaws.com (dldata-public.s3.us-east-2.amazonaws.com)|52.219.88.16|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 165164 (161K) [audio/wav]\n", "Saving to: ‘1919-142785-0028.wav’\n", "\n", - "1919-142785-0028.wa 100%[===================>] 161.29K 810KB/s in 0.2s \n", + "1919-142785-0028.wa 100%[===================>] 161.29K --.-KB/s in 0.08s \n", "\n", - "2021-04-27 12:16:42 (810 KB/s) - ‘1919-142785-0028.wav’ saved [165164/165164]\n", + "2021-10-01 10:20:31 (1.94 MB/s) - ‘1919-142785-0028.wav’ saved [165164/165164]\n", "\n" ] } @@ -61,41 +61,45 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "dd961d5964ba4042a18617fb0dcde9e8", + "model_id": "ba9ee9e92297439f81dded72d3590a74", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Transcribing', max=1.0, style=ProgressStyle(description_w…" + "Transcribing: 0%| | 0/1 [00:00\"` with `\"\"` and `\"|\"` with `\" \"` as well as the other special tokens (which are essentially unused)\n", - "\n", - "We need to standardize the special tokens and then specifically pass which index is the ctc blank token index (since it's not the last). For that reason we have to manually build the Alphabet and the decoder instead of using the convenience wrapper `build_ctcdecoder`." + "The vocabulary is in a slightly unconventional shape, for example the blank ctc token is ``. Pyctcdecode tries to automatically convert that and will give a warning if it's not sure. In this case it does a good job so we don't need to modify the vocabulary by hand." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 10, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?\n" + ] + }, { "data": { "text/plain": [ "'BOIL THEM BEFORE THEY ARE PUT INTO THE SOUP OR OTHER DISH THEY MAY BE INTENDED FOR'" ] }, - "execution_count": 7, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from pyctcdecode import Alphabet, BeamSearchDecoderCTC\n", - "\n", - "# make alphabet\n", - "vocab_list = list(asr_processor.tokenizer.get_vocab().keys())\n", - "# convert ctc blank character representation\n", - "vocab_list[0] = \"\"\n", - "# replace special characters\n", - "vocab_list[1] = \"⁇\"\n", - "vocab_list[2] = \"⁇\"\n", - "vocab_list[3] = \"⁇\"\n", - "# convert space character representation\n", - "vocab_list[4] = \" \"\n", - "# specify ctc blank char index, since conventionally it is the last entry of the logit matrix\n", - "alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=0)\n", + "from pyctcdecode import build_ctcdecoder\n", "\n", - "# build the decoder and decode the logits\n", - "decoder = BeamSearchDecoderCTC(alphabet)\n", + "decoder = build_ctcdecoder(vocab_list)\n", "decoder.decode(logits)" ] }, @@ -140,7 +154,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -154,7 +168,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.7.10" } }, "nbformat": 4,