Skip to content

Commit

Permalink
Update tests for generative multilingual
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Jul 18, 2024
1 parent e19b9e6 commit f39befc
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 27 deletions.
4 changes: 2 additions & 2 deletions scripts/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ def process_judgement(judgment: str, is_prometheus: bool = False) -> Literal["A"
# noqa adapted from FastChat https://github.com/lm-sys/FastChat/blob/b015f21cb9d0cf3c87d2a5e53008074c537e8be0/fastchat/llm_judge/common.py#L235C1-L312C1
def run_judge_pair(
question: str,
answer_a: str,
answer_b: str,
answer_a: list[dict[str, str]],
answer_b: list[dict[str, str]],
model: str,
multi_turn: bool = False,
model_modifier: str = None,
Expand Down
93 changes: 68 additions & 25 deletions tests/test_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,14 @@

def test_format_judge_answers_multilingual_includes_language():
question = "Ano ang sagot sa (2+3) * 4? Ipaliwanag ang iyong sagot"
answer_a = [
{
"role": "user",
"content": "Ano ang sagot sa (2+3) * 4? Ipaliwanag ang iyong sagot",
},
{
"role": "assistant",
"content": "20. Unahing i-add ang nasa loob ng parenthesis. Tapos i-multiply sa 4.",
},
]
answer_b = [
{
"role": "user",
"content": "Ano ang sagot sa (2+3) * 4? Ipaliwanag ang iyong sagot",
},
{
"role": "assistant",
"content": "Ang sagot ay 20.",
},
]
ans_a = "20. Unahing i-add ang nasa loob ng parenthesis. Tapos i-multiply sa 4."
ans_b = "Ang sagot ay 20."
answer_a = [{"role": "user", "content": question}, {"role": "assistant", "content": ans_a}]
answer_b = [{"role": "user", "content": question}, {"role": "assistant", "content": ans_b}]
src_lang = "Filipino" # language the prompt is written on
tgt_lang = "English" # language the assistant should reply on
include_languages = [src_lang, tgt_lang]
sys_prompt, user_prompt = format_judge_answers(
sys_prompt, _ = format_judge_answers(
question=question,
answer_a=answer_a,
answer_b=answer_b,
Expand All @@ -49,10 +33,69 @@ def test_process_judgment_answers(judgment, expected):


@pytest.mark.api
def test_cohere_api():
pass
@pytest.mark.parametrize("multilingual", [True, False])
def test_cohere_api(multilingual):
from fastchat.conversation import get_conv_template

if multilingual:
question = "Quelle est la capitale du Japon?"
ans_a = "Tokyo"
ans_b = "La capitale du Japon est Tokyo"
include_langs = ["Japanese", "English"]
else:
question = "What is the capital of Japan?"
ans_a = "Tokyo"
ans_b = "The capital of Japan is Tokyo"
include_langs = None

sys_prompt, user_prompt = format_judge_answers(
question=question,
answer_a=[{"role": "user", "content": question}, {"role": "assistant", "content": ans_a}],
answer_b=[{"role": "user", "content": question}, {"role": "assistant", "content": ans_b}],
include_langs=include_langs,
)

conv = get_conv_template("raw")
conv.append_message(conv.roles[0], user_prompt)
conv.set_system_message(sys_prompt)
judgement = chat_completion_cohere(
conv=conv,
model="command-r",
temperature=0,
max_tokens=2048,
)

assert judgement
assert isinstance(judgement, str)


@pytest.mark.api
def test_run_judge_pair():
pass
@pytest.mark.parametrize("multilingual", [True, False])
def test_run_judge_pair(multilingual):
if multilingual:
question = "Quelle est la capitale du Japon?"
ans_a = "Tokyo"
ans_b = "La capitale du Japon est Tokyo"
include_langs = ["Japanese", "English"]
else:
question = "What is the capital of Japan?"
ans_a = "Tokyo"
ans_b = "The capital of Japan is Tokyo"
include_langs = None

answer_a = ([{"role": "user", "content": question}, {"role": "assistant", "content": ans_a}],)
answer_b = ([{"role": "user", "content": question}, {"role": "assistant", "content": ans_b}],)

winner, user_prompt, judgement = run_judge_pair(
question,
answer_a=answer_a,
answer_b=answer_b,
model="command-r",
multi_turn=False,
model_modifier=None,
include_langs=include_langs,
)

assert winner in ["A", "B", "none"]
assert judgement
assert user_prompt

0 comments on commit f39befc

Please sign in to comment.