Skip to content

Commit

Permalink
Merge pull request #18 from pzelasko/11-descriptive-names-for-edit_di…
Browse files Browse the repository at this point in the history
…stance

Descriptive param names, typing, update docs, add pre-commit hooks
  • Loading branch information
pzelasko authored Mar 5, 2024
2 parents 80e08e9 + a027d5b commit c2a6d29
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 32 deletions.
27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-ast
- id: check-executables-have-shebangs
- id: end-of-file-fixer
- id: mixed-line-ending
- id: trailing-whitespace

- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
args: ['--select=E9,F63,F7,F82']

- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: [--profile=black]

- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
additional_dependencies: ['click==8.0.1']
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ python3 -m pip install --verbose .

### Alignment

`align(seq1, seq2, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence.
`align(ref, hyp, epsilon)` - used to obtain the alignment between two string sequences. `epsilon` should be a null symbol (indicating deletion/insertion) that doesn't exist in either sequence.

```python
from kaldialign import align
Expand All @@ -46,7 +46,7 @@ assert ali == [('a', 'a'), ('b', 's'), (EPS, 'x'), ('c', 'c')]

### Edit distance

`edit_distance(seq1, seq2)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions.
`edit_distance(ref, hyp)` - used to obtain the total edit distance, as well as the number of insertions, deletions and substitutions.

```python
from kaldialign import edit_distance
Expand All @@ -67,7 +67,7 @@ based on SCLITE style weights, i.e., insertion/deletion cost 3 and substitution

### Bootstrapping method to extract WER 95% confidence intervals

`boostrap_wer_ci(ref, hyp)` - obtain the 95% confidence intervals for WER using Bisani and Ney boostrapping method.
`bootstrap_wer_ci(ref, hyp, hyp2=None)` - obtain the 95% confidence intervals for WER using the Bisani and Ney bootstrapping method.

```python
from kaldialign import bootstrap_wer_ci
Expand Down Expand Up @@ -123,4 +123,4 @@ assert ans["p_s2_improv_over_s1"] == 1.0

## Motivation

The need for this arised from the fact that practically all implementations of the Levenshtein distance have slight differences, making it impossible to use a different scoring tool than Kaldi and get the same error rate results. This package copies code from Kaldi directly and wraps it using Cython, avoiding the issue altogether.
The need for this arose from the fact that practically all implementations of the Levenshtein distance have slight differences, making it impossible to use a different scoring tool than Kaldi and get the same error rate results. This package copies code from Kaldi directly and wraps it using pybind11, avoiding the issue altogether.
63 changes: 42 additions & 21 deletions kaldialign/__init__.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,55 @@
import math
from typing import List, Tuple
import random
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union

import _kaldialign

Symbol = TypeVar("Symbol")


def edit_distance(a, b, sclite_mode=False):
def edit_distance(
ref: Iterable[Symbol], hyp: Iterable[Symbol], sclite_mode: bool = False
) -> Dict[str, Union[int, float]]:
"""
Compute the edit distance between sequences ``a`` and ``b``.
Compute the edit distance between sequences ``ref`` and ``hyp``.
Both sequences can be strings or lists of strings or ints.
Optional ``sclite_mode`` sets INS/DEL/SUB costs to 3/3/4 for
compatibility with sclite tool.
Returns a dict with keys ``ins``, ``del``, ``sub``, ``total``,
which stand for the count of insertions, deletions, substitutions,
and the total number of errors.
Returns a dict with keys:
* ``ins`` -- the number of insertions (in ``hyp`` vs ``ref``)
* ``del`` -- the number of deletions (in ``hyp`` vs ``ref``)
* ``sub`` -- the number of substitutions
* ``total`` -- total number of errors
* ``ref_len`` -- the number of symbols in ``ref``
* ``err_rate`` -- the error rate (total number of errors divided by ``ref_len``)
"""
int2sym = dict(enumerate(sorted(set(a) | set(b))))
int2sym = dict(enumerate(sorted(set(ref) | set(hyp))))
sym2int = {v: k for k, v in int2sym.items()}

ai: List[int] = []
bi: List[int] = []
for sym in a:
ai.append(sym2int[sym])
refi: List[int] = []
hypi: List[int] = []
for sym in ref:
refi.append(sym2int[sym])

for sym in b:
bi.append(sym2int[sym])
for sym in hyp:
hypi.append(sym2int[sym])

return _kaldialign.edit_distance(ai, bi, sclite_mode)
ans = _kaldialign.edit_distance(refi, hypi, sclite_mode)
ans["ref_len"] = len(refi)
ans["err_rate"] = ans["total"] / len(refi)
return ans


def align(a, b, eps_symbol, sclite_mode=False):
def align(
ref: Iterable[Symbol],
hyp: Iterable[Symbol],
eps_symbol: Symbol,
sclite_mode: bool = False,
) -> List[Tuple[Symbol, Symbol]]:
"""
Compute the alignment between sequences ``a`` and ``b``.
Compute the alignment between sequences ``ref`` and ``hyp``.
Both sequences can be strings or lists of strings or ints.
``eps_symbol`` is used as a blank symbol to indicate insertion or deletion.
Expand All @@ -44,16 +61,16 @@ def align(a, b, eps_symbol, sclite_mode=False):
in the first pair index indicates insertion, and in the second pair index, deletion.
Mismatched symbols indicate substitution.
"""
int2sym = dict(enumerate(sorted(set(a) | set(b) | {eps_symbol})))
int2sym = dict(enumerate(sorted(set(ref) | set(hyp) | {eps_symbol})))
sym2int = {v: k for k, v in int2sym.items()}

ai: List[int] = []
bi: List[int] = []

for sym in a:
for sym in ref:
ai.append(sym2int[sym])

for sym in b:
for sym in hyp:
bi.append(sym2int[sym])

eps_int = sym2int[eps_symbol]
Expand All @@ -67,8 +84,12 @@ def align(a, b, eps_symbol, sclite_mode=False):


def bootstrap_wer_ci(
ref_seqs, hyp_seqs, hyp2_seqs=None, replications: int = 10000, seed: int = 0
):
ref_seqs: Sequence[Sequence[Symbol]],
hyp_seqs: Sequence[Sequence[Symbol]],
hyp2_seqs: Optional[Sequence[Sequence[Symbol]]] = None,
replications: int = 10000,
seed: int = 0,
) -> Dict:
"""
Compute a bootstrapping of WER to extract the 95% confidence interval (CI)
using the bootstrap method of Bisani and Ney [1].
Expand Down
56 changes: 49 additions & 7 deletions tests/test_align.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kaldialign import align, edit_distance, bootstrap_wer_ci
from kaldialign import align, bootstrap_wer_ci, edit_distance

EPS = "*"

Expand All @@ -9,21 +9,42 @@ def test_align():
ali = align(a, b, EPS)
assert ali == [("a", "a"), ("b", "s"), (EPS, "x"), ("c", "c")]
dist = edit_distance(a, b)
assert dist == {"ins": 1, "del": 0, "sub": 1, "total": 2}
assert dist == {
"ins": 1,
"del": 0,
"sub": 1,
"total": 2,
"ref_len": 3,
"err_rate": 2 / 3,
}

a = ["a", "b"]
b = ["b", "c"]
ali = align(a, b, EPS)
assert ali == [("a", EPS), ("b", "b"), (EPS, "c")]
dist = edit_distance(a, b)
assert dist == {"ins": 1, "del": 1, "sub": 0, "total": 2}
assert dist == {
"ins": 1,
"del": 1,
"sub": 0,
"total": 2,
"ref_len": 2,
"err_rate": 1.0,
}

a = ["A", "B", "C"]
b = ["D", "C", "A"]
ali = align(a, b, EPS)
assert ali == [("A", "D"), ("B", EPS), ("C", "C"), (EPS, "A")]
dist = edit_distance(a, b)
assert dist == {"ins": 1, "del": 1, "sub": 1, "total": 3}
assert dist == {
"ins": 1,
"del": 1,
"sub": 1,
"total": 3,
"ref_len": 3,
"err_rate": 1.0,
}

a = ["A", "B", "C", "D"]
b = ["C", "E", "D", "F"]
Expand All @@ -37,21 +58,42 @@ def test_align():
(EPS, "F"),
]
dist = edit_distance(a, b)
assert dist == {"ins": 2, "del": 2, "sub": 0, "total": 4}
assert dist == {
"ins": 2,
"del": 2,
"sub": 0,
"total": 4,
"ref_len": 4,
"err_rate": 1.0,
}


def test_edit_distance():
a = ["a", "b", "c"]
b = ["a", "s", "x", "c"]
results = edit_distance(a, b)
assert results == {"ins": 1, "del": 0, "sub": 1, "total": 2}
assert results == {
"ins": 1,
"del": 0,
"sub": 1,
"total": 2,
"ref_len": 3,
"err_rate": 2 / 3,
}


def test_edit_distance_sclite():
a = ["a", "b"]
b = ["b", "c"]
results = edit_distance(a, b, sclite_mode=True)
assert results == {"ins": 1, "del": 1, "sub": 0, "total": 2}
assert results == {
"ins": 1,
"del": 1,
"sub": 0,
"total": 2,
"ref_len": 2,
"err_rate": 1.0,
}


def test_bootstrap_wer_ci_1system():
Expand Down

0 comments on commit c2a6d29

Please sign in to comment.