diff --git a/README.md b/README.md
index 19d6a7c2..df855389 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,8 @@ MeetEval supports the following metrics for meeting transcription evaluation:
`meeteval-wer tcpwer -r ref.stm -h hyp.stm --collar 5`
- **Time-Constrained Optimal Reference Combination Word Error Rate (tcORC WER)**
`meeteval-wer tcorcwer -r ref.stm -h hyp.stm --collar 5`
+- **Fast Greedy Approximation of Time-Constrained Optimal Reference Combination Word Error Rate (greedy tcORC WER)**
+ `meeteval-wer greedy_tcorcwer -r ref.stm -h hyp.stm --collar 5`
- **Diarization-Invariant cpWER (DI-cpWER)**
`meeteval-wer greedy_dicpwer -r ref.stm -h hyp.stm`
- **Diarization Error Rate (DER)** by wrapping [mdeval](https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl)
diff --git a/meeteval/wer/__main__.py b/meeteval/wer/__main__.py
index 992dd2f4..714fefce 100644
--- a/meeteval/wer/__main__.py
+++ b/meeteval/wer/__main__.py
@@ -1,6 +1,5 @@
import argparse
import dataclasses
-import glob
import json
import logging
import os
@@ -301,6 +300,35 @@ def tcorcwer(
_save_results(results, hypothesis, per_reco_out, average_out)
+def greedy_tcorcwer(
+ reference, hypothesis,
+ average_out='{parent}/{stem}_greedy_tcorcwer.json',
+ per_reco_out='{parent}/{stem}_greedy_tcorcwer_per_reco.json',
+ regex=None,
+ collar=0,
+ hyp_pseudo_word_timing='character_based_points',
+ ref_pseudo_word_timing='character_based',
+ hypothesis_sort='segment',
+ reference_sort='segment',
+ uem=None,
+ normalizer=None,
+ partial=False,
+):
+ """Computes the time-constrained ORC WER (tcORC WER)"""
+ results = meeteval.wer.greedy_tcorcwer(
+ reference, hypothesis, regex=regex,
+ collar=collar,
+ hyp_pseudo_word_timing=hyp_pseudo_word_timing,
+ ref_pseudo_word_timing=ref_pseudo_word_timing,
+ hypothesis_sort=hypothesis_sort,
+ reference_sort=reference_sort,
+ uem=uem, partial=partial,
+ normalizer=normalizer,
+ )
+ _save_results(results, hypothesis, per_reco_out, average_out)
+
+
+
def greedy_dicpwer(
reference, hypothesis,
average_out='{parent}/{stem}_greedy_dicpwer.json',
@@ -651,6 +679,7 @@ def cli():
cli.add_command(tcpwer)
cli.add_command(tcorcwer)
cli.add_command(greedy_dicpwer)
+ cli.add_command(greedy_tcorcwer)
cli.add_command(merge)
cli.add_command(average)
diff --git a/meeteval/wer/api.py b/meeteval/wer/api.py
index af6d4d32..f1cef3c2 100644
--- a/meeteval/wer/api.py
+++ b/meeteval/wer/api.py
@@ -13,6 +13,7 @@
'mimower',
'tcpwer',
'tcorcwer',
+ 'greedy_tcorcwer',
'greedy_dicpwer',
]
@@ -282,3 +283,36 @@ def greedy_dicpwer(
hypothesis_sort=hypothesis_sort,
)
return results
+
+
+def greedy_tcorcwer(
+ reference, hypothesis,
+ regex=None,
+ collar=0,
+ hyp_pseudo_word_timing='character_based_points',
+ ref_pseudo_word_timing='character_based',
+ hypothesis_sort='segment',
+ reference_sort='segment',
+ uem=None,
+ normalizer=None,
+ partial=False,
+):
+ """Computes the time-constrained ORC WER with a greedy algorithm"""
+ from meeteval.wer.wer.time_constrained_orc import greedy_time_constrained_orc_wer_multifile
+ reference, hypothesis = _load_texts(reference, hypothesis, regex=regex, uem=uem, normalizer=normalizer)
+ results = greedy_time_constrained_orc_wer_multifile(
+ reference, hypothesis,
+ reference_pseudo_word_level_timing=ref_pseudo_word_timing,
+ hypothesis_pseudo_word_level_timing=hyp_pseudo_word_timing,
+ collar=collar,
+ hypothesis_sort=hypothesis_sort,
+ reference_sort=reference_sort,
+ partial=partial,
+ )
+ from meeteval.wer import combine_error_rates
+ average: ErrorRate = combine_error_rates(results)
+ if average.hypothesis_self_overlap is not None:
+ average.hypothesis_self_overlap.warn('hypothesis')
+ if average.reference_self_overlap is not None:
+ average.reference_self_overlap.warn('reference')
+ return results
diff --git a/meeteval/wer/matching/cy_greedy_combination_matching.pyx b/meeteval/wer/matching/cy_greedy_combination_matching.pyx
index 451adc64..66b5a874 100644
--- a/meeteval/wer/matching/cy_greedy_combination_matching.pyx
+++ b/meeteval/wer/matching/cy_greedy_combination_matching.pyx
@@ -21,7 +21,6 @@ def cy_forward_col(
column: The column to be updated
a: Sequence in column direction (make sure that `len(column) == len(a) + 1`! otherwise SEGFAULT!!)
b: Sequence in row direction. This function updates `column` `len(b)` times
- tmp: Pre-allocated temporary memory. Must have the same shape as `column`
cost_substitution: Cost for a substitution
"""
cdef uint i, j, a_, b_, current, prev
@@ -32,7 +31,6 @@ def cy_forward_col(
current = (j + 1) % 2
prev = j % 2
b_ = b[j]
- # TODO: can we swap without copy?
tmp[current, 0] = tmp[prev, 0] + 1
for i in range(1, a.shape[0] + 1):
a_ = a[i - 1]
@@ -41,3 +39,49 @@ def cy_forward_col(
else:
tmp[current, i] = min([tmp[prev, i - 1] + cost_substitution, tmp[current, i - 1] + 1, tmp[prev, i] + 1])
return np.asarray(tmp[current]).copy()
+
+
+@cython.boundscheck(False) # turn off bounds-checking for entire function
+@cython.wraparound(False) # turn off negative index wrapping for entire function
+def cy_forward_col_time_constrained(
+ npy_uint64[:] column not None,
+ npy_uint64[:] a not None,
+ npy_uint64[:] b not None,
+ npy_float64[:] a_begin not None,
+ npy_float64[:] a_end not None,
+ npy_float64[:] b_begin not None,
+ npy_float64[:] b_end not None,
+ uint cost_substitution = 1,
+) -> np.ndarray:
+ """
+ Args:
+ column: The column to be updated
+ a: Sequence in column direction (make sure that `len(column) == len(a) + 1`! otherwise SEGFAULT!!)
+ b: Sequence in row direction. This function updates `column` `len(b)` times
+ cost_substitution: Cost for a substitution
+ """
+ cdef uint i, j, a_, b_, current, prev
+ cdef double a_begin_, a_end_, b_begin_, b_end_
+ cdef npy_uint64[:, :] tmp = np.empty((2, column.shape[0]), dtype=np.uint)
+ tmp[0, ...] = column
+ current = 0
+ for j in range(b.shape[0]):
+ current = (j + 1) % 2
+ prev = j % 2
+ b_ = b[j]
+ b_begin_ = b_begin[j]
+ b_end_ = b_end[j]
+
+ tmp[current, 0] = tmp[prev, 0] + 1
+ for i in range(1, a.shape[0] + 1):
+ a_ = a[i - 1]
+ if a_begin[i - 1] >= b_end_ or b_begin_ >= a_end[i - 1]:
+ # No overlap
+ tmp[current, i] = min([tmp[current, i - 1] + 1, tmp[prev, i] + 1])
+ elif a_ == b_:
+ # Overlap correct
+ tmp[current, i] = tmp[prev, i - 1]
+ else:
+ # Overlap incorrect
+ tmp[current, i] = min([tmp[prev, i - 1] + cost_substitution, tmp[current, i - 1] + 1, tmp[prev, i] + 1])
+ return np.asarray(tmp[current]).copy()
diff --git a/meeteval/wer/matching/greedy_combination_matching.py b/meeteval/wer/matching/greedy_combination_matching.py
index 1ca21b72..ab3506e1 100644
--- a/meeteval/wer/matching/greedy_combination_matching.py
+++ b/meeteval/wer/matching/greedy_combination_matching.py
@@ -1,12 +1,11 @@
import itertools
import functools
-from typing import List, Iterable
-
+from typing import List, Iterable, Tuple
import numpy as np
from meeteval.io.seglst import SegLST
-from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col
+from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col, cy_forward_col_time_constrained
def _apply_assignment(assignment, segments, n=None):
@@ -270,3 +269,67 @@ def greedy_combination_matching(
)
return distance, assignment
+
+
+def greedy_time_constrained_combination_matching(
+ segments: List[Iterable[Tuple[int, float, float]]],
+ streams: List[Iterable[Tuple[int, float, float]]],
+ initial_assignment: List[int],
+ *,
+ distancetype: str = '21', # '21', '2', '1'
+):
+ """
+ Segments in `segments` are assigned to streams in `streams`.
+
+ Args:
+ segments: A list of segments for which stream labels should be obtained
+ streams: A list of streams to which the segments are assigned
+ initial_assignment: The initial assignment of the segments to the streams.
+ Can be obtained with `initialize_assignment`.
+ distancetype: The type of distance to use. Can be one of:
+ - `'1'`: Use insertion cost of 1 (like in Levenshtein distance)
+ - `'2'`: Use insertion cost of 2 (cost of insertion + deletion)
+ - `'21'`: Start with '2' until converged and then use '1' until converged
+
+ >>> greedy_time_constrained_combination_matching(
+ ... [[(0, 0, 1), (1, 1, 2)]],
+ ... [[(0, 0, 1), (1, 1, 2)]],
+ ... [0]
+ ... )
+ (0, [0])
+
+ >>> greedy_time_constrained_combination_matching(
+ ... [[(0, 0, 1)], [(1, 1, 2)]],
+ ... [[(0, 0, 1)], [(1, 1, 2)]],
+ ... [0, 0]
+ ... )
+ (0, [0, 1])
+ """
+ if len(segments) == 0:
+ return sum([len(s) for s in streams]), []
+ if len(streams) == 0:
+ return sum([len(s) for s in segments]), [0] * len(segments)
+
+ assert len(initial_assignment) == len(segments), (len(initial_assignment), len(segments), initial_assignment)
+
+ # Correct assignment
+ assignment = initial_assignment
+ assert distancetype in ('1', '2', '21'), distancetype
+ for d in distancetype:
+ def forward_col(column, a, b):
+ return cy_forward_col_time_constrained(
+ column,
+ a=np.asarray([t for t, _, _ in a], dtype=np.uint),
+ b=np.asarray([t for t, _, _ in b], dtype=np.uint),
+ a_begin=np.asarray([t for _, t, _ in a], dtype=np.float64),
+ a_end=np.asarray([t for _, _, t in a], dtype=np.float64),
+ b_begin=np.asarray([t for _, t, _ in b], dtype=np.float64),
+ b_end=np.asarray([t for _, _, t in b], dtype=np.float64),
+ cost_substitution=int(d)
+ )
+ assignment, distance = _greedy_correct_assignment(
+ segments, streams, assignment,
+ forward_col
+ )
+
+ return distance, assignment
diff --git a/meeteval/wer/wer/time_constrained_orc.py b/meeteval/wer/wer/time_constrained_orc.py
index 7d80fd6d..b76763b1 100644
--- a/meeteval/wer/wer/time_constrained_orc.py
+++ b/meeteval/wer/wer/time_constrained_orc.py
@@ -14,6 +14,8 @@
__all__ = [
'time_constrained_orc_wer',
'time_constrained_orc_wer_multifile',
+ 'greedy_time_constrained_orc_wer',
+ 'greedy_time_constrained_orc_wer_multifile',
]
@@ -30,7 +32,7 @@ def time_constrained_orc_wer(
The time-constrained version of the ORC-WER (tcORC-WER).
Special cases where the reference or hypothesis is empty
- # >>> time_constrained_orc_wer([], [])
+ >>> time_constrained_orc_wer([], [])
OrcErrorRate(errors=0, length=0, insertions=0, deletions=0, substitutions=0, assignment=())
>>> time_constrained_orc_wer([], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}])
OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=())
@@ -106,3 +108,99 @@ def time_constrained_orc_wer_multifile(
reference_sort=reference_sort,
), reference, hypothesis, partial=partial)
return r
+
+
+def greedy_time_constrained_orc_wer(
+ reference,
+ hypothesis,
+ reference_pseudo_word_level_timing='character_based',
+ hypothesis_pseudo_word_level_timing='character_based_points',
+ collar: int = 0,
+ reference_sort='segment',
+ hypothesis_sort='segment',
+):
+ """
+ Special cases where the reference or hypothesis is empty
+ >>> greedy_time_constrained_orc_wer([], [])
+ OrcErrorRate(errors=0, length=0, insertions=0, deletions=0, substitutions=0, assignment=())
+ >>> greedy_time_constrained_orc_wer([], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}])
+ OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=())
+ >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}], [])
+ OrcErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=1, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('dummy',))
+ >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': '', 'speaker': 'A'}], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}])
+ OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('A',))
+ >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a b', 'speaker': 'A'}], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a d', 'speaker': 'A'}])
+ OrcErrorRate(error_rate=0.5, errors=1, length=2, insertions=0, deletions=0, substitutions=1, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('A',))
+ """
+ from meeteval.wer.matching.greedy_combination_matching import greedy_time_constrained_combination_matching, \
+ initialize_assignment
+ from meeteval.wer.wer.orc import _orc_error_rate
+
+ def matching(reference, hypothesis):
+ """Use the mimo matching algorithm. Convert inputs and outputs between the formats"""
+ distance, assignment = greedy_time_constrained_combination_matching(
+ [list(zip(*r)) for r in reference.T['words', 'start_time', 'end_time']],
+ [[w for words in stream.T['words', 'start_time', 'end_time'] for w in zip(*words)] for stream in
+ hypothesis.values()],
+ initial_assignment=initialize_assignment(reference, hypothesis, initialization='cp'),
+ )
+ return distance, assignment
+
+ def siso(reference, hypothesis):
+ return _time_constrained_siso_error_rate(
+ reference.flatmap(lambda x: [
+ {**x, 'words': w, 'start_time': s, 'end_time': e}
+ for w, s, e in zip(x['words'], x['start_time'], x['end_time'])
+ ]),
+ hypothesis.flatmap(lambda x: [
+ {**x, 'words': w, 'start_time': s, 'end_time': e}
+ for w, s, e in zip(x['words'], x['start_time'], x['end_time'])
+ ]),
+ )
+
+ # Drop segment index in reference. It will get a new one after merging by speakers
+ reference = meeteval.io.asseglst(reference)
+ reference = reference.map(lambda x: {k: v for k, v in x.items() if k != 'segment_index'})
+
+ reference, hypothesis, ref_self_overlap, hyp_self_overlap = preprocess(
+ reference, hypothesis,
+ keep_keys=('words', 'segment_index', 'speaker', 'start_time', 'end_time'),
+ reference_sort=reference_sort,
+ hypothesis_sort=hypothesis_sort,
+ segment_representation='segment',
+ segment_index='segment',
+ remove_empty_segments=False,
+ convert_to_int=True,
+ reference_pseudo_word_level_timing=reference_pseudo_word_level_timing,
+ hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
+ collar=collar,
+ )
+
+ er = _orc_error_rate(reference, hypothesis, matching, siso)
+ er = dataclasses.replace(
+ er,
+ reference_self_overlap=ref_self_overlap,
+ hypothesis_self_overlap=hyp_self_overlap,
+ )
+ return er
+
+
+def greedy_time_constrained_orc_wer_multifile(
+ reference: 'STM', hypothesis: 'STM',
+ reference_pseudo_word_level_timing='character_based',
+ hypothesis_pseudo_word_level_timing='character_based_points',
+ collar: int = 0,
+ hypothesis_sort='segment',
+ reference_sort='segment',
+ partial=False,
+) -> 'dict[str, OrcErrorRate]':
+ from meeteval.io.seglst import apply_multi_file
+ r = apply_multi_file(lambda r, h: greedy_time_constrained_orc_wer(
+ r, h,
+ reference_pseudo_word_level_timing=reference_pseudo_word_level_timing,
+ hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
+ collar=collar,
+ hypothesis_sort=hypothesis_sort,
+ reference_sort=reference_sort,
+ ), reference, hypothesis, partial=partial)
+ return r
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9a0e4ec8..2dd5f9ea 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -127,6 +127,14 @@ def test_burn_tcorc():
run(f'python -m meeteval.wer tcorcwer -h hyp.stm -r ref.stm --hypothesis-sort true')
+def test_burn_greedy_tcorc():
+ run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm')
+ run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --collar 5')
+ run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --hyp-pseudo-word-timing equidistant_points')
+ run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.seglst.json -r ref.seglst.json')
+ run(f'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --hypothesis-sort true')
+
+
def test_burn_md_eval_22():
run(f'python -m meeteval.der md_eval_22 -h hyp.stm -r ref.stm')
run(f'meeteval-der md_eval_22 -h hyp.stm -r ref.stm')
diff --git a/tests/test_greedy_combination_matching.py b/tests/test_greedy_combination_matching.py
index 3823f690..abe1adf8 100644
--- a/tests/test_greedy_combination_matching.py
+++ b/tests/test_greedy_combination_matching.py
@@ -1,11 +1,11 @@
+import itertools
from pathlib import Path
import pytest
-import time
from meeteval.wer.matching import greedy_combination_matching, orc_matching
import meeteval
-from hypothesis import given, strategies as st, settings, assume
+from hypothesis import given, strategies as st, settings
example_files = (Path(__file__).parent.parent / 'example_files').absolute()
@@ -17,31 +17,32 @@
streams = lambda max_size: st.lists(utterance, min_size=1, max_size=max_size)
-def _check_output(distance, assignment, segments, streams):
+def _check_output(distance, assignment, segments, streams, check_distance=True):
assert isinstance(distance, int)
assert distance >= 0
assert len(assignment) == len(segments)
assert all(isinstance(a, int) for a in assignment)
assert all(0 <= a < len(streams) for a in assignment)
- d = orc_matching._levensthein_distance_for_assignment(segments, streams, assignment)
- assert d == distance, (d, distance, assignment)
+ if check_distance:
+ d = orc_matching._levensthein_distance_for_assignment(segments, streams, assignment)
+ assert d == distance, (d, distance, assignment)
-@given(segments(10), streams(4))
+@given(segments(10), streams(4), st.sampled_from(['1', '2', '21']))
@settings(deadline=None) # The tests take longer on the GitHub actions test servers
-def test_greedy_combination_matching_burn(segments, streams):
+def test_greedy_combination_matching_burn(segments, streams, distancetype):
"""Burn-test. Brute-force is exponential in the number of reference
utterances, so choose a small number."""
distance, assignment = greedy_combination_matching.greedy_combination_matching(
- segments, streams, [0] * len(segments)
+ segments, streams, [0] * len(segments), distancetype=distancetype
)
- _check_output(distance, assignment, segments, streams)
+ _check_output(distance, assignment, segments, streams, check_distance=distancetype != '2')
-@given(segments(6), streams(4))
-def test_greedy_bound_by_optimal(segments, streams):
+@given(segments(6), streams(4), st.sampled_from(['1', '2', '21']))
+def test_greedy_bound_by_optimal(segments, streams, distancetype):
greedy_distance, _ = greedy_combination_matching.greedy_combination_matching(
- segments, streams, [0] * len(segments)
+ segments, streams, [0] * len(segments), distancetype=distancetype
)
optimal_distance, _ = orc_matching.orc_matching_v3(segments, streams)
@@ -73,3 +74,106 @@ def test_optimal_assignment_is_not_changed(segments, streams):
assert optimal_distance == greedy_distance
assert optimal_assignment == greedy_assignment
+
+
+# Limit alphabet to ensure a few correct matches
+string = st.text(alphabet='abcdefg', min_size=0, max_size=100)
+
+
+@st.composite
+def string_with_timing(draw):
+ """
+ Constraints:
+ - end >= start
+ - start values must be increasing
+ """
+ s = draw(string)
+ t = []
+ start = 0
+ for _ in s:
+ start = draw(st.integers(min_value=start, max_value=10))
+ end = draw(st.integers(min_value=start, max_value=start + 10))
+ t.append((start, end))
+ return s, t
+
+
+@given(
+ string_with_timing(),
+ string_with_timing(),
+)
+def test_greedy_time_constrained_correct(a, b):
+ """
+ Tests the time-constrained cython matrix implementation against the
+ time-constrained distance C++ implementation used in the time-constrained
+ siso WER
+ """
+ from meeteval.wer.matching.cy_levenshtein import time_constrained_levenshtein_distance
+ from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col_time_constrained
+ import numpy as np
+
+ a, a_timing = a
+ b, b_timing = b
+
+ # cy_forward_col_time_constrained needs the sequences as integers
+ import collections
+ sym2int = collections.defaultdict(itertools.count().__next__)
+ _ = sym2int[''] # Reserve 0 for the empty string
+ a = [sym2int[c] for c in a]
+ b = [sym2int[c] for c in b]
+
+ siso_dist = time_constrained_levenshtein_distance(a, b, a_timing, b_timing)
+
+ column = cy_forward_col_time_constrained(
+ np.arange(len(a) + 1, dtype=np.uint),
+ np.asarray(a, dtype=np.uint),
+ np.asarray(b, dtype=np.uint),
+ np.asarray([t[0] for t in a_timing], float),
+ np.asarray([t[1] for t in a_timing], float),
+ np.asarray([t[0] for t in b_timing], float),
+ np.asarray([t[1] for t in b_timing], float),
+ )
+
+ assert siso_dist == column[-1]
+
+
+
+utterance_with_timings = st.lists(st.tuples(st.integers(min_value=0, max_value=10), st.floats(allow_nan=False, allow_infinity=False), st.floats(0, 1, allow_nan=False, allow_infinity=False)), min_size=1, max_size=10)
+segments_with_timing = lambda max_size: st.lists(utterance_with_timings, min_size=1, max_size=max_size)
+streams_with_timing = lambda max_size: st.lists(utterance_with_timings, min_size=1, max_size=max_size)
+
+
+@given(segments_with_timing(2), streams_with_timing(2), st.sampled_from(['1', '2', '21']))
+def test_greedy_time_constrained_bound_by_optimal(segments, streams, distancetype):
+ greedy_distance, _ = greedy_combination_matching.greedy_time_constrained_combination_matching(
+ segments, streams, [0] * len(segments), distancetype=distancetype
+ )
+
+ from meeteval.wer.matching.cy_time_constrained_orc_matching import time_constrained_orc_levenshtein_distance
+ optimal_distance, _ = time_constrained_orc_levenshtein_distance(
+ [[s[0] for s in ss] for ss in segments],
+ [[s[0] for s in ss] for ss in streams],
+ [[s[1:] for s in ss] for ss in segments],
+ [[s[1:] for s in ss] for ss in streams],
+ )
+ assert optimal_distance <= greedy_distance
+
+
+@given(segments(4), streams(6), st.sampled_from(['1', '2', '21']))
+def test_greedy_vs_greedy_time_constrained(segments, streams, distancetype):
+ """
+ Tests that the time-constrained version gives the same result as the
+ non-time-constrained version when all words overlap.
+ """
+ greedy_distance, greedy_assignment = greedy_combination_matching.greedy_combination_matching(
+ segments, streams, [0] * len(segments), distancetype=distancetype
+ )
+
+ greedy_time_constrained_distance, greedy_time_constrained_assignment = greedy_combination_matching.greedy_time_constrained_combination_matching(
+ [[(w, 0, 1) for w in segment] for segment in segments],
+ [[(w, 0, 1) for w in stream] for stream in streams],
+ [0] * len(segments),
+ distancetype=distancetype
+ )
+
+ assert greedy_time_constrained_distance == greedy_distance
+ assert greedy_time_constrained_assignment == greedy_assignment