Skip to content

Commit

Permalink
Add new classmethod
Browse files Browse the repository at this point in the history
  • Loading branch information
thequilo committed Dec 7, 2023
1 parent 57232f5 commit 3cb8b06
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 14 deletions.
12 changes: 4 additions & 8 deletions meeteval/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import typing
from pathlib import Path
import contextlib
from typing import Dict, List, NamedTuple
from typing import Dict, List
import dataclasses
from dataclasses import dataclass
from itertools import groupby

if typing.TYPE_CHECKING:
from typing import Self
from meeteval.io.seglst import SegLstSegment, SegLSTMixin
from meeteval.io.uem import UEM, UEMLine
from meeteval.io.stm import STM, STMLine
from meeteval.io.ctm import CTM, CTMLine
Expand Down Expand Up @@ -119,16 +120,11 @@ def replace(self, **kwargs) -> 'Self':
return dataclasses.replace(self, **kwargs)


class Base:
class Base(SegLSTMixin):
lines: 'List[LineSubclasses]'
line_cls = 'LineSubclasses'

def __init__(self, data, **defaults):
if isinstance(data, self.__class__):
self.lines = data.lines
elif hasattr(data, 'to_seglst'):
self.lines = [self.line_cls.from_seglst({**defaults, **segment}) for segment in data.to_seglst()]
else:
def __init__(self, data):
self.lines = data

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions meeteval/io/ctm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'CTMGroup',
]

from meeteval.io.seglst import SegLstSegment, SegLST
from meeteval.io.seglst import SegLstSegment, SegLST, SegLSTMixin


@dataclass(frozen=True)
Expand Down Expand Up @@ -128,7 +128,7 @@ def from_seglst(cls, s: 'SegLST', **defaults) -> 'Self':


@dataclass(frozen=True)
class CTMGroup:
class CTMGroup(SegLSTMixin):
ctms: 'Dict[str, CTM]'

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions meeteval/io/pbjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
from pathlib import Path

from meeteval.io.seglst import SegLST
from meeteval.io.seglst import SegLST, SegLSTMixin


def _load_json(file):
Expand Down Expand Up @@ -42,7 +42,7 @@ def get_sample_rate(ex):
return sample_rate


class PBJson:
class PBJson(SegLSTMixin):
"""
The JSON format used at the NT department internally for storing databases.
Expand Down
4 changes: 2 additions & 2 deletions meeteval/io/py.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

if typing.TYPE_CHECKING:
from typing import Self
from meeteval.io.seglst import SegLST
from meeteval.io.seglst import SegLST, SegLSTMixin


def _convert_python_structure(structure, *, keys=(), final_key='words', final_types=str):
Expand Down Expand Up @@ -120,7 +120,7 @@ def _invert_python_structure(t: 'SegLST', types, keys):


@dataclasses.dataclass(frozen=True)
class NestedStructure:
class NestedStructure(SegLSTMixin):
"""
Wraps a Python structure where the structure levels represent keys.
Expand Down
14 changes: 14 additions & 0 deletions meeteval/io/seglst.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ class SegLstSegment(TypedDict, total=False):
confidence: float


class SegLSTMixin:
@classmethod
def new(cls, d, **defaults):
return cls.from_seglst(asseglst(d, **defaults))

@classmethod
def from_seglst(cls, d: 'SegLST', **defaults):
raise NotImplementedError()

def to_seglst(self):
raise NotImplementedError()



@dataclasses.dataclass(frozen=True)
class SegLST:
"""
Expand Down

0 comments on commit 3cb8b06

Please sign in to comment.