Skip to content

Commit

Permalink
fix (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
shenxiangzhuang authored Apr 22, 2024
1 parent 215f0b1 commit 87ab215
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
8 changes: 5 additions & 3 deletions benchmark/test_benchmark_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time

from typing import Tuple, List
from hypothesis import given
from hypothesis import given, settings
from hypothesis import strategies as st

import bleuscore
Expand Down Expand Up @@ -32,6 +32,7 @@ def build_translation_pair(text: str, n: int = 10) -> Tuple[List[str], List[List
return predictions, references


@settings(max_examples=100)
@given(st.text(alphabet=st.characters(min_codepoint=32, max_codepoint=126),
min_size=10, max_size=20))
def test_bleu(input_text):
Expand All @@ -55,5 +56,6 @@ def test_bleu(input_text):
print(rust_result)
rust_result = rust_result.get("bleu")
t2 = time.time()
print(t1 - t0, t2 - t1, (t1 - t0) > (t2 - t1))
assert (py_result - rust_result) < 1e-10
# print(t1 - t0, t2 - t1, (t1 - t0) > (t2 - t1))
print(py_result, rust_result, abs(py_result - rust_result))
assert abs(py_result - rust_result) < 1e-10
34 changes: 29 additions & 5 deletions src/bleu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ pub fn compute_bleu(
matches_by_order[ngram.len() - 1] += overlap[ngram]
}
for order in 1..=max_order {
let possible_matches = translation.len() - order + 1;
let possible_matches = translation.len().saturating_sub(order - 1);
if possible_matches > 0 {
// println!("Order: {order}");
possible_matches_by_order[order - 1] += possible_matches
}
}
Expand All @@ -50,6 +51,7 @@ pub fn compute_bleu(
match smooth {
true => {
precisions[i] = (matches_by_order[i] as f64 + 1.0) / (possible_matches_by_order[i] as f64 + 1.0);
// println!("precision[i]: {i}, {} / {} = {}", matches_by_order[i] as f64 + 1.0, possible_matches_by_order[i] as f64 + 1.0, precisions[i]);
},
false => {
if possible_matches_by_order[i] > 0 {
Expand Down Expand Up @@ -84,12 +86,34 @@ mod test {
use crate::bleu::{compute_bleu};
#[test]
fn test_bleu() {
let reference_corpus: Vec<Vec<String>> = vec![vec!["Hello".to_string()]];
let translation_corpus: Vec<String> = vec!["Yellow".to_string()];
let reference_corpus: Vec<Vec<String>> = vec![vec!["Hello, World!".to_string()]];
let translation_corpus: Vec<String> = vec!["Yellow, World!".to_string()];
let max_order: usize = 4;
let smooth: bool = true;
let res = compute_bleu(reference_corpus, translation_corpus, max_order, smooth);
// (0.6147881529512643, [0.7142857142857143, 0.6666666666666666, 0.6, 0.5], 1.0, 1.2, 6, 5)
assert_eq!((res.bleu - 0.6147881529512643) < 1e-10, true);
// (0.7241577342575828, [0.8666666666666667, 0.7857142857142857, 0.6923076923076923, 0.5833333333333334], 1.0, 1.0769230769230769, 14, 13)
println!("BLEU: {:?}", res);
assert_eq!((res.bleu - 0.7241577342575828).abs() < 1e-10, true);
}

#[test]
fn test_bleu_error() {
let reference_corpus: Vec<Vec<String>> = vec![
vec!["0000000000".to_string()],
vec!["0000000000".to_string()],
vec!["0000000000".to_string()],
vec!["0000000000".to_string()],
];
let translation_corpus: Vec<String> = vec!["000000".to_string(),
"00000".to_string(),
"0000000000".to_string(),
"00".to_string()
];
let max_order: usize = 4;
let smooth: bool = true;
let res = compute_bleu(reference_corpus, translation_corpus, max_order, smooth);
// (0.7241577342575828, [0.8666666666666667, 0.7857142857142857, 0.6923076923076923, 0.5833333333333334], 1.0, 1.0769230769230769, 14, 13)
println!("BLEU: {:?}", res);
assert_eq!((res.bleu - 0.47752897762233404).abs() < 1e-10, true);
}
}

0 comments on commit 87ab215

Please sign in to comment.