Skip to content

Commit

Permalink
Use the C++ implementation of Levenshtein instead of the Python one.
Browse files (browse the repository at this point in the history)
metrics_test.py running time reduced from 136.061s to 0.551s
  • Loading branch information
wq2012 committed Sep 25, 2024
1 parent e8f40bb commit dedc4b7
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 7 deletions.
2 changes: 0 additions & 2 deletions DiarizationLM/diarizationlm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""__init__ file."""

from . import levenshtein
from . import utils
from . import metrics

levenshtein_with_edits = levenshtein.levenshtein_with_edits
PromptOptions = utils.PromptOptions
transcript_preserving_speaker_transfer = (
utils.transcript_preserving_speaker_transfer)
Expand Down
6 changes: 5 additions & 1 deletion DiarizationLM/diarizationlm/levenshtein.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""Function for the Levenshtein algorithm."""
"""Function for the Levenshtein algorithm.

Note: This Python implementation is very inefficient. Please use this C++
implementation instead: https://github.com/wq2012/word_levenshtein
"""
import numpy as np
from enum import Enum

Expand Down
11 changes: 10 additions & 1 deletion DiarizationLM/diarizationlm/levenshtein_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Test levenshterin."""
"""Test levenshtein."""
import unittest
from diarizationlm import levenshtein

Expand All @@ -21,6 +21,15 @@ def test_levenshtein_with_edits_2(self):
[(0, 0), (1, -1), (2, 1), (-1, 2), (3, 3), (4, 4), (5, 5)],
align[1])

# Checks levenshtein.levenshtein_with_edits on inputs carrying stray
# whitespace: s1 starts with a space and s2 ends with one.
# NOTE(review): the test name says "multiple spaces", but the literals as
# captured here contain only single spaces between words -- runs of spaces may
# have been collapsed when this page was rendered; confirm against the
# repository before relying on these exact strings.
def test_levenshtein_with_edit_multiple_spaces(self):
s1 = " hello good morning how are you"
s2 = "hello morning hi how are you "
align = levenshtein.levenshtein_with_edits(s1, s2)
# align[0] is asserted to equal 2 -- presumably the edit distance; confirm
# against levenshtein_with_edits.
self.assertEqual(2, align[0])
# align[1] is asserted to equal a list of index pairs, the same list that
# test_levenshtein_with_edits_2 asserts just above. A -1 presumably marks a
# position with no counterpart in the other string -- TODO confirm.
self.assertListEqual(
[(0, 0), (1, -1), (2, 1), (-1, 2), (3, 3), (4, 4), (5, 5)],
align[1])


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion DiarizationLM/diarizationlm/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from scipy import optimize
import tqdm
from diarizationlm import utils
from diarizationlm import levenshtein
import word_levenshtein as levenshtein


@dataclasses.dataclass
Expand Down
2 changes: 1 addition & 1 deletion DiarizationLM/diarizationlm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from scipy import optimize

from diarizationlm import levenshtein
import word_levenshtein as levenshtein

PUNCTUATIONS = [",", ".", "_", "?", "!", "-", '"', "'"]

Expand Down
1 change: 1 addition & 0 deletions DiarizationLM/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ openai
datasets
tqdm
colortimelog
word_levenshtein
2 changes: 1 addition & 1 deletion DiarizationLM/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import setuptools

VERSION = "0.0.11"
VERSION = "0.1.0"

with open("README.md", "r") as file_object:
LONG_DESCRIPTION = file_object.read()
Expand Down

0 comments on commit dedc4b7

Please sign in to comment.