Skip to content

Commit

Permalink
Add SegLST format (#44)
Browse files Browse the repository at this point in the history
* Add tidy format

* Use tidy format for cpWER

* Use tidy format for tcpWER

* Update tidy data

* Update Tidy format

* Rename module io.tidy -> io.seglst

* Rename types and methods (tidy -> seglst) and fix doctests

* Use SegLST in testcases

* Use seglst.apply_multi_file for CLI

* Remove apply_stm_multi_file

* Add new classmethod

* Rename PBJson -> PBJsonUtt

* Rename from_seglst -> new and SegLSTMixin -> BaseABC

* Fix flake8 errors

* Use decimal where possible

* Fix typing.Self imports

* Add fallback when TypedDict is not available

* Use old blacklist_categories instead of exclude_categories for Py3.7

* Fix new implementations and example

* Implement __radd__ for SelfOverlap

- Py <3.8 doesn't have sum(start=...) arg, so we have to overwrite __radd__

* Rename from_seglst -> from_dict

* Use builtin types for type annotations instead of types from typing

* Disallow conversion to list/tuple with non-integer keys

* Sort segments in remove_overlaps

* Make final_types arg private

* Remove caching from SegLST.keys

* SegLST docs

* Rename BaseLine.to_seglst -> to_seglst_segment

* Handle and remove TODOs

- A few todos are still present. They will be addressed in later PRs

* Add .T (transpose) to SegLST and replace .keys with .T.keys()

* Handle emtpy segments correctly in align

* Define _repr_pretty_ for SegLST to get a consistent pprint across Python versions
  • Loading branch information
thequilo authored Jan 5, 2024
1 parent aa06431 commit 244f6e4
Show file tree
Hide file tree
Showing 22 changed files with 1,820 additions and 657 deletions.
62 changes: 50 additions & 12 deletions meeteval/io/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import abc
import io
import os
import sys
import typing
from pathlib import Path
import contextlib
from typing import Dict, List, NamedTuple
import dataclasses
from dataclasses import dataclass
from itertools import groupby
import decimal

if typing.TYPE_CHECKING:
from typing import Self
from meeteval.io.seglst import SegLstSegment, SegLST
from meeteval.io.uem import UEM, UEMLine
from meeteval.io.stm import STM, STMLine
from meeteval.io.ctm import CTM, CTMLine
Expand All @@ -20,12 +22,33 @@
Subclasses = 'UEM | STM | CTM | RTTM'



class BaseABC:
@classmethod
def new(cls, d, **defaults):
# Example code:
# from meeteval.io.seglst import asseglst
# seglst = asseglst(d).map(lambda s: {**defaults, **s})
# ... (convert seglst to cls)
raise NotImplementedError(cls)

def to_seglst(self):
raise NotImplementedError()


@dataclass(frozen=True)
class BaseLine:
@classmethod
def parse(cls, line: str) -> 'Self':
raise NotImplementedError(cls)

@classmethod
def from_dict(cls, segment: 'SegLstSegment') -> 'Self':
raise NotImplementedError(cls)

def to_seglst_segment(self) -> 'SegLstSegment':
raise NotImplementedError(self)

def serialize(self):
raise NotImplementedError(type(self))

Expand Down Expand Up @@ -112,26 +135,31 @@ def replace(self, **kwargs) -> 'Self':
return dataclasses.replace(self, **kwargs)


@dataclass(frozen=True)
class Base:
lines: 'List[LineSubclasses]'
class Base(BaseABC):
lines: 'list[LineSubclasses]'
line_cls = 'LineSubclasses'

@classmethod
def _load(cls, file_descriptor, parse_float) -> 'List[Self.line_cls]':
raise NotImplementedError()
def __init__(self, data):
self.lines = data

@classmethod
def load(cls, file: [Path, str, io.TextIOBase, tuple, list], parse_float=float) -> 'Self':
def load(cls, file: [Path, str, io.TextIOBase, tuple, list], parse_float=decimal.Decimal) -> 'Self':
files = file if isinstance(file, (tuple, list)) else [file]

parsed_lines = []
for f in files:
with _open(f, 'r') as fd:
parsed_lines.extend(cls._load(fd, parse_float=parse_float))
parsed_lines.extend(cls.parse(fd.read(), parse_float=parse_float))

return cls(parsed_lines)

@classmethod
def parse(cls, s: str, parse_float=decimal.Decimal) -> 'Self':
# Many of the supported file-formats have different conventions for comments.
# Below is an example for a file that doesn't have comments.
# return cls([cls.line_cls.parse(line) for line in s.splitlines() if line.strip()])
raise NotImplementedError

def _repr_pretty_(self, p, cycle):
name = self.__class__.__name__
with p.group(len(name) + 1, name + '(', ')'):
Expand Down Expand Up @@ -167,7 +195,7 @@ def __add__(self, other):
return NotImplemented
return self.__class__(self.lines + other.lines)

def groupby(self, key) -> Dict[str, 'Self']:
def groupby(self, key) -> 'dict[str, Self]':
"""
>>> from meeteval.io.stm import STM, STMLine
>>> stm = STM([STMLine.parse('rec1 0 A 10 20 Hello World')])
Expand All @@ -191,7 +219,7 @@ def groupby(self, key) -> Dict[str, 'Self']:
)
}

def grouped_by_filename(self) -> Dict[str, 'Self']:
def grouped_by_filename(self) -> 'dict[str, Self]':
return self.groupby(lambda x: x.filename)

def grouped_by_speaker_id(self):
Expand Down Expand Up @@ -318,6 +346,16 @@ def cut_by_uem(self: 'Subclasses', uem: 'UEM', verbose=False):
def filenames(self):
return {x.filename for x in self.lines}

def to_seglst(self) -> 'SegLST':
from meeteval.io.seglst import SegLST
return SegLST([l.to_seglst_segment() for l in self.lines])

@classmethod
def new(cls, s, **defaults) -> 'Self':
from meeteval.io.seglst import asseglst
return cls([cls.line_cls.from_dict({**defaults, **segment}) for segment in asseglst(s)])



def _open(f, mode='r'):
if isinstance(f, io.TextIOBase):
Expand All @@ -328,7 +366,7 @@ def _open(f, mode='r'):
raise TypeError(type(f), f)


def load(file, parse_float=float):
def load(file, parse_float=decimal.Decimal):
import meeteval
file = Path(file)
if file.suffix == '.stm':
Expand Down
94 changes: 73 additions & 21 deletions meeteval/io/ctm.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import typing
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import NamedTuple
from meeteval.io.base import Base, BaseLine
from typing import Optional
from meeteval.io.base import Base, BaseLine, BaseABC
import decimal

if typing.TYPE_CHECKING:
from typing import Self
from meeteval.io.seglst import SegLstSegment, SegLST

__all__ = [
'CTMLine',
Expand All @@ -12,6 +16,8 @@
]




@dataclass(frozen=True)
class CTMLine(BaseLine):
"""
Expand All @@ -30,13 +36,13 @@ class CTMLine(BaseLine):
"""
filename: str
channel: 'str | int'
begin_time: float
duration: float
begin_time: 'decimal.Decimal | float'
duration: 'decimal.Decimal | float'
word: str
confidence: Optional[int] = None

@classmethod
def parse(cls, line: str, parse_float=float) -> 'CTMLine':
def parse(cls, line: str, parse_float=decimal.Decimal) -> 'CTMLine':
try:
# CB: Should we disable the support for missing confidence?
filename, channel, begin_time, duration, word, *confidence = line.strip().split()
Expand All @@ -49,7 +55,6 @@ def parse(cls, line: str, parse_float=float) -> 'CTMLine':
word,
confidence[0] if confidence else None
)
assert ctm_line.begin_time >= 0, ctm_line
assert ctm_line.duration >= 0, ctm_line
except Exception as e:
raise ValueError(f'Unable to parse CTM line: {line}') from e
Expand All @@ -61,42 +66,80 @@ def serialize(self):
>>> line.serialize()
'rec1 0 10 2 Hello 1'
"""
return (f'{self.filename} {self.channel} {self.begin_time} '
f'{self.duration} {self.word} {self.confidence}')
s = f'{self.filename} {self.channel} {self.begin_time} {self.duration} {self.word}'
if self.confidence is not None:
s += f' {self.confidence}'
return s

@classmethod
def from_dict(cls, segment: 'SegLstSegment') -> 'Self':
# CTM only supports words as segments.
# If this check fails, the input data was not converted to words before.
assert ' ' not in segment['words'], segment
return cls(
filename=segment['session_id'],
channel=segment['channel'],
begin_time=segment['start_time'],
duration=segment['end_time'] - segment['start_time'],
word=segment['words'],
confidence=segment.get('confidence', None),
)

def to_seglst_segment(self) -> 'SegLstSegment':
d = {
'session_id': self.filename,
'channel': self.channel,
'start_time': self.begin_time,
'end_time': self.begin_time + self.duration,
'words': self.word,
}
if self.confidence is not None:
d['confidence'] = self.confidence
return d


@dataclass(frozen=True)
class CTM(Base):
lines: List[CTMLine]
line_cls = CTMLine
lines: 'list[CTMLine]'
line_cls = 'CTMLine'

@classmethod
def _load(cls, file_descriptor, parse_float) -> 'List[CTMLine]':
return [
def parse(cls, s: str, parse_float=decimal.Decimal) -> 'Self':
return cls([
CTMLine.parse(line, parse_float=parse_float)
for line in map(str.strip, file_descriptor)
for line in map(str.strip, s.split('\n'))
if len(line) > 0
if not line.startswith(';;')
]
])

def merged_transcripts(self) -> str:
return ' '.join([x.word for x in sorted(self.lines, key=lambda x: x.begin_time)])

def utterance_transcripts(self) -> List[str]:
def utterance_transcripts(self) -> 'list[str]':
"""There is no notion of an "utterance" in CTM files."""
raise NotImplementedError()

@classmethod
def new(cls, s, **defaults) -> 'Self':
# CTM only supports a single speaker. Use CTMGroup to represent multiple speakers with this format.
if len(s.unique('speaker')) > 1:
raise ValueError(
f'CTM only supports a single speaker, but found {len(s.unique("speaker"))} speakers '
f'({s.unique("speaker")}). Use CTMGroup to represent multiple speakers with this format.'
)
return super().new(s, **defaults)


@dataclass(frozen=True)
class CTMGroup:
ctms: 'Dict[str, CTM]'
class CTMGroup(BaseABC):
ctms: 'dict[str, CTM]'

@classmethod
def load(cls, ctm_files, parse_float=float):
def load(cls, ctm_files, parse_float=decimal.Decimal):
return cls({str(ctm_file): CTM.load(ctm_file, parse_float=parse_float)
for ctm_file in ctm_files})

def grouped_by_filename(self) -> Dict[str, 'CTMGroup']:
def grouped_by_filename(self) -> 'dict[str, CTMGroup]':
groups = {
k: ctm.grouped_by_filename() for k, ctm in self.ctms.items()
}
Expand Down Expand Up @@ -128,9 +171,18 @@ def grouped_by_filename(self) -> Dict[str, 'CTMGroup']:
for key in keys
}

def grouped_by_speaker_id(self) -> Dict[str, CTM]:
def grouped_by_speaker_id(self) -> 'dict[str, CTM]':
return self.ctms

@classmethod
def new(cls, s: 'SegLST', **defaults) -> 'Self':
from meeteval.io.seglst import asseglst
return cls({k: CTM.new(v) for k, v in asseglst(s).map(lambda s: {**defaults, **s}).groupby('speaker').items()})

def to_seglst(self) -> 'SegLST':
from meeteval.io.seglst import SegLST
return SegLST.merge(*[ctm.to_seglst().map(lambda x: {**x, 'speaker': speaker}) for speaker, ctm in self.ctms.items()])

def to_stm(self):
from meeteval.io import STM, STMLine
stm = []
Expand Down
38 changes: 30 additions & 8 deletions meeteval/io/keyed_text.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import typing
from dataclasses import dataclass
from typing import List

from meeteval.io.base import BaseLine, Base
from meeteval.io.seglst import SegLstSegment
import decimal


if typing.TYPE_CHECKING:
from typing import Self


@dataclass(frozen=True)
Expand All @@ -10,7 +16,7 @@ class KeyedTextLine(BaseLine):
transcript: str

@classmethod
def parse(cls, line: str, parse_float=float) -> 'KeyedTextLine':
def parse(cls, line: str, parse_float=decimal.Decimal) -> 'KeyedTextLine':
"""
>>> KeyedTextLine.parse("key a transcript")
KeyedTextLine(filename='key', transcript='a transcript')
Expand All @@ -29,24 +35,40 @@ def parse(cls, line: str, parse_float=float) -> 'KeyedTextLine':
transcript = ''
return cls(filename, transcript)

def serialize(self):
return f'{self.filename} {self.transcript}'

@classmethod
def from_dict(cls, segment: 'SegLstSegment') -> 'Self':
return cls(
filename=segment['session_id'],
transcript=segment['words'],
)

def to_seglst_segment(self) -> 'SegLstSegment':
return {
'session_id': self.filename,
'words': self.transcript,
}


@dataclass(frozen=True)
class KeyedText(Base):
lines: List[KeyedTextLine]
lines: 'list[KeyedTextLine]'
line_cls = KeyedTextLine

@classmethod
def _load(cls, file_descriptor, parse_float) -> 'List[KeyedTextLine]':
return [
def parse(cls, s: str, parse_float=decimal.Decimal) -> 'Self':
return cls([
KeyedTextLine.parse(line, parse_float=parse_float)
for line in map(str.strip, file_descriptor)
for line in map(str.strip, s.split('\n'))
if len(line) > 0
# if not line.startswith(';;')
]
])

def merged_transcripts(self) -> str:
raise NotImplementedError()

def utterance_transcripts(self) -> List[str]:
def utterance_transcripts(self) -> 'list[str]':
"""There is no notion of an "utterance" in CTM files."""
raise NotImplementedError()
Loading

0 comments on commit 244f6e4

Please sign in to comment.