diff --git a/meeteval/io/base.py b/meeteval/io/base.py index acf88ef6..eeb8f130 100644 --- a/meeteval/io/base.py +++ b/meeteval/io/base.py @@ -1,16 +1,18 @@ +import abc import io import os import sys import typing from pathlib import Path import contextlib -from typing import Dict, List, NamedTuple import dataclasses from dataclasses import dataclass from itertools import groupby +import decimal if typing.TYPE_CHECKING: from typing import Self + from meeteval.io.seglst import SegLstSegment, SegLST from meeteval.io.uem import UEM, UEMLine from meeteval.io.stm import STM, STMLine from meeteval.io.ctm import CTM, CTMLine @@ -20,12 +22,33 @@ Subclasses = 'UEM | STM | CTM | RTTM' + +class BaseABC: + @classmethod + def new(cls, d, **defaults): + # Example code: + # from meeteval.io.seglst import asseglst + # seglst = asseglst(d).map(lambda s: {**defaults, **s}) + # ... (convert seglst to cls) + raise NotImplementedError(cls) + + def to_seglst(self): + raise NotImplementedError() + + @dataclass(frozen=True) class BaseLine: @classmethod def parse(cls, line: str) -> 'Self': raise NotImplementedError(cls) + @classmethod + def from_dict(cls, segment: 'SegLstSegment') -> 'Self': + raise NotImplementedError(cls) + + def to_seglst_segment(self) -> 'SegLstSegment': + raise NotImplementedError(self) + def serialize(self): raise NotImplementedError(type(self)) @@ -112,26 +135,31 @@ def replace(self, **kwargs) -> 'Self': return dataclasses.replace(self, **kwargs) -@dataclass(frozen=True) -class Base: - lines: 'List[LineSubclasses]' +class Base(BaseABC): + lines: 'list[LineSubclasses]' line_cls = 'LineSubclasses' - @classmethod - def _load(cls, file_descriptor, parse_float) -> 'List[Self.line_cls]': - raise NotImplementedError() + def __init__(self, data): + self.lines = data @classmethod - def load(cls, file: [Path, str, io.TextIOBase, tuple, list], parse_float=float) -> 'Self': + def load(cls, file: [Path, str, io.TextIOBase, tuple, list], parse_float=decimal.Decimal) -> 'Self': files = file if isinstance(file, (tuple, list)) else [file] parsed_lines = [] for f in files: with _open(f, 'r') as fd: - parsed_lines.extend(cls._load(fd, parse_float=parse_float)) + parsed_lines.extend(cls.parse(fd.read(), parse_float=parse_float)) return cls(parsed_lines) + @classmethod + def parse(cls, s: str, parse_float=decimal.Decimal) -> 'Self': + # Many of the supported file-formats have different conventions for comments. + # Below is an example for a file that doesn't have comments. 
+ # return cls([cls.line_cls.parse(line) for line in s.splitlines() if line.strip()]) + raise NotImplementedError + def _repr_pretty_(self, p, cycle): name = self.__class__.__name__ with p.group(len(name) + 1, name + '(', ')'): @@ -167,7 +195,7 @@ def __add__(self, other): return NotImplemented return self.__class__(self.lines + other.lines) - def groupby(self, key) -> Dict[str, 'Self']: + def groupby(self, key) -> 'dict[str, Self]': """ >>> from meeteval.io.stm import STM, STMLine >>> stm = STM([STMLine.parse('rec1 0 A 10 20 Hello World')]) @@ -191,7 +219,7 @@ def groupby(self, key) -> Dict[str, 'Self']: ) } - def grouped_by_filename(self) -> Dict[str, 'Self']: + def grouped_by_filename(self) -> 'dict[str, Self]': return self.groupby(lambda x: x.filename) def grouped_by_speaker_id(self): @@ -318,6 +346,16 @@ def cut_by_uem(self: 'Subclasses', uem: 'UEM', verbose=False): def filenames(self): return {x.filename for x in self.lines} + def to_seglst(self) -> 'SegLST': + from meeteval.io.seglst import SegLST + return SegLST([l.to_seglst_segment() for l in self.lines]) + + @classmethod + def new(cls, s, **defaults) -> 'Self': + from meeteval.io.seglst import asseglst + return cls([cls.line_cls.from_dict({**defaults, **segment}) for segment in asseglst(s)]) + + def _open(f, mode='r'): if isinstance(f, io.TextIOBase): @@ -328,7 +366,7 @@ def _open(f, mode='r'): raise TypeError(type(f), f) -def load(file, parse_float=float): +def load(file, parse_float=decimal.Decimal): import meeteval file = Path(file) if file.suffix == '.stm': diff --git a/meeteval/io/ctm.py b/meeteval/io/ctm.py index 9c14cb5a..b47394cc 100644 --- a/meeteval/io/ctm.py +++ b/meeteval/io/ctm.py @@ -1,9 +1,13 @@ +import typing import warnings from dataclasses import dataclass -from typing import Dict, List, Optional -from typing import NamedTuple -from meeteval.io.base import Base, BaseLine +from typing import Optional +from meeteval.io.base import Base, BaseLine, BaseABC +import decimal +if typing.TYPE_CHECKING: + from typing import Self + from meeteval.io.seglst import SegLstSegment, SegLST __all__ = [ 'CTMLine', @@ -12,6 +16,8 @@ ] + + @dataclass(frozen=True) class CTMLine(BaseLine): """ @@ -30,13 +36,13 @@ class CTMLine(BaseLine): """ filename: str channel: 'str | int' - begin_time: float - duration: float + begin_time: 'decimal.Decimal | float' + duration: 'decimal.Decimal | float' word: str confidence: Optional[int] = None @classmethod - def parse(cls, line: str, parse_float=float) -> 'CTMLine': + def parse(cls, line: str, parse_float=decimal.Decimal) -> 'CTMLine': try: # CB: Should we disable the support for missing confidence? filename, channel, begin_time, duration, word, *confidence = line.strip().split() @@ -49,7 +55,6 @@ def parse(cls, line: str, parse_float=float) -> 'CTMLine': word, confidence[0] if confidence else None ) - assert ctm_line.begin_time >= 0, ctm_line assert ctm_line.duration >= 0, ctm_line except Exception as e: raise ValueError(f'Unable to parse CTM line: {line}') from e @@ -61,42 +66,80 @@ def serialize(self): >>> line.serialize() 'rec1 0 10 2 Hello 1' """ - return (f'{self.filename} {self.channel} {self.begin_time} ' - f'{self.duration} {self.word} {self.confidence}') + s = f'{self.filename} {self.channel} {self.begin_time} {self.duration} {self.word}' + if self.confidence is not None: + s += f' {self.confidence}' + return s + + @classmethod + def from_dict(cls, segment: 'SegLstSegment') -> 'Self': + # CTM only supports words as segments. 
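+        # A word-level segment looks like this (illustrative values, not from the repo):
+        #   {'session_id': 'rec1', 'channel': 0, 'start_time': 10,
+        #    'end_time': 12, 'words': 'Hello'}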
+        # If this check fails, the input data was not converted to words before.
+        assert ' ' not in segment['words'], segment
+        return cls(
+            filename=segment['session_id'],
+            channel=segment['channel'],
+            begin_time=segment['start_time'],
+            duration=segment['end_time'] - segment['start_time'],
+            word=segment['words'],
+            confidence=segment.get('confidence', None),
+        )
+
+    def to_seglst_segment(self) -> 'SegLstSegment':
+        d = {
+            'session_id': self.filename,
+            'channel': self.channel,
+            'start_time': self.begin_time,
+            'end_time': self.begin_time + self.duration,
+            'words': self.word,
+        }
+        if self.confidence is not None:
+            d['confidence'] = self.confidence
+        return d
+

 @dataclass(frozen=True)
 class CTM(Base):
-    lines: List[CTMLine]
+    lines: 'list[CTMLine]'
     line_cls = CTMLine

     @classmethod
-    def _load(cls, file_descriptor, parse_float) -> 'List[CTMLine]':
-        return [
+    def parse(cls, s: str, parse_float=decimal.Decimal) -> 'Self':
+        return cls([
             CTMLine.parse(line, parse_float=parse_float)
-            for line in map(str.strip, file_descriptor)
+            for line in map(str.strip, s.split('\n'))
             if len(line) > 0
             if not line.startswith(';;')
-        ]
+        ])

     def merged_transcripts(self) -> str:
         return ' '.join([x.word for x in sorted(self.lines, key=lambda x: x.begin_time)])

-    def utterance_transcripts(self) -> List[str]:
+    def utterance_transcripts(self) -> 'list[str]':
         """There is no notion of an "utterance" in CTM files."""
         raise NotImplementedError()

+    @classmethod
+    def new(cls, s, **defaults) -> 'Self':
+        from meeteval.io.seglst import asseglst
+        s = asseglst(s)
+        # CTM only supports a single speaker. Use CTMGroup to represent multiple speakers with this format.
+        if len(s.unique('speaker')) > 1:
+            raise ValueError(
+                f'CTM only supports a single speaker, but found {len(s.unique("speaker"))} speakers '
+                f'({s.unique("speaker")}). Use CTMGroup to represent multiple speakers with this format.'
+ ) + return super().new(s, **defaults) + @dataclass(frozen=True) -class CTMGroup: - ctms: 'Dict[str, CTM]' +class CTMGroup(BaseABC): + ctms: 'dict[str, CTM]' @classmethod - def load(cls, ctm_files, parse_float=float): + def load(cls, ctm_files, parse_float=decimal.Decimal): return cls({str(ctm_file): CTM.load(ctm_file, parse_float=parse_float) for ctm_file in ctm_files}) - def grouped_by_filename(self) -> Dict[str, 'CTMGroup']: + def grouped_by_filename(self) -> 'dict[str, CTMGroup]': groups = { k: ctm.grouped_by_filename() for k, ctm in self.ctms.items() } @@ -128,9 +171,18 @@ def grouped_by_filename(self) -> Dict[str, 'CTMGroup']: for key in keys } - def grouped_by_speaker_id(self) -> Dict[str, CTM]: + def grouped_by_speaker_id(self) -> 'dict[str, CTM]': return self.ctms + @classmethod + def new(cls, s: 'SegLST', **defaults) -> 'Self': + from meeteval.io.seglst import asseglst + return cls({k: CTM.new(v) for k, v in asseglst(s).map(lambda s: {**defaults, **s}).groupby('speaker').items()}) + + def to_seglst(self) -> 'SegLST': + from meeteval.io.seglst import SegLST + return SegLST.merge(*[ctm.to_seglst().map(lambda x: {**x, 'speaker': speaker}) for speaker, ctm in self.ctms.items()]) + def to_stm(self): from meeteval.io import STM, STMLine stm = [] diff --git a/meeteval/io/keyed_text.py b/meeteval/io/keyed_text.py index 8caa9110..33c3def4 100644 --- a/meeteval/io/keyed_text.py +++ b/meeteval/io/keyed_text.py @@ -1,7 +1,13 @@ +import typing from dataclasses import dataclass -from typing import List from meeteval.io.base import BaseLine, Base +from meeteval.io.seglst import SegLstSegment +import decimal + + +if typing.TYPE_CHECKING: + from typing import Self @dataclass(frozen=True) @@ -10,7 +16,7 @@ class KeyedTextLine(BaseLine): transcript: str @classmethod - def parse(cls, line: str, parse_float=float) -> 'KeyedTextLine': + def parse(cls, line: str, parse_float=decimal.Decimal) -> 'KeyedTextLine': """ >>> KeyedTextLine.parse("key a transcript") KeyedTextLine(filename='key', transcript='a transcript') @@ -29,24 +35,40 @@ def parse(cls, line: str, parse_float=float) -> 'KeyedTextLine': transcript = '' return cls(filename, transcript) + def serialize(self): + return f'{self.filename} {self.transcript}' + + @classmethod + def from_dict(cls, segment: 'SegLstSegment') -> 'Self': + return cls( + filename=segment['session_id'], + transcript=segment['words'], + ) + + def to_seglst_segment(self) -> 'SegLstSegment': + return { + 'session_id': self.filename, + 'words': self.transcript, + } + @dataclass(frozen=True) class KeyedText(Base): - lines: List[KeyedTextLine] + lines: 'list[KeyedTextLine]' line_cls = KeyedTextLine @classmethod - def _load(cls, file_descriptor, parse_float) -> 'List[KeyedTextLine]': - return [ + def parse(cls, s: str, parse_float=decimal.Decimal) -> 'Self': + return cls([ KeyedTextLine.parse(line, parse_float=parse_float) - for line in map(str.strip, file_descriptor) + for line in map(str.strip, s.split('\n')) if len(line) > 0 # if not line.startswith(';;') - ] + ]) def merged_transcripts(self) -> str: raise NotImplementedError() - def utterance_transcripts(self) -> List[str]: + def utterance_transcripts(self) -> 'list[str]': """There is no notion of an "utterance" in CTM files.""" raise NotImplementedError() diff --git a/meeteval/io/pbjson.py b/meeteval/io/pbjson.py index 04ed20b7..f5fbd7a5 100644 --- a/meeteval/io/pbjson.py +++ b/meeteval/io/pbjson.py @@ -6,8 +6,14 @@ """ import json +import typing from pathlib import Path +from meeteval.io.base import BaseABC + +if 
+if typing.TYPE_CHECKING:
+    from meeteval.io.seglst import SegLST
+

 def _load_json(file):
     with open(file) as fd:
@@ -40,6 +46,88 @@ def get_sample_rate(ex):
     return sample_rate


+class PBJsonUtt(BaseABC):
+    """
+    A JSON format where each entry/example represents a single utterance.
+
+    Example:
+    ```python
+    pbjson = {
+        'datasets': {
+            '<dataset_name>': {
+                '<example_id>': {
+                    'num_samples': {
+                        'original_source': <int>,  # in samples
+                    },
+                    'offset': <int>,  # in samples
+                    'speaker_id': <str>,
+                    'transcription': <str>,
+                    'kaldi_transcription': <str>,  # optional
+                }
+            }
+        }
+    }
+    ```
+    """
+
+    def __init__(self, json, sample_rate=16000):
+        self.json = json
+        self.sample_rate = sample_rate
+
+    @classmethod
+    def load(cls, file):
+        return cls(_load_json(file))
+
+    @classmethod
+    def new(cls, s, *, sample_rate=16000, dataset_name='default_dataset', **defaults):
+        # This copies the segments (since we are going to `pop` keys later), applies defaults and makes sure
+        # that the dataset_name key is set for all segments.
+        from meeteval.io.seglst import asseglst
+        s = asseglst(s).map(lambda x: {**defaults, **x, 'dataset_name': x.get('dataset_name') or dataset_name})
+        return cls({
+            'datasets': {
+                dataset_name: {
+                    example_id: {
+                        # Translate structure from SegLST to pbjson.
+                        # Note: 'start_time' is read here and popped below for 'offset'.
+                        'num_samples': {
+                            'original_source': (example.pop('end_time') - example['start_time']) * sample_rate,
+                        },
+                        'offset': example.pop('start_time') * sample_rate,
+                        'speaker_id': example.pop('speaker'),
+                        'transcription': example.pop('words'),
+                        # Any additional keys are simply appended to the json
+                        **example,
+                    }
+                    # This format has exactly one segment per example
+                    for example_id, (example,) in dataset.groupby('example_id').items()
+                }
+                for dataset_name, dataset in s.groupby('dataset_name').items()
+            }
+        })
+
+    def to_seglst(self) -> 'SegLST':
+        from meeteval.io.seglst import SegLST
+        return SegLST([
+            {
+                # Translate known keys
+                'dataset_name': dataset_name,
+                'example_id': example_id,
+                'words': example.get('kaldi_transcription') or example['transcription'],
+                'speaker': example['speaker_id'],
+                'start_time': example['offset'] / self.sample_rate,
+                'end_time': (example['offset'] + example['num_samples']['original_source']) / self.sample_rate,
+                # Any other keys are appended
+                **{
+                    k: v
+                    for k, v in example.items() if k not in {
+                        'offset', 'num_samples', 'speaker_id', 'transcription',
+                        'kaldi_transcription',
+                    }
+                },
+            }
+            for dataset_name, dataset in self.json['datasets'].items()
+            for example_id, example in dataset.items()
+        ])
+
+
 def to_stm(json, out, datasets=None):
     import lazy_dataset.database
     out = Path(out)
@@ -69,7 +157,7 @@ def add_line(speaker_id, begin_time, end_time, transcript):
         stm_lines.append(STMLine(
             filename=ex['example_id'], channel=0, speaker_id=speaker_id,
-            begin_time=begin_time/sample_rate, end_time=end_time/sample_rate,
+            begin_time=begin_time / sample_rate, end_time=end_time / sample_rate,
             transcript=transcript))

     for spk, o, n, t in zip_strict(
@@ -80,9 +168,9 @@ def add_line(speaker_id, begin_time, end_time, transcript):
     ):
         if isinstance(t, (tuple, list)):
             for o_, n_, t_ in zip_strict(o, n, t):
-                add_line(spk, o_, o_+n_, t_)
+                add_line(spk, o_, o_ + n_, t_)
         else:
-            add_line(spk, o, o+n, t)
+            add_line(spk, o, o + n, t)

     file = out / f'{dataset_name}_ref.stm'
     STM(stm_lines).dump(file)
@@ -91,7 +179,7 @@ def add_line(speaker_id, begin_time, end_time, transcript):

 if __name__ == '__main__':
     import fire
+
     fire.Fire({
         'to_stm': to_stm,
     })
-
diff --git a/meeteval/io/py.py b/meeteval/io/py.py
new file mode 100644
index 00000000..bf887846
--- /dev/null
+++ b/meeteval/io/py.py
@@ -0,0 +1,261 @@
+import dataclasses
+import typing
+from typing import Any
+
+from meeteval.io.base import BaseABC
+
+if typing.TYPE_CHECKING:
+    from typing import Self
+    from meeteval.io.seglst import SegLST
+
+
+def _convert_python_structure(structure, *, keys=(), final_key='words', _final_types=str):
+    from meeteval.io.seglst import SegLST
+    from meeteval.wer.utils import _keys
+
+    def _convert(d, index=0):
+        # A single string is converted to a single segment
+        # If the final key is not set, we set it to 0
+        if isinstance(d, str):
+            segment = {final_key: d}
+
+            if len(keys) > index + 1:
+                raise ValueError(
+                    f"The structure contains fewer levels than keys provided. "
+                    f"keys: {keys}, structure depth: {index}"
+                )
+
+            if index < len(keys):
+                segment[keys[-1]] = 0
+            return [segment], (str,)
+
+        # A structure. `keys[index]` is the key for this level
+        if isinstance(d, (list, dict, tuple)):
+            if len(d) == 0:
+                # Special case where no structure information is available
+                # This case is not invertible!
+                return [], None if len(keys) > index else (type(d),)
+
+            # Check if we have a key for this level. If not, raise an exception
+            if index >= len(keys):
+                # Only raise the exception if the final key doesn't match the `final_types`
+                if _final_types is not None and not isinstance(d, _final_types):
+                    raise ValueError(
+                        f'{structure} cannot be converted because it contains more nested levels than keys given! '
+                        f'keys={keys!r}, '
+                        f'final_key={final_key!r}, final_types={_final_types!r}'
+                    )
+                return [{final_key: d}], (type(d),)
+
+            key = keys[index]
+            converted = {}
+            types = {}
+
+            for k in _keys(d):
+                converted[k], types[k] = _convert(d[k], index=index + 1)
+
+            # We can only convert back if all items can be converted.
+            # Return None if not invertible.
+            types = {k: types[k] for k in types.keys() if types[k] is not None}
+
+            if len(set(types.values())) == 1:
+                types = (type(d),) + next(iter(types.values()))
+            else:
+                types = None
+
+            # Set or check (when already set) the key. Make sure that
+            # - The group keys are unique
+            # - All items in a group have the same value for `key`
+            for k, segments in converted.items():
+                for s in segments:
+                    s.setdefault(key, k)
+                assert set(s[key] for s in segments) == {k} or len(segments) == 0, (k, segments)
+            assert set(c[0][key] for c in converted.values() if c).issubset(
+                set(converted.keys())), f'Group values are not unique! {converted}'
+            return [v for l in converted.values() for v in l], types
+        raise TypeError(d)

+    segments, types = _convert(structure)
+    return SegLST(segments), types
+
+
+def _invert_python_structure(t: 'SegLST', types, keys):
+    if len(types) != len(keys):
+        if len(types) < len(keys):
+            # _convert_python_structure adds a dummy key if the final key is not set. We have to remove it here
+            # again.
+            for k in keys[len(types) - 1:-1]:
+                if len(t.unique(k)) != 1:
+                    raise ValueError(
+                        f'Cannot convert SegLST to Python structure with t.keys={t.T.keys()!r}, types={types!r} and '
+                        f'keys={keys!r}. Each non-unique key must have a type, otherwise the structure is not '
+                        f'convertible.'
+                    )
+            keys = keys[:len(types) - 1] + (keys[-1],)
+        else:
+            raise ValueError(
+                f'Cannot convert SegLST to Python structure with '
+                f't.keys={t.T.keys()!r}, types={types!r} and keys={keys!r}.'
+            )
+
+    if len(types) == 1:
+        # After modification, it can happen that the SegLST representation contains multiple segments.
+        # We concatenate here to keep some sort of "old" behavior.
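+        # Illustration (assumed behavior): two segments {'words': 'a'} and
+        # {'words': 'b'} in the same group invert to the single string 'a b'.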
+ # + # Note: This inversion gives back the original when the SegLST representation is not modified, + # but it can give a different representation when it was modified. + if keys[0] == 'words': + words = [s['words'] for s in t] + if types[0] == str: + return ' '.join(words) + else: + return types[0](words) + assert len(t) == 1, t + return types[0](t.segments[0][keys[0]]) + groups = {k: _invert_python_structure(v, types[1:], keys[1:]) for k, v in t.groupby(keys[0]).items()} + if types[0] in (list, tuple): + if any([not isinstance(k, int) for k in groups.keys()]): + # The behavior is not well-defined for non-integer keys. It is unclear whether keys should be sorted + # with Python's sort or natsort or whether they should be sorted at all. Hence, we only allow integer keys + # for conversion to list/tuple, where sorting is reasonable. + raise ValueError( + f'Cannot convert SegLST to Python sequence structures (list or tuple) with non-int keys. ' + f'Expected integer keys, but found {groups.keys()}.' + ) + groups = types[0](v for _, v in sorted(groups.items())) + return groups + + +@dataclasses.dataclass(frozen=True) +class NestedStructure(BaseABC): + """ + Wraps a Python structure where the structure levels represent keys. + + Example structure for cpWER: + ```python + structure = { + 'Alice': ['Transcript of segment 1', 'Transcript of segment 2'], # Speaker 1: Alice + 'Bob': ['Utterance 1', 'Utterance 2'], # Speaker 2: Bob + } + NestedStructure(structure, ('speaker', 'segment')) + ``` + """ + structure: Any + level_keys: 'list[str, ...] | tuple[str, ...]' = ('speaker', 'segment_index') + final_key: 'str' = 'words' + + # Private attributes. Intended for use in a generalized apply_assignment function. + # Use with care! It can lead to unexpected behavior! + _final_types: 'type | list[type] | tuple[type]' = str + + # Cache variables + _types = None + _used_keys = None + _seglst = None + + def new(self, t: 'SegLST', **defaults) -> 'Self': + """ + This is usually a classmethod, but here, it's an instance method + because we need `keys` and `types` for conversion. + + >>> def convert_cycle(structure, keys, mod=None): + ... s = NestedStructure(structure, keys) + ... t = s.to_seglst() + ... if mod: + ... t = mod(t) + ... return s.new(t).structure + >>> convert_cycle('a b c', keys=()) + 'a b c' + >>> convert_cycle(['a b c', 'd e f'], keys=('speaker',)) + ['a b c', 'd e f'] + >>> convert_cycle({'A': 'a b c', 'B': 'd e f'}, keys=('speaker',)) + {'A': 'a b c', 'B': 'd e f'} + >>> s = NestedStructure({'B': 'd e f', 'A': 'a b c'}, level_keys=('speaker',)) + >>> s2 = NestedStructure(['a b c', 'd e f'], level_keys=('speaker',)) + >>> s.new(s2.to_seglst()).structure + {0: 'a b c', 1: 'd e f'} + + Empty structures are only invertible if all keys can be inferred and empty nesting levels get lost + >>> convert_cycle([], keys=()) + [] + >>> convert_cycle([], keys=('speaker',)) + Traceback (most recent call last): + ... + ValueError: Cannot convert to Python structure because this structure is not invertible. + >>> convert_cycle([[]], keys=('speaker', 'segment_index')) + Traceback (most recent call last): + ... + ValueError: Cannot convert to Python structure because this structure is not invertible. 
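+
+        Trailing empty groups are dropped on inversion: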
+        >>> convert_cycle([['abc'], []], keys=('speaker', 'segment_index'))
+        [['abc']]
+        """
+        if self.types is None:
+            raise ValueError('Cannot convert to Python structure because this structure is not invertible.')
+        from meeteval.io.seglst import asseglst
+        t = asseglst(t).map(lambda x: {**defaults, **x})
+        return NestedStructure(_invert_python_structure(t, self.types, self.level_keys + (self.final_key,)),
+                               self.level_keys, self.final_key, self._final_types)
+
+    @property
+    def types(self):
+        if self._seglst is None:
+            self.to_seglst()
+        return self._types
+
+    def to_seglst(self):
+        """
+        Converting Python structures for cpWER, ORC WER and MIMO WER
+        >>> NestedStructure('a b c', level_keys=()).to_seglst()
+        SegLST(segments=[{'words': 'a b c'}])
+
+        Structure levels are interpreted as the keys in keys.
+        >>> NestedStructure('a b c', level_keys=('speaker',)).to_seglst()
+        SegLST(segments=[{'words': 'a b c', 'speaker': 0}])
+        >>> NestedStructure(['a b c', 'd e f'], level_keys=('speaker',)).to_seglst()
+        SegLST(segments=[{'words': 'a b c', 'speaker': 0}, {'words': 'd e f', 'speaker': 1}])
+        >>> NestedStructure({'A': 'a b c', 'B': 'd e f'}, level_keys=('speaker',)).to_seglst()
+        SegLST(segments=[{'words': 'a b c', 'speaker': 'A'}, {'words': 'd e f', 'speaker': 'B'}])
+        >>> NestedStructure({'ex1': {'A': 'a b c', }, 'ex2': {'C': 'd e f'}}, level_keys=('session_id', 'speaker')).to_seglst()
+        SegLST(segments=[{'words': 'a b c', 'speaker': 'A', 'session_id': 'ex1'}, {'words': 'd e f', 'speaker': 'C', 'session_id': 'ex2'}])
+
+        All keys in `keys` must be present, otherwise an exception is raised
+        >>> NestedStructure('a b c', level_keys=('speaker', 'channel')).to_seglst()
+        Traceback (most recent call last):
+          ...
+        ValueError: The structure contains fewer levels than keys provided. keys: ('speaker', 'channel'), structure depth: 0
+
+        Empty structures
+
+        Empty structures result in empty `SegLST` objects
+        >>> NestedStructure([], level_keys=('speaker',)).to_seglst()
+        SegLST(segments=[])
+        >>> NestedStructure({}, level_keys=('speaker',)).to_seglst()
+        SegLST(segments=[])
+        >>> NestedStructure([{}], level_keys=('speaker',)).to_seglst()
+        SegLST(segments=[])
+
+        Empty nested structures are allowed but not represented in the `SegLST` format
+        >>> NestedStructure([['ab'], []], level_keys=('speaker', 'segment_index')).to_seglst()
+        SegLST(segments=[{'words': 'ab', 'segment_index': 0, 'speaker': 0}])
+
+        The last key is filled with a dummy key
+        >>> NestedStructure(['a b c', 'd e f'], level_keys=('speaker', 'channel')).to_seglst()
+        SegLST(segments=[{'words': 'a b c', 'channel': 0, 'speaker': 0}, {'words': 'd e f', 'channel': 0, 'speaker': 1}])
+
+        Nested structures beyond the specified level keys are by default not allowed. With `_final_types=None`,
+        you can have nested structures. But be careful, this can lead to unexpected results!
+        >>> NestedStructure({'A': ['abc', 'def']}, level_keys=(), _final_types=None).to_seglst()
+        SegLST(segments=[{'words': {'A': ['abc', 'def']}}])
+        >>> NestedStructure({'A': ['abc', 'def']}, level_keys=('speaker',), _final_types=None).to_seglst()
+        SegLST(segments=[{'words': ['abc', 'def'], 'speaker': 'A'}])
+        >>> NestedStructure({'A': [{'x': 'abc'}, {'y': 'def'}]}, level_keys=(), _final_types=None).to_seglst()
+        SegLST(segments=[{'words': {'A': [{'x': 'abc'}, {'y': 'def'}]}}])
+        """
+        if self._seglst is None:
+            self.__dict__['_seglst'], self.__dict__['_types'] = _convert_python_structure(
+                self.structure,
+                keys=self.level_keys,
+                final_key=self.final_key,
+                _final_types=self._final_types,
+            )
+        return self._seglst
diff --git a/meeteval/io/rttm.py b/meeteval/io/rttm.py
index a0324070..e53c550c 100644
--- a/meeteval/io/rttm.py
+++ b/meeteval/io/rttm.py
@@ -1,11 +1,10 @@
 import typing
-from typing import List, NamedTuple
 from dataclasses import dataclass

 from meeteval.io.base import Base, BaseLine
-
+from meeteval.io.seglst import SegLstSegment
+import decimal

 if typing.TYPE_CHECKING:
-    import decimal
     from meeteval.io.uem import UEM, UEMLine
@@ -53,10 +52,10 @@ class RTTMLine(BaseLine):
     signal_look_ahead_time: str = '<NA>'

     @classmethod
-    def parse(cls, line: str, parse_float=float) -> 'RTTMLine':
+    def parse(cls, line: str, parse_float=decimal.Decimal) -> 'RTTMLine':
         """
         >>> RTTMLine.parse('SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>')
-        RTTMLine(type='SPEAKER', filename='CMU_20020319-1400_d01_NONE', channel='1', begin_time=130.43, duration=2.35, othography='<NA>', speaker_type='<NA>', speaker_id='juliet', confidence='<NA>', signal_look_ahead_time='<NA>')
+        RTTMLine(type='SPEAKER', filename='CMU_20020319-1400_d01_NONE', channel='1', begin_time=Decimal('130.430000'), duration=Decimal('2.350'), othography='<NA>', speaker_type='<NA>', speaker_id='juliet', confidence='<NA>', signal_look_ahead_time='<NA>')
         """
         type_, filename, channel, begin_time, duration, othography, \
             speaker_type, speaker_id, confidence, signal_look_ahead_time, \
             *empty = line.split()
@@ -75,11 +74,36 @@ def parse(cls, line: str, parse_float=float) -> 'RTTMLine':
             signal_look_ahead_time=signal_look_ahead_time,
         )

+    @classmethod
+    def from_dict(cls, segment: 'SegLstSegment') -> 'RTTMLine':
+        # TODO: read spec and handle speech segments with transcripts
+        return RTTMLine(
+            filename=segment['session_id'],
+            channel=segment.get('channel', cls.channel),
+            speaker_id=segment['speaker'],
+            begin_time=segment['start_time'],
+            duration=segment['end_time'] - segment['start_time'],
+        )
+
+    def to_seglst_segment(self) -> 'SegLstSegment':
+        # TODO: read spec and handle speech segments with transcripts and other types
+        d = {
+            'session_id': self.filename,
+            'speaker': self.speaker_id,
+            'start_time': self.begin_time,
+            'end_time': self.begin_time + self.duration,
+        }
+
+        if self.channel != self.__class__.channel:
+            d['channel'] = self.channel
+
+        return d
+
     def serialize(self):
         """
         >>> line = RTTMLine.parse('SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>')
         >>> line.serialize()
-        'SPEAKER CMU_20020319-1400_d01_NONE 1 130.43 2.35 <NA> <NA> juliet <NA> <NA>'
+        'SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>'
         """
         return (f'{self.type} {self.filename} {self.channel} '
                 f'{self.begin_time} {self.duration} {self.othography} '
@@ -89,13 +113,13 @@ def serialize(self):

 @dataclass(frozen=True)
 class RTTM(Base):
-    lines: List[RTTMLine]
+    lines: 'list[RTTMLine]'
     line_cls = RTTMLine

     @classmethod
-    def _load(cls, file_descriptor, parse_float) -> 'List[RTTMLine]':
-        return [
+    def parse(cls, s: str, parse_float=decimal.Decimal) -> 'RTTM':
+        return cls([
             RTTMLine.parse(line, parse_float)
-            for line in file_descriptor
+            for line in s.split('\n')
             if len(line.strip()) > 0 and not line.strip().startswith(';')
-        ]
+        ])
diff --git a/meeteval/io/seglst.py b/meeteval/io/seglst.py
new file mode 100644
index 00000000..1c6daab6
--- /dev/null
+++ b/meeteval/io/seglst.py
@@ -0,0 +1,413 @@
+import dataclasses
+import functools
+import typing
+
+from meeteval.io.base import BaseABC
+from meeteval.io.py import NestedStructure
+from meeteval._typing import TypedDict
+
+if typing.TYPE_CHECKING:
+    from meeteval.wer.wer.error_rate import ErrorRate
+    from typing import Callable, Iterable, Any
+
+
+class SegLstSegment(TypedDict, total=False):
+    """
+    A segment.
+
+    Note:
+        We do not define an enum with all these keys for speed reasons
+    """
+    session_id: str
+    start_time: float
+    end_time: float
+    words: str
+    speaker: str
+    segment_index: int
+
+    # Unused by MeetEval but present in some file formats. They are defined here for compatibility and
+    # conversion in both directions
+    channel: int
+    confidence: float
+
+
+@dataclasses.dataclass(frozen=True)
+class SegLST(BaseABC):
+    """
+    A collection of segments in SegLST format. This is the input type to most
+    functions in MeetEval that process transcript segments.
+    """
+    segments: 'list[SegLstSegment]'
+
+    # Caches
+    _unique = None
+
+    @property
+    class T:
+        """
+        The "transpose" of the segments, i.e., a mapping that maps keys to lists of values.
+
+        The name `T` is inspired by the `T` in `pandas.DataFrame.T` and `numpy.ndarray.T`.
+        """
+        def __init__(self, outer):
+            self._outer = outer
+
+        def keys(self):
+            """
+            The keys that are common among all segments
+            """
+            return set.intersection(*[set(s.keys()) for s in self._outer.segments])
+
+        def __getitem__(self, key):
+            """
+            Returns the values for `key` of all segments as a list.
+            """
+            return [s[key] for s in self._outer.segments]
+
+    def unique(self, key) -> 'set[Any]':
+        """
+        Returns the unique values for `key` among all segments.
+        """
+        return set([s[key] for s in self.segments])
+
+    def __iter__(self):
+        return iter(self.segments)
+
+    def __getitem__(self, item):
+        return self.segments[item]
+
+    def __len__(self):
+        return len(self.segments)
+
+    def __add__(self, other):
+        if isinstance(other, SegLST):
+            return SegLST(self.segments + other.segments)
+        return NotImplemented
+
+    def groupby(self, key) -> 'dict[Any, SegLST]':
+        """
+        >>> t = asseglst(['a b c', 'd e f', 'g h i'])
+        >>> t.segments
+        [{'words': 'a b c', 'segment_index': 0, 'speaker': 0}, {'words': 'd e f', 'segment_index': 0, 'speaker': 1}, {'words': 'g h i', 'segment_index': 0, 'speaker': 2}]
+
+        >>> from pprint import pprint
+        >>> pprint(t.groupby('speaker'))  # doctest: +ELLIPSIS
+        {0: SegLST(segments=[{'words': 'a b c', 'segment_index': 0, 'speaker': 0}]),
+         1: SegLST(segments=[{'words': 'd e f', 'segment_index': 0, 'speaker': 1}]),
+         2: SegLST(segments=[{'words': 'g h i', 'segment_index': 0, 'speaker': 2}])}
+        """
+        return {k: SegLST(g) for k, g in groupby(self.segments, key=key).items()}
+
+    def sorted(self, key) -> 'SegLST':
+        """
+        Returns a copy of this object with the segments sorted by `key`.
+        """
+        return SegLST(sorted(self.segments, key=_get_key(key)))
+
+    def map(self, fn: 'Callable[[SegLstSegment], SegLstSegment]') -> 'SegLST':
+        """
+        Applies `fn` to all segments and returns a new `SegLST` object with the results.
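+
+        Example:
+        >>> SegLST([{'words': 'hello'}]).map(lambda s: {**s, 'words': s['words'].upper()})
+        SegLST(segments=[{'words': 'HELLO'}])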
+ """ + return SegLST([fn(s) for s in self.segments]) + + def flatmap(self, fn: 'Callable[[list[SegLstSegment]], Iterable[SegLstSegment]]') -> 'SegLST': + """ + Returns a new `SegLST` by applying `fn`, which is exptected to return an iterable of `SegLstSegment`s, + to all segments and flattening the output. + + The name is inspired by other programming languages (e.g., JavaScript, Rust, Java, Scala) where + flatmap is a common operation on lists / arrays / iterators. In data loading frameworks, + this operation is known as map followed by unbatch. + + Example: Split utterances into words + >>> SegLST([{'words': 'a b c'}]).flatmap(lambda x: [{'words': w} for w in x['words'].split()]) + SegLST(segments=[{'words': 'a'}, {'words': 'b'}, {'words': 'c'}]) + """ + return SegLST([s for t in self.segments for s in fn(t)]) + + def filter(self, fn: 'Callable[[SegLstSegment], bool]') -> 'SegLST': + """ + Applies `fn` to all segments and returns a new `SegLST` object with the segments for which `fn` returns true. + """ + return SegLST([s for s in self.segments if fn(s)]) + + @classmethod + def merge(cls, *t) -> 'SegLST': + """ + Merges multiple `SegLST` objects into one by concatenating all segments. + """ + return SegLST([s for t_ in t for s in t_.segments]) + + def to_seglst(self) -> 'SegLST': + return self + + @classmethod + def new(cls, d, **defaults) -> 'SegLST': + d = asseglst(d) + if defaults: + d = d.map(lambda s: {**defaults, **s}) + return d + + def _repr_pretty_(self, p, cycle): + """ + >>> from IPython.lib.pretty import pprint + >>> pprint(SegLST([{'words': 'a b c', 'segment_index': 0, 'speaker': 0}])) + SegLST([{'words': 'a b c', 'segment_index': 0, 'speaker': 0}]) + >>> pprint(SegLST([{'words': 'a b c', 'segment_index': 0, 'speaker': 0}, {'words': 'd e f', 'segment_index': 0, 'speaker': 1}, {'words': 'g h i', 'segment_index': 0, 'speaker': 2}])) + SegLST([{'words': 'a b c', 'segment_index': 0, 'speaker': 0}, + {'words': 'd e f', 'segment_index': 0, 'speaker': 1}, + {'words': 'g h i', 'segment_index': 0, 'speaker': 2}]) + """ + name = self.__class__.__name__ + with p.group(len(name) + 1, name + '(', ')'): + if cycle: + p.text('...') + else: + p.pretty(list(self.segments)) + + +def asseglistconvertible(d, *, py_convert=NestedStructure): + """ + Converts `d` into a structure that is convertible to the SegLST format, i.e., that + has `to_seglst` (and often `new`) defined. + """ + # Already convertible + if hasattr(d, 'to_seglst'): + return d + + # Chime7 format / List of `SegLstSegment`s + if isinstance(d, list) and (len(d) == 0 or isinstance(d[0], dict) and 'words' in d[0]): + # TODO: Conversion back to list of segments (Python structure)? + return SegLST(d) + + # TODO: pandas DataFrame + + # Convert Python structures + if isinstance(d, (list, tuple, dict, str)): + if py_convert is None: + raise TypeError(f'Cannot convert {type(d)} to SegLST with py_convert={py_convert!r}!') + # TODO: Conversion back to Python structure? + return py_convert(d) + + raise NotImplementedError(f'No conversion implemented for {type(d)}!') + + +def asseglst(d, *, required_keys=(), py_convert=NestedStructure) -> 'SegLST': + """ + Converts an object `d` into SegLST data format. `d` can be anything convertible to the SegLST format. + Returns `d` if `isinstance(d, SegLST)`. + + Python structures have to have one or two nested levels. The first level is interpreted as the speaker key and the + second level as the segment key. 
+ >>> asseglst(['a b c']) + SegLST(segments=[{'words': 'a b c', 'segment_index': 0, 'speaker': 0}]) + >>> asseglst([['a b c', 'd e f'], ['g h i']]) + SegLST(segments=[{'words': 'a b c', 'segment_index': 0, 'speaker': 0}, {'words': 'd e f', 'segment_index': 1, 'speaker': 0}, {'words': 'g h i', 'segment_index': 0, 'speaker': 1}]) + >>> asseglst({'A': ['a b c', 'd e f'], 'B': ['g h i']}) + SegLST(segments=[{'words': 'a b c', 'segment_index': 0, 'speaker': 'A'}, {'words': 'd e f', 'segment_index': 1, 'speaker': 'A'}, {'words': 'g h i', 'segment_index': 0, 'speaker': 'B'}]) + + Data formats are also converted + >>> from meeteval.io.stm import STM, STMLine + >>> stm = STM.parse('ex 1 A 0 1 a b c') + >>> asseglst(stm).segments + [{'session_id': 'ex', 'channel': 1, 'speaker': 'A', 'start_time': 0, 'end_time': 1, 'words': 'a b c'}] + + The SegLST representation can be converted back to its original representation + >>> print(stm.new(stm).dumps()) + ex 1 A 0 1 a b c + + + And modified before inversion + >>> s = asseglst(stm) + >>> s.segments[0]['words'] = 'x y z' + >>> print(stm.new(s).dumps()) + ex 1 A 0 1 x y z + + """ + assert isinstance(required_keys, tuple), required_keys + + # Exit early if already in the correct format + if isinstance(d, SegLST): + return d + + # Get a type that is convertible to SegLST + d = asseglistconvertible(d, py_convert=py_convert) + + t = d.to_seglst() + + # Check that `t` has all required keys + if len(t) and not set(required_keys).issubset(t.T.keys()): + required_keys = set(required_keys) + raise ValueError( + f'Some required keys are not present in the converted data structure!\n' + f'Required: {required_keys}, found: {t.T.keys()}, missing: {required_keys - t.T.keys()}' + ) + return t + + +def _get_key(key): + import operator + if callable(key) or key is None: + return key + elif isinstance(key, (str, int)): + return operator.itemgetter(key) + else: + raise TypeError(f'Invalid type for key: {type(key)}') + + +def groupby( + iterable, + key=None, + default_key=None, +): + """ + A non-lazy variant of `itertools.groupby` with advanced features. + + Copied from `paderbox.utils.iterable.groupby`. + + Args: + iterable: Iterable to group + key: Determines by what to group. Can be: + - `None`: Use the iterables elements as keys directly + - `callable`: Gets called with every element and returns the group + key + - `str`, or `int`: Use `__getitem__` on elements in `iterable` + to obtain the key + - `Iterable`: Provides the keys. Has to have the same length as + `iterable`. + + Examples: + >>> groupby('ab'*3) + {'a': ['a', 'a', 'a'], 'b': ['b', 'b', 'b']} + >>> groupby(range(10), lambda x: x%2) + {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]} + >>> groupby(({'a': x%2, 'b': x} for x in range(3)), 'a') + {0: [{'a': 0, 'b': 0}, {'a': 0, 'b': 2}], 1: [{'a': 1, 'b': 1}]} + >>> groupby(['abc', 'bd', 'abd', 'cdef', 'c'], 0) + {'a': ['abc', 'abd'], 'b': ['bd'], 'c': ['cdef', 'c']} + >>> groupby('abc', {}) + Traceback (most recent call last): + ... 
TypeError: Invalid type for key: <class 'dict'>
+    """
+    import collections
+    import itertools
+
+    groups = collections.defaultdict(list)
+    try:
+        for key, group in itertools.groupby(iterable, _get_key(key)):
+            groups[key].extend(group)
+    except KeyError:
+        if default_key is None:
+            raise
+        else:
+            assert len(groups) == 0, (groups, iterable, key)
+            return iterable
+    return dict(groups)
+
+
+def seglst_map(*, required_keys=(), py_convert=NestedStructure):
+    """
+    Decorator for a function that takes a (single) `SegLST` object as input and returns a (single) `SegLST` object
+    as output. Automatically converts the input to `SegLST` and converts the returned value back to its original type.
+
+    >>> @seglst_map(required_keys=('speaker',))
+    ... def fn(seglst, *, speaker='X'):
+    ...     return seglst.map(lambda x: {**x, 'speaker': speaker})
+    >>> from meeteval.io.stm import STM
+    >>> fn(STM.parse('X 1 A 0 1 a b c'))
+    STM(lines=[STMLine(filename='X', channel=1, speaker_id='X', begin_time=0, end_time=1, transcript='a b c')])
+    >>> from meeteval.io.rttm import RTTM
+    >>> fn(RTTM.parse('SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>'))
+    RTTM(lines=[RTTMLine(type='SPEAKER', filename='CMU_20020319-1400_d01_NONE', channel='1', begin_time=Decimal('130.430000'), duration=Decimal('2.350000'), othography='<NA>', speaker_type='<NA>', speaker_id='X', confidence='<NA>', signal_look_ahead_time='<NA>')])
+    >>> fn({'A': 'abc', 'B': 'def'}).structure
+    {'X': 'abc def'}
+    """
+
+    def _seglst_map(fn):
+        @functools.wraps(fn)
+        def _seglst_map(arg, *args, **kwargs):
+            c = asseglistconvertible(arg, py_convert=py_convert)
+            arg = asseglst(c, required_keys=required_keys)
+            arg = fn(arg, *args, **kwargs)
+            return c.new(arg)
+
+        return _seglst_map
+
+    return _seglst_map
+
+
+def apply_multi_file(
+        fn: 'Callable[[SegLST, SegLST], ErrorRate]',
+        reference, hypothesis,
+        *,
+        allowed_empty_examples_ratio=0.1
+):
+    """
+    Applies a function individually to all sessions / files.
+
+    `reference` and `hypothesis` must be convertible to `SegLST`. If they are a Python structure, the first level
+    is interpreted as the session / file key.
+
+    >>> from meeteval.wer.wer.cp import cp_word_error_rate
+    >>> from pprint import pprint
+    >>> ref = [['a b c', 'd e f'], ['g h i']]
+    >>> hyp = [['a b c'], ['d e f', 'g h i']]
+    >>> er = apply_multi_file(cp_word_error_rate, ref, hyp)
+    >>> pprint(er)
+    {0: CPErrorRate(error_rate=0.5, errors=3, length=6, insertions=0, deletions=3, substitutions=0, missed_speaker=1, falarm_speaker=0, scored_speaker=2, assignment=((0, 0), (1, None))),
+     1: CPErrorRate(error_rate=1.0, errors=3, length=3, insertions=3, deletions=0, substitutions=0, missed_speaker=0, falarm_speaker=1, scored_speaker=1, assignment=((0, 1), (None, 0)))}
+    """
+    import logging
+    reference = asseglst(
+        reference, required_keys=('session_id',),
+        py_convert=lambda p: NestedStructure(p, ('session_id', 'speaker', 'segment_id'))
+    ).groupby('session_id')
+    hypothesis = asseglst(
+        hypothesis, required_keys=('session_id',),
+        py_convert=lambda p: NestedStructure(p, ('session_id', 'speaker', 'segment_id'))
+    ).groupby('session_id')
+
+    # Check session keys. 
Print a warning if they differ and raise an exception when they differ too much + if reference.keys() != hypothesis.keys(): + h_minus_r = list(set(hypothesis.keys()) - set(reference.keys())) + r_minus_h = list(set(reference.keys()) - set(hypothesis.keys())) + + ratio = len(r_minus_h) / len(reference.keys()) + + if h_minus_r: + # This is a warning, because missing in reference is not a problem, + # we can safely ignore it. Missing in hypothesis is a problem, + # because we cannot distinguish between silence and missing. + logging.warning( + f'Keys of reference and hypothesis differ\n' + f'hypothesis - reference: e.g. {h_minus_r[:5]} (Total: {len(h_minus_r)} of {len(reference)})\n' + f'Drop them.', + ) + hypothesis = { + k: v + for k, v in hypothesis.items() + if k not in h_minus_r + } + + if len(r_minus_h) == 0: + pass + elif ratio <= allowed_empty_examples_ratio: + logging.warning( + f'Missing {ratio * 100:.3} % = {len(r_minus_h)}/{len(reference.keys())} of recordings in hypothesis.\n' + f'Please check your system, if it ignored some recordings or predicted no transcriptions for some recordings.\n' + f'Continue with the assumption, that the system predicted silence for the missing recordings.', + ) + else: + raise RuntimeError( + 'Keys of reference and hypothesis differ\n' + f'hypothesis - reference: e.g. {h_minus_r[:5]} (Total: {len(h_minus_r)} of {len(hypothesis)})\n' + f'reference - hypothesis: e.g. {r_minus_h[:5]} (Total: {len(r_minus_h)} of {len(reference)})' + ) + + results = {} + for session in reference.keys(): + results[session] = fn(reference[session], hypothesis[session]) + + return results diff --git a/meeteval/io/stm.py b/meeteval/io/stm.py index ea43afce..18997f3c 100644 --- a/meeteval/io/stm.py +++ b/meeteval/io/stm.py @@ -1,14 +1,8 @@ -import sys -import typing from dataclasses import dataclass -from typing import List, NamedTuple from meeteval.io.base import Base, BaseLine -import logging +import decimal -if typing.TYPE_CHECKING: - import decimal - from meeteval.io.uem import UEM, UEMLine - from meeteval.wer import ErrorRate +from meeteval.io.seglst import SegLstSegment __all__ = [ 'STMLine', @@ -40,7 +34,7 @@ class STMLine(BaseLine): transcript: str @classmethod - def parse(cls, line: str, parse_float=float) -> 'STMLine': + def parse(cls, line: str, parse_float=decimal.Decimal) -> 'STMLine': filename, channel, speaker_id, begin_time, end_time, *transcript = line.strip().split(maxsplit=5) if len(transcript) == 1: @@ -60,11 +54,32 @@ def parse(cls, line: str, parse_float=float) -> 'STMLine': ) except Exception as e: raise ValueError(f'Unable to parse STM line: {line}') from e - assert stm_line.begin_time >= 0, stm_line + # assert stm_line.begin_time >= 0, stm_line # We currently ignore the end time, so it's fine when it's before begin_time # assert stm_line.end_time >= stm_line.begin_time, stm_line return stm_line + @classmethod + def from_dict(cls, segment: 'SegLstSegment'): + return cls( + filename=segment['session_id'], + channel=segment.get('channel', 1), + speaker_id=segment['speaker'], + begin_time=segment['start_time'], + end_time=segment['end_time'], + transcript=segment['words'], + ) + + def to_seglst_segment(self) -> 'SegLstSegment': + return { + 'session_id': self.filename, + 'channel': self.channel, + 'speaker': self.speaker_id, + 'start_time': self.begin_time, + 'end_time': self.end_time, + 'words': self.transcript, + } + def serialize(self): """ >>> line = STMLine.parse('rec1 0 A 10 20 Hello World') @@ -74,29 +89,19 @@ def serialize(self): return 
(f'{self.filename} {self.channel} {self.speaker_id} ' f'{self.begin_time} {self.end_time} {self.transcript}') - def segment_dict(self): - """Returns a segment dict in the style of Chime-7 annotations""" - return { - 'start_time': self.begin_time, - 'end_time': self.end_time, - 'words': self.transcript, - 'speaker': self.speaker_id, - 'session_id': self.filename - } - @dataclass(frozen=True) class STM(Base): - lines: List[STMLine] + lines: 'list[STMLine]' line_cls = STMLine @classmethod - def _load(cls, file_descriptor, parse_float) -> 'List[STMLine]': - return [ + def parse(cls, s: str, parse_float=decimal.Decimal) -> 'STM': + return cls([ STMLine.parse(line, parse_float) - for line in file_descriptor + for line in s.split('\n') if len(line.strip()) > 0 and not line.strip().startswith(';') - ] + ]) @classmethod def merge(cls, *stms) -> 'STM': @@ -108,19 +113,7 @@ def to_rttm(self): # ToDo: Fix `line.end_time - line.begin_time`, when they are floats. # Sometimes there is a small error and the error will be written # to the rttm file. - - return RTTM([ - RTTMLine( - filename=line.filename, - channel=line.channel, - begin_time=line.begin_time, - duration=line.end_time - line.begin_time, - speaker_id=line.speaker_id, - # line.transcript RTTM doesn't support transcript - # hence this information is dropped. - ) - for line in self.lines - ]) + return RTTM.new(self.to_seglst()) def to_array_interval(self, sample_rate, group=True): import paderbox as pb @@ -137,78 +130,21 @@ def to_array_interval(self, sample_rate, group=True): (round(line.begin_time * sample_rate), round(line.end_time * sample_rate)) for line in self.lines]) - def utterance_transcripts(self) -> List[str]: + def utterance_transcripts(self) -> 'list[str]': return [x.transcript for x in sorted(self.lines, key=lambda x: x.begin_time)] def merged_transcripts(self) -> str: return ' '.join(self.utterance_transcripts()) - def segments(self): - return [l.segment_dict() for l in self] - - -def iter_examples(reference: 'STM', hypothesis: 'STM', *, allowed_empty_examples_ratio=0.1): - reference = reference.grouped_by_filename() - hypothesis = hypothesis.grouped_by_filename() - - if reference.keys() != hypothesis.keys(): - h_minus_r = list(set(hypothesis.keys()) - set(reference.keys())) - r_minus_h = list(set(reference.keys()) - set(hypothesis.keys())) - ratio = len(r_minus_h) / len(reference.keys()) +if __name__ == '__main__': + def to_rttm(file): + from pathlib import Path + STM.load(file).to_rttm().dump(Path(file).with_suffix('.rttm')) - if h_minus_r: - # This is a warning, because missing in reference is not a problem, - # we can safely ignore it. Missing in hypothesis is a problem, - # because we cannot distinguish between silence and missing. - logging.warning( - f'Keys of reference and hypothesis differ\n' - f'hypothesis - reference: e.g. {h_minus_r[:5]} (Total: {len(h_minus_r)} of {len(reference)})\n' - f'Drop them.', - ) - hypothesis = { - k: v - for k, v in hypothesis.items() - if k not in h_minus_r - } - if len(r_minus_h) == 0: - pass - elif ratio <= allowed_empty_examples_ratio: - logging.warning( - f'Missing {ratio * 100:.3} % = {len(r_minus_h)}/{len(reference.keys())} of recordings in hypothesis.\n' - f'Please check your system, if it ignored some recordings or predicted no transcriptions for some recordings.\n' - f'Continue with the assumption, that the system predicted silence for the missing recordings.', - ) - else: - raise RuntimeError( - 'Keys of reference and hypothesis differ\n' - f'hypothesis - reference: e.g. 
{h_minus_r[:5]} (Total: {len(h_minus_r)} of {len(hypothesis)})\n'
-                f'reference - hypothesis: e.g. {r_minus_h[:5]} (Total: {len(r_minus_h)} of {len(reference)})'
-            )
-
-    for filename in reference:
-        yield filename, reference[filename], hypothesis[filename]
-
-
-def apply_stm_multi_file(
-        fn: 'typing.Callable[[STM, STM], ErrorRate]',
-        reference: 'STM',
-        hypothesis: 'STM',
-        *,
-        allowed_empty_examples_ratio=0.1
-):
-    result = {}
-    for f, r, h in iter_examples(
-            reference, hypothesis,
-            allowed_empty_examples_ratio=allowed_empty_examples_ratio
-    ):
-        logging.debug(f'Processing example {f}')
-        try:
-            result[f] = fn(r, h)
-            logging.debug(f'Result of example {f}: {result[f]}')
-        except Exception:
-            logging.error(f'Exception in example {f}')
-            raise
-    return result
+    import fire
+    fire.Fire({
+        'to_rttm': to_rttm,
+    })
diff --git a/meeteval/io/uem.py b/meeteval/io/uem.py
index 22dd3573..b1e82956 100644
--- a/meeteval/io/uem.py
+++ b/meeteval/io/uem.py
@@ -1,4 +1,3 @@
-from typing import List, NamedTuple
 import decimal
 from dataclasses import dataclass
 from meeteval.io.base import Base, BaseLine
@@ -38,10 +37,10 @@ class UEMLine(BaseLine):
     end_time: 'float | int | decimal.Decimal' = 0

     @classmethod
-    def parse(cls, line: str, parse_float=float) -> 'UEMLine':
+    def parse(cls, line: str, parse_float=decimal.Decimal) -> 'UEMLine':
         """
         >>> UEMLine.parse('S01 1 60.001 79.003')
-        UEMLine(filename='S01', channel='1', begin_time=60.001, end_time=79.003)
+        UEMLine(filename='S01', channel='1', begin_time=Decimal('60.001'), end_time=Decimal('79.003'))
         """
         filename, channel, begin_time, end_time = line.split()
@@ -64,7 +63,7 @@ def serialize(self):

 @dataclass(frozen=True)
 class UEM(Base):
-    lines: List[UEMLine]
+    lines: 'list[UEMLine]'
     line_cls = UEMLine

     @cached_property
@@ -74,12 +73,12 @@ def _key_to_index(self):
         return {k: v for v, k in enumerate(keys)}

     @classmethod
-    def _load(cls, file_descriptor, parse_float) -> 'List[UEMLine]':
-        return [
+    def parse(cls, s: str, parse_float=decimal.Decimal) -> 'UEM':
+        return cls([
             UEMLine.parse(line, parse_float)
-            for line in file_descriptor
+            for line in s.split('\n')
             if len(line.strip()) > 0
             # and not line.strip().startswith(';')  # Does uem allow comments?
-        ]
+        ])

     def __getitem__(self, item):
         if isinstance(item, str):
diff --git a/meeteval/wer/__main__.py b/meeteval/wer/__main__.py
index 4e957e4f..ea79afbe 100644
--- a/meeteval/wer/__main__.py
+++ b/meeteval/wer/__main__.py
@@ -7,7 +7,6 @@
 import re
 import decimal
 from pathlib import Path
-from typing import List, Tuple

 import meeteval.io
 from meeteval.io.ctm import CTMGroup
@@ -94,12 +93,12 @@ def _load(path: Path):
         raise NotImplementedError(f'Unknown file ext: {path.suffix}')


-def _load_reference(reference: 'Path | List[Path]'):
+def _load_reference(reference: 'Path | list[Path]'):
     """Loads a reference transcription file. Currently only STM supported"""
     return STM.load(reference)


-def _load_hypothesis(hypothesis: List[Path]):
+def _load_hypothesis(hypothesis: 'list[Path]'):
     """Loads the hypothesis. Supports one STM file or multiple CTM files
     (one per channel)"""
     if len(hypothesis) > 1:
@@ -126,7 +125,7 @@ def _load_hypothesis(hypothesis: List[Path]):
         raise RuntimeError(hypothesis, filename)


-def _load_texts(reference_paths: List[str], hypothesis_paths: List[str], regex) -> Tuple[STM, List[Path], STM, List[Path]]:
+def _load_texts(reference_paths: 'list[str]', hypothesis_paths: 'list[str]', regex) -> 'tuple[STM, list[Path], STM, list[Path]]':
     """Load and validate reference and hypothesis texts.

Validation checks that reference and hypothesis have the same example IDs. @@ -164,7 +163,7 @@ def filter(stm): return reference, reference_paths, hypothesis, hypothesis_paths -def _get_parent_stem(hypothesis_paths: List[Path]): +def _get_parent_stem(hypothesis_paths: 'list[Path]'): hypothesis_paths = [p.resolve() for p in hypothesis_paths] if len(hypothesis_paths) == 1: @@ -184,7 +183,7 @@ def _get_parent_stem(hypothesis_paths: List[Path]): def _save_results( per_reco, - hypothesis_paths: List[Path], + hypothesis_paths: 'list[Path]', per_reco_out: str, average_out: str, ): @@ -224,8 +223,8 @@ def wer( f'Got: {reference_paths} for reference and {hypothesis_paths} for hypothesis.') reference = KeyedText.load(reference) hypothesis = KeyedText.load(hypothesis) - from meeteval.wer.wer.siso import siso_word_error_rate_keyed_text - results = siso_word_error_rate_keyed_text(reference, hypothesis) + from meeteval.wer.wer.siso import siso_word_error_rate_multifile + results = siso_word_error_rate_multifile(reference, hypothesis) _save_results(results, hypothesis_paths, per_reco_out, average_out) @@ -236,10 +235,10 @@ def orcwer( regex=None, ): """Computes the Optimal Reference Combination Word Error Rate (ORC WER)""" - from meeteval.wer.wer.orc import orc_word_error_rate_stm + from meeteval.wer.wer.orc import orc_word_error_rate_multifile reference, _, hypothesis, hypothesis_paths = _load_texts( reference, hypothesis, regex=regex) - results = orc_word_error_rate_stm(reference, hypothesis) + results = orc_word_error_rate_multifile(reference, hypothesis) _save_results(results, hypothesis_paths, per_reco_out, average_out) @@ -250,10 +249,10 @@ def cpwer( regex=None, ): """Computes the Concatenated minimum-Permutation Word Error Rate (cpWER)""" - from meeteval.wer.wer.cp import cp_word_error_rate_stm + from meeteval.wer.wer.cp import cp_word_error_rate_multifile reference, _, hypothesis, hypothesis_paths = _load_texts( reference, hypothesis, regex) - results = cp_word_error_rate_stm(reference, hypothesis) + results = cp_word_error_rate_multifile(reference, hypothesis) _save_results(results, hypothesis_paths, per_reco_out, average_out) @@ -264,10 +263,10 @@ def mimower( regex=None, ): """Computes the MIMO WER""" - from meeteval.wer.wer.mimo import mimo_word_error_rate_stm + from meeteval.wer.wer.mimo import mimo_word_error_rate_multifile reference, _, hypothesis, hypothesis_paths = _load_texts( reference, hypothesis, regex=regex) - results = mimo_word_error_rate_stm(reference, hypothesis) + results = mimo_word_error_rate_multifile(reference, hypothesis) _save_results(results, hypothesis_paths, per_reco_out, average_out) @@ -283,10 +282,10 @@ def tcpwer( hypothesis_sort='segment', ): """Computes the time-constrained minimum permutation WER""" - from meeteval.wer.wer.time_constrained import tcp_word_error_rate_stm + from meeteval.wer.wer.time_constrained import tcp_word_error_rate_multifile reference, _, hypothesis, hypothesis_paths = _load_texts( reference, hypothesis, regex=regex) - results = tcp_word_error_rate_stm( + results = tcp_word_error_rate_multifile( reference, hypothesis, reference_pseudo_word_level_timing=ref_pseudo_word_timing, hypothesis_pseudo_word_level_timing=hyp_pseudo_word_timing, @@ -301,7 +300,7 @@ def tcpwer( def _merge( - files: List[str], + files: 'list[str]', out: str = None, average: bool = None ): diff --git a/meeteval/wer/matching/mimo_matching.py b/meeteval/wer/matching/mimo_matching.py index 5c41e524..1a9d2881 100644 --- a/meeteval/wer/matching/mimo_matching.py +++ 
b/meeteval/wer/matching/mimo_matching.py @@ -1,9 +1,9 @@ import typing Utterance = typing.Sequence[typing.Hashable] -Hypothesis = typing.List[Utterance] -Reference = typing.List[typing.List[Utterance]] -Assignment = typing.List[typing.Tuple[int, int]] +Hypothesis = 'list[Utterance]' +Reference = 'list[list[Utterance]]' +Assignment = 'list[tuple[int, int]]' def levenshtein_distance(ref, hyp): diff --git a/meeteval/wer/matching/orc_matching.py b/meeteval/wer/matching/orc_matching.py index 5f210f53..d7c8a73c 100644 --- a/meeteval/wer/matching/orc_matching.py +++ b/meeteval/wer/matching/orc_matching.py @@ -11,14 +11,14 @@ import typing Utterance = typing.Iterable[typing.Hashable] -Assignment = typing.Tuple[int, ...] +Assignment = 'tuple[int, ...]' def _get_channel_transcription_from_assignment( - utterances: typing.List[Utterance], + utterances: 'list[Utterance]', assignment: Assignment, num_channels: int -) -> typing.List[typing.List[typing.Hashable]]: +) -> 'list[list[typing.Hashable]]': import itertools c = [[] for _ in range(num_channels)] @@ -150,8 +150,8 @@ def lev(ref_len, hyps_len, idx=None): def orc_matching_v3( - ref: typing.List[Utterance], - hyps: typing.List[typing.List[typing.Hashable]] + ref: 'list[Utterance]', + hyps: 'list[list[typing.Hashable]]' ): """ A Cython implementation of the ORC matching algorithm diff --git a/meeteval/wer/wer/__init__.py b/meeteval/wer/wer/__init__.py index 4647699d..222cc76b 100644 --- a/meeteval/wer/wer/__init__.py +++ b/meeteval/wer/wer/__init__.py @@ -1,6 +1,6 @@ -from .cp import cp_word_error_rate, CPErrorRate, cp_word_error_rate_stm -from .mimo import mimo_word_error_rate, MimoErrorRate, mimo_word_error_rate_stm -from .orc import orc_word_error_rate, orc_word_error_rate_stm, OrcErrorRate -from .siso import siso_word_error_rate, siso_character_error_rate, siso_word_error_rate_keyed_text +from .cp import cp_word_error_rate, CPErrorRate, cp_word_error_rate_multifile +from .mimo import mimo_word_error_rate, MimoErrorRate, mimo_word_error_rate_multifile +from .orc import orc_word_error_rate, orc_word_error_rate_multifile, OrcErrorRate +from .siso import siso_word_error_rate, siso_character_error_rate, siso_word_error_rate_multifile from .error_rate import ErrorRate, combine_error_rates -from .time_constrained import time_constrained_minimum_permutation_word_error_rate, time_constrained_siso_word_error_rate, tcp_word_error_rate_stm +from .time_constrained import time_constrained_minimum_permutation_word_error_rate, time_constrained_siso_word_error_rate, tcp_word_error_rate_multifile diff --git a/meeteval/wer/wer/cp.py b/meeteval/wer/wer/cp.py index 9eb07b09..ca6a5dcb 100644 --- a/meeteval/wer/wer/cp.py +++ b/meeteval/wer/wer/cp.py @@ -1,16 +1,14 @@ import dataclasses import itertools import string -from typing import Optional, Tuple, List, Dict, Any, Iterable +from typing import Optional, Any, Iterable from meeteval._typing import Literal -from meeteval.io import STM +from meeteval.io.seglst import SegLST, asseglst from meeteval.wer.wer.error_rate import ErrorRate -from meeteval.wer.wer.siso import siso_word_error_rate, _siso_error_rate -from meeteval.wer.utils import _items, _values, _keys, _map -__all__ = ['CPErrorRate', 'cp_word_error_rate', 'apply_cp_assignment', 'cp_word_error_rate_stm'] +__all__ = ['CPErrorRate', 'cp_word_error_rate', 'apply_cp_assignment', 'cp_word_error_rate_multifile'] @dataclasses.dataclass(frozen=True, repr=False) @@ -28,8 +26,8 @@ class CPErrorRate(ErrorRate): missed_speaker: int falarm_speaker: int scored_speaker: int - 
# assignment: Optional[Tuple[int, ...]] = None - assignment: Optional[Tuple['int | str | Any', ...]] = None + # assignment: 'Optional[tuple[int, ...]]' = None + assignment: 'Optional[tuple[int | str | Any, ...]]' = None @classmethod def zero(cls): @@ -92,24 +90,17 @@ def apply_assignment( ) -def cp_error_rate( - reference: 'List[Iterable] | Dict[Any, Iterable]', - hypothesis: 'List[Iterable] | Dict[Any, Iterable]', -) -> CPErrorRate: - from meeteval.wer.matching.cy_levenshtein import levenshtein_distance - +def cp_error_rate(reference: 'SegLST', hypothesis: 'SegLST') -> CPErrorRate: + from meeteval.wer.wer.siso import _seglst_siso_error_rate, siso_levenshtein_distance return _cp_error_rate( reference, hypothesis, - distance_fn=levenshtein_distance, - siso_error_rate=_siso_error_rate, + distance_fn=siso_levenshtein_distance, + siso_error_rate=_seglst_siso_error_rate, ) -def cp_word_error_rate( - reference: 'List[str | Iterable[str]] | Dict[str | Iterable[str]] | STM', - hypothesis: 'List[str | Iterable[str]] | Dict[str | Iterable[str]] | STM', -): +def cp_word_error_rate(reference: 'SegLST', hypothesis: 'SegLST') -> CPErrorRate: """ The Concatenated minimum Permutation WER (cpWER). @@ -154,53 +145,31 @@ def cp_word_error_rate( >>> cp_word_error_rate(['a b c'.split(), 'd e f'.split()], ['a b c'.split(), 'd e f'.split()]) CPErrorRate(error_rate=0.0, errors=0, length=6, insertions=0, deletions=0, substitutions=0, missed_speaker=0, falarm_speaker=0, scored_speaker=2, assignment=((0, 0), (1, 1))) """ - import meeteval.io - - def transcription_to_words(x): - def split(words): - if isinstance(words, str): - return words.split() - elif isinstance(words, list): - assert isinstance(words[0], str), (type(words[0]), words) - return words - elif isinstance(words, meeteval.io.stm.STM): - assert len({(line.filename, line.speaker_id) for line in words}) == 1, words - return [ - word - for line in words.sorted_by_begin_time() - for word in line.transcript.split() - ] - else: - raise TypeError(type(words), words) - - if isinstance(x, meeteval.io.stm.STM): - assert len(x.filenames()) <= 1, (len(x.filenames()), x.filenames(), x) - return transcription_to_words(x.grouped_by_speaker_id()) - else: - return _map(split, x) - - return cp_error_rate( - transcription_to_words(reference), - transcription_to_words(hypothesis), - ) + reference = asseglst(reference, required_keys=('speaker', 'words')) + hypothesis = asseglst(hypothesis, required_keys=('speaker', 'words')) + + def split_words(d: 'SegLST'): + return d.flatmap( + lambda s: [{**s, 'words': w} for w in (s['words'].split() if s['words'].strip() else [''])]) + + return cp_error_rate(split_words(reference), split_words(hypothesis)) -def cp_word_error_rate_stm(reference_stm: 'STM', hypothesis_stm: 'STM') -> 'Dict[str, CPErrorRate]': +def cp_word_error_rate_multifile(reference_stm, hypothesis_stm) -> 'dict[str, CPErrorRate]': """ Computes the cpWER for each example in the reference and hypothesis STM files. - To compute the overall WER, use `sum(cp_word_error_rate_stm(r, h).values())`. + To compute the overall WER, use `sum(cp_word_error_rate_multifile(r, h).values())`. 
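+
+    A minimal usage sketch (hypothetical file names; any SegLST-convertible
+    input, e.g. STM, should work the same way):
+
+        from meeteval.io.stm import STM
+        per_session = cp_word_error_rate_multifile(STM.load('ref.stm'), STM.load('hyp.stm'))
+        overall = sum(per_session.values())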
""" - from meeteval.io.stm import apply_stm_multi_file - return apply_stm_multi_file(cp_word_error_rate, reference_stm, hypothesis_stm) + from meeteval.io.seglst import apply_multi_file + return apply_multi_file(cp_word_error_rate, reference_stm, hypothesis_stm) def _cp_error_rate( - reference, - hypothesis, + reference: SegLST, + hypothesis: SegLST, distance_fn: callable, siso_error_rate: callable, - missing=(), ): # Used in # cp_word_error_rate @@ -210,6 +179,9 @@ def _cp_error_rate( import scipy.optimize import numpy as np + reference = reference.groupby('speaker') + hypothesis = hypothesis.groupby('speaker') + if max(len(hypothesis), len(reference)) > 20: num_speakers = max(len(hypothesis), len(reference)) raise RuntimeError( @@ -224,15 +196,15 @@ def _cp_error_rate( [ distance_fn(tt, et) for et, _ in itertools.zip_longest( - _values(hypothesis), - reference, # ignored, "padding" for underestimation - fillvalue=missing, - ) + hypothesis.values(), + reference.values(), # ignored, "padding" for underestimation + fillvalue=SegLST([]), + ) ] for tt, _ in itertools.zip_longest( - _values(reference), - hypothesis, # ignored, "padding" for overestimation - fillvalue=missing, + reference.values(), + hypothesis.values(), # ignored, "padding" for overestimation + fillvalue=SegLST([]), ) ]) @@ -244,8 +216,8 @@ def _cp_error_rate( # Compute WER from distance distance = sum(distances) - reference_keys = dict(enumerate(_keys(reference))) # need `dict.get` of the keys for overestimation - hypothesis_keys = dict(enumerate(_keys(hypothesis))) # need `dict.get` of the keys for underestimation + reference_keys = dict(enumerate(reference.keys())) # need `dict.get` of the keys for overestimation + hypothesis_keys = dict(enumerate(hypothesis.keys())) # need `dict.get` of the keys for underestimation assignment = tuple([ (reference_keys.get(r), hypothesis_keys.get(c)) @@ -259,12 +231,12 @@ def _cp_error_rate( assignment, reference=reference, hypothesis=hypothesis, - missing=missing, + missing=SegLST([]), ) er = sum([ siso_error_rate(r, hypothesis_new[speaker]) - for speaker, r in _items(reference_new) + for speaker, r in reference_new.items() ]) assert distance == er.errors, (distance, er) @@ -284,7 +256,7 @@ def _cp_error_rate( def apply_cp_assignment( - assignment: 'List[Tuple[Any, ...]] | Tuple[Tuple[Any, ...], ...]', + assignment: 'list[tuple[Any, ...]] | tuple[tuple[Any, ...], ...]', reference: dict, hypothesis: dict, style: 'Literal["hyp", "ref"]' = 'ref', diff --git a/meeteval/wer/wer/error_rate.py b/meeteval/wer/wer/error_rate.py index 326257dc..9be6f9a3 100644 --- a/meeteval/wer/wer/error_rate.py +++ b/meeteval/wer/wer/error_rate.py @@ -2,7 +2,7 @@ __all__ = ['ErrorRate', 'combine_error_rates'] -from typing import Optional +from typing import Optional, Any import logging logger = logging.getLogger('error_rate') @@ -38,6 +38,12 @@ def __add__(self, other: 'SelfOverlap') -> 'SelfOverlap': self.total_time + other.total_time, ) + def __radd__(self, other: 'int') -> 'SelfOverlap': + if isinstance(other, int) and other == 0: + # Special case to support sum. 
+ return self + return NotImplemented + @classmethod def from_dict(cls, d: dict): return cls(d['overlap_time'], d['total_time']) @@ -236,3 +242,35 @@ def combine_error_rates(*error_rates: ErrorRate) -> ErrorRate: if len(error_rates) == 1: return error_rates[0] return sum(error_rates) + + +@dataclasses.dataclass(frozen=True) +class CombinedErrorRate(ErrorRate): + details: 'dict[Any, ErrorRate]' + + @classmethod + def from_error_rates(cls, error_rates: 'dict[Any, ErrorRate]'): + from meeteval.wer.utils import _values + er = sum(_values(error_rates)) + return cls( + errors=er.errors, + length=er.length, + insertions=er.insertions, + deletions=er.deletions, + substitutions=er.substitutions, + reference_self_overlap=er.reference_self_overlap, + hypothesis_self_overlap=er.hypothesis_self_overlap, + details=error_rates, + ) + + def __repr__(self): + return ( + self.__class__.__qualname__ + '(' + + ', '.join([ + f"{f.name}={getattr(self, f.name)!r}" + if f.name != 'details' else 'details=...' + for f in dataclasses.fields(self) + if getattr(self, f.name) is not None + ]) + ')' + ) + diff --git a/meeteval/wer/wer/mimo.py b/meeteval/wer/wer/mimo.py index 606dfc05..b38d32a7 100644 --- a/meeteval/wer/wer/mimo.py +++ b/meeteval/wer/wer/mimo.py @@ -1,11 +1,12 @@ import dataclasses -from typing import Tuple, List, Dict, Iterable, Any +from typing import Iterable, Any +from meeteval.io.seglst import asseglst from meeteval.wer.wer.error_rate import ErrorRate -from meeteval.wer.wer.siso import siso_word_error_rate, _siso_error_rate -from meeteval.wer.utils import _keys, _items, _values, _map +from meeteval.wer.wer.siso import _siso_error_rate +from meeteval.wer.utils import _keys, _items, _values -__all__ = ['MimoErrorRate', 'mimo_word_error_rate', 'apply_mimo_assignment', 'mimo_word_error_rate_stm'] +__all__ = ['MimoErrorRate', 'mimo_word_error_rate', 'apply_mimo_assignment', 'mimo_word_error_rate_multifile'] from meeteval.io import STM @@ -18,12 +19,12 @@ class MimoErrorRate(ErrorRate): >>> MimoErrorRate(0, 10, 0, 0, 0, None, None, [(0, 0)]) + MimoErrorRate(10, 10, 0, 0, 10, None, None, [(0, 0)]) ErrorRate(error_rate=0.5, errors=10, length=20, insertions=0, deletions=0, substitutions=10) """ - assignment: Tuple[int, ...] + assignment: 'tuple[int, ...]' def mimo_error_rate( - reference: 'List[List[Iterable]] | Dict[Any, List[Iterable]]', - hypothesis: 'List[Iterable] | Dict[Iterable]', + reference: 'list[list[Iterable]] | dict[Any, list[Iterable]]', + hypothesis: 'list[Iterable] | dict[Iterable]', ): if max(len(hypothesis), len(reference)) > 10: num_speakers = max(len(hypothesis), len(reference)) @@ -69,13 +70,7 @@ def mimo_error_rate( ) - - - -def mimo_word_error_rate( - reference: 'List[List[str] | Dict[Any, str]] | Dict[List[str], Dict[Any, str]] | STM', - hypothesis: 'List[str] | Dict[str] | STM', -) -> MimoErrorRate: +def mimo_word_error_rate(reference, hypothesis) -> MimoErrorRate: """ The Multiple Input speaker, Multiple Output channel (MIMO) WER. @@ -94,38 +89,47 @@ def mimo_word_error_rate( ... 
{'O1': 'c d', 'O2': 'a b e f'}) MimoErrorRate(error_rate=0.0, errors=0, length=6, insertions=0, deletions=0, substitutions=0, assignment=[('A', 'O2'), ('B', 'O2'), ('A', 'O1')]) + >>> mimo_word_error_rate(STM.parse('X 1 A 0.0 1.0 a b\\nX 1 A 1.0 2.0 c d\\nX 1 B 0.0 2.0 e f\\n'), STM.parse('X 1 1 0.0 2.0 c d\\nX 1 0 0.0 2.0 a b e f\\n')) + MimoErrorRate(error_rate=0.0, errors=0, length=6, insertions=0, deletions=0, substitutions=0, assignment=[('A', '0'), ('B', '0'), ('A', '1')]) """ - if isinstance(reference, STM) or isinstance(hypothesis, STM): - from meeteval.wer.wer.utils import _check_valid_input_files - _check_valid_input_files(reference, hypothesis) - reference = { - speaker_id: r.utterance_transcripts() - for speaker_id, r in reference.grouped_by_speaker_id().items() - } - hypothesis = { - speaker_id: h.merged_transcripts() - for speaker_id, h in hypothesis.grouped_by_speaker_id().items() - } - - reference = _map(lambda x: _map(str.split, x), reference) - hypothesis = _map(str.split, hypothesis) + reference = asseglst(reference) + hypothesis = asseglst(hypothesis) + + # Sort by start time if the start time is available + # TODO: implement something like reference_sort from time_constrained.py? + if 'start_time' in reference.T.keys(): + reference = reference.sorted('start_time') + if 'start_time' in hypothesis.T.keys(): + hypothesis = hypothesis.sorted('start_time') + + # Convert to dict of lists of words + reference = { + k: [s['words'].split() for s in v if s['words'] != ''] + for k, v in reference.groupby('speaker').items() + } + hypothesis = { + k: [w for s in v if s['words'] != '' for w in s['words'].split()] + for k, v in hypothesis.groupby('speaker').items() + } + + # Call core function return mimo_error_rate(reference, hypothesis) -def mimo_word_error_rate_stm(reference_stm: 'STM', hypothesis_stm: 'STM') -> 'Dict[str, MimoErrorRate]': +def mimo_word_error_rate_multifile(reference_stm, hypothesis_stm) -> 'dict[str, MimoErrorRate]': """ Computes the MIMO WER for each example in the reference and hypothesis STM files. - To compute the overall WER, use `sum(mimo_word_error_rate_stm(r, h).values())`. + To compute the overall WER, use `sum(mimo_word_error_rate_multifile(r, h).values())`. 
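+
+    A sketch of typical use (hypothetical file names; any SegLST-convertible
+    input, e.g. STM, should work):
+
+        from meeteval.io.stm import STM
+        per_session = mimo_word_error_rate_multifile(STM.load('ref.stm'), STM.load('hyp.stm'))
+        overall = sum(per_session.values())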
""" - from meeteval.io.stm import apply_stm_multi_file - return apply_stm_multi_file(mimo_word_error_rate, reference_stm, hypothesis_stm) + from meeteval.io.seglst import apply_multi_file + return apply_multi_file(mimo_word_error_rate, reference_stm, hypothesis_stm) def apply_mimo_assignment( - assignment: 'List[tuple]', - reference: 'List[List[Any]] | Dict[List[Any]]', - hypothesis: 'List[Any] | Dict[Any, Any]', + assignment: 'list[tuple]', + reference: 'list[list[Any]] | dict[list[Any]]', + hypothesis: 'list[Any] | dict[Any, Any]', ): """ >>> assignment = [('A', 'O2'), ('B', 'O2'), ('A', 'O1')] diff --git a/meeteval/wer/wer/orc.py b/meeteval/wer/wer/orc.py index 76ab41e4..827b91c5 100644 --- a/meeteval/wer/wer/orc.py +++ b/meeteval/wer/wer/orc.py @@ -1,13 +1,14 @@ import collections import dataclasses -from typing import Tuple, List, Dict, Iterable, Any +from typing import Iterable, Any +from meeteval.io.seglst import asseglst, SegLST from meeteval.wer.wer.error_rate import ErrorRate from meeteval.wer.wer.siso import _siso_error_rate from meeteval.wer.utils import _items, _keys, _values, _map from meeteval.io.stm import STM -__all__ = ['OrcErrorRate', 'orc_word_error_rate', 'orc_word_error_rate_stm', 'apply_orc_assignment'] +__all__ = ['OrcErrorRate', 'orc_word_error_rate', 'orc_word_error_rate_multifile', 'apply_orc_assignment'] @dataclasses.dataclass(frozen=True, repr=False) @@ -18,7 +19,7 @@ class OrcErrorRate(ErrorRate): >>> OrcErrorRate(0, 10, 0, 0, 0, None, None, (0, 1)) + OrcErrorRate(10, 10, 0, 0, 10, None, None, (1, 0, 1)) ErrorRate(error_rate=0.5, errors=10, length=20, insertions=0, deletions=0, substitutions=10) """ - assignment: Tuple[int, ...] + assignment: 'tuple[int, ...]' def apply_assignment(self, reference, hypothesis): ref = collections.defaultdict(list) @@ -39,19 +40,19 @@ def apply_assignment(self, reference, hypothesis): return ref, hypothesis -def orc_word_error_rate_stm(reference_stm: 'STM', hypothesis_stm: 'STM') -> 'Dict[str, OrcErrorRate]': +def orc_word_error_rate_multifile(reference_stm: 'STM', hypothesis_stm: 'STM') -> 'dict[str, OrcErrorRate]': """ Computes the ORC WER for each example in the reference and hypothesis STM files. - To compute the overall WER, use `sum(orc_word_error_rate_stm(r, h).values())`. + To compute the overall WER, use `sum(orc_word_error_rate_multifile(r, h).values())`. """ - from meeteval.io.stm import apply_stm_multi_file - return apply_stm_multi_file(orc_word_error_rate, reference_stm, hypothesis_stm) + from meeteval.io.seglst import apply_multi_file + return apply_multi_file(orc_word_error_rate, reference_stm, hypothesis_stm) def orc_error_rate( - reference: 'List[Iterable]', - hypothesis: 'List[Iterable] | Dict[Any, Iterable]', + reference: 'list[Iterable]', + hypothesis: 'list[Iterable] | dict[Any, Iterable]', ): # Safety check: The complexity explodes for large numbers of speakers if len(hypothesis) > 10: @@ -94,10 +95,7 @@ def orc_error_rate( -def orc_word_error_rate( - reference: 'List[str] | STM', - hypothesis: 'List[str] | dict[str] | STM', -) -> OrcErrorRate: +def orc_word_error_rate(reference: 'SegLST', hypothesis: 'SegLST') -> OrcErrorRate: """ The Optimal Reference Combination (ORC) WER, implemented efficiently. 
@@ -116,30 +114,48 @@ def orc_word_error_rate(
     >>> er.apply_assignment(['a', 'c d', 'e'], ['a c', 'd e'])
     ([['a', 'c d'], ['e']], ['a c', 'd e'])
 
+    >>> orc_word_error_rate(STM.parse('X 1 A 0.0 1.0 a b\\nX 1 A 1.0 2.0 c d\\nX 1 B 0.0 2.0 e f\\n'), STM.parse('X 1 1 0.0 2.0 c d\\nX 1 0 0.0 2.0 a b e f\\n'))
+    OrcErrorRate(error_rate=0.0, errors=0, length=6, insertions=0, deletions=0, substitutions=0, assignment=('0', '0', '1'))
+
     >>> er = orc_word_error_rate(['a', 'c d', 'e'], {'A': 'a c', 'B': 'd e'})
     >>> er
    OrcErrorRate(error_rate=0.5, errors=2, length=4, insertions=1, deletions=1, substitutions=0, assignment=('A', 'A', 'B'))
     >>> er.apply_assignment(['a', 'c d', 'e'], {'A': 'a c', 'B': 'd e'})
     ({'A': ['a', 'c d'], 'B': ['e']}, {'A': 'a c', 'B': 'd e'})
     """
-    if isinstance(reference, STM) or isinstance(hypothesis, STM):
-        from meeteval.wer.wer.utils import _check_valid_input_files
-        _check_valid_input_files(reference, hypothesis)
-        reference = reference.utterance_transcripts()
-        hypothesis = {
-            speaker_id: h_.merged_transcripts()
-            for speaker_id, h_ in hypothesis.grouped_by_speaker_id().items()
-        }
-
-    reference_words = [r.split() for r in reference]
-    hypothesis_words = _map(str.split, hypothesis)
-    return orc_error_rate(reference_words, hypothesis_words)
+    # Convert to SegLST
+    reference = asseglst(reference, required_keys=('words',))
+    hypothesis = asseglst(hypothesis, required_keys=('words', 'speaker'))
+
+    if 'start_time' in reference.T.keys():
+        reference = reference.sorted('start_time')
+    if 'start_time' in hypothesis.T.keys():
+        hypothesis = hypothesis.sorted('start_time')
+
+    reference = [w.split() for w in reference.T['words'] if w]
+    hypothesis = {
+        speaker: [w for s in seglst if s['words'] for w in s['words'].split()]
+        for speaker, seglst in hypothesis.groupby('speaker').items()
+    }
+
+    return orc_error_rate(reference, hypothesis)
 
 
 def apply_orc_assignment(
-    assignment: 'List[tuple]',
-    reference: 'List[str]',
-    hypothesis: 'List[str] | dict[str]',
+    assignment: 'list[tuple]',
+    reference: 'list[str]',
+    hypothesis: 'list[str] | dict[str]',
 ):
     """
     >>> assignment = ('A', 'A', 'B')
diff --git a/meeteval/wer/wer/siso.py b/meeteval/wer/wer/siso.py
index b35d33bf..5252c36d 100644
--- a/meeteval/wer/wer/siso.py
+++ b/meeteval/wer/wer/siso.py
@@ -1,19 +1,33 @@
 import typing
-from typing import List, Hashable, Dict
+from typing import Hashable
 
-from meeteval.io.keyed_text import KeyedText
+from meeteval.io.py import NestedStructure
 from meeteval.wer.wer.error_rate import ErrorRate
+from meeteval.io.seglst import asseglst
 
 if typing.TYPE_CHECKING:
     from meeteval.io.stm import STM
+    from meeteval.io.seglst import SegLST
 
-__all__ = ['siso_word_error_rate', 'siso_character_error_rate', 'siso_word_error_rate_keyed_text']
+__all__ = ['siso_word_error_rate', 'siso_character_error_rate', 'siso_word_error_rate_multifile']
 
 
-def _siso_error_rate(
-    reference: List[Hashable],
-    hypothesis: List[Hashable]
-) -> ErrorRate:
+def siso_levenshtein_distance(reference: 'SegLST', hypothesis: 'SegLST') -> int:
+    """
+    Every element is
treated as a single word.
+    """
+    reference = asseglst(reference, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+    hypothesis = asseglst(hypothesis, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+
+    from meeteval.wer.matching.cy_levenshtein import levenshtein_distance
+
+    reference = [w for w in reference.T['words'] if w]
+    hypothesis = [w for w in hypothesis.T['words'] if w]
+
+    return levenshtein_distance(reference, hypothesis)
+
+
+def _siso_error_rate(reference: 'list[Hashable]', hypothesis: 'list[Hashable]') -> ErrorRate:
     import kaldialign
 
     try:
@@ -33,10 +47,24 @@ def _siso_error_rate(
     )
 
 
-def siso_word_error_rate(
-    reference: 'str | STM | KeyedText',
-    hypothesis: 'str | STM | KeyedText',
-) -> ErrorRate:
+def _seglst_siso_error_rate(reference: 'SegLST', hypothesis: 'SegLST') -> ErrorRate:
+    reference = [w for w in reference.T['words'] if w]
+    hypothesis = [w for w in hypothesis.T['words'] if w]
+    return _siso_error_rate(reference, hypothesis)
+
+
+def siso_error_rate(reference: 'SegLST', hypothesis: 'SegLST') -> ErrorRate:
+    reference = asseglst(reference, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+    hypothesis = asseglst(hypothesis, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+    if reference[0].get('session_id') != hypothesis[0].get('session_id'):
+        raise ValueError(
+            f'Session ID must be identical, but found {reference[0].get("session_id")} for the reference '
+            f'and {hypothesis[0].get("session_id")} for the hypothesis.'
+        )
+    return _seglst_siso_error_rate(reference, hypothesis)
+
+
+def siso_word_error_rate(reference: 'SegLST', hypothesis: 'SegLST') -> ErrorRate:
     """
     The "standard" Single Input speaker, Single Output speaker (SISO) WER.
 
@@ -49,45 +77,55 @@ def siso_word_error_rate(
     ErrorRate(error_rate=1.0, errors=2, length=2, insertions=0, deletions=0, substitutions=2)
     >>> siso_word_error_rate(reference='This is wikipedia', hypothesis='This wikipedia')  # Deletion example from https://en.wikipedia.org/wiki/Word_error_rate
     ErrorRate(error_rate=0.3333333333333333, errors=1, length=3, insertions=0, deletions=1, substitutions=0)
+    >>> from meeteval.io.stm import STM
+    >>> siso_word_error_rate(STM.parse('X 1 Wikipedia 0 1 This is wikipedia'), STM.parse('X 1 Wikipedia 0 1 This wikipedia'))
+    ErrorRate(error_rate=0.3333333333333333, errors=1, length=3, insertions=0, deletions=1, substitutions=0)
     """
-    if isinstance(reference, KeyedText) or isinstance(hypothesis, KeyedText):
-        from meeteval.wer.wer.utils import _check_valid_input_files
-        _check_valid_input_files(reference, hypothesis)
-        if len(reference.lines) != 1:
-            raise ValueError(
-                f'Reference must contain exactly one line, but found {len(reference.lines)} lines in {reference}.'
-            )
-        if len(hypothesis.lines) != 1:
-            raise ValueError(
-                f'Hypothesis must contain exactly one line, but found {len(hypothesis.lines)} lines in {reference}.'
-            )
-        reference = reference.lines[0].transcript
-        hypothesis = hypothesis.lines[0].transcript
-
-    return _siso_error_rate(
-        reference.split(),
-        hypothesis.split()
-    )
+    reference = asseglst(reference, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+    hypothesis = asseglst(hypothesis, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+    if len(reference) != 1:
+        raise ValueError(f'Reference must contain exactly one line, but found {len(reference)} lines.')
+    if len(hypothesis) != 1:
+        raise ValueError(f'Hypothesis must contain exactly one line, but found {len(hypothesis)} lines.')
 
-def siso_word_error_rate_keyed_text(reference: 'STM | KeyedText', hypothesis: 'STM | KeyedText') -> 'Dict[str, ErrorRate]':
+    def split_words(d):
+        return [
+            {**s, 'words': w}
+            for s in d
+            for w in (s['words'].split() if s['words'].strip() else [''])
+        ]
+
+    return siso_error_rate(split_words(reference), split_words(hypothesis))
+
+
+def siso_word_error_rate_multifile(reference, hypothesis) -> 'dict[str, ErrorRate]':
     """
     Computes the standard WER for each example in the reference and hypothesis files.
 
-    To compute the overall WER, use `sum(siso_word_error_rate_keyed_text(r, h).values())`.
+    To compute the overall WER, use `sum(siso_word_error_rate_multifile(r, h).values())`.
     """
-    from meeteval.io.stm import apply_stm_multi_file
-    return apply_stm_multi_file(siso_word_error_rate, reference, hypothesis, allowed_empty_examples_ratio=0)
+    from meeteval.io.seglst import apply_multi_file
+    return apply_multi_file(siso_word_error_rate, reference, hypothesis, allowed_empty_examples_ratio=0)
 
 
-def siso_character_error_rate(
-    reference: str,
-    hypothesis: str,
-) -> ErrorRate:
+def siso_character_error_rate(reference: 'SegLST', hypothesis: 'SegLST') -> ErrorRate:
     """
     >>> siso_character_error_rate('abc', 'abc')
     ErrorRate(error_rate=0.0, errors=0, length=3, insertions=0, deletions=0, substitutions=0)
     """
-    return _siso_error_rate(
-        list(reference), list(hypothesis)
-    )
+    reference = asseglst(reference, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+    hypothesis = asseglst(hypothesis, required_keys=('words',), py_convert=lambda p: NestedStructure(p, ('segment',)))
+    if len(reference) != 1:
+        raise ValueError(f'Reference must contain exactly one line, but found {len(reference)} lines.')
+    if len(hypothesis) != 1:
+        raise ValueError(f'Hypothesis must contain exactly one line, but found {len(hypothesis)} lines.')
+
+    def split_characters(s):
+        return [
+            {**s, 'words': c}
+            for c in s['words'].strip()
+            if c != ' '
+        ]
+
+    return siso_error_rate(reference.flatmap(split_characters), hypothesis.flatmap(split_characters))
diff --git a/meeteval/wer/wer/time_constrained.py b/meeteval/wer/wer/time_constrained.py
index ccb609bd..1d2f5686 100644
--- a/meeteval/wer/wer/time_constrained.py
+++ b/meeteval/wer/wer/time_constrained.py
@@ -6,13 +6,11 @@
 from dataclasses import dataclass, replace
 
 from meeteval.io.stm import STM
+from meeteval.io.seglst import SegLST, seglst_map, asseglst, SegLstSegment
 from meeteval.wer.wer.error_rate import ErrorRate, SelfOverlap
 from meeteval.wer.wer.cp import CPErrorRate
-from typing import List, Dict
 import logging
 
-from meeteval.wer.utils import _values, _map, _keys
-
 logger = logging.getLogger('time_constrained')
 
 if typing.TYPE_CHECKING:
@@ -28,7 +26,10 @@ class Segment(TypedDict):
 __all__ = [
     'time_constrained_minimum_permutation_word_error_rate',
     'time_constrained_siso_word_error_rate',
-
'tcp_word_error_rate_stm' + 'tcp_word_error_rate_multifile', + 'apply_collar', + 'get_pseudo_word_level_timings', + 'align', ] @@ -54,7 +55,7 @@ def equidistant_points(interval, words): return [((interval[0] + interval[1]) / 2,) * 2] interval_length = (interval[1] - interval[0]) / count - return [(interval[0] + (i + 0.5) * interval_length,) * 2 for i in range(count)] + return [(interval[0] + i * interval_length + interval_length / 2,) * 2 for i in range(count)] def character_based(interval, words): @@ -134,17 +135,24 @@ def _check_timing_annotations(t, k): def _time_constrained_siso_error_rate( - reference, hypothesis, reference_timing, hypothesis_timing, prune='auto' + reference: SegLST, hypothesis: SegLST, prune='auto' ): from meeteval.wer.matching.cy_levenshtein import time_constrained_levenshtein_distance_with_alignment + reference = reference.filter(lambda s: s['words']) + hypothesis = hypothesis.filter(lambda s: s['words']) + reference_words = reference.T['words'] + reference_timing = list(zip(reference.T['start_time'], reference.T['end_time'])) + hypothesis_words = hypothesis.T['words'] + hypothesis_timing = list(zip(hypothesis.T['start_time'], hypothesis.T['end_time'])) + result = time_constrained_levenshtein_distance_with_alignment( - reference, hypothesis, reference_timing, hypothesis_timing, prune=prune + reference_words, hypothesis_words, reference_timing, hypothesis_timing, prune=prune ) return ErrorRate( result['total'], - len(reference), + len(reference_words), insertions=result['insertions'], deletions=result['deletions'], substitutions=result['substitutions'], @@ -155,8 +163,30 @@ def _time_constrained_siso_error_rate( @dataclass class TimeMarkedTranscript: - transcript: List[str] - timings: List[typing.Tuple[float, float]] + transcript: 'list[str]' + timings: 'list[tuple[float, float]]' + + def to_seglst(self): + return SegLST([{ + 'words': transcript, + 'start_time': timing[0], + 'end_time': timing[1], + } for transcript, timing in zip(self.transcript, self.timings)]) + + @classmethod + def new(cls, d): + d = asseglst(d) + return cls( + transcript=[s['words'] for s in d], + timings=[(s['start_time'], s['end_time']) for s in d], + ) + + @classmethod + def merge(cls, *t): + return TimeMarkedTranscript( + transcript=[w for tt in t for w in tt.transcript], + timings=[timing for tt in t for timing in tt.timings] + ) def has_self_overlaps(self): last_end = 0 @@ -234,7 +264,7 @@ def from_stm(cls, stm: STM) -> 'TimeMarkedTranscript': return time_marked_transcript @classmethod - def from_segment_dicts(cls, data: 'List[Segment]') -> 'TimeMarkedTranscript': + def from_segment_dicts(cls, data: 'list[Segment]') -> 'TimeMarkedTranscript': if len(data) == 0: return cls([], []) if 'speaker' in data[0]: @@ -278,187 +308,278 @@ def _repr_pretty_(self, p, cycle): # Annotation for input -TimeMarkedTranscriptLike = 'TimeMarkedTranscript | STM | List[Segment]' +TimeMarkedTranscriptLike = 'TimeMarkedTranscript | STM | list[Segment]' + + +@seglst_map() +def apply_collar(s: SegLST, collar: float): + """ + Adds a collar to begin and end times. + Works with any format that is convertible to SegLST and back, such as `STM` and `RTTM`. 
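+    Note: the collar is applied as-is, i.e. times are not clipped, so start
+    times can become negative as in the examples below.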
-def apply_collar(s: TimeMarkedTranscript, collar: float):
-    return replace(s, timings=[(t[0] - collar, t[1] + collar) for t in s.timings])
+    >>> apply_collar(SegLST([{'start_time': 0, 'end_time': 1}]), 1)
+    SegLST(segments=[{'start_time': -1, 'end_time': 2}])
+    >>> print(apply_collar(STM.parse('X 1 A 0 1 a b'), 1).dumps())
+    X 1 A -1 2 a b
+
+    """
+    return s.map(lambda s: {**s, 'start_time': s['start_time'] - collar, 'end_time': s['end_time'] + collar})
 
 
-def get_pseudo_word_level_timings(
-    s: TimeMarkedTranscript,
-    strategy: str,
-) -> TimeMarkedTranscript:
+@seglst_map()
+def get_pseudo_word_level_timings(t: SegLST, strategy: str) -> SegLST:
     """
+    Takes a transcript with segment-level annotations and outputs a transcript with estimated word-level annotations.
+
+    Choices for `strategy`:
+    - `'equidistant_intervals'`: Divide the segment-level timing into equally sized intervals
+    - `'equidistant_points'`: Place equally spaced time points in the segment-level intervals
+    - `'full_segment'`: Use the full segment for each word that belongs to that segment
+    - `'character_based'`: Estimate the word length based on the number of characters
+    - `'character_based_points'`: Estimate the word length based on the number of characters and create a point in the center of each word
+    - `'none'` or `None`: Do not estimate word-level timings but assume that the provided timings are already given on a word level.
+
     >>> from IPython.lib.pretty import pprint
-    >>> s = TimeMarkedTranscript(['abc b', 'c d e f'], [(0, 4), (4, 8)])
+    >>> from meeteval.io.seglst import SegLST
+    >>> s = SegLST([{'words': 'abc b', 'start_time': 0, 'end_time': 4}, {'words': 'c d e f', 'start_time': 4, 'end_time': 8}])
     >>> pprint(get_pseudo_word_level_timings(s, 'full_segment'))
-    TimeMarkedTranscript(
-        transcript=['abc', 'b', 'c', 'd', 'e', 'f'],
-        timings=[(0, 4), (0, 4), (4, 8), (4, 8), (4, 8), (4, 8)]
-    )
+    SegLST([{'words': 'abc', 'start_time': 0, 'end_time': 4},
+            {'words': 'b', 'start_time': 0, 'end_time': 4},
+            {'words': 'c', 'start_time': 4, 'end_time': 8},
+            {'words': 'd', 'start_time': 4, 'end_time': 8},
+            {'words': 'e', 'start_time': 4, 'end_time': 8},
+            {'words': 'f', 'start_time': 4, 'end_time': 8}])
     >>> pprint(get_pseudo_word_level_timings(s, 'equidistant_points'))
-    TimeMarkedTranscript(
-        transcript=['abc', 'b', 'c', 'd', 'e', 'f'],
-        timings=[(1.0, 1.0),
-                 (3.0, 3.0),
-                 (4.5, 4.5),
-                 (5.5, 5.5),
-                 (6.5, 6.5),
-                 (7.5, 7.5)]
-    )
+    SegLST([{'words': 'abc', 'start_time': 1.0, 'end_time': 1.0},
+            {'words': 'b', 'start_time': 3.0, 'end_time': 3.0},
+            {'words': 'c', 'start_time': 4.5, 'end_time': 4.5},
+            {'words': 'd', 'start_time': 5.5, 'end_time': 5.5},
+            {'words': 'e', 'start_time': 6.5, 'end_time': 6.5},
+            {'words': 'f', 'start_time': 7.5, 'end_time': 7.5}])
     >>> pprint(get_pseudo_word_level_timings(s, 'equidistant_intervals'))
-    TimeMarkedTranscript(
-        transcript=['abc', 'b', 'c', 'd', 'e', 'f'],
-        timings=[(0.0, 2.0),
-                 (2.0, 4.0),
-                 (4.0, 5.0),
-                 (5.0, 6.0),
-                 (6.0, 7.0),
-                 (7.0, 8.0)]
-    )
+    SegLST([{'words': 'abc', 'start_time': 0.0, 'end_time': 2.0},
+            {'words': 'b', 'start_time': 2.0, 'end_time': 4.0},
+            {'words': 'c', 'start_time': 4.0, 'end_time': 5.0},
+            {'words': 'd', 'start_time': 5.0, 'end_time': 6.0},
+            {'words': 'e', 'start_time': 6.0, 'end_time': 7.0},
+            {'words': 'f', 'start_time': 7.0, 'end_time': 8.0}])
     >>> word_level = get_pseudo_word_level_timings(s, 'character_based')
     >>> pprint(word_level)
-    TimeMarkedTranscript(
-        transcript=['abc', 'b', 'c', 'd', 'e', 'f'],
-        timings=[(0.0, 3.0),
-                 (3.0,
4.0), - (4.0, 5.0), - (5.0, 6.0), - (6.0, 7.0), - (7.0, 8.0)] - ) + SegLST([{'words': 'abc', 'start_time': 0.0, 'end_time': 3.0}, + {'words': 'b', 'start_time': 3.0, 'end_time': 4.0}, + {'words': 'c', 'start_time': 4.0, 'end_time': 5.0}, + {'words': 'd', 'start_time': 5.0, 'end_time': 6.0}, + {'words': 'e', 'start_time': 6.0, 'end_time': 7.0}, + {'words': 'f', 'start_time': 7.0, 'end_time': 8.0}]) >>> pprint(get_pseudo_word_level_timings(word_level, 'none')) # Copies over the timings since word-level timings are already assumed - TimeMarkedTranscript( - transcript=['abc', 'b', 'c', 'd', 'e', 'f'], - timings=[(0.0, 3.0), - (3.0, 4.0), - (4.0, 5.0), - (5.0, 6.0), - (6.0, 7.0), - (7.0, 8.0)] - ) + SegLST([{'words': 'abc', 'start_time': 0.0, 'end_time': 3.0}, + {'words': 'b', 'start_time': 3.0, 'end_time': 4.0}, + {'words': 'c', 'start_time': 4.0, 'end_time': 5.0}, + {'words': 'd', 'start_time': 5.0, 'end_time': 6.0}, + {'words': 'e', 'start_time': 6.0, 'end_time': 7.0}, + {'words': 'f', 'start_time': 7.0, 'end_time': 8.0}]) >>> pprint(get_pseudo_word_level_timings(s, 'character_based_points')) - TimeMarkedTranscript( - transcript=['abc', 'b', 'c', 'd', 'e', 'f'], - timings=[(1.5, 1.5), - (3.5, 3.5), - (4.5, 4.5), - (5.5, 5.5), - (6.5, 6.5), - (7.5, 7.5)] - ) + SegLST([{'words': 'abc', 'start_time': 1.5, 'end_time': 1.5}, + {'words': 'b', 'start_time': 3.5, 'end_time': 3.5}, + {'words': 'c', 'start_time': 4.5, 'end_time': 4.5}, + {'words': 'd', 'start_time': 5.5, 'end_time': 5.5}, + {'words': 'e', 'start_time': 6.5, 'end_time': 6.5}, + {'words': 'f', 'start_time': 7.5, 'end_time': 7.5}]) + + Works with any format that is convertible to SegLST and back, for example STM: + >>> print(get_pseudo_word_level_timings(STM.new(s, session_id='dummy', speaker='dummy'), 'character_based_points').dumps()) + dummy 1 dummy 1.5 1.5 abc + dummy 1 dummy 3.5 3.5 b + dummy 1 dummy 4.5 4.5 c + dummy 1 dummy 5.5 5.5 d + dummy 1 dummy 6.5 6.5 e + dummy 1 dummy 7.5 7.5 f + """ pseudo_word_level_strategy = pseudo_word_level_strategies[strategy] - all_words = [] - word_level_timings = [] - - for words, interval in zip(s.transcript, s.timings): - words = words.split() # Get words form segment - segment_timings = pseudo_word_level_strategy(interval, words) - word_level_timings.extend(segment_timings) - all_words.extend(words) - assert len(words) == len(segment_timings), (words, segment_timings) + def get_words(s): + res = [] + words = s['words'].split() + if not words: # Make sure that we don't drop a speaker + words = [''] + for w, (start, end) in zip(words, pseudo_word_level_strategy((s['start_time'], s['end_time']), words)): + res.append({**s, 'words': w, 'start_time': start, 'end_time': end}) + return res - return TimeMarkedTranscript(all_words, word_level_timings) + return t.flatmap(get_words) +@seglst_map(required_keys=('start_time', 'end_time')) def remove_overlaps( - s: TimeMarkedTranscript, + t: SegLST, max_overlap: float = 0.4, warn_message: str = None, -) -> TimeMarkedTranscript: +) -> SegLST: """ Remove overlaps between words or segments in a transcript. + Note: Sorts the segments by begin time. + Args: - s: TimeMarkedTranscript + t: SegLST object to remove overlaps from max_overlap: maximum allowed relative overlap between words or segments. Raises a `ValueError` when more overlap is found. warn_message: if not None, a warning is printed when overlaps are corrected. 
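+
+    A sketch of the correction with hypothetical numbers: segments `(0, 2)`
+    and `(1, 3)` overlap by 1, which is below the allowed
+    `max_overlap * (3 - 0) = 1.2` for the default `max_overlap` of 0.4; the
+    boundary is moved to the center of the overlap, yielding `(0, 1.5)` and
+    `(1.5, 3)`.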
""" - corrected_timings = [] - for t in s.timings: - if corrected_timings and corrected_timings[-1][1] > t[0]: + last: 'typing.Optional[SegLstSegment]' = None + + def correct(s): + nonlocal last + if last and last['end_time'] > s['start_time']: if warn_message is not None: import warnings warnings.warn(warn_message) - last = corrected_timings[-1] - overlap = last[1] - t[0] - if overlap > max_overlap * (t[1] - last[0]): + overlap = last['end_time'] - s['start_time'] + if overlap > max_overlap * (s['start_time'] - last['start_time']): import numpy as np raise ValueError( f'Overlapping segments exceed max allowed relative overlap. ' - f'Segment {last} overlaps with {t}. ' - f'{overlap} > {max_overlap * (t[1] - last[0])} ' - f'relative overlap: {np.divide(overlap, (t[1] - last[-1]))}' + f'Segment {last} overlaps with {s}. ' + f'{overlap} > {max_overlap * (s["end_time"] - last["start_time"])} ' + f'relative overlap: {np.divide(overlap, (s["end_time"] - last["end_time"]))}' ) - center = (last[-1] + t[0]) / 2 - corrected_timings[-1] = (last[0], center) - t = (center, t[1]) - - assert t[1] > t[0], t - assert last[1] > last[0], last - - corrected_timings.append(t) - return TimeMarkedTranscript(s.transcript, corrected_timings) - - -def sort_segments(s: TimeMarkedTranscript): - import numpy as np - order = np.argsort(np.asarray(s.timings)[:, 0]) - return TimeMarkedTranscript( - [s.transcript[int(i)] for i in order], - [s.timings[int(i)] for i in order], - ) - + center = (last['end_time'] + s['start_time']) / 2 + assert center > last['start_time'], (center, last['start_time']) + last['start_time'] = last['start_time'] + last['end_time'] = center + assert last['end_time'] > last['start_time'], last + last = s + return s -def sort_and_validate(segments, sort, pseudo_word_level_timing, name): - # Check that all timings are valid - if len(segments.transcript) != len(segments.timings): - raise ValueError( - f'Number of words does not match number of timings in {name}: ' - f'{len(segments.transcript)} != {len(segments.timings)}' - ) + return t.sorted('start_time').map(correct) - for t in segments.timings: - if t[1] < t[0]: - raise ValueError(f'The end time of an interval must be larger than the start time. Found {t} in {name}') +def sort_and_validate(segments: SegLST, sort, pseudo_word_level_timing, name): + """ + Args: + segments: + sort: How to sort words/segments. Options: + - `True`: sort by segment start time and assert that the word-level timings are sorted by start time + - `False`: do not sort and do not check word order + - `'segment'`: sort segments by start time and do not check word order + - `'word'`: sort words by start time + pseudo_word_level_timing: + name: + + >>> segments = SegLST([{'words': 'c d', 'start_time': 1, 'end_time': 3}, {'words': 'a b', 'start_time': 0, 'end_time': 3}]) + >>> sort_and_validate(segments, True, 'character_based', 'test') + Traceback (most recent call last): + ... + ValueError: The order of word-level timings contradicts the segment-level order in test: 2 of 4 times. + Consider setting sort to False or "segment" or "word". 
+ >>> sort_and_validate(segments, False, 'character_based', 'test') + SegLST(segments=[{'words': 'c', 'start_time': 1.0, 'end_time': 2.0}, {'words': 'd', 'start_time': 2.0, 'end_time': 3.0}, {'words': 'a', 'start_time': 0.0, 'end_time': 1.5}, {'words': 'b', 'start_time': 1.5, 'end_time': 3.0}]) + >>> sort_and_validate(segments, 'segment', 'character_based', 'test') + SegLST(segments=[{'words': 'a', 'start_time': 0.0, 'end_time': 1.5}, {'words': 'b', 'start_time': 1.5, 'end_time': 3.0}, {'words': 'c', 'start_time': 1.0, 'end_time': 2.0}, {'words': 'd', 'start_time': 2.0, 'end_time': 3.0}]) + >>> sort_and_validate(segments, 'word', 'character_based', 'test') + SegLST(segments=[{'words': 'a', 'start_time': 0.0, 'end_time': 1.5}, {'words': 'c', 'start_time': 1.0, 'end_time': 2.0}, {'words': 'b', 'start_time': 1.5, 'end_time': 3.0}, {'words': 'd', 'start_time': 2.0, 'end_time': 3.0}]) + """ if sort not in (True, False, 'segment', 'word'): raise ValueError(f'Invalid value for sort: {sort}. Choose one of True, False, "segment", "word"') + for s in segments: + if s['end_time'] < s['start_time']: + raise ValueError(f'The end time of an interval must be larger than the start time. Found {s} in {name}') + if sort in (True, 'segment', 'word'): - segments = sort_segments(segments) + segments = segments.sorted('start_time') words = get_pseudo_word_level_timings(segments, pseudo_word_level_timing) # Check whether words are sorted by start time - words_sorted = sort_segments(words) - prune = True + words_sorted = words.sorted('start_time') if words_sorted != words: - contradictions = [a != b for a, b in zip(words_sorted.transcript, words.transcript)] + # This check should be fast because `sorted` doesn't change the identity of the contained objects + # (so `words_sorted[0] is words[0] == True` when they are sorted). + contradictions = [a != b for a, b in zip(words_sorted, words)] msg = ( f'The order of word-level timings contradicts the segment-level order in {name}: ' f'{sum(contradictions)} of {len(contradictions)} times.' ) if sort is not True: logger.warning(msg) - prune = False else: - raise ValueError(f'{msg}\nConsider setting sort to False or "segment" or "word".\n') + raise ValueError(f'{msg}\nConsider setting sort to False or "segment" or "word".') if sort == 'word': words = words_sorted - prune = True # Pruning always works when words are sorted by start time - return words, prune + return words + + +def get_self_overlap(d: SegLST): + """ + Returns the self-overlap of the transcript. 
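+    The overlap time counts every additional overlapping layer once (three
+    fully overlapping segments of length 1 give an overlap time of 2), and
+    the total time is the length of the union of all segment intervals.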
+ + ▇ + ▇ + ▇ + >>> get_self_overlap([{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 2, 'end_time': 3}]) + SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=3) + + ▇ + ▇ + >>> get_self_overlap([{'words': 'a', 'start_time': 0, 'end_time': 1}] * 2) + SelfOverlap(overlap_rate=1.0, overlap_time=1, total_time=1) + + ▇ + ▇ + ▇ + >>> get_self_overlap([{'words': 'a', 'start_time': 0, 'end_time': 1}] * 3) + SelfOverlap(overlap_rate=2.0, overlap_time=2, total_time=1) + + ▇▇▇▇▇▇▇▇▇▇ + ▇ + ▇ + >>> get_self_overlap([{'words': 'a', 'start_time': 0, 'end_time': 10}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 3, 'end_time': 4}]) + SelfOverlap(overlap_rate=0.2, overlap_time=2, total_time=10) + + ▇▇▇▇ + ▇▇ + ▇▇▇ + >>> get_self_overlap([{'words': 'a', 'start_time': 0, 'end_time': 4}, {'words': 'b', 'start_time': 1, 'end_time': 3}, {'words': 'c', 'start_time': 2, 'end_time': 5}]) + SelfOverlap(overlap_rate=0.8, overlap_time=4, total_time=5) + + >>> get_self_overlap([{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 0.5, 'end_time': 1.5}, {'words': 'c', 'start_time': 1, 'end_time': 2}]) + SelfOverlap(overlap_rate=0.5, overlap_time=1.0, total_time=2.0) + """ + d = asseglst(d, required_keys=('start_time', 'end_time'), py_convert=None) + latest_end = 0 + self_overlap = 0 + total = 0 + for t in sorted(d, key=lambda x: x['start_time']): + if latest_end > t['start_time']: + self_overlap += min(latest_end, t['end_time']) - t['start_time'] + total += max(0, t['end_time'] - latest_end) + latest_end = max(latest_end, t['end_time']) + return SelfOverlap(self_overlap, total) + + +def time_constrained_siso_levenshtein_distance(reference: 'SegLST', hypothesis: 'SegLST') -> int: + from meeteval.wer.matching.cy_levenshtein import time_constrained_levenshtein_distance + + # Ignore empty segments + reference = reference.filter(lambda s: s['words']) + hypothesis = hypothesis.filter(lambda s: s['words']) + + return time_constrained_levenshtein_distance( + reference=reference.T['words'], + hypothesis=hypothesis.T['words'], + reference_timing=list(zip(reference.T['start_time'], reference.T['end_time'])), + hypothesis_timing=list(zip(hypothesis.T['start_time'], hypothesis.T['end_time'])), + ) def time_constrained_siso_word_error_rate( - reference: TimeMarkedTranscriptLike, - hypothesis: TimeMarkedTranscriptLike, + reference: 'SegLST', + hypothesis: 'SegLST', reference_pseudo_word_level_timing='character_based', hypothesis_pseudo_word_level_timing='character_based_points', collar: int = 0, @@ -483,36 +604,41 @@ def time_constrained_siso_word_error_rate( - 'segment': sort by segment start time and don't check word order - 'word': sort by word start time - >>> time_constrained_siso_word_error_rate(TimeMarkedTranscript(['a b', 'c d'], [(0,2), (0,2)]), TimeMarkedTranscript(['a'], [(0,1)])) + >>> time_constrained_siso_word_error_rate( + ... [{'words': 'a b', 'start_time': 0, 'end_time': 2}, {'words': 'c d', 'start_time': 0, 'end_time': 2}], + ... [{'words': 'a', 'start_time': 0, 'end_time': 1}]) ErrorRate(error_rate=0.75, errors=3, length=4, insertions=0, deletions=3, substitutions=0, reference_self_overlap=SelfOverlap(overlap_rate=1.0, overlap_time=2, total_time=2), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=1)) """ - reference = TimeMarkedTranscript.create(reference) - hypothesis = TimeMarkedTranscript.create(hypothesis) + # Convert to SegLST. 
Disallow Python conversions since there is currently no way to get the timings from a + # Python structure. + reference = asseglst(reference, required_keys=('start_time', 'end_time', 'words'), py_convert=None) + hypothesis = asseglst(hypothesis, required_keys=('start_time', 'end_time', 'words'), py_convert=None) - reference_, prune1 = sort_and_validate(reference, reference_sort, reference_pseudo_word_level_timing, 'reference') - hypothesis_, prune2 = sort_and_validate(hypothesis, hypothesis_sort, hypothesis_pseudo_word_level_timing, - 'hypothesis') + # Only single-speaker transcripts are supported, but we can here have multiple segments, e.g., for word-level + # transcripts + assert 'speaker' not in reference.T.keys() or len(reference.unique('speaker')) <= 1, 'Only single-speaker transcripts are supported' + assert 'speaker' not in hypothesis.T.keys() or len(hypothesis.unique('speaker')) <= 1, 'Only single-speaker transcripts are supported' - hypothesis = apply_collar(hypothesis, collar) + _reference = sort_and_validate(reference, reference_sort, reference_pseudo_word_level_timing, 'reference') + _hypothesis = sort_and_validate(hypothesis, hypothesis_sort, hypothesis_pseudo_word_level_timing, 'hypothesis') + _hypothesis = apply_collar(_hypothesis, collar) + + er = _time_constrained_siso_error_rate(_reference, _hypothesis) - er = _time_constrained_siso_error_rate( - reference_.transcript, hypothesis_.transcript, - reference_.timings, hypothesis_.timings, - prune=prune1 and prune2 - ) # pseudo_word_level_timing and collar change the time stamps, # hence calculate the overlap with the original time stamps er = replace( er, - reference_self_overlap=reference.get_self_overlap(), - hypothesis_self_overlap=hypothesis.get_self_overlap(), + reference_self_overlap=get_self_overlap(reference), + hypothesis_self_overlap=get_self_overlap(hypothesis), ) return er def time_constrained_minimum_permutation_word_error_rate( - reference: 'List[TimeMarkedTranscriptLike] | Dict[str, TimeMarkedTranscriptLike] | STM', - hypothesis: 'List[TimeMarkedTranscriptLike] | Dict[str, TimeMarkedTranscriptLike] | STM', + reference: 'SegLST', + hypothesis: 'SegLST', + *, reference_pseudo_word_level_timing='character_based', hypothesis_pseudo_word_level_timing='character_based_points', collar: int = 0, @@ -537,54 +663,48 @@ def time_constrained_minimum_permutation_word_error_rate( - 'segment': sort by segment start time and don't check word order - 'word': sort by word start time """ - from meeteval.wer.matching.cy_levenshtein import time_constrained_levenshtein_distance from meeteval.wer.wer.cp import _cp_error_rate - if isinstance(reference, STM): - reference = reference.grouped_by_speaker_id() - if isinstance(hypothesis, STM): - hypothesis = hypothesis.grouped_by_speaker_id() + reference = asseglst(reference, required_keys=('start_time', 'end_time', 'words', 'speaker'), py_convert=None) + hypothesis = asseglst(hypothesis, required_keys=('start_time', 'end_time', 'words', 'speaker'), py_convert=None) + + reference = reference.groupby('speaker') + hypothesis = hypothesis.groupby('speaker') - reference = _map(TimeMarkedTranscript.create, reference) - hypothesis = _map(TimeMarkedTranscript.create, hypothesis) + # Compute self-overlap for ref and hyp before converting to words and applying the collar. 
+ # This is required later + reference_self_overlap = sum([get_self_overlap(v) for v in reference.values()]) + hypothesis_self_overlap = sum([get_self_overlap(v) for v in hypothesis.values()]) # Convert segments into lists of words and word-level timings - prune = True - reference_self_overlap = SelfOverlap(0, 0) - for k in _keys(reference): - reference_self_overlap += reference[k].get_self_overlap() - reference[k], p = sort_and_validate( - reference[k], reference_sort, reference_pseudo_word_level_timing, f'reference {k}' - ) - prune = prune and p + reference = { + k: sort_and_validate(v, reference_sort, reference_pseudo_word_level_timing, f'reference speaker "{k}"') + for k, v in reference.items() + } + hypothesis = { + k: sort_and_validate(v, hypothesis_sort, hypothesis_pseudo_word_level_timing, f'hypothesis speaker "{k}"') + for k, v in hypothesis.items() + } + + reference = SegLST.merge(*reference.values()) + hypothesis = SegLST.merge(*hypothesis.values()) - hypothesis_self_overlap = SelfOverlap(0, 0) - for k in _keys(hypothesis): - hypothesis_self_overlap += hypothesis[k].get_self_overlap() - hypothesis[k], p = sort_and_validate( - hypothesis[k], hypothesis_sort, hypothesis_pseudo_word_level_timing, f'hypothesis {k}' - ) - hypothesis[k] = apply_collar(hypothesis[k], collar) - prune = prune and p + hypothesis = apply_collar(hypothesis, collar) - sym2int = {v: i for i, v in enumerate({ - word for words in itertools.chain(_values(reference), _values(hypothesis)) for word in words.transcript - })} + # Convert into integer representation to save some computation later. `'words'` contains a single word only. + sym2int = {v: i for i, v in enumerate([ + segment['words'] for segment in itertools.chain(reference, hypothesis) + if segment['words'] + ], start=1)} + sym2int[''] = 0 - reference = _map(lambda x: TimeMarkedTranscript([sym2int[s] for s in x.transcript], x.timings), reference) - hypothesis = _map(lambda x: TimeMarkedTranscript([sym2int[s] for s in x.transcript], x.timings), hypothesis) + reference = reference.map(lambda s: {**s, 'words': sym2int[s['words']]}) + hypothesis = hypothesis.map(lambda s: {**s, 'words': sym2int[s['words']]}) er = _cp_error_rate( reference, hypothesis, - distance_fn=lambda tt, et: time_constrained_levenshtein_distance( - tt.transcript, et.transcript, tt.timings, et.timings, - prune=prune, - ), - siso_error_rate=lambda tt, et: _time_constrained_siso_error_rate( - tt.transcript, et.transcript, tt.timings, et.timings, - prune=prune, - ), - missing=TimeMarkedTranscript([], []), + distance_fn=time_constrained_siso_levenshtein_distance, + siso_error_rate=_time_constrained_siso_error_rate, ) er = replace( er, @@ -597,29 +717,29 @@ def time_constrained_minimum_permutation_word_error_rate( tcp_word_error_rate = time_constrained_minimum_permutation_word_error_rate -def tcp_word_error_rate_stm( - reference_stm: 'STM', hypothesis_stm: 'STM', +def tcp_word_error_rate_multifile( + reference, hypothesis, reference_pseudo_word_level_timing='character_based', hypothesis_pseudo_word_level_timing='character_based_points', collar: int = 0, reference_sort='segment', hypothesis_sort='segment', -) -> 'Dict[str, CPErrorRate]': +) -> 'dict[str, CPErrorRate]': """ - Computes the tcpWER for each example in the reference and hypothesis STM files. + Computes the tcpWER for each example in the reference and hypothesis files. See `time_constrained_minimum_permutation_word_error_rate` for details. - To compute the overall WER, use `sum(tcp_word_error_rate_stm(r, h).values())`. 
+    To compute the overall WER, use `sum(tcp_word_error_rate_multifile(r, h).values())`.
     """
-    from meeteval.io.stm import apply_stm_multi_file
-    r = apply_stm_multi_file(lambda r, h: time_constrained_minimum_permutation_word_error_rate(
+    from meeteval.io.seglst import apply_multi_file
+    r = apply_multi_file(lambda r, h: time_constrained_minimum_permutation_word_error_rate(
         r, h,
         reference_pseudo_word_level_timing=reference_pseudo_word_level_timing,
         hypothesis_pseudo_word_level_timing=hypothesis_pseudo_word_level_timing,
         collar=collar,
         reference_sort=reference_sort,
         hypothesis_sort=hypothesis_sort,
-    ), reference_stm, hypothesis_stm)
+    ), reference, hypothesis)
     return r
 
 
@@ -631,7 +751,8 @@ def index_alignment_to_kaldi_alignment(alignment, reference, hypothesis, eps='*'
 
 
 def align(
-    reference: TimeMarkedTranscriptLike, hypothesis: TimeMarkedTranscriptLike,
+    reference: SegLST, hypothesis: SegLST,
+    *,
     reference_pseudo_word_level_timing='character_based',
     hypothesis_pseudo_word_level_timing='character_based_points',
     collar: int = 0,
@@ -640,7 +761,9 @@ def align(
     hypothesis_sort='segment',
 ):
     """
-    Align two time-marked transcripts, similar to `kaldialign.align`, but with time constriant.
+    Align two transcripts, similar to `kaldialign.align`, but with a time constraint.
+
+    Note that empty segments are ignored (skipped) for the alignment.
 
     Args:
         reference: reference transcript
@@ -650,8 +773,13 @@ def align(
         collar: collar applied to hypothesis pseudo-word level timings
         style: Alignment output style. Can be one of
             - 'words' or 'kaldi': Output in the style of `kaldialign.align`
-            - 'index': Output indices of the reference and hypothesis words instead of the words
-            - 'words_and_times': Output the time interval (pseudo-word-level, without collar) with each word
+            - 'index': Output indices of the reference and hypothesis words instead of the words. Empty segments are
+                included in the index, so aligning `('', 'a')` would give index `1` for `'a'`. Note that the
+                indices refer to words, not segments, and that word indices do not necessarily correspond to the index
+                of the segment in the input. If you want the indices to be valid for your input, make sure to pass
+                word-level timings and set `reference_pseudo_word_level_timing=None` and/or
+                `hypothesis_pseudo_word_level_timing=None`.
+            - 'seglst': Output the (seglst) segments for each word. Note: Empty segments are ignored.
         reference_sort: How to sort the reference. Options: 'segment', 'word', True, False. See below
-        hypothesis_sort: How to sort the reference. Options: 'segment', 'word', True, False. See below
+        hypothesis_sort: How to sort the hypothesis. Options: 'segment', 'word', True, False. See below
@@ -661,42 +789,116 @@ def align(
         - 'segment': sort by segment start time and don't check word order
         - 'word': sort by word start time
 
-    >>> align(TimeMarkedTranscript('a b c'.split(), [(0,1), (1,2), (2,3)]), TimeMarkedTranscript('a b c'.split(), [(0,1), (1,2), (3,4)]))
+    >>> from pprint import pprint
+    >>> align(
+    ...     [{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 2, 'end_time': 3}],
+    ...     [{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 3, 'end_time': 4}])
     [('a', 'a'), ('b', 'b'), ('c', '*'), ('*', 'c')]
-    >>> align(TimeMarkedTranscript('a b c'.split(), [(0,1), (1,2), (2,3)]), TimeMarkedTranscript('a b c'.split(), [(0,1), (1,2), (3,4)]), collar=1)
+    >>> align(
+    ...
[{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 2, 'end_time': 3}], + ... [{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 3, 'end_time': 4}], + ... collar=1) [('a', 'a'), ('b', 'b'), ('c', 'c')] - >>> align(TimeMarkedTranscript(['a b', 'c', 'd e'], [(0,1), (1,2), (2,3)]), TimeMarkedTranscript(['a', 'b c', 'e f'], [(0,1), (1,2), (3,4)]), collar=1) + >>> align( + ... [{'words': 'a b', 'start_time': 0, 'end_time': 1}, {'words': 'c', 'start_time': 1, 'end_time': 2}, {'words': 'd e', 'start_time': 2, 'end_time': 3}], + ... [{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b c', 'start_time': 1, 'end_time': 2}, {'words': 'e f', 'start_time': 3, 'end_time': 4}], collar=1) [('a', 'a'), ('b', 'b'), ('c', 'c'), ('d', '*'), ('e', 'e'), ('*', 'f')] - >>> align(TimeMarkedTranscript(['a b', 'c', 'd e'], [(0,1), (1,2), (2,3)]), TimeMarkedTranscript(['a', 'b c', 'e f'], [(0,1), (1,2), (3,4)]), collar=1, style='index') + >>> align( + ... [{'words': 'a b', 'start_time': 0, 'end_time': 1}, {'words': 'c', 'start_time': 1, 'end_time': 2}, {'words': 'd e', 'start_time': 2, 'end_time': 3}], + ... [{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b c', 'start_time': 1, 'end_time': 2}, {'words': 'e f', 'start_time': 3, 'end_time': 4}], + ... collar=1, style='index') [(0, 0), (1, 1), (2, 2), (3, None), (4, 3), (None, 4)] - >>> align(TimeMarkedTranscript(['a b', 'c', 'd e'], [(0,1), (1,2), (2,3)]), TimeMarkedTranscript(['a', 'b c', 'e f'], [(0,1), (1,2), (3,4)]), collar=1, style='words_and_times') - [(('a', (0.0, 0.5)), ('a', (0.5, 0.5))), (('b', (0.5, 1.0)), ('b', (1.25, 1.25))), (('c', (1, 2)), ('c', (1.75, 1.75))), (('d', (2.0, 2.5)), ('*', (2.0, 2.5))), (('e', (2.5, 3.0)), ('e', (3.25, 3.25))), (('*', (3.75, 3.75)), ('f', (3.75, 3.75)))] + >>> pprint(align( + ... [{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 2, 'end_time': 3}], + ... [{'words': 'a', 'start_time': 0, 'end_time': 1}, {'words': 'b', 'start_time': 1, 'end_time': 2}, {'words': 'c', 'start_time': 3, 'end_time': 4}], style='seglst')) + [({'end_time': 1, 'start_time': 0, 'words': 'a'}, + {'end_time': 0.5, 'start_time': 0.5, 'words': 'a'}), + ({'end_time': 2, 'start_time': 1, 'words': 'b'}, + {'end_time': 1.5, 'start_time': 1.5, 'words': 'b'}), + ({'end_time': 3, 'start_time': 2, 'words': 'c'}, None), + (None, {'end_time': 3.5, 'start_time': 3.5, 'words': 'c'})] + + Empty segments / words are ignored + >>> pprint(align( + ... [{'words': '', 'start_time': 0, 'end_time': 1}, {'words': 'a', 'start_time': 1, 'end_time': 2}], + ... [{'words': 'a', 'start_time': 1, 'end_time': 2}, {'words': '', 'start_time': 2, 'end_time': 3}] + ... )) + [('a', 'a')] + >>> pprint(align( + ... [{'words': '', 'start_time': 0, 'end_time': 1}, {'words': 'a', 'start_time': 1, 'end_time': 2}], + ... [{'words': 'a', 'start_time': 1, 'end_time': 2}, {'words': '', 'start_time': 2, 'end_time': 3}], + ... 
     hypothesis_ = apply_collar(hypothesis, collar=collar)
+
+    # Compute the alignment with Cython code
+    reference_words = reference.T['words']
+    reference_timing = list(zip(reference.T['start_time'], reference.T['end_time']))
+    hypothesis_words = hypothesis_.T['words']
+    hypothesis_timing = list(zip(hypothesis_.T['start_time'], hypothesis_.T['end_time']))
+
     from meeteval.wer.matching.cy_levenshtein import time_constrained_levenshtein_distance_with_alignment
     alignment = time_constrained_levenshtein_distance_with_alignment(
-        reference.transcript, hypothesis_.transcript,
-        reference.timings, hypothesis_.timings,
-        prune=prune1 and prune2,
+        reference_words, hypothesis_words, reference_timing, hypothesis_timing
     )['alignment']

-    if style in ('kaldi', 'words'):
-        alignment = index_alignment_to_kaldi_alignment(alignment, reference.transcript, hypothesis.transcript)
-    elif style == 'words_and_times':
-        reference = list(zip(reference.transcript, reference.timings))
-        hypothesis = list(zip(hypothesis.transcript, hypothesis.timings))
+    # Convert "local" (relative to filtered words) indices to segments
+    alignment = [
+        (None if a is None else reference[a], None if b is None else hypothesis[b])
+        for a, b in alignment
+    ]
+    if style == 'index':
+        # Use the "global" (before filtering) index so that it corresponds to the input when the input
+        # already consists of words
+        alignment = [
+            (None if a is None else a['__align_index'],
+             None if b is None else b['__align_index'])
+            for a, b in alignment
+        ]
+    elif style in ('kaldi', 'words'):
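+        # kaldialign-style output: emit word pairs and mark insertions/deletions with '*'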
         alignment = [
-            (('*', hypothesis[b][1]) if a is None else reference[a],
-             ('*', reference[a][1]) if b is None else hypothesis[b])
+            ('*' if a is None else a['words'],
+             '*' if b is None else b['words'])
             for a, b in alignment
         ]
+    elif style == 'seglst':
+        pass  # Already in correct format
     elif style != 'index':
         raise ValueError(f'Unknown alignment style: {style}')
diff --git a/setup.py b/setup.py
index 960c2c82..bc0a8df3 100644
--- a/setup.py
+++ b/setup.py
@@ -112,7 +112,7 @@
         'Cython'
     ],
     extras_require=extras_require,
-    package_data={'meeteval': ['**/*.pyx', '**/*.h']},  # https://stackoverflow.com/a/60751886
+    package_data={'meeteval': ['**/*.pyx', '**/*.h', '**/*.js', '**/*.css']},  # https://stackoverflow.com/a/60751886
     entry_points={
         'console_scripts': [
             'meeteval-wer=meeteval.wer.__main__:cli',
diff --git a/tests/test_io.py b/tests/test_io.py
index 0709db4e..de2d05b5 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -1,6 +1,11 @@
 import tempfile
 from pathlib import Path

+from hypothesis import given, strategies as st, example
+import decimal
+
+from meeteval.io import RTTM
+from meeteval.io.keyed_text import KeyedText
 from meeteval.io.stm import STM, STMLine
 from meeteval.io.ctm import CTM, CTMLine
@@ -30,14 +35,22 @@ def test_ctm_load():
         ctm = CTM.load(file)

     assert ctm.lines == [
-        CTMLine(filename='7654', channel='A', begin_time=11.34, duration=0.2, word='YES', confidence='-6.763'),
-        CTMLine(filename='7654', channel='A', begin_time=12.0, duration=0.34, word='YOU', confidence='-12.384530'),
-        CTMLine(filename='7654', channel='A', begin_time=13.3, duration=0.5, word='CAN', confidence='2.806418'),
-        CTMLine(filename='7654', channel='A', begin_time=17.5, duration=0.2, word='AS', confidence='0.537922'),
-        CTMLine(filename='7654', channel='B', begin_time=1.34, duration=0.2, word='I', confidence='-6.763'),
-        CTMLine(filename='7654', channel='B', begin_time=2.0, duration=0.34, word='CAN', confidence='-12.384530'),
-        CTMLine(filename='7654', channel='B', begin_time=3.4, duration=0.5, word='ADD', confidence='2.806418'),
-        CTMLine(filename='7654', channel='B', begin_time=7.0, duration=0.2, word='AS', confidence='0.537922')
+        CTMLine(filename='7654', channel='A', begin_time=decimal.Decimal('11.34'), duration=decimal.Decimal('0.2'),
+                word='YES', confidence='-6.763'),
+        CTMLine(filename='7654', channel='A', begin_time=decimal.Decimal('12.0'), duration=decimal.Decimal('0.34'),
+                word='YOU', confidence='-12.384530'),
+        CTMLine(filename='7654', channel='A', begin_time=decimal.Decimal('13.3'), duration=decimal.Decimal('0.5'),
+                word='CAN', confidence='2.806418'),
+        CTMLine(filename='7654', channel='A', begin_time=decimal.Decimal('17.5'), duration=decimal.Decimal('0.2'),
+                word='AS', confidence='0.537922'),
+        CTMLine(filename='7654', channel='B', begin_time=decimal.Decimal('1.34'), duration=decimal.Decimal('0.2'),
+                word='I', confidence='-6.763'),
+        CTMLine(filename='7654', channel='B', begin_time=decimal.Decimal('2.0'), duration=decimal.Decimal('0.34'),
+                word='CAN', confidence='-12.384530'),
+        CTMLine(filename='7654', channel='B', begin_time=decimal.Decimal('3.4'), duration=decimal.Decimal('0.5'),
+                word='ADD', confidence='2.806418'),
+        CTMLine(filename='7654', channel='B', begin_time=decimal.Decimal('7.0'), duration=decimal.Decimal('0.2'),
+                word='AS', confidence='0.537922')
     ]
@@ -58,7 +71,61 @@ def test_stm_load():
         stm = STM.load(file)

     assert stm.lines == [
-        STMLine(filename='2345', channel='A', speaker_id='2345-a', begin_time=0.1, end_time=2.03, transcript='uh huh yes i thought'),
-        STMLine(filename='2345', channel='A', speaker_id='2345-b', begin_time=2.1, end_time=3.04, transcript='dog walking is a very'),
-        STMLine(filename='2345', channel='A', speaker_id='2345-a', begin_time=3.5, end_time=4.59, transcript="yes but it's worth it")
+        STMLine(filename='2345', channel='A', speaker_id='2345-a', begin_time=decimal.Decimal('0.1'),
+                end_time=decimal.Decimal('2.03'), transcript='uh huh yes i thought'),
+        STMLine(filename='2345', channel='A', speaker_id='2345-b', begin_time=decimal.Decimal('2.1'),
+                end_time=decimal.Decimal('3.04'), transcript='dog walking is a very'),
+        STMLine(filename='2345', channel='A', speaker_id='2345-a', begin_time=decimal.Decimal('3.5'),
+                end_time=decimal.Decimal('4.59'), transcript="yes but it's worth it")
     ]
+
+
+# Generate files
+# The generated files don't contain comments since comments cannot be reconstructed from the parsed data
+filenames = st.text(st.characters(blacklist_categories=['Z', 'C'], blacklist_characters=';'), min_size=1)  # (no space, no control char, no comment char)
+speaker_ids = st.text(st.characters(blacklist_categories=['Z', 'C']), min_size=1)  # (no space, no control char)
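+# Timestamps are drawn as decimals so that their exact textual form (e.g. trailing zeros) survives the round trip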
+timestamps = st.decimals(allow_nan=False, allow_infinity=False)
+durations = st.decimals(min_value=0, allow_nan=False, allow_infinity=False)
+words = st.text(st.characters(blacklist_categories=['C', 'Z']), min_size=1)  # (no space, no control char)
+transcripts = st.builds(' '.join, st.lists(words))
+
+keyed_text_line = st.builds('{} {}\n'.format, filenames, transcripts)
+keyed_text_str = st.builds(''.join, st.lists(keyed_text_line))
+
+stm_line = st.builds('{} 1 {} {} {} {}\n'.format, filenames, speaker_ids, timestamps, timestamps, transcripts)
+stm_str = st.builds(''.join, st.lists(stm_line))
+
+ctm_line = st.builds('{} 1 {} {} {}\n'.format, filenames, timestamps, durations, words)
+ctm_str = st.builds(''.join, st.lists(ctm_line))
+
+# Only SPEAKER lines are supported
+rttm_line = st.builds(
+    'SPEAKER {} 1 {} {} <NA> <NA> {} <NA> <NA>\n'.format,
+    filenames, timestamps, durations, speaker_ids
+)
+rttm_str = st.builds(''.join, st.lists(rttm_line))
+
+
+@given(keyed_text_str)
+def test_reconstruct_keyed_text(keyed_text_str):
+    reconstructed = KeyedText.parse(keyed_text_str).dumps()
+    assert reconstructed == keyed_text_str, (reconstructed, keyed_text_str)
+
+
+@given(stm_str)
+def test_reconstruct_stm(stm_str):
+    reconstructed = STM.parse(stm_str).dumps()
+    assert reconstructed == stm_str, (reconstructed, stm_str)
+
+
+@given(ctm_str)
+def test_reconstruct_ctm(ctm_str):
+    reconstructed = CTM.parse(ctm_str.strip()).dumps()
+    assert reconstructed == ctm_str, (reconstructed, ctm_str)
+
+
+@given(rttm_str)
+@example('SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>\n')
+def test_reconstruct_rttm(rttm_str):
+    reconstructed = RTTM.parse(rttm_str).dumps()
+    assert reconstructed == rttm_str, (reconstructed, rttm_str)
diff --git a/tests/test_time_constrained.py b/tests/test_time_constrained.py
index 321b89d2..ffd8fcba 100644
--- a/tests/test_time_constrained.py
+++ b/tests/test_time_constrained.py
@@ -1,6 +1,9 @@
 import pytest
 from hypothesis import settings, given, strategies as st

+from meeteval.io.ctm import CTMGroup, CTM
+from meeteval.io.seglst import SegLST
+
 # Limit alphabet to ensure a few correct matches
 string = st.text(alphabet='abcdefg', min_size=0, max_size=100)
@@ -113,32 +116,24 @@ def test_time_constrained_levenshtein_distance_with_alignment_against_kaldialign
 @given(
-    st.composite(lambda draw: [
-        [
-            draw(st.text(alphabet='abcdefg', min_size=1, max_size=3))
-            for _ in range(draw(st.integers(min_value=2, max_value=10)))
-        ]
-        for _ in range(draw(st.integers(min_value=2, max_value=10)))
-    ])(),
-    st.composite(lambda draw: [
-        [
-            draw(st.text(alphabet='abcdefg', min_size=1, max_size=3))
-            for _ in range(draw(st.integers(min_value=2, max_value=10)))
-        ]
-        for _ in range(draw(st.integers(min_value=2, max_value=10)))
-    ])(),
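+    # Nested lists: the outer list contains speakers, each inner list the words of one speaker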
+    st.lists(st.lists(string, min_size=2, max_size=10), min_size=2, max_size=10),
+    st.lists(st.lists(string, min_size=2, max_size=10), min_size=2, max_size=10),
 )
 @settings(deadline=None)
-def test_tcpwer_vs_cpwer(
-    a, b
-):
+def test_tcpwer_vs_cpwer(a, b):
     from meeteval.wer.wer.time_constrained import time_constrained_minimum_permutation_word_error_rate
     from meeteval.wer.wer.cp import cp_word_error_rate

     cp_statistics = cp_word_error_rate([' '.join(speaker) for speaker in a], [' '.join(speaker) for speaker in b])
     tcp_statistics = time_constrained_minimum_permutation_word_error_rate(
-        [[{'words': word, 'start_time': 0, 'end_time': 1} for word in speaker] for speaker in a],
-        [[{'words': word, 'start_time': 0, 'end_time': 1} for word in speaker] for speaker in b],
+        SegLST([
+            {'words': word, 'start_time': 0, 'end_time': 1, 'speaker': speaker_id}
+            for speaker_id, speaker in enumerate(a) for word in speaker
+        ]),
+        SegLST([
+            {'words': word, 'start_time': 0, 'end_time': 1, 'speaker': speaker_id}
+            for speaker_id, speaker in enumerate(b) for word in speaker
+        ]),
     )
     from dataclasses import replace
     tcp_statistics = replace(tcp_statistics, reference_self_overlap=None, hypothesis_self_overlap=None)
@@ -146,95 +141,94 @@
 def test_tcpwer_input_formats():
-    from meeteval.wer.wer.time_constrained import time_constrained_minimum_permutation_word_error_rate, \
-        TimeMarkedTranscript
-    from meeteval.io.stm import STM, STMLine
+    from meeteval.wer.wer.time_constrained import time_constrained_minimum_permutation_word_error_rate
+    from meeteval.io.stm import STM

     r1 = time_constrained_minimum_permutation_word_error_rate(
-        [TimeMarkedTranscript(['a'], [(0, 1)]), TimeMarkedTranscript(['b c'], [(1, 2)])],
-        [TimeMarkedTranscript(['a b'], [(0, 1)]), TimeMarkedTranscript(['c'], [(1, 2)])],
+        SegLST([
+            {'words': 'a', 'start_time': 0, 'end_time': 1, 'speaker': 'A'},
+            {'words': 'b c', 'start_time': 1, 'end_time': 2, 'speaker': 'B'}
+        ]),
+        SegLST([
+            {'words': 'a b', 'start_time': 0, 'end_time': 1, 'speaker': 'A'},
+            {'words': 'c', 'start_time': 1, 'end_time': 2, 'speaker': 'B'}
+        ]),
     )
     r2 = time_constrained_minimum_permutation_word_error_rate(
-        [[{'words': 'a', 'start_time': 0, 'end_time': 1}], [{'words': 'b c', 'start_time': 1, 'end_time': 2}]],
-        [[{'words': 'a b', 'start_time': 0, 'end_time': 1}], [{'words': 'c', 'start_time': 1, 'end_time': 2}]],
+        STM.parse('dummy 1 A 0 1 a\ndummy 1 A 1 2 b c'),
+        STM.parse('dummy 1 A 0 1 a b\ndummy 1 A 1 2 c'),
     )
     r3 = time_constrained_minimum_permutation_word_error_rate(
-        [
-            STM([STMLine('dummy', 0, 'A', 0, 1, 'a')]),
-            STM([STMLine('dummy', 1, 'A', 1, 2, 'b c')])
-        ],
-        [
-            STM([STMLine('dummy', 0, 'A', 0, 1, 'a b')]),
-            STM([STMLine('dummy', 1, 'A', 1, 2, 'c')])
-        ]
-    )
-    r4 = time_constrained_minimum_permutation_word_error_rate(
-        {'A': TimeMarkedTranscript(['a'], [(0, 1)]), 'B': TimeMarkedTranscript(['b c'], [(1, 2)])},
-        {'A': TimeMarkedTranscript(['a b'], [(0, 1)]), 'B': TimeMarkedTranscript(['c'], [(1, 2)])},
-    )
-    r5 = time_constrained_minimum_permutation_word_error_rate(
-        {'A': [{'words': 'a', 'start_time': 0, 'end_time': 1}],
-         'B': [{'words': 'b c', 'start_time': 1, 'end_time': 2}]},
-        {'A': [{'words': 'a b', 'start_time': 0, 'end_time': 1}],
-         'B': [{'words': 'c', 'start_time': 1, 'end_time': 2}]},
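+        # A CTMGroup holds one CTM per speaker; the mapping keys are used as speaker labels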
+        CTMGroup({'A': CTM.parse("dummy 1 0 1 a\ndummy 1 1 0.5 b\ndummy 1 1.5 0.5 c")}),
+        CTMGroup({0: CTM.parse("dummy 1 0 0.5 a\ndummy 1 0.5 0.5 b\ndummy 1 1 1 c")})
     )

     assert r1.error_rate == r2.error_rate
     assert r1.error_rate == r3.error_rate
-    assert r1.error_rate == r4.error_rate
-    assert r1.error_rate == r5.error_rate


 def test_time_constrained_sorting_options():
-    from meeteval.wer.wer.time_constrained import time_constrained_minimum_permutation_word_error_rate, \
-        TimeMarkedTranscript
+    from meeteval.wer.wer.time_constrained import time_constrained_minimum_permutation_word_error_rate

-    r1 = TimeMarkedTranscript(['a b', 'c d'], [(0, 1), (0, 1)])
+    r1 = SegLST([
+        {'words': 'a b', 'start_time': 0, 'end_time': 1, 'speaker': 'A'},
+        {'words': 'b c', 'start_time': 0, 'end_time': 1, 'speaker': 'A'}
+    ])

     # "True" checks whether word order matches the word-level timestamps.
     # Here, it doesn't match, so ValueError is raised.
     with pytest.raises(ValueError):
         time_constrained_minimum_permutation_word_error_rate(
-            [r1], [r1], reference_sort=True, hypothesis_sort=True
+            r1, r1, reference_sort=True, hypothesis_sort=True
         )

     er = time_constrained_minimum_permutation_word_error_rate(
-        [r1], [r1], reference_sort='word', hypothesis_sort='word'
+        r1, r1, reference_sort='word', hypothesis_sort='word'
     )
     assert er.error_rate == 0

-    r1 = TimeMarkedTranscript(['a b c d', 'e f g h'], [(0, 4), (2, 6)])
-    r2 = TimeMarkedTranscript(['a b c d e f g h'], [(0, 6)])
+    r1 = SegLST([
+        {'words': 'a b c d', 'start_time': 0, 'end_time': 4, 'speaker': 'A'},
+        {'words': 'e f g h', 'start_time': 2, 'end_time': 6, 'speaker': 'A'},
+    ])
+    r2 = SegLST([
+        {'words': 'a b c d e f g h', 'start_time': 0, 'end_time': 6, 'speaker': 'A'},
+    ])

     er = time_constrained_minimum_permutation_word_error_rate(
-        [r1], [r2], reference_sort='word',
+        r1, r2, reference_sort='word'
     )
     assert er.error_rate == 0.75

     er = time_constrained_minimum_permutation_word_error_rate(
-        [r1], [r2], reference_sort='segment',
+        r1, r2, reference_sort='segment'
     )
     assert er.error_rate == 0.75

     # With collar: "segment" keeps word order, so the error becomes 0
     er = time_constrained_minimum_permutation_word_error_rate(
-        [r1], [r2], reference_sort='segment', collar=1
+        r1, r2, reference_sort='segment', collar=1
     )
     assert er.error_rate == 0

     # With collar: "word" does not keep word order, so the overlap gets penalized
     er = time_constrained_minimum_permutation_word_error_rate(
-        [r1], [r2], reference_sort='word', collar=1
+        r1, r2, reference_sort='word', collar=1
     )
     assert er.error_rate == 0.25

     # False means the user provides the sorting, so we can pass anything
-    r1 = TimeMarkedTranscript(['e f g h', 'a b c d'], [(4, 8), (0, 4)])
-    r2 = TimeMarkedTranscript(['a b c d e f g h'], [(0, 8)])
+    r1 = SegLST([
+        {'words': 'e f g h', 'start_time': 4, 'end_time': 8, 'speaker': 'A'},
+        {'words': 'a b c d', 'start_time': 0, 'end_time': 4, 'speaker': 'A'},
+    ])
+    r2 = SegLST([
+        {'words': 'a b c d e f g h', 'start_time': 0, 'end_time': 8, 'speaker': 'A'},
+    ])

     er = time_constrained_minimum_permutation_word_error_rate(
-        [r1], [r2], reference_sort='segment',
+        r1, r2, reference_sort='segment',
     )
     assert er.error_rate == 0

     er = time_constrained_minimum_permutation_word_error_rate(
-        [r1], [r2], reference_sort=False, hypothesis_sort=False,
+        r1, r2, reference_sort=False, hypothesis_sort=False,
     )
     assert er.error_rate == 1
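
Note on why the round-trip tests above construct timestamps as decimal.Decimal: float normalizes the textual form of a number, so a timestamp that carries trailing zeros (as in the RTTM @example line) could never be reproduced byte-for-byte by dumps(). A minimal standalone sketch of the failure mode (plain Python, no meeteval imports; the variable names are illustrative only):

import decimal

line = 'SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>'
begin_time = line.split()[3]  # '130.430000'

# float drops the trailing zeros, so re-serializing changes the line
assert str(float(begin_time)) == '130.43'

# Decimal keeps the exact textual representation, so parse -> dumps is lossless
assert str(decimal.Decimal(begin_time)) == '130.430000'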