Add robustness to automatic decoder instantiation (#4)
Simplify automatic vocabulary parsing
gkucsko authored Oct 1, 2021
1 parent a624a46 commit 455f16a
Showing 8 changed files with 332 additions and 254 deletions.
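
In user code, the main effect of this commit is a simpler decoder instantiation: build_ctcdecoder now accepts a language-model path instead of a kenlm.Model instance, and the alphabet style (regular vs. BPE) and unigrams are inferred automatically (see the decoder.py hunk below). A rough before/after sketch; the label list and model paths are placeholders, not part of this diff:

# before: caller loads kenlm itself and declares BPE explicitly
decoder = build_ctcdecoder(labels, kenlm_model=kenlm.Model("lm.binary"), is_bpe=True)

# after: pass a path; alphabet style and unigrams are detected automatically
decoder = build_ctcdecoder(labels, kenlm_model_path="lm.arpa")
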
207 changes: 118 additions & 89 deletions pyctcdecode/alphabet.py
@@ -1,93 +1,122 @@
# Copyright 2021-present Kensho Technologies, LLC.
from __future__ import division

import logging
from typing import List, Optional
import re
from typing import Collection, List


BPE_CHAR = "▁"  # character used for token boundary if BPE is used
UNK_BPE_CHAR = "▁⁇▁"  # representation of unknown character in BPE

BPE_TOKEN = "▁"  # nosec # representation of token boundary in BPE alphabet
UNK_TOKEN = "⁇"  # nosec # representation of special UNK token in regular alphabet
UNK_BPE_TOKEN = "▁⁇▁"  # nosec # representation of special UNK token in BPE alphabet

# special tokens are usually encoded with things like `[]` or `<>`
SPECIAL_TOKEN_PTN = re.compile(r"^[<\[].+[>\]]$")
BLANK_TOKEN_PTN = re.compile(r"^[<\[]pad[>\]]$", flags=re.IGNORECASE)
UNK_TOKEN_PTN = re.compile(r"^[<\[]unk[>\]]$", flags=re.IGNORECASE)

logger = logging.getLogger(__name__)


def _get_ctc_index(label_list: List[str]) -> int:
    """Get index of ctc blank character in alphabet."""
    return len(label_list) - 1 if label_list[-1] == "" else -1


def _normalize_alphabet(label_list: List[str], ctc_token_idx: Optional[int] = None) -> List[str]:
    """Normalize alphabet for non-bpe decoder."""
    if any([len(c) > 1 for c in label_list]):
        raise ValueError("For non-bpe alphabet only length 1 entries and blank token are allowed.")
    if ctc_token_idx is None:
        ctc_token_idx = _get_ctc_index(label_list)
    clean_labels = label_list[:]
    # check for space token
    if " " not in clean_labels:
        raise ValueError("Space token ' ' missing from vocabulary.")
    # specify ctc blank token
    if ctc_token_idx == -1:
        clean_labels.append("")
    else:
        clean_labels[ctc_token_idx] = ""
    return clean_labels


def _check_if_bpe(labels: List[str]) -> bool:
    """Check if input alphabet is BPE or not."""
    is_bpe = any([s.startswith("##") for s in labels]) or any(
        [s.startswith(BPE_TOKEN) for s in labels]
    )
    if is_bpe:
        logger.info("Alphabet determined to be of BPE style.")
    else:
        logger.info("Alphabet determined to be of regular style.")
    return is_bpe


def _normalize_regular_alphabet(labels: List[str]) -> List[str]:
    """Normalize non-bpe labels to alphabet for decoder."""
    normalized_labels = labels[:]
    # substitute space characters
    if "|" in normalized_labels and " " not in normalized_labels:
        logger.info("Found '|' in vocabulary but not ' ', doing substitution.")
        normalized_labels[normalized_labels.index("|")] = " "
    # substitute ctc blank char
    for n, label in enumerate(normalized_labels):
        if BLANK_TOKEN_PTN.match(label):
            logger.info(
                "Found %s in vocabulary, interpreted as a CTC blank token, substituting with %s.",
                label,
                "",
            )
            normalized_labels[n] = ""
    if "_" in normalized_labels and "" not in normalized_labels:
        logger.info("Found '_' in vocabulary but not '', doing substitution.")
        normalized_labels[normalized_labels.index("_")] = ""
    if "" not in normalized_labels:
        logger.info("CTC blank char '' not found, appending to end.")
        normalized_labels.append("")
    # substitute unk
    for n, label in enumerate(normalized_labels):
        if UNK_TOKEN_PTN.match(label):
            logger.info(
                "Found %s in vocabulary, interpreting as unknown token, substituting with %s.",
                label,
                UNK_TOKEN,
            )
            normalized_labels[n] = UNK_TOKEN
    # additional checks
    if any([len(c) > 1 for c in normalized_labels]):
        logger.warning(
            "Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the "
            "alphabet was not recognized as BPE type. Is this correct?"
        )
    if " " not in normalized_labels:
        logger.warning("Space token ' ' missing from vocabulary.")
    return normalized_labels


def _convert_bpe_format(token: str) -> str:
    """Convert token from ## type bpe format to ▁ type."""
    if token[:2] == "##":
        return token[2:]
    elif token == BPE_CHAR:
        return token
    elif token == "":  # nosec
        return token
    elif token in ("<unk>", UNK_BPE_CHAR):
        return token
    else:
        return BPE_CHAR + token


def _convert_bpe_token_style(token: str) -> str:
    """Convert token from ## style bpe format to ▁ style."""
    if token.startswith("##"):
        return token[2:]
    elif SPECIAL_TOKEN_PTN.match(token) or token in ("", BPE_TOKEN, UNK_BPE_TOKEN):
        return token
    elif token in ("<unk>", UNK_BPE_TOKEN):
        return token
    else:
        return BPE_TOKEN + token


def _normalize_bpe_alphabet(
    label_list: List[str],
    unk_token_idx: Optional[int] = None,
    ctc_token_idx: Optional[int] = None,
) -> List[str]:
    """Normalize alphabet for bpe decoder."""
    if ctc_token_idx is None:
        ctc_token_idx = _get_ctc_index(label_list)
    # create copy
    clean_labels = label_list[:]
    # there are two common formats for BPE vocabulary
    # 1) where ▁ indicates a space (this is the main format we use)
    if any([s[:1] == BPE_CHAR and len(s) > 1 for s in clean_labels]):
        # verify unk token and make sure it is consistently represented as ▁⁇▁
        if unk_token_idx is None and clean_labels[0] in ("<unk>", UNK_BPE_CHAR):
            unk_token_idx = 0
        else:
            raise ValueError(
                "First token in vocab for BPE should be '▁⁇▁' or specify unk_token_idx."
            )
        clean_labels[unk_token_idx] = UNK_BPE_CHAR
    # 2) where ## indicates continuation of a token (note: also contains the single token: ▁)
    elif any([s[:2] == "##" for s in clean_labels]):
        # convert to standard format 1)
        clean_labels = [_convert_bpe_format(c) for c in clean_labels]
        # add unk token if needed
        if clean_labels[0] in ("<unk>", UNK_BPE_CHAR):
            clean_labels[0] = UNK_BPE_CHAR
        else:
            clean_labels = [UNK_BPE_CHAR] + clean_labels
            ctc_token_idx += 1
    else:
        raise ValueError(
            "Unknown BPE format for vocabulary. Supported formats are 1) ▁ for indicating a"
            " space and 2) ## for continuation of a word."
        )
    # specify ctc blank token
    if ctc_token_idx == -1:
        clean_labels.append("")
    else:
        clean_labels[ctc_token_idx] = ""
    return clean_labels


def _normalize_bpe_alphabet(labels: List[str]) -> List[str]:
    """Normalize alphabet for bpe decoder."""
    normalized_labels = labels[:]
    # if BPE is of style '##' then convert it
    if any([s.startswith("##") for s in labels]):
        normalized_labels = [_convert_bpe_token_style(c) for c in normalized_labels]
    # substitute ctc blank char
    for n, label in enumerate(normalized_labels):
        if BLANK_TOKEN_PTN.match(label):
            logger.info("Found %s in vocabulary, substituting with %s.", label, "")
            normalized_labels[n] = ""
    if "" not in normalized_labels:
        logger.info("CTC blank char '' not found, appending to end.")
        normalized_labels.append("")
    # substitute unk
    for n, label in enumerate(normalized_labels):
        if UNK_TOKEN_PTN.match(label):
            logger.info("Found %s in vocabulary, substituting with %s.", label, UNK_BPE_TOKEN)
            normalized_labels[n] = UNK_BPE_TOKEN
    # additional checks
    if UNK_BPE_TOKEN not in normalized_labels:
        logger.warning("UNK token %s not found, is this a mistake?", UNK_BPE_TOKEN)
    return normalized_labels


def _verify_alphabet(labels: List[str], is_bpe: bool) -> None:
    """Verify basic alphabet labels."""
    # check if duplicates exist
    if len(labels) != len(set(labels)):
        raise ValueError("Alphabet contains duplicate entries, this is not allowed.")
    # check if space character is absent in bpe alphabet
    if is_bpe and any([" " in s for s in labels]):
        raise ValueError("Space token ' ' found in vocabulary even though it looks like BPE.")


class Alphabet:
@@ -98,7 +127,7 @@ def __init__(self, labels: List[str], is_bpe: bool) -> None:

    @property
    def is_bpe(self) -> bool:
        """Whether the alphabet is a bytepair encoded one."""
        """Whether the alphabet is bpe style."""
        return self._is_bpe

    @property
@@ -107,20 +136,20 @@ def labels(self) -> List[str]:
        return self._labels[:]  # this is a copy

    @classmethod
    def build_alphabet(
        cls, label_list: List[str], ctc_token_idx: Optional[int] = None
    ) -> "Alphabet":
        """Make a non-BPE alphabet."""
        formatted_alphabet_list = _normalize_alphabet(label_list, ctc_token_idx)
        return cls(formatted_alphabet_list, False)

    @classmethod
    def build_alphabet(cls, labels: List[str]) -> "Alphabet":
        """Make an alphabet from labels in standardized format for decoder."""
        is_bpe = _check_if_bpe(labels)
        _verify_alphabet(labels, is_bpe)
        if is_bpe:
            normalized_labels = _normalize_bpe_alphabet(labels)
        else:
            normalized_labels = _normalize_regular_alphabet(labels)
        return cls(normalized_labels, is_bpe)

    @classmethod
    def build_bpe_alphabet(
        cls,
        label_list: List[str],
        unk_token_idx: Optional[int] = None,
        ctc_token_idx: Optional[int] = None,
    ) -> "Alphabet":
        """Make a BPE alphabet."""
        formatted_label_list = _normalize_bpe_alphabet(label_list, unk_token_idx, ctc_token_idx)
        return cls(formatted_label_list, True)

def verify_alphabet_coverage(alphabet: Alphabet, unigrams: Collection[str]) -> None:
    """Verify if the alphabet covers a given set of unigrams."""
    label_chars = set(alphabet.labels)
    unigram_sample_chars = set("".join(unigrams))
    if len(unigram_sample_chars - label_chars) / len(unigram_sample_chars) > 0.2:
        logger.warning("Unigrams and labels don't seem to agree.")
45 changes: 28 additions & 17 deletions pyctcdecode/decoder.py
@@ -6,11 +6,11 @@
import logging
import math
import os
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np

from .alphabet import BPE_CHAR, Alphabet
from .alphabet import BPE_TOKEN, Alphabet, verify_alphabet_coverage
from .constants import (
    DEFAULT_ALPHA,
    DEFAULT_BEAM_WIDTH,
@@ -23,7 +23,12 @@
    DEFAULT_UNK_LOGP_OFFSET,
    MIN_TOKEN_CLIP_P,
)
from .language_model import AbstractLanguageModel, HotwordScorer, LanguageModel
from .language_model import (
    AbstractLanguageModel,
    HotwordScorer,
    LanguageModel,
    load_unigram_set_from_arpa,
)


logger = logging.getLogger(__name__)
@@ -363,13 +368,13 @@ def _decode_logits(
                        )
                    )
                # if bpe and leading space char
                elif self._is_bpe and (char[:1] == BPE_CHAR or force_next_break):
                elif self._is_bpe and (char[:1] == BPE_TOKEN or force_next_break):
                    force_next_break = False
                    # some tokens are bounded on both sides like ▁⁇▁
                    clean_char = char
                    if char[:1] == BPE_CHAR:
                    if char[:1] == BPE_TOKEN:
                        clean_char = clean_char[1:]
                    if char[-1:] == BPE_CHAR:
                    if char[-1:] == BPE_TOKEN:
                        clean_char = clean_char[:-1]
                        force_next_break = True
                    new_frame_list = (
@@ -668,35 +673,41 @@ def decode_batch(

def build_ctcdecoder(
    labels: List[str],
    kenlm_model: Optional[kenlm.Model] = None,
    unigrams: Optional[Iterable[str]] = None,
    kenlm_model_path: Optional[str] = None,
    unigrams: Optional[Collection[str]] = None,
    alpha: float = DEFAULT_ALPHA,
    beta: float = DEFAULT_BETA,
    unk_score_offset: float = DEFAULT_UNK_LOGP_OFFSET,
    lm_score_boundary: bool = DEFAULT_SCORE_LM_BOUNDARY,
    ctc_token_idx: Optional[int] = None,
    is_bpe: bool = False,
) -> BeamSearchDecoderCTC:
    """Build a BeamSearchDecoderCTC instance with main functionality.
    Args:
        labels: class containing the labels for input logit matrices
        kenlm_model: instance of kenlm n-gram language model `kenlm.Model`
        kenlm_model_path: path to kenlm n-gram language model
        unigrams: list of known word unigrams
        alpha: weight for language model during shallow fusion
        beta: weight for length score adjustment during scoring
        unk_score_offset: amount of log score offset for unknown tokens
        lm_score_boundary: whether to have kenlm respect boundaries when scoring
        ctc_token_idx: index of ctc blank token within the labels
        is_bpe: indicate if labels are BPE type
    Returns:
        instance of BeamSearchDecoderCTC
    """
    if is_bpe:
        alphabet = Alphabet.build_bpe_alphabet(labels, ctc_token_idx=ctc_token_idx)
    else:
        alphabet = Alphabet.build_alphabet(labels, ctc_token_idx=ctc_token_idx)
    kenlm_model = None if kenlm_model_path is None else kenlm.Model(kenlm_model_path)
    if kenlm_model_path is not None and kenlm_model_path.endswith(".arpa"):
        logger.info("Using arpa instead of binary LM file, decoder instantiation might be slow.")
    if unigrams is None and kenlm_model_path is not None:
        if kenlm_model_path.endswith(".arpa"):
            unigrams = load_unigram_set_from_arpa(kenlm_model_path)
        else:
            logger.warning(
                "Unigrams not provided and cannot be automatically determined from LM file (only "
                "arpa format). Decoding accuracy might be reduced."
            )
    alphabet = Alphabet.build_alphabet(labels)
    if unigrams is not None:
        verify_alphabet_coverage(alphabet, unigrams)
    if kenlm_model is not None:
        language_model: Optional[AbstractLanguageModel] = LanguageModel(
            kenlm_model,
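
A short end-to-end sketch of the new instantiation path shown above; the .arpa path is hypothetical and the logits are random, assuming kenlm and numpy are installed:

import numpy as np
from pyctcdecode import build_ctcdecoder

labels = [" ", "a", "b", "c", "d", "e"]

decoder = build_ctcdecoder(
    labels,
    kenlm_model_path="lm.arpa",  # hypothetical path; unigrams are read from the .arpa automatically
    alpha=0.5,
    beta=1.0,
)

# fake logits: 50 frames over len(labels) + 1 classes (the CTC blank is appended by the alphabet)
logits = np.log(np.random.dirichlet(np.ones(len(labels) + 1), size=50)).astype(np.float32)
text = decoder.decode(logits)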
