diff --git a/codebleu/codebleu.py b/codebleu/codebleu.py index f6d4bb3..76aa6b3 100644 --- a/codebleu/codebleu.py +++ b/codebleu/codebleu.py @@ -34,6 +34,7 @@ def calc_codebleu( Return: Scores dict """ + assert len(references) == len(predictions), "Number of references and predictions should be the same" assert lang in AVAILABLE_LANGS, f"Language {lang} is not supported (yet). Available languages: {AVAILABLE_LANGS}" assert len(weights) == 4, "weights should be a tuple of 4 floats (alpha, beta, gamma, theta)" assert keywords_dir.exists(), f"keywords_dir {keywords_dir} does not exist" @@ -43,9 +44,6 @@ def calc_codebleu( references = [[x.strip() for x in ref] if isinstance(ref, list) else [ref.strip()] for ref in references] hypothesis = [x.strip() for x in predictions] - if not len(references) == len(hypothesis): - raise ValueError - # calculate ngram match (BLEU) if tokenizer is None: diff --git a/setup.py b/setup.py index c5b7ca5..1b0e285 100644 --- a/setup.py +++ b/setup.py @@ -7,9 +7,10 @@ ROOT = Path(__file__).parent -subprocess.check_call( +subprocess.run( ['bash', 'build.sh'], - cwd=ROOT / 'codebleu' / 'parser' + cwd=ROOT / 'codebleu' / 'parser', + check=True, ) diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index 8af38b4..1102415 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -25,6 +25,31 @@ def test_exact_match_works_for_all_langs(lang: str) -> None: assert calc_codebleu(references, predictions, lang)['codebleu'] == 1.0 +@pytest.mark.parametrize(['lang', 'predictions', 'references'], [ + ('python', ['def foo ( x ) :\n return x'], ['def bar ( y ) :\n return y']), + ('java', ['public function foo ( x ) { return x }'], ['public function bar ( y ) {\n return y\n}']), + ('javascript', ['function foo ( x ) { return x }'], ['function bar ( y ) {\n return y\n}']), + ('c', ['int foo ( int x ) { return x }'], ['int bar ( int y ) {\n return y\n}']), + ('c_sharp', ['public int foo ( int x ) { return x }'], ['public int bar ( int y ) {\n return y\n}']), + ('cpp', ['int foo ( int x ) { return x }'], ['int bar ( int y ) {\n return y\n}']), + ('php', ['function foo ( x ) { return x }'], ['function bar ( y ) {\n return y\n}']), +]) +def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None: + result = calc_codebleu(references, predictions, lang) + print(result) + assert result['codebleu'] == pytest.approx(0.6, 0.1) + + +def test_error_when_lang_not_supported() -> None: + with pytest.raises(AssertionError): + calc_codebleu(['def foo : pass'], ['def bar : pass'], 'not_supported_lang') + + +def test_error_when_input_length_mismatch() -> None: + with pytest.raises(AssertionError): + calc_codebleu(['def foo : pass'], ['def bar : pass', 'def buz : pass'], 'python') + + @pytest.mark.parametrize(['predictions', 'references', 'codebleu'], [ ( ['public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }'],