Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add greedy time-constrained ORC-WER #91

Merged
merged 7 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion meeteval/wer/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import dataclasses
import glob
import json
import logging
import os
Expand Down Expand Up @@ -301,6 +300,35 @@ def tcorcwer(
_save_results(results, hypothesis, per_reco_out, average_out)


def greedy_tcorcwer(
        reference, hypothesis,
        average_out='{parent}/{stem}_greedy_tcorcwer.json',
        per_reco_out='{parent}/{stem}_greedy_tcorcwer_per_reco.json',
        regex=None,
        collar=0,
        hyp_pseudo_word_timing='character_based_points',
        ref_pseudo_word_timing='character_based',
        hypothesis_sort='segment',
        reference_sort='segment',
        uem=None,
        normalizer=None,
        partial=False,
):
    """Computes the time-constrained ORC WER (tcORC WER) with a greedy algorithm"""
    # All metric options are forwarded unchanged to the API function; only
    # the two output templates (`average_out`, `per_reco_out`) are consumed
    # here, by `_save_results`.
    results = meeteval.wer.greedy_tcorcwer(
        reference, hypothesis, regex=regex,
        collar=collar,
        hyp_pseudo_word_timing=hyp_pseudo_word_timing,
        ref_pseudo_word_timing=ref_pseudo_word_timing,
        hypothesis_sort=hypothesis_sort,
        reference_sort=reference_sort,
        uem=uem, partial=partial,
        normalizer=normalizer,
    )
    _save_results(results, hypothesis, per_reco_out, average_out)



def greedy_dicpwer(
reference, hypothesis,
average_out='{parent}/{stem}_greedy_dicpwer.json',
Expand Down Expand Up @@ -651,6 +679,7 @@ def cli():
cli.add_command(tcpwer)
cli.add_command(tcorcwer)
cli.add_command(greedy_dicpwer)
cli.add_command(greedy_tcorcwer)
cli.add_command(merge)
cli.add_command(average)

Expand Down
34 changes: 34 additions & 0 deletions meeteval/wer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'mimower',
'tcpwer',
'tcorcwer',
'greedy_tcorcwer',
'greedy_dicpwer',
]

Expand Down Expand Up @@ -282,3 +283,36 @@ def greedy_dicpwer(
hypothesis_sort=hypothesis_sort,
)
return results


def greedy_tcorcwer(
        reference, hypothesis,
        regex=None,
        collar=0,
        hyp_pseudo_word_timing='character_based_points',
        ref_pseudo_word_timing='character_based',
        hypothesis_sort='segment',
        reference_sort='segment',
        uem=None,
        normalizer=None,
        partial=False,
):
    """
    Computes the time-constrained ORC WER with a greedy algorithm.

    The inputs are loaded with `_load_texts` (which receives `regex`, `uem`
    and `normalizer`) and then scored with
    `greedy_time_constrained_orc_wer_multifile`.

    Returns:
        The result of `greedy_time_constrained_orc_wer_multifile`. The
        combined (average) error rate is only computed to emit self-overlap
        warnings and is not returned.
    """
    from meeteval.wer.wer.time_constrained_orc import greedy_time_constrained_orc_wer_multifile

    reference, hypothesis = _load_texts(
        reference, hypothesis, regex=regex, uem=uem, normalizer=normalizer,
    )
    results = greedy_time_constrained_orc_wer_multifile(
        reference, hypothesis,
        reference_pseudo_word_level_timing=ref_pseudo_word_timing,
        hypothesis_pseudo_word_level_timing=hyp_pseudo_word_timing,
        collar=collar,
        hypothesis_sort=hypothesis_sort,
        reference_sort=reference_sort,
        partial=partial,
    )

    from meeteval.wer import combine_error_rates
    average = combine_error_rates(results)

    # Warn about self-overlap, hypothesis first and then reference
    for side in ('hypothesis', 'reference'):
        self_overlap = getattr(average, f'{side}_self_overlap')
        if self_overlap is not None:
            self_overlap.warn(side)
    return results
48 changes: 48 additions & 0 deletions meeteval/wer/matching/cy_greedy_combination_matching.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,51 @@ def cy_forward_col(
else:
tmp[current, i] = min([tmp[prev, i - 1] + cost_substitution, tmp[current, i - 1] + 1, tmp[prev, i] + 1])
return np.asarray(tmp[current]).copy()


@cython.boundscheck(False)  # turn off bounds-checking for entire function
@cython.wraparound(False)  # turn off negative index wrapping for entire function
def cy_forward_col_time_constrained(
        npy_uint64[:] column not None,
        npy_uint64[:] a not None,
        npy_uint64[:] b not None,
        npy_float64[:] a_begin not None,
        npy_float64[:] a_end not None,
        npy_float64[:] b_begin not None,
        npy_float64[:] b_end not None,
        uint cost_substitution = 1,
) -> np.ndarray:
    """
    Advances an edit-distance column over all tokens of `b`, where a match or
    substitution between two tokens is only allowed when their time intervals
    overlap. Insertions and deletions always cost 1.

    Bounds checking is disabled for this function, so the length requirements
    below are NOT verified. Violating them reads out of bounds (SEGFAULT).

    Args:
        column: The start column. It is copied and not modified in place.
            Make sure that `len(column) == len(a) + 1`!
        a: Sequence in column direction
        b: Sequence in row direction. The column is advanced `len(b)` times
        a_begin: Begin time for each entry of `a`. Indexed like `a`, so it
            must have at least `len(a)` entries
        a_end: End time for each entry of `a`. Indexed like `a`
        b_begin: Begin time for each entry of `b`. Indexed like `b`, so it
            must have at least `len(b)` entries
        b_end: End time for each entry of `b`. Indexed like `b`
        cost_substitution: Cost for a substitution

    Returns:
        A new array with the column after processing all of `b`. For an
        empty `b` this is a copy of `column`.
    """
    cdef uint i, j, a_, b_, current, prev
    cdef double b_begin_, b_end_
    # Two rows that are used alternately as "previous" and "current" column,
    # so no column has to be copied between iterations
    cdef npy_uint64[:, :] tmp = np.empty((2, column.shape[0]), dtype=np.uint)
    tmp[0, ...] = column
    current = 0
    for j in range(b.shape[0]):
        current = (j + 1) % 2
        prev = j % 2
        b_ = b[j]
        b_begin_ = b_begin[j]
        b_end_ = b_end[j]

        tmp[current, 0] = tmp[prev, 0] + 1
        for i in range(1, a.shape[0] + 1):
            a_ = a[i - 1]
            # Touching intervals (begin == end of the other) count as "no
            # overlap". This matches the time-constrained Levenshtein
            # distance in levenshtein.h, against which this code is tested.
            if a_begin[i - 1] >= b_end_ or b_begin_ >= a_end[i - 1]:
                # No overlap
                tmp[current, i] = min([tmp[current, i - 1] + 1, tmp[prev, i] + 1])
            elif a_ == b_:
                # Overlap correct
                tmp[current, i] = tmp[prev, i - 1]
            else:
                # Overlap incorrect
                tmp[current, i] = min([tmp[prev, i - 1] + cost_substitution, tmp[current, i - 1] + 1, tmp[prev, i] + 1])
    return np.asarray(tmp[current]).copy()
69 changes: 66 additions & 3 deletions meeteval/wer/matching/greedy_combination_matching.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import itertools
import functools
from typing import List, Iterable

from typing import List, Iterable, Tuple

import numpy as np

from meeteval.io.seglst import SegLST
from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col
from meeteval.wer.matching.cy_greedy_combination_matching import cy_forward_col, cy_forward_col_time_constrained


def _apply_assignment(assignment, segments, n=None):
Expand Down Expand Up @@ -270,3 +269,67 @@ def greedy_combination_matching(
)

return distance, assignment


def greedy_time_constrained_combination_matching(
        segments: List[Iterable[Tuple[int, float, float]]],
        streams: List[Iterable[Tuple[int, float, float]]],
        initial_assignment: List[int],
        *,
        distancetype: str = '21',  # '21', '2', '1'
):
    """
    Assigns every segment in `segments` to one of the streams in `streams`,
    starting from `initial_assignment` and refining it greedily with a
    time-constrained distance.

    Every word is given as a `(token, begin_time, end_time)` triple.

    Args:
        segments: A list of segments for which stream labels should be obtained
        streams: A list of streams to which the segments are assigned
        initial_assignment: The initial assignment of the segments to the streams.
            Can be obtained with `initialize_assignment`.
        distancetype: The type of distance to use. Can be one of:
            - `'1'`: Use insertion cost of 1 (like in Levenshtein distance)
            - `'2'`: Use insertion cost of 2 (cost of insertion + deletion)
            - `'21'`: Start with '2' until converged and then use '1' until converged

    Returns:
        A tuple `(distance, assignment)`. When `segments` or `streams` is
        empty, the distance is the total number of words on the other side.

    >>> greedy_time_constrained_combination_matching(
    ...     [[(0, 0, 1), (1, 1, 2)]],
    ...     [[(0, 0, 1), (1, 1, 2)]],
    ...     [0]
    ... )
    (0, [0])

    >>> greedy_time_constrained_combination_matching(
    ...     [[(0, 0, 1)], [(1, 1, 2)]],
    ...     [[(0, 0, 1)], [(1, 1, 2)]],
    ...     [0, 0]
    ... )
    (0, [0, 1])
    """
    # Trivial cases: nothing to assign, or nothing to assign to
    if len(segments) == 0:
        return sum(len(stream) for stream in streams), []
    if len(streams) == 0:
        return sum(len(segment) for segment in segments), [0] * len(segments)

    assert len(initial_assignment) == len(segments), (len(initial_assignment), len(segments), initial_assignment)
    assert distancetype in ('1', '2', '21'), distancetype

    def _split_words(words):
        """Splits (token, begin, end) triples into three typed arrays"""
        tokens = np.asarray([token for token, _, _ in words], dtype=np.uint)
        begins = np.asarray([begin for _, begin, _ in words], dtype=np.float64)
        ends = np.asarray([end for _, _, end in words], dtype=np.float64)
        return tokens, begins, ends

    # Correct assignment, once for every digit in `distancetype`
    assignment = initial_assignment
    for d in distancetype:
        substitution_cost = int(d)

        def forward_col(column, a, b):
            a_tokens, a_begin, a_end = _split_words(a)
            b_tokens, b_begin, b_end = _split_words(b)
            return cy_forward_col_time_constrained(
                column,
                a=a_tokens,
                b=b_tokens,
                a_begin=a_begin,
                a_end=a_end,
                b_begin=b_begin,
                b_end=b_end,
                cost_substitution=substitution_cost,
            )

        assignment, distance = _greedy_correct_assignment(
            segments, streams, assignment,
            forward_col
        )

    return distance, assignment
100 changes: 99 additions & 1 deletion meeteval/wer/wer/time_constrained_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# Names exported by a star-import of this module
__all__ = [
    'time_constrained_orc_wer',
    'time_constrained_orc_wer_multifile',
    'greedy_time_constrained_orc_wer',
    'greedy_time_constrained_orc_wer_multifile',
]


Expand All @@ -30,7 +32,7 @@ def time_constrained_orc_wer(
The time-constrained version of the ORC-WER (tcORC-WER).

Special cases where the reference or hypothesis is empty
# >>> time_constrained_orc_wer([], [])
>>> time_constrained_orc_wer([], [])
OrcErrorRate(errors=0, length=0, insertions=0, deletions=0, substitutions=0, assignment=())
>>> time_constrained_orc_wer([], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}])
OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=())
Expand Down Expand Up @@ -106,3 +108,99 @@ def time_constrained_orc_wer_multifile(
reference_sort=reference_sort,
), reference, hypothesis, partial=partial)
return r


def greedy_time_constrained_orc_wer(
        reference,
        hypothesis,
        reference_pseudo_word_level_timing='character_based',
        hypothesis_pseudo_word_level_timing='character_based_points',
        collar: int = 0,
        reference_sort='segment',
        hypothesis_sort='segment',
):
    """
    The time-constrained ORC-WER (tcORC-WER), where the assignment of
    reference segments to hypothesis streams is found with the greedy
    time-constrained combination matching.

    Special cases where the reference or hypothesis is empty
    >>> greedy_time_constrained_orc_wer([], [])
    OrcErrorRate(errors=0, length=0, insertions=0, deletions=0, substitutions=0, assignment=())
    >>> greedy_time_constrained_orc_wer([], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}])
    OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=())
    >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}], [])
    OrcErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=1, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('dummy',))
    >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': '', 'speaker': 'A'}], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a', 'speaker': 'A'}])
    OrcErrorRate(errors=1, length=0, insertions=1, deletions=0, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('A',))
    >>> greedy_time_constrained_orc_wer([{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a b', 'speaker': 'A'}], [{'session_id': 'a', 'start_time': 0, 'end_time': 1, 'words': 'a d', 'speaker': 'A'}])
    OrcErrorRate(error_rate=0.5, errors=1, length=2, insertions=0, deletions=0, substitutions=1, reference_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1), assignment=('A',))
    """
    from meeteval.wer.matching.greedy_combination_matching import greedy_time_constrained_combination_matching, \
        initialize_assignment
    from meeteval.wer.wer.orc import _orc_error_rate

    timed_word_keys = ('words', 'start_time', 'end_time')

    def _explode_segment(segment):
        """Turns one segment with word-level lists into one entry per word"""
        return [
            {**segment, 'words': word, 'start_time': start, 'end_time': end}
            for word, start, end in zip(
                segment['words'], segment['start_time'], segment['end_time']
            )
        ]

    def matching(reference, hypothesis):
        """
        Use the greedy time-constrained combination matching.
        Convert the inputs into lists of (word, start_time, end_time) tuples.
        """
        # One list of timed words per reference segment
        segments = [list(zip(*r)) for r in reference.T[timed_word_keys]]
        # One flat list of timed words per hypothesis stream
        streams = [
            [w for words in stream.T[timed_word_keys] for w in zip(*words)]
            for stream in hypothesis.values()
        ]
        distance, assignment = greedy_time_constrained_combination_matching(
            segments,
            streams,
            initial_assignment=initialize_assignment(reference, hypothesis, initialization='cp'),
        )
        return distance, assignment

    def siso(reference, hypothesis):
        return _time_constrained_siso_error_rate(
            reference.flatmap(_explode_segment),
            hypothesis.flatmap(_explode_segment),
        )

    # Drop segment index in reference. It will get a new one after merging by speakers
    reference = meeteval.io.asseglst(reference).map(
        lambda x: {k: v for k, v in x.items() if k != 'segment_index'}
    )

    reference, hypothesis, ref_self_overlap, hyp_self_overlap = preprocess(
        reference, hypothesis,
        keep_keys=('words', 'segment_index', 'speaker', 'start_time', 'end_time'),
        reference_sort=reference_sort,
        hypothesis_sort=hypothesis_sort,
        segment_representation='segment',
        segment_index='segment',
        remove_empty_segments=False,
        convert_to_int=True,
        reference_pseudo_word_level_timing=reference_pseudo_word_level_timing,
        hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
        collar=collar,
    )

    # Attach the self-overlap statistics from preprocessing to the result
    return dataclasses.replace(
        _orc_error_rate(reference, hypothesis, matching, siso),
        reference_self_overlap=ref_self_overlap,
        hypothesis_self_overlap=hyp_self_overlap,
    )


def greedy_time_constrained_orc_wer_multifile(
        reference: 'STM', hypothesis: 'STM',
        reference_pseudo_word_level_timing='character_based',
        hypothesis_pseudo_word_level_timing='character_based_points',
        collar: int = 0,
        hypothesis_sort='segment',
        reference_sort='segment',
        partial=False,
) -> 'dict[str, OrcErrorRate]':
    """
    Applies `greedy_time_constrained_orc_wer` to `reference` and `hypothesis`
    through `apply_multi_file`.

    All timing, collar and sort options are passed on unchanged to
    `greedy_time_constrained_orc_wer`; `partial` is passed to
    `apply_multi_file`.
    """
    from meeteval.io.seglst import apply_multi_file

    def _single_file(ref, hyp):
        return greedy_time_constrained_orc_wer(
            ref, hyp,
            reference_pseudo_word_level_timing=reference_pseudo_word_level_timing,
            hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
            collar=collar,
            hypothesis_sort=hypothesis_sort,
            reference_sort=reference_sort,
        )

    return apply_multi_file(_single_file, reference, hypothesis, partial=partial)
8 changes: 8 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ def test_burn_tcorc():
run(f'python -m meeteval.wer tcorcwer -h hyp.stm -r ref.stm --hypothesis-sort true')


def test_burn_greedy_tcorc():
    """Smoke test: run the `greedy_tcorcwer` command with several option sets."""
    commands = (
        'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm',
        'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --collar 5',
        'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --hyp-pseudo-word-timing equidistant_points',
        'python -m meeteval.wer greedy_tcorcwer -h hyp.seglst.json -r ref.seglst.json',
        'python -m meeteval.wer greedy_tcorcwer -h hyp.stm -r ref.stm --hypothesis-sort true',
    )
    for command in commands:
        run(command)


def test_burn_md_eval_22():
run(f'python -m meeteval.der md_eval_22 -h hyp.stm -r ref.stm')
run(f'meeteval-der md_eval_22 -h hyp.stm -r ref.stm')
Expand Down
Loading
Loading