diff --git a/README.md b/README.md
index df85538..7f70760 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,8 @@ MeetEval supports the following metrics for meeting transcription evaluation:
`meeteval-wer greedy_tcorcwer -r ref.stm -h hyp.stm --collar 5`
- **Diarization-Invariant cpWER (DI-cpWER)**
`meeteval-wer greedy_dicpwer -r ref.stm -h hyp.stm`
+- **Diarization Error Rate (DER)** by wrapping [mdeval](https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl) the way [dscore](https://github.com/nryant/dscore) does (see https://github.com/fgnt/meeteval/issues/97#issuecomment-2508140402)
+ `meeteval-der dscore -r ref.stm -h hyp.stm --collar .25`
- **Diarization Error Rate (DER)** by wrapping [mdeval](https://github.com/nryant/dscore/raw/master/scorelib/md-eval-22.pl)
`meeteval-der md_eval_22 -r ref.stm -h hyp.stm --collar .25`
diff --git a/meeteval/der/__main__.py b/meeteval/der/__main__.py
index 672fcef..0496c22 100644
--- a/meeteval/der/__main__.py
+++ b/meeteval/der/__main__.py
@@ -26,6 +26,35 @@ def md_eval_22(
_save_results(results, hypothesis, per_reco_out, average_out, wer_name='DER')
+def dscore(
+ reference,
+ hypothesis,
+ average_out='{parent}/{stem}_dscore.json',
+ per_reco_out='{parent}/{stem}_dscore_per_reco.json',
+ collar=0,
+ regions='all',
+ regex=None,
+ uem=None,
+):
+ """
+ Computes the Diarization Error Rate (DER) using md-eval-22.pl,
+    but creates a UEM if uem is None, as is done in dscore [1].
+ Commonly used in challenge evaluations, e.g., DIHARD II, CHiME.
+
+ [1] https://github.com/nryant/dscore
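+
+    Example (from the README; ref.stm and hyp.stm are placeholder paths):
+
+        meeteval-der dscore -r ref.stm -h hyp.stm --collar .25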
+ """
+ from meeteval.der.api import dscore
+ results = dscore(
+ reference,
+ hypothesis,
+ collar=collar,
+ regex=regex,
+ regions=regions,
+ uem=uem,
+ )
+ _save_results(results, hypothesis, per_reco_out, average_out, wer_name='DER')
+
+
def cli():
from meeteval.wer.__main__ import CLI
@@ -57,6 +86,7 @@ def add_argument(self, command_parser, name, p):
cli = DerCLI()
+ cli.add_command(dscore)
cli.add_command(md_eval_22)
cli.run()
diff --git a/meeteval/der/api.py b/meeteval/der/api.py
index b6b5e44..04822e5 100644
--- a/meeteval/der/api.py
+++ b/meeteval/der/api.py
@@ -4,6 +4,7 @@
__all__ = [
'md_eval_22',
+ 'dscore',
]
@@ -26,3 +27,24 @@ def md_eval_22(
r, h, collar, regions=regions, uem=uem
)
return results
+
+
+def dscore(
+ reference,
+ hypothesis,
+ collar=0,
+ regions='all',
+ regex=None,
+ uem=None,
+):
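+    """
+    Python API behind `meeteval-der dscore`: computes the DER with md-eval-22.pl
+    and, if uem is None, generates a dscore-style UEM from reference and
+    hypothesis (see `meeteval.der.dscore.dscore_multifile`).
+
+    Illustrative call (the file names are placeholders):
+
+    >>> dscore('ref.rttm', 'hyp.rttm', collar=0.25)  # doctest: +SKIP
+    """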
+ r, h = _load_texts(reference, hypothesis, regex)
+ from meeteval.der.dscore import dscore_multifile
+    if uem is not None:
+        from pathlib import Path
+        from meeteval.io.uem import UEM
+        if isinstance(uem, (str, Path, list, tuple)):
+            uem = UEM.load(uem)
+
+ results = dscore_multifile(
+ r, h, collar, regions=regions, uem=uem
+ )
+ return results
diff --git a/meeteval/der/dscore.py b/meeteval/der/dscore.py
new file mode 100644
index 0000000..121ab4a
--- /dev/null
+++ b/meeteval/der/dscore.py
@@ -0,0 +1,275 @@
+import logging
+import sys
+import subprocess
+import tempfile
+import decimal
+from pathlib import Path
+
+
+def _dscore_multifile(
+ reference, hypothesis, collar=0, regions='all',
+ uem=None
+):
+ """
+    dscore only produces a table with the final error rates, but we need the
+    detailed error times. Hence, dscore is called only to cross-check the
+    error rate against md_eval_22_multifile.
+
+ >>> import packaging.version
+ >>> import numpy as np
+ >>> if packaging.version.parse(np.__version__) >= packaging.version.parse('1.24'):
+ ... import pytest
+ ... pytest.skip(f'dscore fails with numpy >= 1.24. Current version: {np.__version__}')
+
+ >>> from meeteval.io.rttm import RTTM
+ >>> from meeteval.io.uem import UEM
+ >>> reference = RTTM.parse('''
+ ... SPEAKER S1 1 0.0 0.5 spk1
+ ... SPEAKER S1 1 0.5 0.5 spk2
+ ... SPEAKER S1 1 1.0 0.5 spk1
+ ... ''')
+ >>> hypothesis = RTTM.parse('''
+ ... SPEAKER S1 1 0.0 0.5 spk1
+ ... SPEAKER S1 1 0.5 0.5 spk2
+ ... SPEAKER S1 1 1.0 0.5 spk2
+ ... ''')
+ >>> uem = UEM.parse('''
+ ... S1 1 0.0 1.5
+ ... ''')
+ >>> import pprint
+ >>> pprint.pprint(_dscore_multifile(reference, hypothesis, uem=uem))
+ {'S1': Decimal('0.3333')}
+ """
+ from meeteval.der.md_eval import _FilenameEscaper
+ escaper = _FilenameEscaper()
+
+ from meeteval.io.rttm import RTTM
+ reference = RTTM.new(reference)
+ hypothesis = RTTM.new(hypothesis)
+
+ reference = escaper.escape_rttm(reference)
+ hypothesis = escaper.escape_rttm(hypothesis)
+
+ score_py = Path(__file__).parent / 'dscore_repo' / 'score.py'
+    if not score_py.exists():
+        # Clone the dscore repository next to this file on first use.
+        subprocess.run(['git', 'clone', 'https://github.com/nryant/dscore.git',
+                        str(score_py.parent)], check=True)
+
+ filtered = 0
+ for line in reference:
+ if line.duration == 0:
+ filtered += 1
+ if filtered:
+ logging.info(f'Filtered {filtered} lines with zero duration in reference (dscore doesn\'t support zero duration)')
+ reference = RTTM([line for line in reference if line.duration != 0])
+
+ filtered = 0
+ for line in hypothesis:
+ if line.duration == 0:
+ filtered += 1
+ if filtered:
+ logging.info(f'Filtered {filtered} lines with zero duration in hypothesis (dscore doesn\'t support zero duration)')
+ hypothesis = RTTM([line for line in hypothesis if line.duration != 0])
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+
+        r_file = tmpdir / 'ref.rttm'
+        h_file = tmpdir / 'hyp.rttm'
+ reference.dump(r_file)
+ hypothesis.dump(h_file)
+
+ cmd = [
+ sys.executable, str(score_py),
+ '--collar', str(collar),
+ # '--ignore_overlaps', regions,
+ '-r', f'{r_file}',
+ '-s', f'{h_file}',
+ '--table_fmt', 'tsv'
+ ]
+ if uem:
+            uem_file = tmpdir / 'uem.uem'
+ uem = escaper.escape_uem(uem)
+ uem.dump(uem_file)
+ cmd.extend(['-u', f'{uem_file}'])
+
+        if regions == 'all':
+            pass
+        elif regions == 'nooverlap':
+            cmd.append('--ignore_overlaps')
+        else:
+            raise ValueError(f"Expected regions to be 'all' or 'nooverlap', not {regions!r}")
+
+ p = subprocess.run(
+ cmd,
+ stdout=subprocess.PIPE,
+ # stderr=subprocess.PIPE,
+ check=True, universal_newlines=True,
+ cwd=score_py.parent
+ )
+
+ result = {}
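+    # score.py prints a table (tsv): a header row, one row per file, and a
+    # final overall summary row. Keep only the per-file rows; the first column
+    # is the file name, the second column the DER in percent (converted to a
+    # fraction below).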
+ for line in p.stdout.strip().split('\n')[1:-1]:
+ line_parts = line.split('\t')
+ result[escaper.restore(line_parts[0].strip())] = decimal.Decimal(line_parts[1].strip()) / 100
+
+ assert result, p.stdout
+ return result
+
+
+def _maybe_gen_uem(uem, reference, hypothesis):
+ # Mirror the behavior of dscore
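+    # dscore derives the scored region from reference *and* hypothesis, while
+    # md-eval-22 without a UEM uses the reference only. So, when no UEM is
+    # given, build one per session that spans from the earliest start to the
+    # latest end over reference + hypothesis and hand it to md-eval-22, but
+    # not to dscore (which regenerates it itself). Illustrative example: a
+    # reference segment 5.00-10.00 and a hypothesis segment 0.00-20.00 for
+    # session "rec" yield the UEM entry "rec 1 0.00 20.00".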
+ if uem is None:
+ from meeteval.io.uem import UEM, UEMLine
+ uem_md_eval = UEM([
+ UEMLine(
+ filename=k, channel='1',
+ begin_time=min(v.T['start_time']), end_time=max(v.T['end_time'])
+ )
+ for k, v in (reference + hypothesis).to_seglst().groupby('session_id').items()
+ ])
+
+ return uem_md_eval, None
+ else:
+ return uem, uem
+
+
+def dscore_multifile(
+ reference, hypothesis, collar=0, regions='all',
+ uem=None, sanity_check=False,
+):
+ """
+ Computes the Diarization Error Rate (DER) using md-eval-22.pl
+    but creates a UEM if uem is None, as is done in dscore [1].
+
+    Additionally, compares the error rate with dscore [1] if sanity_check is True.
+
+ Args:
+ reference:
+ hypothesis:
+ collar:
+ regions: 'all' or 'nooverlap'
+        uem: If None, a UEM is generated from the reference and hypothesis.
+            This is the default behavior of dscore, while md-eval-22
+            uses only the reference.
+        sanity_check: Compare the result with dscore to ensure
+            the correctness of the implementation.
+            Requires numpy < 1.24, because dscore uses the removed
+            np.int alias and fails with recent numpy versions.
+
+ [1] https://github.com/nryant/dscore
+
+ >>> from meeteval.io.rttm import RTTM
+ >>> from meeteval.io.uem import UEM
+ >>> reference = RTTM.parse('''
+ ... SPEAKER rec 1 5.00 5.00 spk01
+ ... SPEAKER rec 1 10.00 10.00 spk00
+ ... ''')
+ >>> hypothesis = RTTM.parse('''
+ ... SPEAKER rec 1 0.00 10.00 spk01
+ ... SPEAKER rec 1 10.00 10.00 spk00
+ ... ''')
+ >>> import pprint
+ >>> pprint.pprint(dscore_multifile(reference, hypothesis)) # doctest: +NORMALIZE_WHITESPACE
+ {'rec': DiaErrorRate(error_rate=Decimal('0.3333'),
+ scored_speaker_time=Decimal('15.000000'),
+ missed_speaker_time=Decimal('0.000000'),
+ falarm_speaker_time=Decimal('5.000000'),
+ speaker_error_time=Decimal('0.000000'))}
+
+ >>> from meeteval.der.md_eval import md_eval_22_multifile
+ >>> pprint.pprint(md_eval_22_multifile(reference, hypothesis)) # doctest: +NORMALIZE_WHITESPACE
+ {'rec': DiaErrorRate(error_rate=Decimal('0.00'),
+ scored_speaker_time=Decimal('15.000000'),
+ missed_speaker_time=Decimal('0.000000'),
+ falarm_speaker_time=Decimal('0.000000'),
+ speaker_error_time=Decimal('0.000000'))}
+
+ """
+ from meeteval.der.md_eval import md_eval_22_multifile
+
+ uem_md_eval, uem_dscore = _maybe_gen_uem(uem, reference, hypothesis)
+
+ result = md_eval_22_multifile(
+ reference, hypothesis, collar=collar, regions=regions, uem=uem_md_eval
+ )
+ if sanity_check:
+ dscore_der = _dscore_multifile(reference, hypothesis, collar=collar, regions=regions, uem=uem_dscore)
+ for key in result:
+ assert key in dscore_der, (key, result, dscore_der)
+ assert abs(dscore_der[key] - result[key].error_rate) <= decimal.Decimal('0.0001'), (key, dscore_der[key], result[key])
+
+ return result
+
+
+def dscore(reference, hypothesis, collar=0, regions='all', uem=None, sanity_check=False):
+ """
+ Computes the Diarization Error Rate (DER) using md-eval-22.pl
+    but creates a UEM if uem is None, as is done in dscore [1].
+
+    Additionally, compares the error rate with dscore [1] if sanity_check is True.
+
+ Args:
+ reference:
+ hypothesis:
+ collar:
+ regions: 'all' or 'nooverlap'
+        uem: If None, a UEM is generated from the reference and hypothesis.
+            This is the default behavior of dscore, while md-eval-22
+            uses only the reference.
+        sanity_check: Compare the result with dscore to ensure
+            the correctness of the implementation.
+            Requires numpy < 1.24, because dscore uses the removed
+            np.int alias and fails with recent numpy versions.
+
+ [1] https://github.com/nryant/dscore
+
+ >>> from meeteval.io.rttm import RTTM
+ >>> from meeteval.io.uem import UEM
+ >>> reference = RTTM.parse('''
+ ... SPEAKER rec.a 1 5.00 5.00 spk01
+ ... SPEAKER rec.a 1 10.00 10.00 spk00
+ ... ''')
+ >>> hypothesis = RTTM.parse('''
+ ... SPEAKER rec.a 1 0.00 10.00 spk01
+ ... SPEAKER rec.a 1 10.00 10.00 spk00
+ ... ''')
+ >>> uem = UEM.parse('''
+ ... rec.a 1 0.00 15.00
+ ... ''')
+ >>> import pprint
+
+ >>> pprint.pprint(dscore(reference, hypothesis)) # doctest: +NORMALIZE_WHITESPACE
+ DiaErrorRate(error_rate=Decimal('0.3333'),
+ scored_speaker_time=Decimal('15.000000'),
+ missed_speaker_time=Decimal('0.000000'),
+ falarm_speaker_time=Decimal('5.000000'),
+ speaker_error_time=Decimal('0.000000'))
+ >>> pprint.pprint(dscore(reference, hypothesis, uem=uem)) # doctest: +NORMALIZE_WHITESPACE
+ DiaErrorRate(error_rate=Decimal('0.50'),
+ scored_speaker_time=Decimal('10.000000'),
+ missed_speaker_time=Decimal('0.000000'),
+ falarm_speaker_time=Decimal('5.000000'),
+ speaker_error_time=Decimal('0.000000'))
+
+ # md_eval_22 ignores hyps before the first ref and after the last ref
+ >>> from meeteval.der.md_eval import md_eval_22
+ >>> pprint.pprint(md_eval_22(reference, hypothesis)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+ DiaErrorRate(error_rate=Decimal('0.00'),
+ scored_speaker_time=Decimal('15.000000'),
+ missed_speaker_time=Decimal('0.000000'),
+ falarm_speaker_time=Decimal('0.000000'),
+ speaker_error_time=Decimal('0.000000'))
+ """
+ from meeteval.der.md_eval import md_eval_22
+ from meeteval.io.rttm import RTTM
+
+ reference = RTTM.new(reference, filename='dummy')
+ hypothesis = RTTM.new(hypothesis, filename='dummy')
+
+ uem_md_eval, uem_dscore = _maybe_gen_uem(uem, reference, hypothesis)
+
+ result = md_eval_22(reference, hypothesis, collar=collar, regions=regions, uem=uem_md_eval)
+
+ if sanity_check:
+ dscore_der = _dscore_multifile(reference, hypothesis, collar=collar, regions=regions, uem=uem_dscore)
+ assert list(dscore_der.values()) == [result.error_rate], (dscore_der, result.error_rate)
+
+ return result
diff --git a/meeteval/der/md_eval.py b/meeteval/der/md_eval.py
index 447d28f..26929cd 100644
--- a/meeteval/der/md_eval.py
+++ b/meeteval/der/md_eval.py
@@ -77,6 +77,101 @@ def __add__(self, other: 'DiaErrorRate'):
)
+class _FilenameEscaper:
+ """
+ >>> import pprint
+ >>> reference = meeteval.io.RTTM.parse('''
+ ... SPEAKER rec.a 1 5.00 5.00 spk01
+ ... SPEAKER rec.a 1 10.00 10.00 spk00
+ ... ''')
+ >>> hypothesis = meeteval.io.RTTM.parse('''
+ ... SPEAKER rec.a 1 0.00 10.00 spk01
+ ... SPEAKER rec.a 1 10.00 10.00 spk00
+ ... ''')
+ >>> uem = meeteval.io.UEM.parse('''
+ ... rec.a 1 0.00 15.00
+ ... ''')
+
+ >>> pprint.pprint(md_eval_22_multifile(reference, hypothesis, uem=uem)) # doctest: +NORMALIZE_WHITESPACE
+ {'rec.a': DiaErrorRate(error_rate=Decimal('0.50'),
+ scored_speaker_time=Decimal('10.000000'),
+ missed_speaker_time=Decimal('0.000000'),
+ falarm_speaker_time=Decimal('5.000000'),
+ speaker_error_time=Decimal('0.000000'))}
+
+ >>> _FilenameEscaper._DISABLED = True
+ >>> pprint.pprint(md_eval_22_multifile(reference, hypothesis, uem=uem)) # doctest: +NORMALIZE_WHITESPACE
+ {'rec.a': DiaErrorRate(error_rate=Decimal('0.00'),
+ scored_speaker_time=Decimal('15.000000'),
+ missed_speaker_time=Decimal('0.000000'),
+ falarm_speaker_time=Decimal('0.000000'),
+ speaker_error_time=Decimal('0.000000'))}
+ >>> _FilenameEscaper._DISABLED = False
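+
+    Direct use of the escaper (the session ids are only illustrative):
+
+    >>> escaper = _FilenameEscaper()
+    >>> escaper('rec.a')
+    'rec_a'
+    >>> escaper('plain')
+    'plain'
+    >>> escaper.restore('rec_a')
+    'rec.a'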
+ """
+ _DISABLED = False
+ def __init__(self):
+ self.warned = False
+ self.cache = {}
+
+ def __call__(self, filename):
+ if self._DISABLED:
+ return filename
+ if filename in self.cache:
+ return self.cache[filename]
+ elif '.' in filename:
+ blocked = set(self.cache.values())
+ for replacement in [
+ '_', '__', '___', '____', '_____',
+ ]:
+ new = filename.replace('.', replacement)
+ if new not in blocked:
+ break
+ else:
+ raise RuntimeError(f'Cannot find a replacement for {filename}.')
+
+ if not self.warned:
+ self.warned = True
+ logging.warning(
+                    f'Replacing filename "{filename}" with "{new}" in the RTTM and UEM files.\n'
+                    f'    md-eval-22 strips the first suffix from uem filenames/session_ids, but not in rttm files\n'
+                    f'    (e.g., uem: some.audio.wav -> some.wav,\n'
+                    f'          rttm: some.audio.wav -> some.audio.wav).\n'
+                    f"    dscore doesn't support dots in uem filenames\n"
+                    f'    (without a uem file, rttm filenames/session_ids may contain dots;\n'
+                    f'     with a uem file, they may not).\n'
+                    f"    -> dscore has no proper support for dots, because it uses md-eval-22.\n"
+                    f'    In meeteval, we assume that the uem file uses the same filenames/session_ids as reference and hypothesis,\n'
+                    f'    so dots are removed from the filenames here and restored afterwards.'
+ )
+ self.cache[filename] = new
+ return new
+ else:
+ self.cache[filename] = filename
+ return filename
+
+ def escape_rttm(self, rttm):
+ return rttm.__class__(
+ [line.replace(filename=self(line.filename)) for line in rttm]
+ )
+
+ def escape_uem(self, uem):
+ return uem.__class__(
+ [line.replace(filename=self(line.filename)) for line in uem]
+ )
+
+ def restore(self, filename):
+ if self._DISABLED:
+ return filename
+ if filename in self.cache:
+ assert self.cache[filename] == filename, (self.cache[filename], filename)
+ return filename
+ for k, v in self.cache.items():
+ if v == filename:
+ return k
+ raise ValueError(f'Cannot find {filename} as value in {self.cache}')
+
+
+
def md_eval_22_multifile(
reference, hypothesis, collar=0, regions='all',
uem=None
@@ -104,6 +199,10 @@ def md_eval_22_multifile(
reference = RTTM.new(reference)
hypothesis = RTTM.new(hypothesis)
+ escaper = _FilenameEscaper()
+ reference = escaper.escape_rttm(reference)
+ hypothesis = escaper.escape_rttm(hypothesis)
+
reference = _fix_channel(reference)
hypothesis = _fix_channel(hypothesis)
@@ -128,7 +227,11 @@ def md_eval_22_multifile(
urllib.request.urlretrieve(url, md_eval_22)
logging.info(f'Wrote {md_eval_22}')
+ warned = False
+
def get_details(r, h, key, tmpdir, uem):
+ nonlocal warned
+
r_file = tmpdir / f'{key}.ref.rttm'
h_file = tmpdir / f'{key}.hyp.rttm'
r.dump(r_file)
@@ -146,11 +249,15 @@ def get_details(r, h, key, tmpdir, uem):
if uem:
uem_file = tmpdir / f'{key}.uem'
+ uem = escaper.escape_uem(uem)
uem.dump(uem_file)
cmd.extend(['-u', f'{uem_file}'])
-
+ elif not warned:
+ warned = True
+            logging.warning(
+                'No UEM provided; md-eval-22 derives the scored regions from the reference only. '
+                'See https://github.com/fgnt/meeteval/issues/97#issuecomment-2508140402 for details.'
+            )
cp = subprocess.run(cmd, stdout=subprocess.PIPE,
check=True, universal_newlines=True)
+
# SCORED SPEAKER TIME =4309.340250 secs
# MISSED SPEAKER TIME =4309.340250 secs
# FALARM SPEAKER TIME =0.000000 secs
@@ -179,7 +286,7 @@ def convert(string):
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
for key in keys:
- per_reco[key] = get_details(r[key], h[key], key, tmpdir, uem)
+ per_reco[escaper.restore(key)] = get_details(r[key], h[key], key, tmpdir, uem)
md_eval = get_details(
meeteval.io.RTTM([line for key in keys for line in r[key]]),
@@ -197,6 +304,7 @@ def convert(string):
f'does not match the average error rate of md-eval-22.pl '
f'applied to each recording ({md_eval.error_rate}).'
)
+
return per_reco
@@ -214,4 +322,4 @@ def md_eval_22(reference, hypothesis, collar=0, regions='all', uem=None):
return md_eval_22_multifile(
reference, hypothesis, collar, regions=regions, uem=uem
- )[reference.filenames()[0]]
+ )[list(reference.filenames())[0]]
diff --git a/setup.py b/setup.py
index d8ec2bf..ffff916 100644
--- a/setup.py
+++ b/setup.py
@@ -128,7 +128,8 @@
'scipy', # scipy.optimize.linear_sum_assignment
"typing_extensions; python_version<'3.8'", # Missing Literal in py37
"cached_property; python_version<'3.8'", # Missing functools.cached_property in py37
- 'Cython'
+ 'Cython',
+        'packaging',  # used to compare package versions (numpy in meeteval/der/dscore.py); commonly installed anyway, e.g., by jupyter, matplotlib, pytest, ...
],
extras_require=extras_require,
package_data={'meeteval': ['**/*.pyx', '**/*.h', '**/*.js', '**/*.css', '**/*.html']}, # https://stackoverflow.com/a/60751886
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 51d8e62..3ec939f 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -156,6 +156,15 @@ def test_burn_md_eval_22():
# ToDo: Table 2 of https://arxiv.org/pdf/2312.04324.pdf lists collars for
# datsets. Add them here.
+def test_burn_dscore():
+ run(f'python -m meeteval.der dscore -h hyp.stm -r ref.stm')
+ run(f'meeteval-der dscore -h hyp.stm -r ref.stm')
+ run(f'python -m meeteval.der dscore -h hyp.stm -r ref.stm --collar 0.25')
+ run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm')
+ run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm --regions all')
+ run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm --regions nooverlap')
+ run(f'python -m meeteval.der dscore -h hyp.rttm -r ref.rttm --regex ".*A"')
+ run(f'python -m meeteval.der dscore -h hyp.seglst.json -r ref.seglst.json')
def test_burn_merge():
run(f'python -m meeteval.wer cpwer -h hypA.stm -r refA.stm') # create hypA_cpwer_per_reco.json and hypA_cpwer.json