From d4aa90a8e8de930deb7981a931f6ff672ca1c9e1 Mon Sep 17 00:00:00 2001
From: init-22
Date: Sun, 22 Dec 2024 19:18:10 +0530
Subject: [PATCH 1/2] fix: removed the sacrebleu dependency

---
 algorithmic_efficiency/workloads/wmt/bleu.py | 366 ++++++++++++++++++-
 setup.cfg                                    |   2 +-
 2 files changed, 355 insertions(+), 13 deletions(-)

diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py
index 1efc87381..dda6d102a 100644
--- a/algorithmic_efficiency/workloads/wmt/bleu.py
+++ b/algorithmic_efficiency/workloads/wmt/bleu.py
@@ -1,8 +1,20 @@
+"""
+Removing the dependency on sacrebleu, we reimplement the BLEU score computation in this file.
+Reference:
+https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py.
+"""
+
+from collections import Counter
+from collections import namedtuple
 from itertools import zip_longest
-from typing import Sequence
+import logging
+import math
+import re
+import sys
+from typing import List, Sequence
+import unicodedata
 
 from absl import logging
-import sacrebleu
 import torch
 import torch.distributed as dist
 
@@ -10,10 +22,340 @@
 USE_PYTORCH_DDP, _, DEVICE, N_GPUS = pytorch_setup()
 
+NGRAM_ORDER = 4
+# The default floor value to use with `--smooth floor`
+SMOOTH_VALUE_DEFAULT = 0.0
+
+
+def my_log(num):
+  """
+  Floors the log function
+
+  :param num: the number
+  :return: log(num) floored to a very low number
+  """
+
+  if num == 0.0:
+    return -9999999999
+  return math.log(num)
+
+
+def tokenize_13a(line):
+  """
+  Tokenizes an input line using a relatively minimal tokenization that is however equivalent to mteval-v13a, used by WMT.
+
+  :param line: a segment to tokenize
+  :return: the tokenized line
+  """
+
+  norm = line
+
+  # language-independent part:
+  norm = norm.replace('<skipped>', '')
+  norm = norm.replace('-\n', '')
+  norm = norm.replace('\n', ' ')
+  norm = norm.replace('&quot;', '"')
+  norm = norm.replace('&amp;', '&')
+  norm = norm.replace('&lt;', '<')
+  norm = norm.replace('&gt;', '>')
+
+  # language-dependent part (assuming Western languages):
+  norm = " {} ".format(norm)
+  norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm)
+  norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ',
+                norm)  # tokenize period and comma unless preceded by a digit
+  norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2',
+                norm)  # tokenize period and comma unless followed by a digit
+  norm = re.sub(r'([0-9])(-)', '\\1 \\2 ',
+                norm)  # tokenize dash when preceded by a digit
+  norm = re.sub(r'\s+', ' ', norm)  # one space only between words
+  norm = re.sub(r'^\s+', '', norm)  # no leading space
+  norm = re.sub(r'\s+$', '', norm)  # no trailing space
+
+  return norm
+
+
+class UnicodeRegex:
+  """Ad-hoc hack to recognize all punctuation and symbols.
+
+  without depending on https://pypi.python.org/pypi/regex/."""
+
+  def _property_chars(prefix):
+    return ''.join(
+        chr(x)
+        for x in range(sys.maxunicode)
+        if unicodedata.category(chr(x)).startswith(prefix))
+
+  punctuation = _property_chars('P')
+  nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
+  punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
+  symbol_re = re.compile('([' + _property_chars('S') + '])')
+
+
+def tokenize_v14_international(string):
+  r"""Tokenize a string following the official BLEU implementation.
+
+  See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
+  In our case, the input string is expected to be just one line
+  and no HTML entities de-escaping is needed.
+ So we just tokenize on punctuation and symbols, + except when a punctuation is preceded and followed by a digit + (e.g. a comma/dot as a thousand/decimal separator). + + Note that a number (e.g., a year) followed by a dot at the end of sentence is NOT tokenized, + i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` + does not match this case (unless we add a space after each sentence). + However, this error is already in the original mteval-v14.pl + and we want to be consistent with it. + The error is not present in the non-international version, + which uses `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + + :param string: the input string + :return: a list of tokens + """ + string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string) + string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string) + string = UnicodeRegex.symbol_re.sub(r' \1 ', string) + return string.strip() + + +def tokenize_zh(sentence): + """MIT License + Copyright (c) 2017 - Shujian Huang + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: separate each Chinese + characters (by utf-8 encoding); tokenize the non Chinese part (following the mteval script). + Author: Shujian Huang huangsj@nju.edu.cn + + :param sentence: input sentence + :return: tokenized sentence + """ + + def is_chinese_char(uchar): + """ + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. 
+ """ + if uchar >= u'\u3400' and uchar <= u'\u4db5': # CJK Unified Ideographs Extension A, release 3.0 + return True + elif uchar >= u'\u4e00' and uchar <= u'\u9fa5': # CJK Unified Ideographs, release 1.1 + return True + elif uchar >= u'\u9fa6' and uchar <= u'\u9fbb': # CJK Unified Ideographs, release 4.1 + return True + elif uchar >= u'\uf900' and uchar <= u'\ufa2d': # CJK Compatibility Ideographs, release 1.1 + return True + elif uchar >= u'\ufa30' and uchar <= u'\ufa6a': # CJK Compatibility Ideographs, release 3.2 + return True + elif uchar >= u'\ufa70' and uchar <= u'\ufad9': # CJK Compatibility Ideographs, release 4.1 + return True + elif uchar >= u'\u20000' and uchar <= u'\u2a6d6': # CJK Unified Ideographs Extension B, release 3.1 + return True + elif uchar >= u'\u2f800' and uchar <= u'\u2fa1d': # CJK Compatibility Supplement, release 3.1 + return True + elif uchar >= u'\uff00' and uchar <= u'\uffef': # Full width ASCII, full width of English punctuation, half width Katakana, half wide half width kana, Korean alphabet + return True + elif uchar >= u'\u2e80' and uchar <= u'\u2eff': # CJK Radicals Supplement + return True + elif uchar >= u'\u3000' and uchar <= u'\u303f': # CJK punctuation mark + return True + elif uchar >= u'\u31c0' and uchar <= u'\u31ef': # CJK stroke + return True + elif uchar >= u'\u2f00' and uchar <= u'\u2fdf': # Kangxi Radicals + return True + elif uchar >= u'\u2ff0' and uchar <= u'\u2fff': # Chinese character structure + return True + elif uchar >= u'\u3100' and uchar <= u'\u312f': # Phonetic symbols + return True + elif uchar >= u'\u31a0' and uchar <= u'\u31bf': # Phonetic symbols (Taiwanese and Hakka expansion) + return True + elif uchar >= u'\ufe10' and uchar <= u'\ufe1f': + return True + elif uchar >= u'\ufe30' and uchar <= u'\ufe4f': + return True + elif uchar >= u'\u2600' and uchar <= u'\u26ff': + return True + elif uchar >= u'\u2700' and uchar <= u'\u27bf': + return True + elif uchar >= u'\u3200' and uchar <= u'\u32ff': + return True + elif uchar >= u'\u3300' and uchar <= u'\u33ff': + return True + + return False + + sentence = sentence.strip() + sentence_in_chars = "" + for char in sentence: + if is_chinese_char(char): + sentence_in_chars += " " + sentence_in_chars += char + sentence_in_chars += " " + else: + sentence_in_chars += char + sentence = sentence_in_chars + + # TODO: the code above could probably be replaced with the following line: + # import regex + # sentence = regex.sub(r'(\p{Han})', r' \1 ', sentence) + + # tokenize punctuation + sentence = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sentence) + + # tokenize period and comma unless preceded by a digit + sentence = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sentence) + + # tokenize period and comma unless followed by a digit + sentence = re.sub(r'([\.,])([^0-9])', r' \1 \2', sentence) + + # tokenize dash when preceded by a digit + sentence = re.sub(r'([0-9])(-)', r'\1 \2 ', sentence) + + # one space only between words + sentence = re.sub(r'\s+', r' ', sentence) + + # no leading or trailing spaces + sentence = sentence.strip() + + return sentence + + +TOKENIZERS = { + '13a': tokenize_13a, + 'intl': tokenize_v14_international, + 'zh': tokenize_zh, + 'none': lambda x: x, +} +DEFAULT_TOKENIZER = '13a' + + +def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter: + """Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens. 
+ + :param line: a segment containing a sequence of words + :param max_order: collect n-grams from 1<=n<=max + :return: a dictionary containing ngrams and counts + """ + + ngrams = Counter() + tokens = line.split() + for n in range(min_order, max_order + 1): + for i in range(0, len(tokens) - n + 1): + ngram = ' '.join(tokens[i:i + n]) + ngrams[ngram] += 1 + + return ngrams + + +def ref_stats(output, refs): + ngrams = Counter() + closest_diff = None + closest_len = None + for ref in refs: + tokens = ref.split() + reflen = len(tokens) + diff = abs(len(output.split()) - reflen) + if closest_diff is None or diff < closest_diff: + closest_diff = diff + closest_len = reflen + elif diff == closest_diff: + if reflen < closest_len: + closest_len = reflen + + ngrams_ref = extract_ngrams(ref) + for ngram in ngrams_ref.keys(): + ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) + + return ngrams, closest_diff, closest_len + + +BLEU = namedtuple('BLEU', + 'score, counts, totals, precisions, bp, sys_len, ref_len') + + +def compute_bleu(correct: List[int], + total: List[int], + sys_len: int, + ref_len: int, + smooth_method='none', + smooth_value=SMOOTH_VALUE_DEFAULT, + use_effective_order=False) -> BLEU: + """Computes BLEU score from its sufficient statistics. Adds smoothing. + + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU", + Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346) + + - exp: NIST smoothing method (Method 3) + - floor: Method 1 + - add-k: Method 2 (generalizing Lin and Och, 2004) + - none: do nothing. + + :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER + :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER + :param sys_len: The cumulative system length + :param ref_len: The cumulative reference length + :param smooth: The smoothing method to use + :param smooth_value: The smoothing value added, if smooth method 'floor' is used + :param use_effective_order: Use effective order. + :return: A BLEU object with the score (100-based) and other statistics. + """ + + precisions = [0 for x in range(NGRAM_ORDER)] + + smooth_mteval = 1. + effective_order = NGRAM_ORDER + for n in range(NGRAM_ORDER): + if smooth_method == 'add-k' and n > 1: + correct[n] += smooth_value + total[n] += smooth_value + if total[n] == 0: + break + + if use_effective_order: + effective_order = n + 1 + + if correct[n] == 0: + if smooth_method == 'exp': + smooth_mteval *= 2 + precisions[n] = 100. / (smooth_mteval * total[n]) + elif smooth_method == 'floor': + precisions[n] = 100. * smooth_value / total[n] + else: + precisions[n] = 100. * correct[n] / total[n] + + # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU score is 0 (technically undefined). + # This is a problem for sentence-level BLEU or a corpus of short sentences, where systems will get no credit + # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales NGRAM_ORDER to the observed + # maximum order. It is only available through the API and off by default + + brevity_penalty = 1.0 + if sys_len < ref_len: + brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 + + bleu = brevity_penalty * math.exp( + sum(map(my_log, precisions[:effective_order])) / effective_order) + + return BLEU._make( + [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]) + -# Modified (added sync for PyTorch DDP) from -# https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. 
-# Assumes that sacrebleu==1.3.1 is installed. def corpus_bleu(sys_stream: Sequence[str], ref_streams: Sequence[str], smooth_method: str = 'exp', @@ -21,7 +363,7 @@ def corpus_bleu(sys_stream: Sequence[str], force: bool = False, lowercase: bool = False, tokenize: str = '13a', - use_effective_order: bool = False) -> sacrebleu.BLEU: + use_effective_order: bool = False) -> BLEU: """Produces BLEU scores along with its sufficient statistics from a source against one or more references. :param sys_stream: The system stream (a sequence of segments). @@ -44,8 +386,8 @@ def corpus_bleu(sys_stream: Sequence[str], sys_len = 0 ref_len = 0 - correct = [0 for _ in range(sacrebleu.NGRAM_ORDER)] - total = [0 for _ in range(sacrebleu.NGRAM_ORDER)] + correct = [0 for _ in range(NGRAM_ORDER)] + total = [0 for _ in range(NGRAM_ORDER)] # Look for already-tokenized sentences. tokenized_count = 0 @@ -70,14 +412,14 @@ def corpus_bleu(sys_stream: Sequence[str], 'or don\'t care, you can suppress this message with ' '\'--force\'.') - output, *refs = [sacrebleu.TOKENIZERS[tokenize](x.rstrip()) for x in lines] + output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines] - ref_ngrams, _, closest_len = sacrebleu.ref_stats(output, refs) + ref_ngrams, _, closest_len = ref_stats(output, refs) sys_len += len(output.split()) ref_len += closest_len - sys_ngrams = sacrebleu.extract_ngrams(output) + sys_ngrams = extract_ngrams(output) for ngram, sys_ngram in sys_ngrams.items(): n = len(ngram.split()) correct[n - 1] += min(sys_ngram, ref_ngrams.get(ngram, 0)) @@ -100,7 +442,7 @@ def corpus_bleu(sys_stream: Sequence[str], dist.all_reduce(total) total = total.cpu().numpy().tolist() - return sacrebleu.compute_bleu( + return compute_bleu( correct, total, sys_len, diff --git a/setup.cfg b/setup.cfg index 2d246b48b..8e37acb7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -102,7 +102,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 - sacrebleu==1.3.1 + # Frameworks # # JAX Core From 5e348e4234b061f1819bddcd8d6a3b70ef9804b2 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 23 Dec 2024 00:31:48 +0530 Subject: [PATCH 2/2] fix: resolving pylint errors --- algorithmic_efficiency/workloads/wmt/bleu.py | 132 ++++++++++--------- 1 file changed, 71 insertions(+), 61 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py index dda6d102a..22f6a57e0 100644 --- a/algorithmic_efficiency/workloads/wmt/bleu.py +++ b/algorithmic_efficiency/workloads/wmt/bleu.py @@ -1,5 +1,6 @@ """ -Removing the dependency on sacrebleu, we reimplement the BLEU score computation in this file. +Removing the dependency on sacrebleu, we reimplement the BLEU score computation +in this file. Reference: https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. """ @@ -42,7 +43,8 @@ def my_log(num): def tokenize_13a(line): """ - Tokenizes an input line using a relatively minimal tokenization that is however equivalent to mteval-v13a, used by WMT. + Tokenizes an input line using a relatively minimal tokenization that is + however equivalent to mteval-v13a, used by WMT. :param line: a segment to tokenize :return: the tokenized line @@ -80,6 +82,7 @@ class UnicodeRegex: without depending on https://pypi.python.org/pypi/regex/.""" + @staticmethod def _property_chars(prefix): return ''.join( chr(x) @@ -95,20 +98,23 @@ def _property_chars(prefix): def tokenize_v14_international(string): r"""Tokenize a string following the official BLEU implementation. 
- See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 + See + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 In our case, the input string is expected to be just one line and no HTML entities de-escaping is needed. So we just tokenize on punctuation and symbols, except when a punctuation is preceded and followed by a digit (e.g. a comma/dot as a thousand/decimal separator). - Note that a number (e.g., a year) followed by a dot at the end of sentence is NOT tokenized, + Note that a number (e.g., a year) followed by a dot at the end of sentence + is NOT tokenized, i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a space after each sentence). However, this error is already in the original mteval-v14.pl and we want to be consistent with it. The error is not present in the non-international version, - which uses `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + which uses, + `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). :param string: the input string :return: a list of tokens @@ -123,26 +129,28 @@ def tokenize_zh(sentence): """MIT License Copyright (c) 2017 - Shujian Huang - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - - The tokenization of Chinese text in this script contains two steps: separate each Chinese - characters (by utf-8 encoding); tokenize the non Chinese part (following the mteval script). + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the + following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE + USE OR OTHER DEALINGS IN THE SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: + separate each Chinese characters (by utf-8 encoding); + tokenize the non Chinese part (following the mteval script). Author: Shujian Huang huangsj@nju.edu.cn :param sentence: input sentence @@ -151,54 +159,53 @@ def tokenize_zh(sentence): def is_chinese_char(uchar): """ - :param uchar: input char in unicode - :return: whether the input char is a Chinese character. - """ - if uchar >= u'\u3400' and uchar <= u'\u4db5': # CJK Unified Ideographs Extension A, release 3.0 + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. + """ + if "\u3400" <= uchar <= "\u4db5": return True - elif uchar >= u'\u4e00' and uchar <= u'\u9fa5': # CJK Unified Ideographs, release 1.1 + elif "\u4e00" <= uchar <= "\u9fa5": return True - elif uchar >= u'\u9fa6' and uchar <= u'\u9fbb': # CJK Unified Ideographs, release 4.1 + elif "\u9fa6" <= uchar <= "\u9fbb": return True - elif uchar >= u'\uf900' and uchar <= u'\ufa2d': # CJK Compatibility Ideographs, release 1.1 + elif "\uf900" <= uchar <= "\ufa2d": return True - elif uchar >= u'\ufa30' and uchar <= u'\ufa6a': # CJK Compatibility Ideographs, release 3.2 + elif "\ufa30" <= uchar <= "\ufa6a": return True - elif uchar >= u'\ufa70' and uchar <= u'\ufad9': # CJK Compatibility Ideographs, release 4.1 + elif "\ufa70" <= uchar <= "\ufad9": return True - elif uchar >= u'\u20000' and uchar <= u'\u2a6d6': # CJK Unified Ideographs Extension B, release 3.1 + elif "\u20000" <= uchar <= "\u2a6d6": return True - elif uchar >= u'\u2f800' and uchar <= u'\u2fa1d': # CJK Compatibility Supplement, release 3.1 + elif "\u2f800" <= uchar <= "\u2fa1d": return True - elif uchar >= u'\uff00' and uchar <= u'\uffef': # Full width ASCII, full width of English punctuation, half width Katakana, half wide half width kana, Korean alphabet + elif "\uff00" <= uchar <= "\uffef": return True - elif uchar >= u'\u2e80' and uchar <= u'\u2eff': # CJK Radicals Supplement + elif "\u2e80" <= uchar <= "\u2eff": return True - elif uchar >= u'\u3000' and uchar <= u'\u303f': # CJK punctuation mark + elif "\u3000" <= uchar <= "\u303f": return True - elif uchar >= u'\u31c0' and uchar <= u'\u31ef': # CJK stroke + elif "\u31c0" <= uchar <= "\u31ef": return True - elif uchar >= u'\u2f00' and uchar <= u'\u2fdf': # Kangxi Radicals + elif "\u2f00" <= uchar <= "\u2fdf": return True - elif uchar >= u'\u2ff0' and uchar <= u'\u2fff': # Chinese character structure + elif "\u2ff0" <= uchar <= "\u2fff": return True - elif uchar >= u'\u3100' and uchar <= u'\u312f': # Phonetic symbols + elif "\u3100" <= uchar <= "\u312f": return True - elif uchar >= u'\u31a0' and uchar <= u'\u31bf': # Phonetic symbols (Taiwanese and Hakka expansion) + elif "\u31a0" <= uchar <= "\u31bf": return True - elif uchar >= u'\ufe10' and uchar <= u'\ufe1f': + elif "\ufe10" <= uchar <= "\ufe1f": return True - elif uchar >= u'\ufe30' and uchar <= u'\ufe4f': + elif "\ufe30" <= uchar <= "\ufe4f": return True - elif uchar >= u'\u2600' and uchar <= u'\u26ff': + elif "\u2600" <= uchar <= "\u26ff": return True - elif uchar >= u'\u2700' and uchar <= u'\u27bf': + elif "\u2700" <= uchar <= "\u27bf": return True - elif uchar >= u'\u3200' and uchar <= u'\u32ff': + elif "\u3200" <= uchar <= 
"\u32ff": return True - elif uchar >= u'\u3300' and uchar <= u'\u33ff': + elif "\u3300" <= uchar <= "\u33ff": return True - return False sentence = sentence.strip() @@ -280,13 +287,13 @@ def ref_stats(output, refs): closest_len = reflen ngrams_ref = extract_ngrams(ref) - for ngram in ngrams_ref.keys(): + for ngram in ngrams_ref: ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) return ngrams, closest_diff, closest_len -BLEU = namedtuple('BLEU', +BLEU = namedtuple('BLE', 'score, counts, totals, precisions, bp, sys_len, ref_len') @@ -299,8 +306,9 @@ def compute_bleu(correct: List[int], use_effective_order=False) -> BLEU: """Computes BLEU score from its sufficient statistics. Adds smoothing. - Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU", - Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346) + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques + for Sentence-Level BLEU", Boxing Chen and Colin Cherry, + WMT 2014: http://aclweb.org/anthology/W14-3346) - exp: NIST smoothing method (Method 3) - floor: Method 1 @@ -312,7 +320,7 @@ def compute_bleu(correct: List[int], :param sys_len: The cumulative system length :param ref_len: The cumulative reference length :param smooth: The smoothing method to use - :param smooth_value: The smoothing value added, if smooth method 'floor' is used + :param smooth_value: The smoothing value added, if smooth is 'floor' :param use_effective_order: Use effective order. :return: A BLEU object with the score (100-based) and other statistics. """ @@ -340,10 +348,12 @@ def compute_bleu(correct: List[int], else: precisions[n] = 100. * correct[n] / total[n] - # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU score is 0 (technically undefined). - # This is a problem for sentence-level BLEU or a corpus of short sentences, where systems will get no credit - # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales NGRAM_ORDER to the observed - # maximum order. It is only available through the API and off by default + # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU + # score is 0 (technically undefined). This is a problem for sentence-level + # BLEU or a corpus of short sentences, where systems will get no credit + # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales + # NGRAM_ORDER to the observed maximum order. + # It is only available through the API and off by default brevity_penalty = 1.0 if sys_len < ref_len: @@ -374,7 +384,7 @@ def corpus_bleu(sys_stream: Sequence[str], :param force: Ignore data that looks already tokenized. :param lowercase: Lowercase the data. :param tokenize: The tokenizer to use. - :return: A BLEU object containing everything you'd want. + :return: A BLEU object containing everything yo'd want. """ # Add some robustness to the input arguments.