Skip to content

Commit

Permalink
More tests on diarization metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
wq2012 committed Jul 6, 2024
1 parent b6fd96c commit 6ee177c
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions DiarizationLM/diarizationlm/metrics_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for metrics."""

import json
import os
from diarizationlm import metrics
import unittest

Expand Down Expand Up @@ -68,11 +70,48 @@ def test_compute_metrics_on_json_dict(self):
]
}
result = metrics.compute_metrics_on_json_dict(json_dict)
self.assertEqual(len(result["utterances"]), 2)
self.assertEqual(result["utterances"][0]["utterance_id"], "utt1")
self.assertEqual(result["utterances"][1]["utterance_id"], "utt2")
self.assertAlmostEqual(result["WER"], 0.2666, delta=0.001)
self.assertAlmostEqual(result["WDER"], 0.1538, delta=0.001)

def test_compute_metrics_on_json_file(self):
json_file = os.path.join("testdata/example_data.json")
with open(json_file, "r") as f:
json_dict = json.load(f)
result = metrics.compute_metrics_on_json_dict(json_dict)
self.assertEqual(len(result["utterances"]), 2)
self.assertEqual(result["utterances"][0]["utterance_id"], "en_0638")
self.assertEqual(result["utterances"][1]["utterance_id"], "en_4157")
self.assertAlmostEqual(result["WER"], 0.2363, delta=0.001)
self.assertAlmostEqual(result["WDER"], 0.0437, delta=0.001)

def test_compute_metrics_on_json_file_oracle(self):
json_file = os.path.join("testdata/example_data.json")
with open(json_file, "r") as f:
json_dict = json.load(f)
result = metrics.compute_metrics_on_json_dict(
json_dict,
hyp_spk_field="hyp_spk_oracle")
self.assertEqual(len(result["utterances"]), 2)
self.assertEqual(result["utterances"][0]["utterance_id"], "en_0638")
self.assertEqual(result["utterances"][1]["utterance_id"], "en_4157")
self.assertAlmostEqual(result["WER"], 0.2363, delta=0.001)
self.assertAlmostEqual(result["WDER"], 0.0, delta=0.001)

def test_compute_metrics_on_json_file_degraded(self):
json_file = os.path.join("testdata/example_data.json")
with open(json_file, "r") as f:
json_dict = json.load(f)
result = metrics.compute_metrics_on_json_dict(
json_dict,
ref_spk_field="ref_spk_degraded")
self.assertEqual(len(result["utterances"]), 2)
self.assertEqual(result["utterances"][0]["utterance_id"], "en_0638")
self.assertEqual(result["utterances"][1]["utterance_id"], "en_4157")
self.assertAlmostEqual(result["WER"], 0.2363, delta=0.001)
self.assertAlmostEqual(result["WDER"], 0.0, delta=0.001)

if __name__ == "__main__":
unittest.main()

0 comments on commit 6ee177c

Please sign in to comment.