diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..8b24099 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,27 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-ast + - id: check-executables-have-shebangs + - id: end-of-file-fixer + - id: mixed-line-ending + - id: trailing-whitespace + + - repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: ['--select=E9,F63,F7,F82'] + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: [--profile=black] + + - repo: https://github.com/psf/black + rev: 24.2.0 + hooks: + - id: black + additional_dependencies: ['click==8.0.1'] diff --git a/README.md b/README.md index 2c72bf2..bc8c7be 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ python3 -m pip install --verbose . ### Alignment -`align(seq1, seq2, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence. +`align(ref, hyp, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence. ```python from kaldialign import align @@ -46,7 +46,7 @@ assert ali == [('a', 'a'), ('b', 's'), (EPS, 'x'), ('c', 'c')] ### Edit distance -`edit_distance(seq1, seq2)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions. +`edit_distance(ref, hyp)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions. 
```python from kaldialign import edit_distance @@ -67,7 +67,7 @@ based on SCLITE style weights, i.e., insertion/deletion cost 3 and substitution ### Bootstrapping method to extract WER 95% confidence intervals -`boostrap_wer_ci(ref, hyp)` - obtain the 95% confidence intervals for WER using Bisani and Ney boostrapping method. +`bootstrap_wer_ci(ref, hyp, hyp2=None)` - obtain the 95% confidence intervals for WER using Bisani and Ney bootstrapping method. ```python from kaldialign import bootstrap_wer_ci @@ -123,4 +123,4 @@ assert ans["p_s2_improv_over_s1"] == 1.0 ## Motivation -The need for this arised from the fact that practically all implementations of the Levenshtein distance have slight differences, making it impossible to use a different scoring tool than Kaldi and get the same error rate results. This package copies code from Kaldi directly and wraps it using Cython, avoiding the issue altogether. +The need for this arose from the fact that practically all implementations of the Levenshtein distance have slight differences, making it impossible to use a different scoring tool than Kaldi and get the same error rate results. This package copies code from Kaldi directly and wraps it using pybind11, avoiding the issue altogether. diff --git a/kaldialign/__init__.py b/kaldialign/__init__.py index bcd3692..fb83d5c 100644 --- a/kaldialign/__init__.py +++ b/kaldialign/__init__.py @@ -1,38 +1,55 @@ import math -from typing import List, Tuple import random +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union + import _kaldialign +Symbol = TypeVar("Symbol") + -def edit_distance(a, b, sclite_mode=False): +def edit_distance( + ref: Iterable[Symbol], hyp: Iterable[Symbol], sclite_mode: bool = False +) -> Dict[str, Union[int, float]]: """ - Compute the edit distance between sequences ``a`` and ``b``. + Compute the edit distance between sequences ``ref`` and ``hyp``. Both sequences can be strings or lists of strings or ints. 
Optional ``sclite_mode`` sets INS/DEL/SUB costs to 3/3/4 for compatibility with sclite tool. - Returns a dict with keys ``ins``, ``del``, ``sub``, ``total``, - which stand for the count of insertions, deletions, substitutions, - and the total number of errors. + Returns a dict with keys: + * ``ins`` -- the number of insertions (in ``hyp`` vs ``ref``) + * ``del`` -- the number of deletions (in ``hyp`` vs ``ref``) + * ``sub`` -- the number of substitutions + * ``total`` -- total number of errors + * ``ref_len`` -- the number of symbols in ``ref`` + * ``err_rate`` -- the error rate (total number of errors divided by ``ref_len``) """ - int2sym = dict(enumerate(sorted(set(a) | set(b)))) + int2sym = dict(enumerate(sorted(set(ref) | set(hyp)))) sym2int = {v: k for k, v in int2sym.items()} - ai: List[int] = [] - bi: List[int] = [] - for sym in a: - ai.append(sym2int[sym]) + refi: List[int] = [] + hypi: List[int] = [] + for sym in ref: + refi.append(sym2int[sym]) - for sym in b: - bi.append(sym2int[sym]) + for sym in hyp: + hypi.append(sym2int[sym]) - return _kaldialign.edit_distance(ai, bi, sclite_mode) + ans = _kaldialign.edit_distance(refi, hypi, sclite_mode) + ans["ref_len"] = len(refi) + ans["err_rate"] = ans["total"] / len(refi) + return ans -def align(a, b, eps_symbol, sclite_mode=False): +def align( + ref: Iterable[Symbol], + hyp: Iterable[Symbol], + eps_symbol: Symbol, + sclite_mode: bool = False, +) -> List[Tuple[Symbol, Symbol]]: """ - Compute the alignment between sequences ``a`` and ``b``. + Compute the alignment between sequences ``ref`` and ``hyp``. Both sequences can be strings or lists of strings or ints. ``eps_symbol`` is used as a blank symbol to indicate insertion or deletion. @@ -44,16 +61,16 @@ def align(a, b, eps_symbol, sclite_mode=False): in the first pair index indicates insertion, and in the second pair index, deletion. Mismatched symbols indicate substitution. 
""" - int2sym = dict(enumerate(sorted(set(a) | set(b) | {eps_symbol}))) + int2sym = dict(enumerate(sorted(set(ref) | set(hyp) | {eps_symbol}))) sym2int = {v: k for k, v in int2sym.items()} ai: List[int] = [] bi: List[int] = [] - for sym in a: + for sym in ref: ai.append(sym2int[sym]) - for sym in b: + for sym in hyp: bi.append(sym2int[sym]) eps_int = sym2int[eps_symbol] @@ -67,8 +84,12 @@ def align(a, b, eps_symbol, sclite_mode=False): def bootstrap_wer_ci( - ref_seqs, hyp_seqs, hyp2_seqs=None, replications: int = 10000, seed: int = 0 -): + ref_seqs: Sequence[Sequence[Symbol]], + hyp_seqs: Sequence[Sequence[Symbol]], + hyp2_seqs: Optional[Sequence[Sequence[Symbol]]] = None, + replications: int = 10000, + seed: int = 0, +) -> Dict: """ Compute a boostrapping of WER to extract the 95% confidence interval (CI) using the bootstrap method of Bisani and Ney [1]. diff --git a/tests/test_align.py b/tests/test_align.py index 3e504c2..b2548c4 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -1,4 +1,4 @@ -from kaldialign import align, edit_distance, bootstrap_wer_ci +from kaldialign import align, bootstrap_wer_ci, edit_distance EPS = "*" @@ -9,21 +9,42 @@ def test_align(): ali = align(a, b, EPS) assert ali == [("a", "a"), ("b", "s"), (EPS, "x"), ("c", "c")] dist = edit_distance(a, b) - assert dist == {"ins": 1, "del": 0, "sub": 1, "total": 2} + assert dist == { + "ins": 1, + "del": 0, + "sub": 1, + "total": 2, + "ref_len": 3, + "err_rate": 2 / 3, + } a = ["a", "b"] b = ["b", "c"] ali = align(a, b, EPS) assert ali == [("a", EPS), ("b", "b"), (EPS, "c")] dist = edit_distance(a, b) - assert dist == {"ins": 1, "del": 1, "sub": 0, "total": 2} + assert dist == { + "ins": 1, + "del": 1, + "sub": 0, + "total": 2, + "ref_len": 2, + "err_rate": 1.0, + } a = ["A", "B", "C"] b = ["D", "C", "A"] ali = align(a, b, EPS) assert ali == [("A", "D"), ("B", EPS), ("C", "C"), (EPS, "A")] dist = edit_distance(a, b) - assert dist == {"ins": 1, "del": 1, "sub": 1, "total": 3} + assert 
dist == { + "ins": 1, + "del": 1, + "sub": 1, + "total": 3, + "ref_len": 3, + "err_rate": 1.0, + } a = ["A", "B", "C", "D"] b = ["C", "E", "D", "F"] @@ -37,21 +58,42 @@ def test_align(): (EPS, "F"), ] dist = edit_distance(a, b) - assert dist == {"ins": 2, "del": 2, "sub": 0, "total": 4} + assert dist == { + "ins": 2, + "del": 2, + "sub": 0, + "total": 4, + "ref_len": 4, + "err_rate": 1.0, + } def test_edit_distance(): a = ["a", "b", "c"] b = ["a", "s", "x", "c"] results = edit_distance(a, b) - assert results == {"ins": 1, "del": 0, "sub": 1, "total": 2} + assert results == { + "ins": 1, + "del": 0, + "sub": 1, + "total": 2, + "ref_len": 3, + "err_rate": 2 / 3, + } def test_edit_distance_sclite(): a = ["a", "b"] b = ["b", "c"] results = edit_distance(a, b, sclite_mode=True) - assert results == {"ins": 1, "del": 1, "sub": 0, "total": 2} + assert results == { + "ins": 1, + "del": 1, + "sub": 0, + "total": 2, + "ref_len": 2, + "err_rate": 1.0, + } def test_bootstrap_wer_ci_1system():