Skip to content

Commit

Permalink
Merge pull request #90 from fgnt/fix_orc_assignment_empty_segments
Browse files Browse the repository at this point in the history
Fix orc assignment empty segments
  • Loading branch information
thequilo authored Sep 5, 2024
2 parents 146b947 + 329dc54 commit 1ef8643
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
1 change: 1 addition & 0 deletions meeteval/io/seglst.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def groupby(

groups = collections.defaultdict(list)
try:
# This keeps the order of the keys
for key, group in itertools.groupby(iterable, _get_key(key)):
groups[key].extend(group)
except KeyError:
Expand Down
14 changes: 9 additions & 5 deletions meeteval/wer/wer/orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,30 +117,35 @@ def _orc_error_rate(

# Consistency check: Compute WER with the siso algorithm after applying the
# assignment and compare the result with the distance from the ORC algorithm
reference_new = reference_new.groupby('speaker')
reference_new_grouped = reference_new.groupby('speaker')
er = combine_error_rates(*[
siso_error_rate(
reference_new.get(k, meeteval.io.SegLST([])),
reference_new_grouped.get(k, meeteval.io.SegLST([])),
hypothesis.get(k, meeteval.io.SegLST([])),
)
for k in set(hypothesis.keys()) | set(reference_new.keys())
for k in set(hypothesis.keys()) | set(reference_new_grouped.keys())
])
length = sum([len(s['words']) if isinstance(s['words'], list) else 1 for s in reference])
assert er.length == length, (length, er)
if distance is not None:
# The matching function can in some cases not compute the distance
assert er.errors == distance, (distance, er, assignment)

# Get the assignment in the order of segment_index
assignment = [r['speaker'] for r in reference_new.sorted('segment_index')]

# Insert labels for empty segments that got removed
if len(assignment) != total_num_segments:
# Make sure this assignment is reproducible by sorting the keys
dummy_key = sorted(hypothesis_keys)[0]
assignment = sorted([(r['segment_index'], a) for a, r in zip(assignment, reference)])
assignment = sorted([(r, a) for a, r in zip(assignment, sorted(reference.unique('segment_index')))])
for i in range(total_num_segments):
if i >= len(assignment) or assignment[i][0] != i:
assignment.insert(i, (i, dummy_key))
assignment = [a[1] for a in assignment]

assert len(assignment) == total_num_segments, (len(assignment), total_num_segments)

return OrcErrorRate(
er.errors, er.length,
insertions=er.insertions,
Expand Down Expand Up @@ -287,7 +292,6 @@ def apply_orc_assignment(
reference = reference.groupby('segment_index').values()
else:
reference = [[r] for r in reference]

assert len(reference) == len(assignment), (len(reference), len(assignment))
reference = meeteval.io.SegLST([
{**s, 'speaker': a}
Expand Down
1 change: 1 addition & 0 deletions tests/test_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def exec_with_source(code, filename, lineno, globals_=None, locals_=None):
[
(str(filename.relative_to(MEETEVAL_ROOT)), codeblock)
for filename in MEETEVAL_ROOT.glob('**/*.md')
if 'build' not in str(filename) and '/.' not in str(filename)
for codeblock in get_fenced_code_blocks(filename.read_text())
]
)
Expand Down
22 changes: 19 additions & 3 deletions tests/test_time_constrained_orc_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Limit alphabet to ensure a few correct matches
@st.composite
def string(draw, max_length=100):
return ' '.join(draw(st.text(alphabet='abcdefg', min_size=1, max_size=max_length)))
return ' '.join(draw(st.text(alphabet='abcdefg', min_size=0, max_size=max_length)))


# Generate a random SegLST object
Expand Down Expand Up @@ -42,6 +42,22 @@ def seglst(draw, min_segments=0, max_segments=10, max_speakers=2):
)


@given(
    seglst(max_speakers=1, min_segments=1),
    seglst(max_speakers=3, min_segments=1),
)
# Deadline disabled: the tests take longer on the GitHub Actions servers.
@settings(deadline=None)
def test_tcorc_burn(reference, hypothesis):
    """Burn test for the time-constrained ORC-WER on generated inputs.

    Checks that the result carries one assignment entry per reference
    segment, that the error count is a non-negative ``int``, and that the
    assignment can be applied to the inputs without raising.
    """
    from meeteval.wer.wer.time_constrained_orc import time_constrained_orc_wer

    result = time_constrained_orc_wer(
        reference,
        hypothesis,
        collar=1000,
        reference_sort=False,
        hypothesis_sort=False,
    )

    assert len(result.assignment) == len(reference)
    assert isinstance(result.errors, int)
    assert result.errors >= 0

    # Only checks that applying the assignment works and yields exactly two
    # objects; their contents are not inspected here.
    _assigned_reference, _assigned_hypothesis = result.apply_assignment(
        reference, hypothesis
    )


@given(
seglst(max_speakers=1, min_segments=1),
seglst(max_speakers=2, min_segments=1)
Expand Down Expand Up @@ -140,7 +156,7 @@ def test_examples_zero_self_overlap():

def test_assignment_keeps_order():
"""
Tests that elements in the assignment corrspond to the order in the input
Tests that elements in the assignment correspond to the order in the input
to the orc_wer function, not the sorted segments.
"""
from meeteval.wer.wer.time_constrained_orc import time_constrained_orc_wer
Expand All @@ -157,6 +173,6 @@ def test_assignment_keeps_order():
]),
reference_sort='segment',
)
assert tcorc.assignment == ('A1', 'A1', 'A2')
assert tcorc.assignment == ('A1', 'A1', 'A2'), tcorc.assignment


0 comments on commit 1ef8643

Please sign in to comment.