Skip to content

Commit

Permalink
Add tests and update codebase to make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Jul 18, 2024
1 parent 432cbc7 commit de0ddc8
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 5 deletions.
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
markers =
api: marks tests as calling a third-party API (deselect with '-m "not api"')
10 changes: 5 additions & 5 deletions scripts/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@
GEMINI_MODEL_LIST = ("gemini-1.5-flash-001", "gemini-1.5-pro-001")

# https://docs.cohere.com/docs/models
COHERE_MODEL_LIST = [
COHERE_MODEL_LIST = (
"command-r-plus",
"command-r",
"command",
"command-nightly",
"command-light",
"command-light-nightly",
]
)

API_MODEL_LIST = (
OPENAI_MODEL_LIST + ANTHROPIC_MODEL_LIST + TOGETHER_MODEL_LIST + COHERE_MODEL_LIST
Expand Down Expand Up @@ -263,8 +263,8 @@
# format with prompt_template.format(question=question, answer_a=answer_a, answer_b=answer_b)
def format_judge_answers(
question: str,
answer_a: str,
answer_b: str,
answer_a: list[dict[str, str]],
answer_b: list[dict[str, str]],
multi_turn: bool = False,
model_modifier: str = None,
include_langs: Optional[Iterable[str]] = None,
Expand Down Expand Up @@ -350,7 +350,7 @@ def process_judgement(
) -> Literal["A", "B", "error"]:
if is_prometheus:
if "[RESULT]" in judgment:
# after [RESULT] is A or B, else error (mayube spaces)
# after [RESULT] is A or B, else error (maybe spaces)
# result = judgment.split("[RESULT]")[1].strip()
if judgment[-1] == "A":
return "A"
Expand Down
Empty file added tests/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions tests/test_generative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from scripts.generative import chat_completion_anthropic
from scripts.generative import chat_completion_cohere, chat_completion_gemini
from scripts.generative import chat_completion_together, format_judge_answers
from scripts.generative import process_judgement, run_judge_pair


def test_format_judge_answers_multilingual_includes_language():
question = "Ano ang sagot sa (2+3) * 4? Ipaliwanag ang iyong sagot"
answer_a = [
{
"role": "user",
"content": "Ano ang sagot sa (2+3) * 4? Ipaliwanag ang iyong sagot",
},
{
"role": "assistant",
"content": "20. Unahing i-add ang nasa loob ng parenthesis. Tapos i-multiply sa 4.",
},
]
answer_b = [
{
"role": "user",
"content": "Ano ang sagot sa (2+3) * 4? Ipaliwanag ang iyong sagot",
},
{
"role": "assistant",
"content": "Ang sagot ay 20.",
},
]
src_lang = "Filipino" # language the prompt is written on
tgt_lang = "English" # language the assistant should reply on
include_languages = [src_lang, tgt_lang]
sys_prompt, user_prompt = format_judge_answers(
question=question,
answer_a=answer_a,
answer_b=answer_b,
include_langs=include_languages,
)

assert src_lang in sys_prompt
assert tgt_lang in sys_prompt

0 comments on commit de0ddc8

Please sign in to comment.