diff --git a/meeteval/io/base.py b/meeteval/io/base.py index 31765ef1..80ef16c3 100644 --- a/meeteval/io/base.py +++ b/meeteval/io/base.py @@ -4,13 +4,14 @@ import typing from pathlib import Path import contextlib -from typing import Dict, List, NamedTuple +from typing import Dict, List import dataclasses from dataclasses import dataclass from itertools import groupby if typing.TYPE_CHECKING: from typing import Self + from meeteval.io.seglst import SegLstSegment, SegLSTMixin from meeteval.io.uem import UEM, UEMLine from meeteval.io.stm import STM, STMLine from meeteval.io.ctm import CTM, CTMLine @@ -119,16 +120,11 @@ def replace(self, **kwargs) -> 'Self': return dataclasses.replace(self, **kwargs) -class Base: +class Base(SegLSTMixin): lines: 'List[LineSubclasses]' line_cls = 'LineSubclasses' - def __init__(self, data, **defaults): - if isinstance(data, self.__class__): - self.lines = data.lines - elif hasattr(data, 'to_seglst'): - self.lines = [self.line_cls.from_seglst({**defaults, **segment}) for segment in data.to_seglst()] - else: + def __init__(self, data): self.lines = data @classmethod diff --git a/meeteval/io/ctm.py b/meeteval/io/ctm.py index 31115e33..9e45377c 100644 --- a/meeteval/io/ctm.py +++ b/meeteval/io/ctm.py @@ -13,7 +13,7 @@ 'CTMGroup', ] -from meeteval.io.seglst import SegLstSegment, SegLST +from meeteval.io.seglst import SegLstSegment, SegLST, SegLSTMixin @dataclass(frozen=True) @@ -128,7 +128,7 @@ def from_seglst(cls, s: 'SegLST', **defaults) -> 'Self': @dataclass(frozen=True) -class CTMGroup: +class CTMGroup(SegLSTMixin): ctms: 'Dict[str, CTM]' @classmethod diff --git a/meeteval/io/pbjson.py b/meeteval/io/pbjson.py index 08f40ed6..3a0a5572 100644 --- a/meeteval/io/pbjson.py +++ b/meeteval/io/pbjson.py @@ -8,7 +8,7 @@ import json from pathlib import Path -from meeteval.io.seglst import SegLST +from meeteval.io.seglst import SegLST, SegLSTMixin def _load_json(file): @@ -42,7 +42,7 @@ def get_sample_rate(ex): return sample_rate -class PBJson: +class PBJson(SegLSTMixin): """ The JSON format used at the NT department internally for storing databases. diff --git a/meeteval/io/py.py b/meeteval/io/py.py index 58eee761..de53b4c0 100644 --- a/meeteval/io/py.py +++ b/meeteval/io/py.py @@ -5,7 +5,7 @@ if typing.TYPE_CHECKING: from typing import Self - from meeteval.io.seglst import SegLST + from meeteval.io.seglst import SegLST, SegLSTMixin def _convert_python_structure(structure, *, keys=(), final_key='words', final_types=str): @@ -120,7 +120,7 @@ def _invert_python_structure(t: 'SegLST', types, keys): @dataclasses.dataclass(frozen=True) -class NestedStructure: +class NestedStructure(SegLSTMixin): """ Wraps a Python structure where the structure levels represent keys. diff --git a/meeteval/io/seglst.py b/meeteval/io/seglst.py index ae1b5e85..d0334a2c 100644 --- a/meeteval/io/seglst.py +++ b/meeteval/io/seglst.py @@ -27,6 +27,20 @@ class SegLstSegment(TypedDict, total=False): confidence: float +class SegLSTMixin: + @classmethod + def new(cls, d, **defaults): + return cls.from_seglst(asseglst(d, **defaults)) + + @classmethod + def from_seglst(cls, d: 'SegLST', **defaults): + raise NotImplementedError() + + def to_seglst(self): + raise NotImplementedError() + + + @dataclasses.dataclass(frozen=True) class SegLST: """