From 9a473650520da1465ac74ded639239ecaa8326a7 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Fri, 6 Sep 2024 11:04:51 +0200 Subject: [PATCH 1/7] Add greedy_time_constrained_combination_matching --- .../cy_greedy_combination_matching.pyx | 48 +++++++++++++ .../matching/greedy_combination_matching.py | 69 ++++++++++++++++++- 2 files changed, 114 insertions(+), 3 deletions(-) diff --git a/meeteval/wer/matching/cy_greedy_combination_matching.pyx b/meeteval/wer/matching/cy_greedy_combination_matching.pyx index 451adc64..3fdf6db1 100644 --- a/meeteval/wer/matching/cy_greedy_combination_matching.pyx +++ b/meeteval/wer/matching/cy_greedy_combination_matching.pyx @@ -41,3 +41,51 @@ def cy_forward_col( else: tmp[current, i] = min([tmp[prev, i - 1] + cost_substitution, tmp[current, i - 1] + 1, tmp[prev, i] + 1]) return np.asarray(tmp[current]).copy() + + +@cython.boundscheck(False) # turn off bounds-checking for entire function +@cython.wraparound(False) # turn off negative index wrapping for entire function +def cy_forward_col_time_constrained( + npy_uint64[:] column not None, + npy_uint64[:] a not None, + npy_uint64[:] b not None, + npy_float64[:] a_begin not None, + npy_float64[:] a_end not None, + npy_float64[:] b_begin not None, + npy_float64[:] b_end not None, + uint cost_substitution = 1, +) -> np.ndarray: + """ + Args: + column: The column to be updated + a: Sequence in column direction (make sure that `len(column) == len(a) + 1`! otherwise SEGFAULT!!) + b: Sequence in row direction. This function updates `column` `len(b)` times + tmp: Pre-allocated temporary memory. Must have the same shape as `column` + cost_substitution: Cost for a substitution + """ + cdef uint i, j, a_, b_, current, prev + cdef double a_begin_, a_end_, b_begin_, b_end_ + cdef npy_uint64[:, :] tmp = np.empty((2, column.shape[0]), dtype=np.uint) + tmp[0, ...] = column + current = 0 + for j in range(b.shape[0]): + current = (j + 1) % 2 + prev = j % 2 + b_ = b[j] + b_begin_ = b_begin[j] + b_end_ = b_end[j] + + # TODO: can we swap without copy? + tmp[current, 0] = tmp[prev, 0] + 1 + for i in range(1, a.shape[0] + 1): + a_ = a[i - 1] + if a_begin[i - 1] >= b_end_ or b_begin_ >= a_end[i - 1]: + # No overlap + tmp[current, i] = min([tmp[current, i - 1] + 1, tmp[prev, i] + 1]) + elif a_ == b_: + # Overlap correct + tmp[current, i] = tmp[prev, i - 1] + else: + # Overlap incorrect + tmp[current, i] = min([tmp[prev, i - 1] + cost_substitution, tmp[current, i - 1] + 1, tmp[prev, i] + 1]) + return np.asarray(tmp[current]).copy() diff --git a/meeteval/wer/matching/greedy_combination_matching.py b/meeteval/wer/matching/greedy_combination_matching.py index 1ca21b72..ab3506e1 100644 --- a/meeteval/wer/matching/greedy_combination_matching.py +++ b/meeteval/wer/matching/greedy_combination_matching.py @@ -1,12 +1,11 @@ import itertools import functools -from typing import List, Iterable - +from typing import List, Iterable, Tuple import numpy as np from meeteval.io.seglst import SegLST -from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col +from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col, cy_forward_col_time_constrained def _apply_assignment(assignment, segments, n=None): @@ -270,3 +269,67 @@ def greedy_combination_matching( ) return distance, assignment + + +def greedy_time_constrained_combination_matching( + segments: List[Iterable[Tuple[int, float, float]]], + streams: List[Iterable[Tuple[int, float, float]]], + initial_assignment: List[int], + *, + distancetype: str = '21', # '21', '2', '1' +): + """ + Segments in `segments` are assigned to streams in `streams`. + + Args: + segments: A list of segments for which stream labels should be obtained + streams: A list of streams to which the segments are assigned + initial_assignment: The initial assignment of the segments to the streams. + Can be obtained with `initialize_assignment`. + distancetype: The type of distance to use. Can be one of: + - `'1'`: Use insertion cost of 1 (like in Levenshtein distance) + - `'2'`: Use insertion cost of 2 (cost of insertion + deletion) + - `'21'`: Start with '2' until converged and then use '1' until converged + + >>> greedy_time_constrained_combination_matching( + ... [[(0, 0, 1), (1, 1, 2)]], + ... [[(0, 0, 1), (1, 1, 2)]], + ... [0] + ... ) + (0, [0]) + + >>> greedy_time_constrained_combination_matching( + ... [[(0, 0, 1)], [(1, 1, 2)]], + ... [[(0, 0, 1)], [(1, 1, 2)]], + ... [0, 0] + ... ) + (0, [0, 1]) + """ + if len(segments) == 0: + return sum([len(s) for s in streams]), [] + if len(streams) == 0: + return sum([len(s) for s in segments]), [0] * len(segments) + + assert len(initial_assignment) == len(segments), (len(initial_assignment), len(segments), initial_assignment) + + # Correct assignment + assignment = initial_assignment + assert distancetype in ('1', '2', '21'), distancetype + for d in distancetype: + def forward_col(column, a, b): + return cy_forward_col_time_constrained( + column, + a=np.asarray([t for t, _, _ in a], dtype=np.uint), + b=np.asarray([t for t, _, _ in b], dtype=np.uint), + a_begin=np.asarray([t for _, t, _ in a], dtype=np.float64), + a_end=np.asarray([t for _, _, t in a], dtype=np.float64), + b_begin=np.asarray([t for _, t, _ in b], dtype=np.float64), + b_end=np.asarray([t for _, _, t in b], dtype=np.float64), + cost_substitution=int(d) + ) + assignment, distance = _greedy_correct_assignment( + segments, streams, assignment, + forward_col + ) + + return distance, assignment From e3593cfa0e89edcb3f5961f5dd62d35667b2b666 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Fri, 6 Sep 2024 11:05:22 +0200 Subject: [PATCH 2/7] Test greedy time constrained core matching function against tc siso matching --- tests/test_greedy_combination_matching.py | 61 +++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/test_greedy_combination_matching.py b/tests/test_greedy_combination_matching.py index 3823f690..10efcc03 100644 --- a/tests/test_greedy_combination_matching.py +++ b/tests/test_greedy_combination_matching.py @@ -1,3 +1,4 @@ +import itertools from pathlib import Path import pytest @@ -73,3 +74,63 @@ def test_optimal_assignment_is_not_changed(segments, streams): assert optimal_distance == greedy_distance assert optimal_assignment == greedy_assignment + + +# Limit alphabet to ensure a few correct matches +string = st.text(alphabet='abcdefg', min_size=0, max_size=100) + + +@st.composite +def string_with_timing(draw): + """ + Constraints: + - end >= start + - start values must be increasing + """ + s = draw(string) + t = [] + start = 0 + for _ in s: + start = draw(st.integers(min_value=start, max_value=10)) + end = draw(st.integers(min_value=start, max_value=start + 10)) + t.append((start, end)) + return s, t + + +@given( + string_with_timing(), + string_with_timing(), +) +def test_greedy_time_constrained_correct(a, b): + """ + Tests the time-constrained cython matrix implementation against the + time-constrained distance C++ implementation used in the time-constrained + siso WER + """ + from meeteval.wer.matching.cy_levenshtein import time_constrained_levenshtein_distance + from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col_time_constrained + import numpy as np + + a, a_timing = a + b, b_timing = b + + # cy_forward_col_time_constrained needs the sequences as integers + import collections + sym2int = collections.defaultdict(itertools.count().__next__) + _ = sym2int[''] # Reserve 0 for the empty string + a = [sym2int[c] for c in a] + b = [sym2int[c] for c in b] + + siso_dist = time_constrained_levenshtein_distance(a, b, a_timing, b_timing) + + column = cy_forward_col_time_constrained( + np.arange(len(a) + 1, dtype=np.uint), + np.asarray(a, dtype=np.uint), + np.asarray(b, dtype=np.uint), + np.asarray([t[0] for t in a_timing], float), + np.asarray([t[1] for t in a_timing], float), + np.asarray([t[0] for t in b_timing], float), + np.asarray([t[1] for t in b_timing], float), + ) + + assert siso_dist == column[-1] From afb8184d83a7aaf99567dc287fdc6e711815dc31 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Fri, 6 Sep 2024 11:05:57 +0200 Subject: [PATCH 3/7] Add greedy time-constrained ORC wer --- meeteval/wer/__main__.py | 31 ++++++- meeteval/wer/api.py | 34 ++++++++ meeteval/wer/wer/time_constrained_orc.py | 100 ++++++++++++++++++++++- tests/test_cli.py | 8 ++ 4 files changed, 171 insertions(+), 2 deletions(-) diff --git a/meeteval/wer/__main__.py b/meeteval/wer/__main__.py index 992dd2f4..714fefce 100644 --- a/meeteval/wer/__main__.py +++ b/meeteval/wer/__main__.py @@ -1,6 +1,5 @@ import argparse import dataclasses -import glob import json import logging import os @@ -301,6 +300,35 @@ def tcorcwer( _save_results(results, hypothesis, per_reco_out, average_out) +def greedy_tcorcwer( + reference, hypothesis, + average_out='{parent}/{stem}_greedy_tcorcwer.json', + per_reco_out='{parent}/{stem}_greedy_tcorcwer_per_reco.json', + regex=None, + collar=0, + hyp_pseudo_word_timing='character_based_points', + ref_pseudo_word_timing='character_based', + hypothesis_sort='segment', + reference_sort='segment', + uem=None, + normalizer=None, + partial=False, +): + """Computes the time-constrained ORC WER (tcORC WER)""" + results = meeteval.wer.greedy_tcorcwer( + reference, hypothesis, regex=regex, + collar=collar, + hyp_pseudo_word_timing=hyp_pseudo_word_timing, + ref_pseudo_word_timing=ref_pseudo_word_timing, + hypothesis_sort=hypothesis_sort, + reference_sort=reference_sort, + uem=uem, partial=partial, + normalizer=normalizer, + ) + _save_results(results, hypothesis, per_reco_out, average_out) + + + def greedy_dicpwer( reference, hypothesis, average_out='{parent}/{stem}_greedy_dicpwer.json', @@ -651,6 +679,7 @@ def cli(): cli.add_command(tcpwer) cli.add_command(tcorcwer) cli.add_command(greedy_dicpwer) + cli.add_command(greedy_tcorcwer) cli.add_command(merge) cli.add_command(average) diff --git a/meeteval/wer/api.py b/meeteval/wer/api.py index af6d4d32..f1cef3c2 100644 --- a/meeteval/wer/api.py +++ b/meeteval/wer/api.py @@ -13,6 +13,7 @@ 'mimower', 'tcpwer', 'tcorcwer', + 'greedy_tcorcwer', 'greedy_dicpwer', ] @@ -282,3 +283,36 @@ def greedy_dicpwer( hypothesis_sort=hypothesis_sort, ) return results + + +def greedy_tcorcwer( + reference, hypothesis, + regex=None, + collar=0, + hyp_pseudo_word_timing='character_based_points', + ref_pseudo_word_timing='character_based', + hypothesis_sort='segment', + reference_sort='segment', + uem=None, + normalizer=None, + partial=False, +): + """Computes the time-constrained ORC WER with a greedy algorithm""" + from meeteval.wer.wer.time_constrained_orc import greedy_time_constrained_orc_wer_multifile + reference, hypothesis = _load_texts(reference, hypothesis, regex=regex, uem=uem, normalizer=normalizer) + results = greedy_time_constrained_orc_wer_multifile( + reference, hypothesis, + reference_pseudo_word_level_timing=ref_pseudo_word_timing, + hypothesis_pseudo_word_level_timing=hyp_pseudo_word_timing, + collar=collar, + hypothesis_sort=hypothesis_sort, + reference_sort=reference_sort, + partial=partial, + ) + from meeteval.wer import combine_error_rates + average: ErrorRate = combine_error_rates(results) + if average.hypothesis_self_overlap is not None: + average.hypothesis_self_overlap.warn('hypothesis') + if average.reference_self_overlap is not None: + average.reference_self_overlap.warn('reference') + return results diff --git a/meeteval/wer/wer/time_constrained_orc.py b/meeteval/wer/wer/time_constrained_orc.py index 7d80fd6d..4d601da3 100644 --- a/meeteval/wer/wer/time_constrained_orc.py +++ b/meeteval/wer/wer/time_constrained_orc.py @@ -14,6 +14,8 @@ __all__ = [ 'time_constrained_orc_wer', 'time_constrained_orc_wer_multifile', + 'greedy_time_constrained_orc_wer', + 'greedy_time_constrained_orc_wer_multifile', ] @@ -30,7 +32,7 @@ def time_constrained_orc_wer( The time-constrained version of the ORC-WER (tcORC-WER). Special cases where the reference or hypothesis is empty - # >>> time_constrained_orc_wer([], []) + >>> time_constrained_orc_wer([], []) OrcErrorRate(errors=0, length=0, insertions=0, deletions=0, substitutions=0, assignment=()) >>> time_constrained_orc_wer([], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}]) OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=()) @@ -106,3 +108,99 @@ def time_constrained_orc_wer_multifile( reference_sort=reference_sort, ), reference, hypothesis, partial=partial) return r + + +def greedy_time_constrained_orc_wer( + reference: 'SegLST', + hypothesis: 'SegLST', + reference_pseudo_word_level_timing='character_based', + hypothesis_pseudo_word_level_timing='character_based_points', + collar: int = 0, + reference_sort='segment', + hypothesis_sort='segment', +): + """ + Special cases where the reference or hypothesis is empty + >>> greedy_time_constrained_orc_wer([], []) + OrcErrorRate(errors=0, length=0, insertions=0, deletions=0, substitutions=0, assignment=()) + >>> greedy_time_constrained_orc_wer([], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}]) + OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=()) + >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}], []) + OrcErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=1, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('dummy',)) + >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': '', 'speaker': 'A'}], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}]) + OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('A',)) + >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a b', 'speaker': 'A'}], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a d', 'speaker': 'A'}]) + OrcErrorRate(error_rate=0.5, errors=1, length=2, insertions=0, deletions=0, substitutions=1, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('A',)) + """ + from meeteval.wer.matching.greedy_combination_matching import greedy_time_constrained_combination_matching, \ + initialize_assignment + from meeteval.wer.wer.orc import _orc_error_rate + + def matching(reference, hypothesis): + """Use the mimo matching algorithm. Convert inputs and outputs between the formats""" + distance, assignment = greedy_time_constrained_combination_matching( + [list(zip(*r)) for r in reference.T['words', 'start_time', 'end_time']], + [[w for words in stream.T['words', 'start_time', 'end_time'] for w in zip(*words)] for stream in + hypothesis.values()], + initial_assignment=initialize_assignment(reference, hypothesis, initialization='cp'), + ) + return distance, assignment + + def siso(reference, hypothesis): + return _time_constrained_siso_error_rate( + reference.flatmap(lambda x: [ + {**x, 'words': w, 'start_time': s, 'end_time': e} + for w, s, e in zip(x['words'], x['start_time'], x['end_time']) + ]), + hypothesis.flatmap(lambda x: [ + {**x, 'words': w, 'start_time': s, 'end_time': e} + for w, s, e in zip(x['words'], x['start_time'], x['end_time']) + ]), + ) + + # Drop segment index in reference. It will get a new one after merging by speakers + reference = meeteval.io.asseglst(reference) + reference = reference.map(lambda x: {k: v for k, v in x.items() if k != 'segment_index'}) + + reference, hypothesis, ref_self_overlap, hyp_self_overlap = preprocess( + reference, hypothesis, + keep_keys=('words', 'segment_index', 'speaker', 'start_time', 'end_time'), + reference_sort=reference_sort, + hypothesis_sort=hypothesis_sort, + segment_representation='segment', + segment_index='segment', + remove_empty_segments=False, + convert_to_int=True, + reference_pseudo_word_level_timing=reference_pseudo_word_level_timing, + hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing, + collar=collar, + ) + + er = _orc_error_rate(reference, hypothesis, matching, siso) + er = dataclasses.replace( + er, + reference_self_overlap=ref_self_overlap, + hypothesis_self_overlap=hyp_self_overlap, + ) + return er + + +def greedy_time_constrained_orc_wer_multifile( + reference: 'STM', hypothesis: 'STM', + reference_pseudo_word_level_timing='character_based', + hypothesis_pseudo_word_level_timing='character_based_points', + collar: int = 0, + hypothesis_sort='segment', + reference_sort='segment', + partial=False, +) -> 'dict[str, OrcErrorRate]': + from meeteval.io.seglst import apply_multi_file + r = apply_multi_file(lambda r, h: greedy_time_constrained_orc_wer( + r, h, + reference_pseudo_word_level_timing=reference_pseudo_word_level_timing, + hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing, + collar=collar, + hypothesis_sort=hypothesis_sort, + reference_sort=reference_sort, + ), reference, hypothesis, partial=partial) + return r diff --git a/tests/test_cli.py b/tests/test_cli.py index 9a0e4ec8..2dd5f9ea 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -127,6 +127,14 @@ def test_burn_tcorc(): run(f'python -m meeteval.wer tcorcwer -h hyp.stm -r ref.stm --hypothesis-sort true') +def test_burn_greedy_tcorc(): + run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm') + run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --collar 5') + run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --hyp-pseudo-word-timing equidistant_points') + run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.seglst.json -r ref.seglst.json') + run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --hypothesis-sort true') + + def test_burn_md_eval_22(): run(f'python -m meeteval.der md_eval_22 -h hyp.stm -r ref.stm') run(f'meeteval-der md_eval_22 -h hyp.stm -r ref.stm') From 0eb7b5b8bd9ed6252b794966db8e004322a5f0ed Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Fri, 6 Sep 2024 12:31:44 +0200 Subject: [PATCH 4/7] Add tests for greedy time-constrained combination matching --- tests/test_greedy_combination_matching.py | 67 +++++++++++++++++++---- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/tests/test_greedy_combination_matching.py b/tests/test_greedy_combination_matching.py index 10efcc03..abe1adf8 100644 --- a/tests/test_greedy_combination_matching.py +++ b/tests/test_greedy_combination_matching.py @@ -2,11 +2,10 @@ from pathlib import Path import pytest -import time from meeteval.wer.matching import greedy_combination_matching, orc_matching import meeteval -from hypothesis import given, strategies as st, settings, assume +from hypothesis import given, strategies as st, settings example_files = (Path(__file__).parent.parent / 'example_files').absolute() @@ -18,31 +17,32 @@ streams = lambda max_size: st.lists(utterance, min_size=1, max_size=max_size) -def _check_output(distance, assignment, segments, streams): +def _check_output(distance, assignment, segments, streams, check_distance=True): assert isinstance(distance, int) assert distance >= 0 assert len(assignment) == len(segments) assert all(isinstance(a, int) for a in assignment) assert all(0 <= a < len(streams) for a in assignment) - d = orc_matching._levensthein_distance_for_assignment(segments, streams, assignment) - assert d == distance, (d, distance, assignment) + if check_distance: + d = orc_matching._levensthein_distance_for_assignment(segments, streams, assignment) + assert d == distance, (d, distance, assignment) -@given(segments(10), streams(4)) +@given(segments(10), streams(4), st.sampled_from(['1', '2', '21'])) @settings(deadline=None) # The tests take longer on the GitHub actions test servers -def test_greedy_combination_matching_burn(segments, streams): +def test_greedy_combination_matching_burn(segments, streams, distancetype): """Burn-test. Brute-force is exponential in the number of reference utterances, so choose a small number.""" distance, assignment = greedy_combination_matching.greedy_combination_matching( - segments, streams, [0] * len(segments) + segments, streams, [0] * len(segments), distancetype=distancetype ) - _check_output(distance, assignment, segments, streams) + _check_output(distance, assignment, segments, streams, check_distance=distancetype != '2') -@given(segments(6), streams(4)) -def test_greedy_bound_by_optimal(segments, streams): +@given(segments(6), streams(4), st.sampled_from(['1', '2', '21'])) +def test_greedy_bound_by_optimal(segments, streams, distancetype): greedy_distance, _ = greedy_combination_matching.greedy_combination_matching( - segments, streams, [0] * len(segments) + segments, streams, [0] * len(segments), distancetype=distancetype ) optimal_distance, _ = orc_matching.orc_matching_v3(segments, streams) @@ -134,3 +134,46 @@ def test_greedy_time_constrained_correct(a, b): ) assert siso_dist == column[-1] + + + +utterance_with_timings = st.lists(st.tuples(st.integers(min_value=0, max_value=10), st.floats(allow_nan=False, allow_infinity=False), st.floats(0, 1, allow_nan=False, allow_infinity=False)), min_size=1, max_size=10) +segments_with_timing = lambda max_size: st.lists(utterance_with_timings, min_size=1, max_size=max_size) +streams_with_timing = lambda max_size: st.lists(utterance_with_timings, min_size=1, max_size=max_size) + + +@given(segments_with_timing(2), streams_with_timing(2), st.sampled_from(['1', '2', '21'])) +def test_greedy_time_constrained_bound_by_optimal(segments, streams, distancetype): + greedy_distance, _ = greedy_combination_matching.greedy_time_constrained_combination_matching( + segments, streams, [0] * len(segments), distancetype=distancetype + ) + + from meeteval.wer.matching.cy_time_constrained_orc_matching import time_constrained_orc_levenshtein_distance + optimal_distance, _ = time_constrained_orc_levenshtein_distance( + [[s[0] for s in ss] for ss in segments], + [[s[0] for s in ss] for ss in streams], + [[s[1:] for s in ss] for ss in segments], + [[s[1:] for s in ss] for ss in streams], + ) + assert optimal_distance <= greedy_distance + + +@given(segments(4), streams(6), st.sampled_from(['1', '2', '21'])) +def test_greedy_vs_greedy_time_constrained(segments, streams, distancetype): + """ + Tests that the time-constrained version gives the same result as the + non-time-constrained version when all words overlap. + """ + greedy_distance, greedy_assignment = greedy_combination_matching.greedy_combination_matching( + segments, streams, [0] * len(segments), distancetype=distancetype + ) + + greedy_time_constrained_distance, greedy_time_constrained_assignment = greedy_combination_matching.greedy_time_constrained_combination_matching( + [[(w, 0, 1) for w in segment] for segment in segments], + [[(w, 0, 1) for w in stream] for stream in streams], + [0] * len(segments), + distancetype=distancetype + ) + + assert greedy_time_constrained_distance == greedy_distance + assert greedy_time_constrained_assignment == greedy_assignment From 55f7c63cd613d5f7a01911855eb940642ead8386 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Fri, 6 Sep 2024 12:56:55 +0200 Subject: [PATCH 5/7] Flake8 --- meeteval/wer/wer/time_constrained_orc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/meeteval/wer/wer/time_constrained_orc.py b/meeteval/wer/wer/time_constrained_orc.py index 4d601da3..b76763b1 100644 --- a/meeteval/wer/wer/time_constrained_orc.py +++ b/meeteval/wer/wer/time_constrained_orc.py @@ -111,8 +111,8 @@ def time_constrained_orc_wer_multifile( def greedy_time_constrained_orc_wer( - reference: 'SegLST', - hypothesis: 'SegLST', + reference, + hypothesis, reference_pseudo_word_level_timing='character_based', hypothesis_pseudo_word_level_timing='character_based_points', collar: int = 0, From e0dd3a903b7acb6ee57a2554a606cfc496c694a5 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Tue, 17 Sep 2024 11:08:03 +0200 Subject: [PATCH 6/7] Remove old comments from greedy matching Cython code --- meeteval/wer/matching/cy_greedy_combination_matching.pyx | 4 ---- 1 file changed, 4 deletions(-) diff --git a/meeteval/wer/matching/cy_greedy_combination_matching.pyx b/meeteval/wer/matching/cy_greedy_combination_matching.pyx index 3fdf6db1..66b5a874 100644 --- a/meeteval/wer/matching/cy_greedy_combination_matching.pyx +++ b/meeteval/wer/matching/cy_greedy_combination_matching.pyx @@ -21,7 +21,6 @@ def cy_forward_col( column: The column to be updated a: Sequence in column direction (make sure that `len(column) == len(a) + 1`! otherwise SEGFAULT!!) b: Sequence in row direction. This function updates `column` `len(b)` times - tmp: Pre-allocated temporary memory. Must have the same shape as `column` cost_substitution: Cost for a substitution """ cdef uint i, j, a_, b_, current, prev @@ -32,7 +31,6 @@ def cy_forward_col( current = (j + 1) % 2 prev = j % 2 b_ = b[j] - # TODO: can we swap without copy? tmp[current, 0] = tmp[prev, 0] + 1 for i in range(1, a.shape[0] + 1): a_ = a[i - 1] @@ -60,7 +58,6 @@ def cy_forward_col_time_constrained( column: The column to be updated a: Sequence in column direction (make sure that `len(column) == len(a) + 1`! otherwise SEGFAULT!!) b: Sequence in row direction. This function updates `column` `len(b)` times - tmp: Pre-allocated temporary memory. Must have the same shape as `column` cost_substitution: Cost for a substitution """ cdef uint i, j, a_, b_, current, prev @@ -75,7 +72,6 @@ def cy_forward_col_time_constrained( b_begin_ = b_begin[j] b_end_ = b_end[j] - # TODO: can we swap without copy? tmp[current, 0] = tmp[prev, 0] + 1 for i in range(1, a.shape[0] + 1): a_ = a[i - 1] From ffec650bf204441182997ee6058ade5bfa677bf5 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Tue, 17 Sep 2024 12:41:27 +0200 Subject: [PATCH 7/7] Add greedy tcORC-WER to README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 19d6a7c2..df855389 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ MeetEval supports the following metrics for meeting transcription evaluation: `meeteval-wer tcpwer -r ref.stm -h hyp.stm --collar 5` - **Time-Constrained Optimal Reference Combination Word Error Rate (tcORC WER)**
`meeteval-wer tcorcwer -r ref.stm -h hyp.stm --collar 5` +- **Fast Greedy Approximation of Time-Constrained Optimal Reference Combination Word Error Rate (greedy tcORC WER)**
+ `meeteval-wer greedy_tcorcwer -r ref.stm -h hyp.stm --collar 5` - **Diarization-Invariant cpWER (DI-cpWER)**
`meeteval-wer greedy_dicpwer -r ref.stm -h hyp.stm` - **Diarization Error Rate (DER)** by wrapping [mdeval](https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl)