Skip to content

Commit

Permalink
Add py.typed and remove some type ignore tags (#101)
Browse files Browse the repository at this point in the history
This should ensure that the package works properly with mypy when imported. There are still quite a few type: ignores left for numpy types that should be cleaned up. This also adds a reset_params() function to the abstract language model so that we're not accessing implementation details like parameters directly here.
  • Loading branch information
lopez86 authored Jan 17, 2023
1 parent 30b72e1 commit 0ca5ccc
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 23 deletions.
3 changes: 2 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
graft pyctcdecode/tests/sample_data
graft tutorials
include pyctcdecode/py.typed
graft tutorials
71 changes: 50 additions & 21 deletions pyctcdecode/decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2021-present Kensho Technologies, LLC.
from __future__ import division
from __future__ import annotations, division

import functools
import heapq
Expand All @@ -9,9 +9,22 @@
from multiprocessing.pool import Pool
import os
from pathlib import Path
from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
import sys
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
)

import numpy as np # type: ignore
import numpy as np
from numpy.typing import NBitBase, NDArray

from .alphabet import BPE_TOKEN, Alphabet, verify_alphabet_coverage
from .constants import (
Expand Down Expand Up @@ -71,6 +84,19 @@
EMPTY_START_BEAM: Beam = ("", "", "", None, [], NULL_FRAMES, 0.0)


# Generic float type
if sys.version_info < (3, 8):
NpFloat = Any
else:
if sys.version_info < (3, 9) and not TYPE_CHECKING:
NpFloat = Any
else:
NpFloat = np.floating[NBitBase]

FloatVar = TypeVar("FloatVar", bound=NpFloat)
Shape = TypeVar("Shape")


def _get_valid_pool(pool: Optional[Pool]) -> Optional[Pool]:
"""Return the pool if the pool is appropriate for multiprocessing."""
if pool is not None and isinstance(
Expand Down Expand Up @@ -106,21 +132,21 @@ def _sum_log_scores(s1: float, s2: float) -> float:


def _log_softmax(
x: np.ndarray, # type: ignore [type-arg]
x: np.ndarray[Shape, np.dtype[FloatVar]],
axis: Optional[int] = None,
) -> np.ndarray: # type: ignore [type-arg]
) -> np.ndarray[Shape, np.dtype[FloatVar]]:
"""Logarithm of softmax function, following implementation of scipy.special."""
x_max = np.amax(x, axis=axis, keepdims=True)
if x_max.ndim > 0:
x_max[~np.isfinite(x_max)] = 0
elif not np.isfinite(x_max):
x_max = 0 # pylint: disable=R0204
tmp = x - x_max
exp_tmp = np.exp(tmp)
exp_tmp: np.ndarray[Shape, np.dtype[FloatVar]] = np.exp(tmp)
# suppress warnings about log of zero
with np.errstate(divide="ignore"):
s = np.sum(exp_tmp, axis=axis, keepdims=True) # type: ignore [arg-type]
out: np.ndarray = np.log(s) # type: ignore [type-arg]
s = np.sum(exp_tmp, axis=axis, keepdims=True)
out: np.ndarray[Shape, np.dtype[FloatVar]] = np.log(s)
out = tmp - out
return out

Expand Down Expand Up @@ -237,15 +263,20 @@ def reset_params(
lm_score_boundary: Optional[bool] = None,
) -> None:
"""Reset parameters that don't require re-instantiating the model."""
# todo: make more generic to accomodate other language models
language_model = self._language_model
if language_model is None:
return
params: Dict[str, Any] = {}
if alpha is not None:
language_model.alpha = alpha # type: ignore
params["alpha"] = alpha
if beta is not None:
language_model.beta = beta # type: ignore
params["beta"] = beta
if unk_score_offset is not None:
language_model.unk_score_offset = unk_score_offset # type: ignore
params["unk_score_offset"] = unk_score_offset
if lm_score_boundary is not None:
language_model.score_boundary = lm_score_boundary # type: ignore
params["score_boundary"] = lm_score_boundary
language_model.reset_params(**params)

@classmethod
def clear_class_models(cls) -> None:
Expand All @@ -264,7 +295,7 @@ def _language_model(self) -> Optional[AbstractLanguageModel]:

def _check_logits_dimension(
self,
logits: np.ndarray, # type: ignore [type-arg]
logits: NDArray[NpFloat],
) -> None:
"""Verify correct shape and dimensions for input logits."""
if len(logits.shape) != 2:
Expand Down Expand Up @@ -358,7 +389,7 @@ def _get_lm_beams(

def _decode_logits(
self,
logits: np.ndarray, # type: ignore [type-arg]
logits: NDArray[NpFloat],
beam_width: int,
beam_prune_logp: float,
token_min_logp: float,
Expand Down Expand Up @@ -528,7 +559,7 @@ def _decode_logits(

def decode_beams(
self,
logits: np.ndarray, # type: ignore [type-arg]
logits: NDArray[NpFloat],
beam_width: int = DEFAULT_BEAM_WIDTH,
beam_prune_logp: float = DEFAULT_PRUNE_LOGP,
token_min_logp: float = DEFAULT_MIN_TOKEN_LOGP,
Expand Down Expand Up @@ -575,7 +606,7 @@ def decode_beams(

def _decode_beams_mp_safe(
self,
logits: np.ndarray, # type: ignore [type-arg]
logits: NDArray[NpFloat],
beam_width: int,
beam_prune_logp: float,
token_min_logp: float,
Expand Down Expand Up @@ -603,7 +634,7 @@ def _decode_beams_mp_safe(
def decode_beams_batch(
self,
pool: Optional[Pool],
logits_list: List[np.ndarray], # type: ignore [type-arg]
logits_list: NDArray[NpFloat],
beam_width: int = DEFAULT_BEAM_WIDTH,
beam_prune_logp: float = DEFAULT_PRUNE_LOGP,
token_min_logp: float = DEFAULT_MIN_TOKEN_LOGP,
Expand Down Expand Up @@ -660,7 +691,7 @@ def decode_beams_batch(

def decode(
self,
logits: np.ndarray, # type: ignore [type-arg]
logits: NDArray[NpFloat],
beam_width: int = DEFAULT_BEAM_WIDTH,
beam_prune_logp: float = DEFAULT_PRUNE_LOGP,
token_min_logp: float = DEFAULT_MIN_TOKEN_LOGP,
Expand Down Expand Up @@ -697,7 +728,7 @@ def decode(
def decode_batch(
self,
pool: Optional[Pool],
logits_list: List[np.ndarray], # type: ignore [type-arg]
logits_list: NDArray[NpFloat],
beam_width: int = DEFAULT_BEAM_WIDTH,
beam_prune_logp: float = DEFAULT_PRUNE_LOGP,
token_min_logp: float = DEFAULT_MIN_TOKEN_LOGP,
Expand Down Expand Up @@ -822,8 +853,6 @@ def load_from_hf_hub( # type: ignore
Returns:
instance of BeamSearchDecoderCTC
"""
import sys

if sys.version_info >= (3, 8):
from importlib.metadata import metadata
else:
Expand Down
37 changes: 36 additions & 1 deletion pyctcdecode/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shutil
from typing import Any, Collection, Dict, Iterable, List, Optional, Pattern, Set, Tuple, cast

import numpy as np # type: ignore
import numpy as np
from pygtrie import CharTrie # type: ignore

from .constants import (
Expand Down Expand Up @@ -193,6 +193,9 @@ def load_from_dir(cls, filepath: str) -> "AbstractLanguageModel":
"""Load a model from a directory."""
raise NotImplementedError()

def reset_params(self, **params: Dict[str, Any]) -> None:
"""Reset some of the parameters in place."""


class LanguageModel(AbstractLanguageModel):
# serializatoin constants
Expand Down Expand Up @@ -235,6 +238,38 @@ def __init__(
self.unk_score_offset = unk_score_offset
self.score_boundary = score_boundary

def reset_params(self, **params: Dict[str, Any]) -> None:
"""Reset some of the simple parameters.
The allowed parameters are [alpha, beta, unk_score_offset, score_boundary]
Args:
params: dict of str to anything
"""
alpha = params.get("alpha")
if alpha is not None:
if not isinstance(alpha, float):
raise ValueError(f"alpha must be a float. Got {type(alpha)}.")
self.alpha = alpha

beta = params.get("beta")
if beta is not None:
if not isinstance(beta, float):
raise ValueError(f"beta must be a float. Got {type(beta)}.")
self.beta = beta

unk_score_offset = params.get("unk_score_offset")
if unk_score_offset is not None:
if not isinstance(unk_score_offset, float):
raise ValueError(f"unk_score_offset must be a float. Got {type(unk_score_offset)}.")
self.unk_score_offset = unk_score_offset

score_boundary = params.get("score_boundary")
if score_boundary is not None:
if not isinstance(score_boundary, bool):
raise ValueError(f"score_boundary must be a bool. Got {type(score_boundary)}.")
self.score_boundary = score_boundary

@property
def order(self) -> int:
"""Get the order of the n-gram language model."""
Expand Down
Empty file added pyctcdecode/py.typed
Empty file.

0 comments on commit 0ca5ccc

Please sign in to comment.