From dedc4b758fcb8208a3d1c485f6bb1847188dd4de Mon Sep 17 00:00:00 2001 From: Quan Wang Date: Tue, 24 Sep 2024 21:31:43 -0400 Subject: [PATCH] Use the C++ implementation of Levenshtein instead of the Python one. metrics_test.py running time reduced from 136.061s to 0.551s --- DiarizationLM/diarizationlm/__init__.py | 2 -- DiarizationLM/diarizationlm/levenshtein.py | 6 +++++- DiarizationLM/diarizationlm/levenshtein_test.py | 11 ++++++++++- DiarizationLM/diarizationlm/metrics.py | 2 +- DiarizationLM/diarizationlm/utils.py | 2 +- DiarizationLM/requirements.txt | 1 + DiarizationLM/setup.py | 2 +- 7 files changed, 19 insertions(+), 7 deletions(-) diff --git a/DiarizationLM/diarizationlm/__init__.py b/DiarizationLM/diarizationlm/__init__.py index 4e4bb1b..075558c 100644 --- a/DiarizationLM/diarizationlm/__init__.py +++ b/DiarizationLM/diarizationlm/__init__.py @@ -1,10 +1,8 @@ """__init__ file.""" -from . import levenshtein from . import utils from . import metrics -levenshtein_with_edits = levenshtein.levenshtein_with_edits PromptOptions = utils.PromptOptions transcript_preserving_speaker_transfer = ( utils.transcript_preserving_speaker_transfer) diff --git a/DiarizationLM/diarizationlm/levenshtein.py b/DiarizationLM/diarizationlm/levenshtein.py index 3f6aedd..8731c95 100644 --- a/DiarizationLM/diarizationlm/levenshtein.py +++ b/DiarizationLM/diarizationlm/levenshtein.py @@ -1,4 +1,8 @@ -"""Function for the Levenshtein algorithm.""" +"""Function for the Levenshtein algorithm. + +Note: This Python implementation is very inefficient. Please use this C++ +implementation instead: https://github.com/wq2012/word_levenshtein +""" import numpy as np from enum import Enum diff --git a/DiarizationLM/diarizationlm/levenshtein_test.py b/DiarizationLM/diarizationlm/levenshtein_test.py index bfa3718..bdf0cb9 100644 --- a/DiarizationLM/diarizationlm/levenshtein_test.py +++ b/DiarizationLM/diarizationlm/levenshtein_test.py @@ -1,4 +1,4 @@ -"""Test levenshterin.""" +"""Test levenshtein.""" import unittest from diarizationlm import levenshtein @@ -21,6 +21,15 @@ def test_levenshtein_with_edits_2(self): [(0, 0), (1, -1), (2, 1), (-1, 2), (3, 3), (4, 4), (5, 5)], align[1]) + def test_levenshtein_with_edit_multiple_spaces(self): + s1 = " hello good morning how are you" + s2 = "hello morning hi how are you " + align = levenshtein.levenshtein_with_edits(s1, s2) + self.assertEqual(2, align[0]) + self.assertListEqual( + [(0, 0), (1, -1), (2, 1), (-1, 2), (3, 3), (4, 4), (5, 5)], + align[1]) + if __name__ == "__main__": unittest.main() diff --git a/DiarizationLM/diarizationlm/metrics.py b/DiarizationLM/diarizationlm/metrics.py index b8dfec5..f99209e 100644 --- a/DiarizationLM/diarizationlm/metrics.py +++ b/DiarizationLM/diarizationlm/metrics.py @@ -28,7 +28,7 @@ from scipy import optimize import tqdm from diarizationlm import utils -from diarizationlm import levenshtein +import word_levenshtein as levenshtein @dataclasses.dataclass diff --git a/DiarizationLM/diarizationlm/utils.py b/DiarizationLM/diarizationlm/utils.py index f1c52ba..606e3a7 100644 --- a/DiarizationLM/diarizationlm/utils.py +++ b/DiarizationLM/diarizationlm/utils.py @@ -9,7 +9,7 @@ import numpy as np from scipy import optimize -from diarizationlm import levenshtein +import word_levenshtein as levenshtein PUNCTUATIONS = [",", ".", "_", "?", "!", "-", '"', "'"] diff --git a/DiarizationLM/requirements.txt b/DiarizationLM/requirements.txt index e3d4815..0c87557 100644 --- a/DiarizationLM/requirements.txt +++ b/DiarizationLM/requirements.txt @@ -6,3 +6,4 @@ openai datasets tqdm colortimelog +word_levenshtein diff --git a/DiarizationLM/setup.py b/DiarizationLM/setup.py index 671cf8a..d20bc74 100644 --- a/DiarizationLM/setup.py +++ b/DiarizationLM/setup.py @@ -2,7 +2,7 @@ import setuptools -VERSION = "0.0.11" +VERSION = "0.1.0" with open("README.md", "r") as file_object: LONG_DESCRIPTION = file_object.read()