diff --git a/compare_mt/scorers.py b/compare_mt/scorers.py index 8714c10..6e69ec7 100644 --- a/compare_mt/scorers.py +++ b/compare_mt/scorers.py @@ -419,11 +419,12 @@ class SacreBleuScorer(Scorer): A scorer that computes BLEU on detokenized text. """ - def __init__(self, smooth_method='exp', smooth_value=0, use_effective_order=False, case_insensitive=False): + def __init__(self, smooth_method='exp', smooth_value=0, effective_order=False, case_insensitive=False): self.smooth_method = smooth_method self.smooth_value = smooth_value - self.use_effective_order = use_effective_order + self.effective_order = effective_order self.case_insensitive = case_insensitive + self.bleu = sacrebleu.BLEU() @property def scale(self): @@ -452,13 +453,10 @@ def cache_stats(self, ref, out, src=None): if self.case_insensitive: ref = corpus_utils.lower(ref) out = corpus_utils.lower(out) + ref = [' '.join(x) for x in ref] + out = [' '.join(x) for x in out] - cached_stats = [] - for r, o in zip(ref, out): - re = sacrebleu.corpus_bleu(" ".join(o), " ".join(r)) - cached_stats.append( (re.counts, re.totals, re.sys_len, re.ref_len) ) - - return cached_stats + return self.bleu._extract_corpus_statistics(out, [ref]) def score_cached_corpus(self, sent_ids, cached_stats): """ @@ -474,10 +472,14 @@ def score_cached_corpus(self, sent_ids, cached_stats): if len(cached_stats) == 0: return 0.0, None - counts, totals, sys_len, ref_len = zip(*cached_stats) - counts, totals, sys_len, ref_len = [np.sum(np.array(x)[sent_ids], 0) for x in [counts, totals, sys_len, ref_len]] + stats = np.sum(np.array(cached_stats)[list(sent_ids)],0) - return sacrebleu.compute_bleu(counts, totals, sys_len, ref_len, smooth_method=self.smooth_method, smooth_value=self.smooth_value, use_effective_order=self.use_effective_order).score, None + return self.bleu.compute_bleu(correct = stats[2: 2 + self.bleu.max_ngram_order], + total = stats[2 + self.bleu.max_ngram_order:], + sys_len = int(stats[0]), ref_len = int(stats[1]), + smooth_method=self.smooth_method, + smooth_value=self.smooth_value, + effective_order=self.effective_order).score, None def name(self): return "SacreBleuScorer" diff --git a/compare_mt/version_info.py b/compare_mt/version_info.py index 75cf783..6232f7a 100644 --- a/compare_mt/version_info.py +++ b/compare_mt/version_info.py @@ -1 +1 @@ -__version__ = "0.2.9" +__version__ = "0.2.10" diff --git a/tests/test_scorers.py b/tests/test_scorers.py index 32f31a5..d3aff4d 100644 --- a/tests/test_scorers.py +++ b/tests/test_scorers.py @@ -79,7 +79,7 @@ def setUpClass(self): def test_score_sentence(self): bleu, _ = self.scorer.score_sentence(self.ref[0], self.out[0]) # compare to nltk - self.assertAlmostEqual(bleu, 32.607099228782377) + self.assertAlmostEqual(bleu, 32.44376694160122) def test_score_corpus(self): sent_bleu_corpus, _ = self.scorer.score_corpus(self.ref, self.out)