Skip to content

Commit

Permalink
Merge pull request #87 from fgnt/greedy_di_cp
Browse files Browse the repository at this point in the history
Add greedy DI-cpWER
  • Loading branch information
thequilo authored Aug 29, 2024
2 parents 1b073ae + 9bf9be2 commit e3aee5a
Show file tree
Hide file tree
Showing 14 changed files with 323 additions and 17 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ jobs:
python-version: [3.8, 3.9, '3.10', '3.11', '3.12']

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# created from tests
example_files/*.json
example_files/*.yaml
example_files/viz

# Files in this JSON format are example files and would otherwise be excluded by the rules above
!example_files/*.seglst.json
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ MeetEval supports the following metrics for meeting transcription evaluation:
`meeteval-wer tcpwer -r ref.stm -h hyp.stm --collar 5`
- **Time-Constrained Optimal Reference Combination Word Error Rate (tcORC WER)**<br>
`meeteval-wer tcorcwer -r ref.stm -h hyp.stm --collar 5`
- **Diarization-Invariant cpWER (DI-cpWER)**<br>
`meeteval-wer greedy_dicpwer -r ref.stm -h hyp.stm`
- **Diarization Error Rate (DER)** by wrapping [mdeval](https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl)<br>
`meeteval-der md_eval_22 -r ref.stm -h hyp.stm --collar .25`

Expand Down
23 changes: 23 additions & 0 deletions meeteval/wer/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,28 @@ def tcorcwer(
_save_results(results, hypothesis, per_reco_out, average_out)


def greedy_dicpwer(
        reference, hypothesis,
        average_out='{parent}/{stem}_greedy_dicpwer.json',
        per_reco_out='{parent}/{stem}_greedy_dicpwer_per_reco.json',
        regex=None,
        reference_sort='segment_if_available',
        hypothesis_sort='segment_if_available',
        uem=None,
        partial=False,
        normalizer=None,
):
    """Computes the greedy DI-cpWER"""
    # Delegate the actual computation to the API layer, then persist the
    # per-recording and averaged results next to the hypothesis file(s).
    per_reco_results = meeteval.wer.api.greedy_dicpwer(
        reference,
        hypothesis,
        regex=regex,
        reference_sort=reference_sort,
        hypothesis_sort=hypothesis_sort,
        uem=uem,
        partial=partial,
        normalizer=normalizer,
    )
    _save_results(per_reco_results, hypothesis, per_reco_out, average_out)


def _merge(
files: 'list[str]',
out: str = '-',
Expand Down Expand Up @@ -628,6 +650,7 @@ def cli():
cli.add_command(mimower)
cli.add_command(tcpwer)
cli.add_command(tcorcwer)
cli.add_command(greedy_dicpwer)
cli.add_command(merge)
cli.add_command(average)

Expand Down
28 changes: 27 additions & 1 deletion meeteval/wer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'mimower',
'tcpwer',
'tcorcwer',
'greedy_dicpwer',
]


Expand Down Expand Up @@ -131,7 +132,8 @@ def greedy_orcwer(
partial=False,
normalizer=None,
):
"""Computes the Optimal Reference Combination Word Error Rate (ORC WER)"""
"""Computes the Optimal Reference Combination Word Error Rate (ORC WER)
with a greedy algorithm"""
from meeteval.wer.wer.orc import greedy_orc_word_error_rate_multifile
reference, hypothesis = _load_texts(
reference, hypothesis, regex=regex,
Expand Down Expand Up @@ -256,3 +258,27 @@ def tcorcwer(
if average.reference_self_overlap is not None:
average.reference_self_overlap.warn('reference')
return results


def greedy_dicpwer(
        reference, hypothesis,
        regex=None,
        reference_sort='segment_if_available',
        hypothesis_sort='segment_if_available',
        uem=None,
        partial=False,
        normalizer=None,
):
    """Computes the Diarization Invariant cpWER (DI-cpWER) with a greedy
    algorithm."""
    # Import lazily to keep `meeteval.wer.api` import light.
    from meeteval.wer.wer.di_cp import greedy_di_cp_word_error_rate_multifile

    # Load / filter / normalize the transcripts before scoring.
    reference, hypothesis = _load_texts(
        reference, hypothesis, regex=regex,
        uem=uem, normalizer=normalizer,
    )
    return greedy_di_cp_word_error_rate_multifile(
        reference,
        hypothesis,
        partial=partial,
        reference_sort=reference_sort,
        hypothesis_sort=hypothesis_sort,
    )
8 changes: 8 additions & 0 deletions meeteval/wer/matching/greedy_combination_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,14 @@ def initialize_assignment(
[0, 0]
"""
if initialization == 'cp':
# Special case when no streams are present
if len(streams) == 0:
return initialize_assignment(segments, streams, 'constant')

# Special case when no segments are present
if len(segments) == 0:
return []

# Compute cpWER to get a good starting point
from meeteval.wer.wer.cp import _minimum_permutation_assignment
from meeteval.wer.wer.siso import siso_levenshtein_distance
Expand Down
17 changes: 14 additions & 3 deletions meeteval/wer/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def words_to_int(*d: 'SegLST'):
words in d and assigning an integer to each word.
>>> words_to_int(SegLST([{'words': 'a b c'}]))
[SegLST(segments=[{'words': 0}])]
[SegLST(segments=[{'words': 1}])]
>>> words_to_int(SegLST([{'words': 'a'}, {'words': 'b'}]), SegLST([{'words': 'c'}, {'words': 'a'}]))
[SegLST(segments=[{'words': 0}, {'words': 1}]), SegLST(segments=[{'words': 2}, {'words': 0}])]
[SegLST(segments=[{'words': 1}, {'words': 2}]), SegLST(segments=[{'words': 3}, {'words': 1}])]
TODO: use cython code for speedup
TODO: unify everything. This stuff is done in multiple places in the code base.
Expand All @@ -210,8 +210,19 @@ def words_to_int(*d: 'SegLST'):
# `'words'` contains a single word only.
import collections
sym2int = collections.defaultdict(itertools.count().__next__)
_ = sym2int[''] # Reserve 0 for the empty string

d = [d_.map(lambda s: {**s, 'words': [sym2int[w] for w in s['words']] if isinstance(s['words'], list) else sym2int[s['words']]}) for d_ in d]
d = [
d_.map(lambda s: {
**s,
'words': [
sym2int[w]
for w in s['words']]
if isinstance(s['words'], list)
else sym2int[s['words']]
})
for d_ in d
]
return d


Expand Down
1 change: 1 addition & 0 deletions meeteval/wer/wer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .siso import siso_word_error_rate, siso_character_error_rate, siso_word_error_rate_multifile
from .error_rate import ErrorRate, combine_error_rates
from .time_constrained import time_constrained_minimum_permutation_word_error_rate, time_constrained_siso_word_error_rate, tcp_word_error_rate_multifile
from .di_cp import greedy_di_cp_word_error_rate, DICPErrorRate, greedy_di_cp_word_error_rate_multifile
4 changes: 2 additions & 2 deletions meeteval/wer/wer/cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _minimum_permutation_assignment(
The score matrix. Shape (reference hypothesis)
>>> _minimum_permutation_assignment({}, {}, lambda x, y: 0)
((), 0)
((), 0, array([], dtype=float64))
>>> _minimum_permutation_assignment({}, {'spkA': meeteval.io.SegLST([])}, lambda x, y: 1)
(((None, 'spkA'),), 1, array([[1]]))
>>> _minimum_permutation_assignment({'spkA': meeteval.io.SegLST([])}, {}, lambda x, y: 1)
Expand All @@ -256,7 +256,7 @@ def _minimum_permutation_assignment(
])

if cost_matrix.size == 0:
return (), 0
return (), 0, cost_matrix

# Find the best permutation with hungarian algorithm
row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
Expand Down
165 changes: 165 additions & 0 deletions meeteval/wer/wer/di_cp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import dataclasses
import functools
from typing import Tuple

import meeteval
from meeteval.io.seglst import SegLST
from meeteval.wer.wer.error_rate import ErrorRate


__all__ = [
'DICPErrorRate',
'greedy_di_cp_word_error_rate',
'greedy_di_cp_word_error_rate_multifile',
'apply_dicp_assignment',
]


@dataclasses.dataclass(frozen=True)
class DICPErrorRate(ErrorRate):
    """Error-rate statistics for the Diarization-Invariant cpWER (DI-cpWER).

    Extends `ErrorRate` with the stream `assignment` produced by the greedy
    matching. Judging from the doctests of `greedy_di_cp_word_error_rate`,
    each entry is a reference-stream label assigned to one hypothesis
    segment/stream — TODO confirm against the matcher.
    """
    # Reference-stream label per assigned hypothesis segment/stream.
    assignment: Tuple[int, ...]

    def apply_assignment(self, reference, hypothesis):
        # Regroup the hypothesis streams so they line up with the reference;
        # a plain WER on the result equals the DI-cpWER on the inputs.
        return apply_dicp_assignment(self.assignment, reference, hypothesis)

    @classmethod
    def from_dict(cls, d):
        # 'error_rate' is derived from the other fields, so it must not be
        # passed to the constructor. NOTE: this mutates `d` in place.
        d.pop('error_rate', None)
        return cls(**d)


def greedy_di_cp_word_error_rate(
        reference,
        hypothesis,
        reference_sort='segment_if_available',
        hypothesis_sort='segment_if_available',
):
    """
    Computes the DI-cpWER with a greedy algorithm

    >>> reference = SegLST([
    ...     {'segment_index': 0, 'speaker': 'A', 'words': 'a'},
    ...     {'segment_index': 1, 'speaker': 'A', 'words': 'b'},
    ...     {'segment_index': 2, 'speaker': 'B', 'words': 'c'},
    ...     {'segment_index': 3, 'speaker': 'B', 'words': 'd'},
    ... ])
    >>> hypothesis = SegLST([
    ...     {'segment_index': 0, 'speaker': 'A', 'words': 'a'},
    ...     {'segment_index': 1, 'speaker': 'A', 'words': 'b'},
    ...     {'segment_index': 2, 'speaker': 'B', 'words': 'c'},
    ...     {'segment_index': 3, 'speaker': 'B', 'words': 'd'},
    ... ])
    >>> greedy_di_cp_word_error_rate(reference, hypothesis)
    DICPErrorRate(error_rate=0.0, errors=0, length=4, insertions=0, deletions=0, substitutions=0, reference_self_overlap=None, hypothesis_self_overlap=None, assignment=('A', 'A', 'B', 'B'))

    >>> hypothesis = SegLST([
    ...     {'segment_index': 0, 'speaker': 'A', 'words': 'a'},
    ...     {'segment_index': 1, 'speaker': 'B', 'words': 'b'},
    ...     {'segment_index': 2, 'speaker': 'A', 'words': 'c'},
    ...     {'segment_index': 3, 'speaker': 'B', 'words': 'd'},
    ... ])
    >>> greedy_di_cp_word_error_rate(reference, hypothesis)
    DICPErrorRate(error_rate=0.0, errors=0, length=4, insertions=0, deletions=0, substitutions=0, reference_self_overlap=None, hypothesis_self_overlap=None, assignment=('A', 'A', 'B', 'B'))

    >>> hypothesis = SegLST([
    ...     {'segment_index': 0, 'speaker': 'A', 'words': 'a b'},
    ...     {'segment_index': 2, 'speaker': 'A', 'words': 'b c d'},
    ... ])
    >>> greedy_di_cp_word_error_rate(reference, hypothesis)
    DICPErrorRate(error_rate=0.25, errors=1, length=4, insertions=1, deletions=0, substitutions=0, reference_self_overlap=None, hypothesis_self_overlap=None, assignment=('A', 'B'))
    """
    # DI-cpWER is the ORC-WER with the roles of reference and hypothesis
    # exchanged, so the greedy ORC matcher can be reused directly. The sort
    # options are swapped along with the transcripts.
    orc_er = meeteval.wer.wer.orc.greedy_orc_word_error_rate(
        hypothesis, reference,
        hypothesis_sort, reference_sort
    )

    # Translate the ORC statistics back: insertions and deletions swap
    # roles, and the length is the reference word count (the ORC call above
    # counted the hypothesis words as its "reference" length).
    reference_length = sum(len(s['words'].split()) for s in reference)
    return DICPErrorRate(
        orc_er.errors,
        reference_length,
        insertions=orc_er.deletions,
        deletions=orc_er.insertions,
        substitutions=orc_er.substitutions,
        assignment=orc_er.assignment,
        reference_self_overlap=orc_er.hypothesis_self_overlap,
        hypothesis_self_overlap=orc_er.reference_self_overlap,
    )


def greedy_di_cp_word_error_rate_multifile(
        reference,
        hypothesis,
        partial=False,
        reference_sort='segment_if_available',
        hypothesis_sort='segment_if_available',
) -> 'dict[str, DICPErrorRate]':
    """
    Computes the (greedy) DI-cpWER separately for every example in the
    reference and hypothesis files.

    To obtain the overall WER, sum the returned values:
    `sum(greedy_di_cp_word_error_rate_multifile(r, h).values())`.
    """
    from meeteval.io.seglst import apply_multi_file

    # Bind the sort options once; `apply_multi_file` calls the resulting
    # function with one (reference, hypothesis) pair per session.
    per_session = functools.partial(
        greedy_di_cp_word_error_rate,
        reference_sort=reference_sort,
        hypothesis_sort=hypothesis_sort,
    )
    return apply_multi_file(
        per_session, reference, hypothesis, partial=partial,
    )


def apply_dicp_assignment(
        assignment: 'list[int | str] | tuple[int | str]',
        reference: 'list[list[str]] | dict[str, list[str]] | SegLST',
        hypothesis: 'list[str] | dict[str] | SegLST',
):
    """
    Apply a DI-cp assignment so that the hypothesis streams match the
    reference streams.

    Computing the standard WER on the output of this function yields the same
    result as the DI-cpWER on the input of this function.

    Arguments:
        assignment: The assignment of hypothesis segments to the reference
            streams, as a sequence of stream labels (see the doctests below
            for the expected length/ordering).
        reference: Passed through unchanged, but used to determine the format
            of the hypothesis output if it is not SegLST.
        hypothesis: The hypothesis segments. This can be a list of lists of
            segments, or a SegLST object. If it is a SegLST object, the
            "segment_index" field is used to group the segments, if present.

    >>> assignment = ('A', 'A', 'B')
    >>> apply_dicp_assignment(assignment, {'A': 'a c', 'B': 'd e'}, ['a', 'c d', 'e'])
    ({'A': 'a c', 'B': 'd e'}, {'A': ['a', 'c d'], 'B': ['e']})

    >>> assignment = (0, 0, 1)
    >>> apply_dicp_assignment(assignment, ['a c', 'd e'], ['a', 'c d', 'e'])
    (['a c', 'd e'], [['a', 'c d'], ['e']])

    >>> assignment = ('A', )
    >>> apply_dicp_assignment(assignment, {'A': 'b', 'B': 'c'}, ['a'])
    ({'A': 'b', 'B': 'c'}, {'A': ['a'], 'B': []})

    >>> ref = meeteval.io.STM.parse('X 1 A 0.0 1.0 a b\\nX 1 A 1.0 2.0 c d\\nX 1 B 0.0 2.0 e f\\n')
    >>> hyp = meeteval.io.STM.parse('X 1 1 0.0 2.0 c d\\nX 1 0 0.0 2.0 a b e f\\n')
    >>> ref, hyp = apply_dicp_assignment((0, 1, 1), hyp, ref)
    >>> print(ref.dumps())
    X 1 1 0.0 2.0 c d
    X 1 0 0.0 2.0 a b e f
    <BLANKLINE>
    >>> print(hyp.dumps())
    X 1 0 0.0 1.0 a b
    X 1 1 1.0 2.0 c d
    X 1 1 0.0 2.0 e f
    <BLANKLINE>
    """
    from meeteval.wer.wer.orc import apply_orc_assignment

    # A DI-cp assignment is an ORC assignment with the roles of reference
    # and hypothesis exchanged, so delegate with swapped arguments and swap
    # the result back.
    new_hypothesis, new_reference = apply_orc_assignment(
        assignment, hypothesis, reference,
    )
    return new_reference, new_hypothesis
16 changes: 8 additions & 8 deletions meeteval/wer/wer/orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,20 +327,20 @@ def greedy_orc_word_error_rate(
OrcErrorRate(error_rate=0.0, errors=0, length=6, insertions=0, deletions=0, substitutions=0, assignment=(0, 1, 1))
# One utterance is split
>>> er = greedy_orc_word_error_rate(['a', 'c d', 'e'], ['a c', 'd e'])
>>> er = greedy_orc_word_error_rate(['a', 'c d e'], ['a c', 'd e'])
>>> er
OrcErrorRate(error_rate=0.5, errors=2, length=4, insertions=1, deletions=1, substitutions=0, assignment=(0, 0, 1))
>>> er.apply_assignment(['a', 'c d', 'e'], ['a c', 'd e'])
([['a', 'c d'], ['e']], ['a c', 'd e'])
OrcErrorRate(error_rate=0.5, errors=2, length=4, insertions=1, deletions=1, substitutions=0, assignment=(0, 1))
>>> er.apply_assignment(['a', 'c d e'], ['a c', 'd e'])
([['a'], ['c d e']], ['a c', 'd e'])
>>> greedy_orc_word_error_rate(STM.parse('X 1 A 0.0 1.0 a b\\nX 1 B 0.0 2.0 e f\\nX 1 A 1.0 2.0 c d\\n'), STM.parse('X 1 1 0.0 2.0 c d\\nX 1 0 0.0 2.0 a b e f\\n'))
OrcErrorRate(error_rate=0.0, errors=0, length=6, insertions=0, deletions=0, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=Decimal('0E+1'), overlap_time=0, total_time=Decimal('4.0')), hypothesis_self_overlap=SelfOverlap(overlap_rate=Decimal('0E+1'), overlap_time=0, total_time=Decimal('4.0')), assignment=('0', '0', '1'))
>>> er = greedy_orc_word_error_rate(['a', 'c d', 'e'], {'A': 'a c', 'B': 'd e'})
>>> er = greedy_orc_word_error_rate(['a', 'c d e'], {'A': 'a c', 'B': 'd e'})
>>> er
OrcErrorRate(error_rate=0.5, errors=2, length=4, insertions=1, deletions=1, substitutions=0, assignment=('A', 'A', 'B'))
>>> er.apply_assignment(['a', 'c d', 'e'], {'A': 'a c', 'B': 'd e'})
({'A': ['a', 'c d'], 'B': ['e']}, {'A': 'a c', 'B': 'd e'})
OrcErrorRate(error_rate=0.5, errors=2, length=4, insertions=1, deletions=1, substitutions=0, assignment=('A', 'B'))
>>> er.apply_assignment(['a', 'c d e'], {'A': 'a c', 'B': 'd e'})
({'A': ['a'], 'B': ['c d e']}, {'A': 'a c', 'B': 'd e'})
>>> greedy_orc_word_error_rate([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': '', 'speaker': 'A'}], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}])
OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('A',))
Expand Down
Loading

0 comments on commit e3aee5a

Please sign in to comment.