diff --git a/README.md b/README.md
index df85538..7f70760 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,8 @@ MeetEval supports the following metrics for meeting transcription evaluation:
   `meeteval-wer greedy_tcorcwer -r ref.stm -h hyp.stm --collar 5`
 - **Diarization-Invariant cpWER (DI-cpWER)**
   `meeteval-wer greedy_dicpwer -r ref.stm -h hyp.stm`
+- **Diarization Error Rate (DER)** by wrapping [mdeval](https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl) like dscore (see https://github.com/fgnt/meeteval/issues/97#issuecomment-2508140402)
+  `meeteval-der dscore -r ref.stm -h hyp.stm --collar .25`
 - **Diarization Error Rate (DER)** by wrapping [mdeval](https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl)
   `meeteval-der md_eval_22 -r ref.stm -h hyp.stm --collar .25`
diff --git a/meeteval/der/__main__.py b/meeteval/der/__main__.py
index 672fcef..0496c22 100644
--- a/meeteval/der/__main__.py
+++ b/meeteval/der/__main__.py
@@ -26,6 +26,35 @@ def md_eval_22(
     _save_results(results, hypothesis, per_reco_out, average_out, wer_name='DER')
 
 
+def dscore(
+        reference,
+        hypothesis,
+        average_out='{parent}/{stem}_dscore.json',
+        per_reco_out='{parent}/{stem}_dscore_per_reco.json',
+        collar=0,
+        regions='all',
+        regex=None,
+        uem=None,
+):
+    """
+    Computes the Diarization Error Rate (DER) using md-eval-22.pl,
+    but creates a UEM if uem is None, as is done in dscore [1].
+    Commonly used in challenge evaluations, e.g., DIHARD II, CHiME.
+
+    [1] https://github.com/nryant/dscore
+    """
+    from meeteval.der.api import dscore
+    results = dscore(
+        reference,
+        hypothesis,
+        collar=collar,
+        regex=regex,
+        regions=regions,
+        uem=uem,
+    )
+    _save_results(results, hypothesis, per_reco_out, average_out, wer_name='DER')
+
+
 def cli():
     from meeteval.wer.__main__ import CLI
 
@@ -57,6 +86,7 @@ def add_argument(self, command_parser, name, p):
 
     cli = DerCLI()
+    cli.add_command(dscore)
     cli.add_command(md_eval_22)
     cli.run()
diff --git a/meeteval/der/api.py b/meeteval/der/api.py
index b6b5e44..04822e5 100644
--- a/meeteval/der/api.py
+++ b/meeteval/der/api.py
@@ -4,6 +4,7 @@
 
 __all__ = [
     'md_eval_22',
+    'dscore',
 ]
 
 
@@ -26,3 +27,24 @@ def md_eval_22(
         r, h, collar, regions=regions, uem=uem
     )
     return results
+
+
+def dscore(
+        reference,
+        hypothesis,
+        collar=0,
+        regions='all',
+        regex=None,
+        uem=None,
+):
+    r, h = _load_texts(reference, hypothesis, regex)
+    from meeteval.der.dscore import dscore_multifile
+    if uem is not None:
+        from meeteval.io.uem import UEM
+        if isinstance(uem, (str, Path, list, tuple)):
+            uem = UEM.load(uem)
+
+    results = dscore_multifile(
+        r, h, collar, regions=regions, uem=uem
+    )
+    return results
diff --git a/meeteval/der/dscore.py b/meeteval/der/dscore.py
new file mode 100644
index 0000000..121ab4a
--- /dev/null
+++ b/meeteval/der/dscore.py
@@ -0,0 +1,275 @@
+import logging
+import sys
+import subprocess
+import tempfile
+import decimal
+from pathlib import Path
+
+
+def _dscore_multifile(
+        reference, hypothesis, collar=0, regions='all',
+        uem=None
+):
+    """
+    dscore produces a table with the final scores, but we need
+    the details. Hence, call dscore only to compare the error rate
+    with md_eval_22_multifile.
+
+    >>> import packaging.version
+    >>> import numpy as np
+    >>> if packaging.version.parse(np.__version__) >= packaging.version.parse('1.24'):
+    ...     import pytest
+    ...     pytest.skip(f'dscore fails with numpy >= 1.24. Current version: {np.__version__}')
+
+    >>> from meeteval.io.rttm import RTTM
+    >>> from meeteval.io.uem import UEM
+    >>> reference = RTTM.parse('''
+    ... SPEAKER S1 1 0.0 0.5 spk1
+    ... SPEAKER S1 1 0.5 0.5 spk2
+    ... SPEAKER S1 1 1.0 0.5 spk1
+    ... ''')
+    >>> hypothesis = RTTM.parse('''
+    ... SPEAKER S1 1 0.0 0.5 spk1
+    ... SPEAKER S1 1 0.5 0.5 spk2
+    ... SPEAKER S1 1 1.0 0.5 spk2
+    ... ''')
+    >>> uem = UEM.parse('''
+    ... S1 1 0.0 1.5
+    ... ''')
+    >>> import pprint
+    >>> pprint.pprint(_dscore_multifile(reference, hypothesis, uem=uem))
+    {'S1': Decimal('0.3333')}
+    """
+    from meeteval.der.md_eval import _FilenameEscaper
+    escaper = _FilenameEscaper()
+
+    from meeteval.io.rttm import RTTM
+    reference = RTTM.new(reference)
+    hypothesis = RTTM.new(hypothesis)
+
+    reference = escaper.escape_rttm(reference)
+    hypothesis = escaper.escape_rttm(hypothesis)
+
+    score_py = Path(__file__).parent / 'dscore_repo' / 'score.py'
+    if not score_py.exists():
+        subprocess.run(['git', 'clone', 'https://github.com/nryant/dscore.git', score_py.parent])
+
+    filtered = 0
+    for line in reference:
+        if line.duration == 0:
+            filtered += 1
+    if filtered:
+        logging.info(f'Filtered {filtered} lines with zero duration in reference (dscore doesn\'t support zero duration)')
+        reference = RTTM([line for line in reference if line.duration != 0])
+
+    filtered = 0
+    for line in hypothesis:
+        if line.duration == 0:
+            filtered += 1
+    if filtered:
+        logging.info(f'Filtered {filtered} lines with zero duration in hypothesis (dscore doesn\'t support zero duration)')
+        hypothesis = RTTM([line for line in hypothesis if line.duration != 0])
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        tmpdir = Path(tmpdir)
+
+        r_file = tmpdir / f'ref.rttm'
+        h_file = tmpdir / f'hyp.rttm'
+        reference.dump(r_file)
+        hypothesis.dump(h_file)
+
+        cmd = [
+            sys.executable, str(score_py),
+            '--collar', str(collar),
+            # '--ignore_overlaps', regions,
+            '-r', f'{r_file}',
+            '-s', f'{h_file}',
+            '--table_fmt', 'tsv'
+        ]
+        if uem:
+            uem_file = tmpdir / 'uem.rttm'
+            uem = escaper.escape_uem(uem)
+            uem.dump(uem_file)
+            cmd.extend(['-u', f'{uem_file}'])
+
+        if regions == 'all':
+            pass
+        elif regions == 'nooverlap':
+            cmd.append('--ignore_overlaps')
+
+        p = subprocess.run(
+            cmd,
+            stdout=subprocess.PIPE,
+            # stderr=subprocess.PIPE,
+            check=True, universal_newlines=True,
+            cwd=score_py.parent
+        )
+
+    result = {}
+    for line in p.stdout.strip().split('\n')[1:-1]:
+        line_parts = line.split('\t')
+        result[escaper.restore(line_parts[0].strip())] = decimal.Decimal(line_parts[1].strip()) / 100
+
+    assert result, p.stdout
+    return result
+
+
+def _maybe_gen_uem(uem, reference, hypothesis):
+    # Mirror the behavior of dscore
+    if uem is None:
+        from meeteval.io.uem import UEM, UEMLine
+        uem_md_eval = UEM([
+            UEMLine(
+                filename=k, channel='1',
+                begin_time=min(v.T['start_time']), end_time=max(v.T['end_time'])
+            )
+            for k, v in (reference + hypothesis).to_seglst().groupby('session_id').items()
+        ])
+
+        return uem_md_eval, None
+    else:
+        return uem, uem
+
+
+def dscore_multifile(
+        reference, hypothesis, collar=0, regions='all',
+        uem=None, sanity_check=False,
+):
+    """
+    Computes the Diarization Error Rate (DER) using md-eval-22.pl,
+    but creates a UEM if uem is None, as is done in dscore [1].
+
+    Additionally, compares the error rate with dscore [1] if sanity_check is True.
+
+    Args:
+        reference:
+        hypothesis:
+        collar:
+        regions: 'all' or 'nooverlap'
+        uem: If None, generate a uem from the reference and hypothesis.
+            This is the default behavior of dscore, while md-eval-22
+            uses only the reference.
+        sanity_check: Compare the result with dscore to ensure
+            the correctness of the implementation.
+            Requires numpy < 1.24 (which still provides e.g. np.int),
+            because dscore fails with recent numpy versions.
+
+    [1] https://github.com/nryant/dscore
+
+    >>> from meeteval.io.rttm import RTTM
+    >>> from meeteval.io.uem import UEM
+    >>> reference = RTTM.parse('''
+    ... SPEAKER rec 1 5.00 5.00 spk01
+    ... SPEAKER rec 1 10.00 10.00 spk00
+    ... ''')
+    >>> hypothesis = RTTM.parse('''
+    ... SPEAKER rec 1 0.00 10.00 spk01
+    ... SPEAKER rec 1 10.00 10.00 spk00
+    ... ''')
+    >>> import pprint
+    >>> pprint.pprint(dscore_multifile(reference, hypothesis))  # doctest: +NORMALIZE_WHITESPACE
+    {'rec': DiaErrorRate(error_rate=Decimal('0.3333'),
+                         scored_speaker_time=Decimal('15.000000'),
+                         missed_speaker_time=Decimal('0.000000'),
+                         falarm_speaker_time=Decimal('5.000000'),
+                         speaker_error_time=Decimal('0.000000'))}
+
+    >>> from meeteval.der.md_eval import md_eval_22_multifile
+    >>> pprint.pprint(md_eval_22_multifile(reference, hypothesis))  # doctest: +NORMALIZE_WHITESPACE
+    {'rec': DiaErrorRate(error_rate=Decimal('0.00'),
+                         scored_speaker_time=Decimal('15.000000'),
+                         missed_speaker_time=Decimal('0.000000'),
+                         falarm_speaker_time=Decimal('0.000000'),
+                         speaker_error_time=Decimal('0.000000'))}
+
+    """
+    from meeteval.der.md_eval import md_eval_22_multifile
+
+    uem_md_eval, uem_dscore = _maybe_gen_uem(uem, reference, hypothesis)
+
+    result = md_eval_22_multifile(
+        reference, hypothesis, collar=collar, regions=regions, uem=uem_md_eval
+    )
+    if sanity_check:
+        dscore_der = _dscore_multifile(reference, hypothesis, collar=collar, regions=regions, uem=uem_dscore)
+        for key in result:
+            assert key in dscore_der, (key, result, dscore_der)
+            assert abs(dscore_der[key] - result[key].error_rate) <= decimal.Decimal('0.0001'), (key, dscore_der[key], result[key])
+
+    return result
+
+
+def dscore(reference, hypothesis, collar=0, regions='all', uem=None, sanity_check=False):
+    """
+    Computes the Diarization Error Rate (DER) using md-eval-22.pl,
+    but creates a UEM if uem is None, as is done in dscore [1].
+
+    Additionally, compares the error rate with dscore [1] if sanity_check is True.
+
+    Args:
+        reference:
+        hypothesis:
+        collar:
+        regions: 'all' or 'nooverlap'
+        uem: If None, generate a uem from the reference and hypothesis.
+            This is the default behavior of dscore, while md-eval-22
+            uses only the reference.
+        sanity_check: Compare the result with dscore to ensure
+            the correctness of the implementation.
+            Requires numpy < 1.24 (which still provides e.g. np.int),
+            because dscore fails with recent numpy versions.
+
+    [1] https://github.com/nryant/dscore
+
+    >>> from meeteval.io.rttm import RTTM
+    >>> from meeteval.io.uem import UEM
+    >>> reference = RTTM.parse('''
+    ... SPEAKER rec.a 1 5.00 5.00 spk01
+    ... SPEAKER rec.a 1 10.00 10.00 spk00
+    ... ''')
+    >>> hypothesis = RTTM.parse('''
+    ... SPEAKER rec.a 1 0.00 10.00 spk01
+    ... SPEAKER rec.a 1 10.00 10.00 spk00
+    ... ''')
+    >>> uem = UEM.parse('''
+    ... rec.a 1 0.00 15.00
+    ... ''')
+    >>> import pprint
+
+    >>> pprint.pprint(dscore(reference, hypothesis))  # doctest: +NORMALIZE_WHITESPACE
+    DiaErrorRate(error_rate=Decimal('0.3333'),
+                 scored_speaker_time=Decimal('15.000000'),
+                 missed_speaker_time=Decimal('0.000000'),
+                 falarm_speaker_time=Decimal('5.000000'),
+                 speaker_error_time=Decimal('0.000000'))
+    >>> pprint.pprint(dscore(reference, hypothesis, uem=uem))  # doctest: +NORMALIZE_WHITESPACE
+    DiaErrorRate(error_rate=Decimal('0.50'),
+                 scored_speaker_time=Decimal('10.000000'),
+                 missed_speaker_time=Decimal('0.000000'),
+                 falarm_speaker_time=Decimal('5.000000'),
+                 speaker_error_time=Decimal('0.000000'))
+
+    # md_eval_22 ignores hyps before the first ref and after the last ref
+    >>> from meeteval.der.md_eval import md_eval_22
+    >>> pprint.pprint(md_eval_22(reference, hypothesis))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+    DiaErrorRate(error_rate=Decimal('0.00'),
+                 scored_speaker_time=Decimal('15.000000'),
+                 missed_speaker_time=Decimal('0.000000'),
+                 falarm_speaker_time=Decimal('0.000000'),
+                 speaker_error_time=Decimal('0.000000'))
+    """
+    from meeteval.der.md_eval import md_eval_22
+    from meeteval.io.rttm import RTTM
+
+    reference = RTTM.new(reference, filename='dummy')
+    hypothesis = RTTM.new(hypothesis, filename='dummy')
+
+    uem_md_eval, uem_dscore = _maybe_gen_uem(uem, reference, hypothesis)
+
+    result = md_eval_22(reference, hypothesis, collar=collar, regions=regions, uem=uem_md_eval)
+
+    if sanity_check:
+        dscore_der = _dscore_multifile(reference, hypothesis, collar=collar, regions=regions, uem=uem_dscore)
+        assert list(dscore_der.values()) == [result.error_rate], (dscore_der, result.error_rate)
+
+    return result
diff --git a/meeteval/der/md_eval.py b/meeteval/der/md_eval.py
index 447d28f..26929cd 100644
--- a/meeteval/der/md_eval.py
+++ b/meeteval/der/md_eval.py
@@ -77,6 +77,101 @@ def __add__(self, other: 'DiaErrorRate'):
         )
 
 
+
+class _FilenameEscaper:
+    """
+    >>> import pprint
+    >>> reference = meeteval.io.RTTM.parse('''
+    ... SPEAKER rec.a 1 5.00 5.00 spk01
+    ... SPEAKER rec.a 1 10.00 10.00 spk00
+    ... ''')
+    >>> hypothesis = meeteval.io.RTTM.parse('''
+    ... SPEAKER rec.a 1 0.00 10.00 spk01
+    ... SPEAKER rec.a 1 10.00 10.00 spk00
+    ... ''')
+    >>> uem = meeteval.io.UEM.parse('''
+    ... rec.a 1 0.00 15.00
+    ... ''')
+
+    >>> pprint.pprint(md_eval_22_multifile(reference, hypothesis, uem=uem))  # doctest: +NORMALIZE_WHITESPACE
+    {'rec.a': DiaErrorRate(error_rate=Decimal('0.50'),
+                           scored_speaker_time=Decimal('10.000000'),
+                           missed_speaker_time=Decimal('0.000000'),
+                           falarm_speaker_time=Decimal('5.000000'),
+                           speaker_error_time=Decimal('0.000000'))}
+
+    >>> _FilenameEscaper._DISABLED = True
+    >>> pprint.pprint(md_eval_22_multifile(reference, hypothesis, uem=uem))  # doctest: +NORMALIZE_WHITESPACE
+    {'rec.a': DiaErrorRate(error_rate=Decimal('0.00'),
+                           scored_speaker_time=Decimal('15.000000'),
+                           missed_speaker_time=Decimal('0.000000'),
+                           falarm_speaker_time=Decimal('0.000000'),
+                           speaker_error_time=Decimal('0.000000'))}
+    >>> _FilenameEscaper._DISABLED = False
+    """
+    _DISABLED = False
+    def __init__(self):
+        self.warned = False
+        self.cache = {}
+
+    def __call__(self, filename):
+        if self._DISABLED:
+            return filename
+        if filename in self.cache:
+            return self.cache[filename]
+        elif '.' in filename:
+            blocked = set(self.cache.values())
+            for replacement in [
+                '_', '__', '___', '____', '_____',
+            ]:
+                new = filename.replace('.', replacement)
+                if new not in blocked:
+                    break
+            else:
+                raise RuntimeError(f'Cannot find a replacement for {filename}.')
+
+            if not self.warned:
+                self.warned = True
+                logging.warning(
+                    f'Warning: Replace UEM filename "{filename}" by "{new}".\n'
+                    f'    md-eval-22 removes the first suffix from uem filenames/session_ids but not from rttm files\n'
+                    f'    (e.g., uem: some.audio.wav -> some.wav).\n'
+                    f'    (e.g., rttm: some.audio.wav -> some.audio.wav).\n'
+                    f"    dscore doesn't support dots in uem\n"
+                    f'    (e.g., without uem file, rttm filenames/session_ids can have dots).\n'
+                    f'    (e.g., with uem file, rttm cannot have dots).\n'
+                    f"    -> dscore has no proper support for dots, because it uses md-eval-22\n"
+                    f'    In meeteval, we assume that the uem file has the same filenames/session_ids as reference and hypothesis.\n'
+                    f'    -> remove dots from filename and restore them later'
+                )
+            self.cache[filename] = new
+            return new
+        else:
+            self.cache[filename] = filename
+            return filename
+
+    def escape_rttm(self, rttm):
+        return rttm.__class__(
+            [line.replace(filename=self(line.filename)) for line in rttm]
+        )
+
+    def escape_uem(self, uem):
+        return uem.__class__(
+            [line.replace(filename=self(line.filename)) for line in uem]
+        )
+
+    def restore(self, filename):
+        if self._DISABLED:
+            return filename
+        if filename in self.cache:
+            assert self.cache[filename] == filename, (self.cache[filename], filename)
+            return filename
+        for k, v in self.cache.items():
+            if v == filename:
+                return k
+        raise ValueError(f'Cannot find {filename} as value in {self.cache}')
+
+
 def md_eval_22_multifile(
         reference, hypothesis, collar=0, regions='all',
         uem=None
@@ -104,6 +199,10 @@ def md_eval_22_multifile(
     reference = RTTM.new(reference)
     hypothesis = RTTM.new(hypothesis)
 
+    escaper = _FilenameEscaper()
+    reference = escaper.escape_rttm(reference)
+    hypothesis = escaper.escape_rttm(hypothesis)
+
     reference = _fix_channel(reference)
     hypothesis = _fix_channel(hypothesis)
 
@@ -128,7 +227,11 @@ def md_eval_22_multifile(
             urllib.request.urlretrieve(url, md_eval_22)
             logging.info(f'Wrote {md_eval_22}')
 
+    warned = False
+
     def get_details(r, h, key, tmpdir, uem):
+        nonlocal warned
+
         r_file = tmpdir / f'{key}.ref.rttm'
         h_file = tmpdir / f'{key}.hyp.rttm'
         r.dump(r_file)
@@ -146,11 +249,15 @@ def get_details(r, h, key, tmpdir, uem):
 
         if uem:
            uem_file = tmpdir / f'{key}.uem'
+            uem = escaper.escape_uem(uem)
            uem.dump(uem_file)
            cmd.extend(['-u', f'{uem_file}'])
-
+        elif not warned:
+            warned = True
+            logging.warning(f'No UEM file provided. See https://github.com/fgnt/meeteval/issues/97#issuecomment-2508140402 for details.')
        cp = subprocess.run(cmd, stdout=subprocess.PIPE, check=True, universal_newlines=True)
+
        # SCORED SPEAKER TIME =4309.340250 secs
        # MISSED SPEAKER TIME =4309.340250 secs
        # FALARM SPEAKER TIME =0.000000 secs
@@ -179,7 +286,7 @@ def convert(string):
     with tempfile.TemporaryDirectory() as tmpdir:
         tmpdir = Path(tmpdir)
         for key in keys:
-            per_reco[key] = get_details(r[key], h[key], key, tmpdir, uem)
+            per_reco[escaper.restore(key)] = get_details(r[key], h[key], key, tmpdir, uem)
 
         md_eval = get_details(
             meeteval.io.RTTM([line for key in keys for line in r[key]]),
@@ -197,6 +304,7 @@ def convert(string):
         f'does not match the average error rate of md-eval-22.pl '
         f'applied to each recording ({md_eval.error_rate}).'
     )
+
     return per_reco
 
 
@@ -214,4 +322,4 @@ def md_eval_22(reference, hypothesis, collar=0, regions='all', uem=None):
     return md_eval_22_multifile(
         reference, hypothesis, collar, regions=regions, uem=uem
-    )[reference.filenames()[0]]
+    )[list(reference.filenames())[0]]
diff --git a/setup.py b/setup.py
index d8ec2bf..ffff916 100644
--- a/setup.py
+++ b/setup.py
@@ -128,7 +128,8 @@
         'scipy',  # scipy.optimize.linear_sum_assignment
         "typing_extensions; python_version<'3.8'",  # Missing Literal in py37
         "cached_property; python_version<'3.8'",  # Missing functools.cached_property in py37
-        'Cython'
+        'Cython',
+        'packaging',  # used to compare package versions (e.g., the numpy version in meeteval/der/dscore.py); a common dependency of jupyter, matplotlib, pytest, ...
     ],
     extras_require=extras_require,
     package_data={'meeteval': ['**/*.pyx', '**/*.h', '**/*.js', '**/*.css', '**/*.html']},  # https://stackoverflow.com/a/60751886
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 51d8e62..3ec939f 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -156,6 +156,15 @@ def test_burn_md_eval_22():
     # ToDo: Table 2 of https://arxiv.org/pdf/2312.04324.pdf lists collars for
     #       datsets. Add them here.
 
+def test_burn_dscore():
+    run(f'python -m meeteval.der dscore -h hyp.stm -r ref.stm')
+    run(f'meeteval-der dscore -h hyp.stm -r ref.stm')
+    run(f'python -m meeteval.der dscore -h hyp.stm -r ref.stm --collar 0.25')
+    run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm')
+    run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm --regions all')
+    run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm --regions nooverlap')
+    run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm --regex ".*A"')
+    run(f'python -m meeteval.der dscore -h hyp.seglst.json -r ref.seglst.json')
 
 def test_burn_merge():
     run(f'python -m meeteval.wer cpwer -h hypA.stm -r refA.stm')  # create hypA_cpwer_per_reco.json and hypA_cpwer.json
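
As a quick reference, the sketch below exercises the new `dscore` Python entry point in the same way as the doctests in `meeteval/der/dscore.py` (the CLI equivalent from the README is `meeteval-der dscore -r ref.stm -h hyp.stm --collar .25`); it is only an illustration of the API added in this patch, not part of it:

```python
from meeteval.io.rttm import RTTM
from meeteval.io.uem import UEM
from meeteval.der.dscore import dscore

# Reference and hypothesis taken from the doctests in meeteval/der/dscore.py.
reference = RTTM.parse('''
SPEAKER rec.a 1 5.00 5.00 spk01
SPEAKER rec.a 1 10.00 10.00 spk00
''')
hypothesis = RTTM.parse('''
SPEAKER rec.a 1 0.00 10.00 spk01
SPEAKER rec.a 1 10.00 10.00 spk00
''')
uem = UEM.parse('''
rec.a 1 0.00 15.00
''')

# Without a UEM, the scoring region is derived from reference + hypothesis
# (dscore behavior); with a UEM, only the given region is scored.
print(dscore(reference, hypothesis))           # error_rate=Decimal('0.3333')
print(dscore(reference, hypothesis, uem=uem))  # error_rate=Decimal('0.50')
```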