Skip to content

Commit

Permalink
Fix sacrebleu and NLTK interfaces (#132)
Browse files Browse the repository at this point in the history
* Update interface to sacrebleu

* Deal with update to NLTK

* Bump version
  • Loading branch information
neubig authored Feb 5, 2022
1 parent 4d2fc94 commit 9b0bfe2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 13 deletions.
24 changes: 13 additions & 11 deletions compare_mt/scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,12 @@ class SacreBleuScorer(Scorer):
A scorer that computes BLEU on detokenized text.
"""
def __init__(self, smooth_method='exp', smooth_value=0, use_effective_order=False, case_insensitive=False):
def __init__(self, smooth_method='exp', smooth_value=0, effective_order=False, case_insensitive=False):
self.smooth_method = smooth_method
self.smooth_value = smooth_value
self.use_effective_order = use_effective_order
self.effective_order = effective_order
self.case_insensitive = case_insensitive
self.bleu = sacrebleu.BLEU()

@property
def scale(self):
Expand Down Expand Up @@ -452,13 +453,10 @@ def cache_stats(self, ref, out, src=None):
if self.case_insensitive:
ref = corpus_utils.lower(ref)
out = corpus_utils.lower(out)
ref = [' '.join(x) for x in ref]
out = [' '.join(x) for x in out]

cached_stats = []
for r, o in zip(ref, out):
re = sacrebleu.corpus_bleu(" ".join(o), " ".join(r))
cached_stats.append( (re.counts, re.totals, re.sys_len, re.ref_len) )

return cached_stats
return self.bleu._extract_corpus_statistics(out, [ref])

def score_cached_corpus(self, sent_ids, cached_stats):
"""
Expand All @@ -474,10 +472,14 @@ def score_cached_corpus(self, sent_ids, cached_stats):
if len(cached_stats) == 0:
return 0.0, None

counts, totals, sys_len, ref_len = zip(*cached_stats)
counts, totals, sys_len, ref_len = [np.sum(np.array(x)[sent_ids], 0) for x in [counts, totals, sys_len, ref_len]]
stats = np.sum(np.array(cached_stats)[list(sent_ids)],0)

return sacrebleu.compute_bleu(counts, totals, sys_len, ref_len, smooth_method=self.smooth_method, smooth_value=self.smooth_value, use_effective_order=self.use_effective_order).score, None
return self.bleu.compute_bleu(correct = stats[2: 2 + self.bleu.max_ngram_order],
total = stats[2 + self.bleu.max_ngram_order:],
sys_len = int(stats[0]), ref_len = int(stats[1]),
smooth_method=self.smooth_method,
smooth_value=self.smooth_value,
effective_order=self.effective_order).score, None

def name(self):
return "SacreBleuScorer"
Expand Down
2 changes: 1 addition & 1 deletion compare_mt/version_info.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.9"
__version__ = "0.2.10"
2 changes: 1 addition & 1 deletion tests/test_scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def setUpClass(self):
def test_score_sentence(self):
bleu, _ = self.scorer.score_sentence(self.ref[0], self.out[0])
# compare to nltk
self.assertAlmostEqual(bleu, 32.607099228782377)
self.assertAlmostEqual(bleu, 32.44376694160122)

def test_score_corpus(self):
sent_bleu_corpus, _ = self.scorer.score_corpus(self.ref, self.out)
Expand Down

0 comments on commit 9b0bfe2

Please sign in to comment.