Skip to content

Commit

Permalink
chore: small style fixes and some small tests
Browse files Browse the repository at this point in the history
  • Loading branch information
k4black committed Jul 13, 2023
1 parent 89ae3db commit a61aaba
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
4 changes: 1 addition & 3 deletions codebleu/codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def calc_codebleu(
Return:
Scores dict
"""
assert len(references) == len(predictions), "Number of references and predictions should be the same"
assert lang in AVAILABLE_LANGS, f"Language {lang} is not supported (yet). Available languages: {AVAILABLE_LANGS}"
assert len(weights) == 4, "weights should be a tuple of 4 floats (alpha, beta, gamma, theta)"
assert keywords_dir.exists(), f"keywords_dir {keywords_dir} does not exist"
Expand All @@ -43,9 +44,6 @@ def calc_codebleu(
references = [[x.strip() for x in ref] if isinstance(ref, list) else [ref.strip()] for ref in references]
hypothesis = [x.strip() for x in predictions]

if not len(references) == len(hypothesis):
raise ValueError

# calculate ngram match (BLEU)
if tokenizer is None:

Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
ROOT = Path(__file__).parent


subprocess.check_call(
subprocess.run(
['bash', 'build.sh'],
cwd=ROOT / 'codebleu' / 'parser'
cwd=ROOT / 'codebleu' / 'parser',
check=True,
)


Expand Down
25 changes: 25 additions & 0 deletions tests/test_codebleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,31 @@ def test_exact_match_works_for_all_langs(lang: str) -> None:
assert calc_codebleu(references, predictions, lang)['codebleu'] == 1.0


@pytest.mark.parametrize(['lang', 'predictions', 'references'], [
('python', ['def foo ( x ) :\n return x'], ['def bar ( y ) :\n return y']),
('java', ['public function foo ( x ) { return x }'], ['public function bar ( y ) {\n return y\n}']),
('javascript', ['function foo ( x ) { return x }'], ['function bar ( y ) {\n return y\n}']),
('c', ['int foo ( int x ) { return x }'], ['int bar ( int y ) {\n return y\n}']),
('c_sharp', ['public int foo ( int x ) { return x }'], ['public int bar ( int y ) {\n return y\n}']),
('cpp', ['int foo ( int x ) { return x }'], ['int bar ( int y ) {\n return y\n}']),
('php', ['function foo ( x ) { return x }'], ['function bar ( y ) {\n return y\n}']),
])
def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None:
result = calc_codebleu(references, predictions, lang)
print(result)
assert result['codebleu'] == pytest.approx(0.6, 0.1)


def test_error_when_lang_not_supported() -> None:
with pytest.raises(AssertionError):
calc_codebleu(['def foo : pass'], ['def bar : pass'], 'not_supported_lang')


def test_error_when_input_length_mismatch() -> None:
with pytest.raises(AssertionError):
calc_codebleu(['def foo : pass'], ['def bar : pass', 'def buz : pass'], 'python')


@pytest.mark.parametrize(['predictions', 'references', 'codebleu'], [
(
['public static int Sign ( double d ) { return ( float ) ( ( d == 0 ) ? 0 : ( c < 0.0 ) ? - 1 : 1) ; }'],
Expand Down

0 comments on commit a61aaba

Please sign in to comment.