Skip to content

Commit

Permalink
Update tests and enhanced validator to validate 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
yaniv-golan committed Oct 21, 2024
1 parent 7f71c21 commit 5febe26
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 12 deletions.
7 changes: 6 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@
"name": "Run pytest",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/-m pytest",
"program": "python",
"args": [
"-m",
"pytest",
"${workspaceFolder}/tests/python/"
],
"env": {
"PYTHONPATH": "${workspaceFolder}/tools/python"
},
"console": "integratedTerminal"
}
]
Expand Down
8 changes: 8 additions & 0 deletions tests/python/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import sys
import os

# Add tools/python to sys.path
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
tools_python = os.path.join(project_root, 'tools', 'python')
sys.path.insert(0, tools_python)
66 changes: 66 additions & 0 deletions tests/python/test_stj_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import pytest
from jsonschema import validate, ValidationError, SchemaError
from stj_validator import validate_segments, validate_words

# Get the absolute path to the project root
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
Expand Down Expand Up @@ -81,3 +82,68 @@ def test_schema_error():
stj_data = {}
with pytest.raises(SchemaError):
validate(instance=stj_data, schema=invalid_schema)

def test_overlapping_segments(schema):
stj_data = {
"metadata": {
"transcriber": {"name": "Test", "version": "1.0"},
"created_at": "2023-10-21T12:00:00Z"
},
"transcript": {
"segments": [
{"start": 0.0, "end": 5.0, "text": "First segment"},
{"start": 4.5, "end": 10.0, "text": "Second segment overlaps"}
]
}
}
with pytest.raises(ValueError, match="Segments overlap or are out of order"):
validate(stj_data, schema)
validate_segments(stj_data['transcript']['segments'])

def test_invalid_word_timing_mode(schema):
stj_data = {
"metadata": {
"transcriber": {"name": "Test", "version": "1.0"},
"created_at": "2023-10-21T12:00:00Z"
},
"transcript": {
"segments": [
{
"start": 0.0,
"end": 5.0,
"text": "Hello world",
"word_timing_mode": "complete",
"words": [
{"start": 0.0, "end": 1.0, "text": "Hello"}
# Missing "world" in words array
]
}
]
}
}
with pytest.raises(ValueError, match="Concatenated words do not match segment text"):
validate(stj_data, schema)
validate_words(stj_data['transcript']['segments'][0])

def test_zero_duration_word_without_flag(schema):
stj_data = {
"metadata": {
"transcriber": {"name": "Test", "version": "1.0"},
"created_at": "2023-10-21T12:00:00Z"
},
"transcript": {
"segments": [
{
"start": 0.0,
"end": 5.0,
"text": "Zero duration word",
"words": [
{"start": 1.0, "end": 1.0, "text": "Zero"}
]
}
]
}
}
with pytest.raises(ValueError, match="Zero-duration word at 1.0 without 'word_duration' set to 'zero'"):
validate(stj_data, schema)
validate_words(stj_data['transcript']['segments'][0])
23 changes: 23 additions & 0 deletions tests/python/test_validator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import json
import os
import subprocess
import filecmp
import tempfile
import pytest
from jsonschema import validate, ValidationError, SchemaError
from stj_validator import validate_segments, validate_words

PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))

Expand Down Expand Up @@ -85,3 +89,22 @@ def test_schema_error():
stj_data = {}
with pytest.raises(SchemaError):
validate(instance=stj_data, schema=invalid_schema)

@pytest.fixture
def base_dir():
return os.path.dirname(os.path.abspath(__file__))

def test_stj_to_srt_conversion(base_dir):
# Paths
stj_tool = os.path.abspath(os.path.join(base_dir, '..', '..', 'tools', 'python', 'stj_to_srt.py'))
stj_input = os.path.abspath(os.path.join(base_dir, '..', '..', 'examples', 'simple.stj.json'))
expected_srt = os.path.abspath(os.path.join(base_dir, '..', 'expected_outputs', 'expected_simple.srt'))

with tempfile.TemporaryDirectory() as temp_dir:
output_srt = os.path.join(temp_dir, 'output_test.srt')

# Run the conversion
subprocess.run(['python', stj_tool, stj_input, output_srt], check=True)

# Compare output with expected SRT file
assert filecmp.cmp(output_srt, expected_srt), "SRT files do not match."
1 change: 1 addition & 0 deletions tools/python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# This file can be empty. It marks the directory as a Python package.
3 changes: 3 additions & 0 deletions tools/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ srt>=3.5.2
webvtt-py>=0.4.6
jsonschema>=3.2.0
pytest>=7.0.0
iso639-lang>=2.4.2
mypy>=1.0.0


172 changes: 161 additions & 11 deletions tools/python/stj_validator.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,171 @@
#!/usr/bin/env python3

import json
import jsonschema
import argparse
import sys
from iso639 import Lang
from datetime import datetime
from typing import List, Dict, Any, Optional, TypedDict

# Define TypedDicts for structured data
class Speaker(TypedDict):
id: str
name: str

class Style(TypedDict):
id: str
description: str

class Word(TypedDict):
start: float
end: float
text: str
confidence: Optional[float]

class Segment(TypedDict):
start: float
end: float
text: str
speaker_id: Optional[str]
style_id: Optional[str]
language: Optional[str]
word_timing_mode: Optional[str]
words: Optional[List[Word]]
word_duration: Optional[str]
confidence: Optional[float]

class Transcript(TypedDict):
segments: List[Segment]
speakers: Optional[List[Speaker]]
styles: Optional[List[Style]]

class Metadata(TypedDict):
transcriber: Dict[str, Any]
created_at: str
source: Optional[Dict[str, Any]]
languages: Optional[List[str]]

class STJData(TypedDict):
metadata: Metadata
transcript: Transcript

def validate_language_code(lang_code: str) -> None:
try:
Lang(lang_code)
# If the code is invalid, Lang() will raise a KeyError
except KeyError:
raise ValueError(f"Invalid language code: {lang_code}")

def validate_language_codes(language_list: List[str]) -> None:
for lang_code in language_list:
validate_language_code(lang_code)

def validate_speaker_ids(speakers: List[Speaker], segments: List[Segment]) -> None:
speaker_ids = {speaker['id'] for speaker in speakers}
for segment in segments:
speaker_id = segment.get('speaker_id')
if speaker_id and speaker_id not in speaker_ids:
raise ValueError(f"Invalid speaker_id '{speaker_id}' in segment starting at {segment['start']}")

def validate_style_ids(styles: List[Style], segments: List[Segment]) -> None:
style_ids = {style['id'] for style in styles}
for segment in segments:
style_id = segment.get('style_id')
if style_id and style_id not in style_ids:
raise ValueError(f"Invalid style_id '{style_id}' in segment starting at {segment['start']}")

def validate_segments(segments: List[Segment]) -> None:
# Implement the logic to check for overlapping segments
for i in range(len(segments) - 1):
current_end = segments[i]['end']
next_start = segments[i + 1]['start']
if next_start < current_end:
raise ValueError("Segments overlap or are out of order")

def validate_words(segment: Segment) -> None:
# Implement the logic to validate word timings and text
if segment.get('word_timing_mode') == 'complete':
concatenated_words = ''.join(word['text'] for word in segment.get('words', []))
if concatenated_words != segment['text'].replace(' ', ''):
raise ValueError("Concatenated words do not match segment text")
for word in segment.get('words', []):
if word['start'] == word['end'] and segment.get('word_duration') != 'zero':
raise ValueError(f"Zero-duration word at {word['start']} without 'word_duration' set to 'zero'")

def validate_confidence_scores(segments: List[Segment]) -> None:
for segment in segments:
confidence = segment.get('confidence')
if confidence is not None and not (0.0 <= confidence <= 1.0):
raise ValueError(f"Segment confidence {confidence} out of range [0.0, 1.0] in segment starting at {segment['start']}")
for word in segment.get('words', []):
word_confidence = word.get('confidence')
if word_confidence is not None and not (0.0 <= word_confidence <= 1.0):
raise ValueError(f"Word confidence {word_confidence} out of range [0.0, 1.0] in segment starting at {segment['start']}")

def validate_stj(stj_data: STJData, schema: Dict[str, Any]) -> None:
# First, validate against the schema
jsonschema.validate(instance=stj_data, schema=schema)
print("Schema validation passed.")

def validate_stj(stj_file, schema_file):
with open(stj_file, 'r', encoding='utf-8') as f:
# Validate metadata languages
metadata = stj_data.get('metadata', {})
source_languages = metadata.get('source', {}).get('languages', [])
transcription_languages = metadata.get('languages', [])

try:
validate_language_codes(source_languages)
validate_language_codes(transcription_languages)
except ValueError as e:
raise ValueError(f"Language code validation error: {e}")

# Validate segments
transcript = stj_data.get('transcript', {})
segments = transcript.get('segments', [])
speakers = transcript.get('speakers', [])
styles = transcript.get('styles', [])

# Validate speaker and style IDs
if speakers:
validate_speaker_ids(speakers, segments)
if styles:
validate_style_ids(styles, segments)

# Validate segments sequentially
validate_segments(segments)

# Validate each segment
for segment in segments:
validate_words(segment)
# Validate language codes in segments
segment_language = segment.get('language')
if segment_language:
validate_language_code(segment_language)

# Validate confidence scores
validate_confidence_scores(segments)

print("All validation checks passed.")

def main() -> None:
parser = argparse.ArgumentParser(description="Validate an STJ file against the schema and additional constraints.")
parser.add_argument('stj_file', help="Path to the STJ file to validate.")
parser.add_argument('schema_file', help="Path to the JSON schema file.")
args = parser.parse_args()

with open(args.stj_file, 'r', encoding='utf-8') as f:
stj_data = json.load(f)
with open(schema_file, 'r', encoding='utf-8') as f:
with open(args.schema_file, 'r', encoding='utf-8') as f:
schema = json.load(f)

try:
jsonschema.validate(instance=stj_data, schema=schema)
print(f"{stj_file} is valid according to the schema.")
validate_stj(stj_data, schema)
except jsonschema.exceptions.ValidationError as e:
print(f"Validation error: {e.message}")
print(f"Schema validation error: {e.message}")
sys.exit(1)
except ValueError as e:
print(f"Validation error: {e}")
sys.exit(1)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Validate an STJ file against the schema.")
parser.add_argument('stj_file', help="Path to the STJ file to validate.")
parser.add_argument('schema_file', help="Path to the JSON schema file.")
args = parser.parse_args()
validate_stj(args.stj_file, args.schema_file)
main()

0 comments on commit 5febe26

Please sign in to comment.