Improve parity on development with RewardBench
ljvmiranda921 committed Jul 18, 2024
1 parent de0ddc8 commit e19b9e6
Showing 8 changed files with 86 additions and 78 deletions.
13 changes: 13 additions & 0 deletions Makefile
@@ -0,0 +1,13 @@
.PHONY: style quality

# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := scripts tests

style:
python -m black --target-version py310 --line-length 119 $(check_dirs)
python -m isort $(check_dirs) --profile black -m 9

quality:
python -m flake8 --max-line-length 119 $(check_dirs)
24 changes: 24 additions & 0 deletions README.md
@@ -103,4 +103,28 @@ python -m scripts/run_generative.py \
--num_gpus 4 \
+ --include_languages German English
--output_dir $OUTDIR
```


## Testing and Development

This codebase contains minimal tests; we mostly test functions that were added or patched from RewardBench.
First, you need to install all the development dependencies:

```sh
pip install -r requirements-dev.txt
```

Then, you can run the tests with:

```sh
pytest tests/
pytest tests/ -m "not api" # to ignore tests that make use of third-party APIs
```
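
The `api` marker used above is a custom pytest marker. If it is not already registered elsewhere in the repository, a small `conftest.py` along these lines (a sketch, not part of this commit) would register it so that `pytest -m "not api"` filters cleanly without unknown-marker warnings:

```python
# conftest.py (hypothetical sketch, not included in this commit):
# registers the custom "api" marker used to tag tests that call
# third-party APIs, so filtering with -m "not api" stays warning-free.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "api: tests that make requests to third-party APIs"
    )
```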

When developing, we format the code using [black](https://black.readthedocs.io/en/stable/index.html) and [isort](https://pycqa.github.io/isort/) to stay consistent with the RewardBench codebase.
You can format your code automatically by running:

```sh
make style
```
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -0,0 +1,3 @@
[tool.black]
line-length = 119
target-version = ['py310']
4 changes: 3 additions & 1 deletion requirements-dev.txt
@@ -1,2 +1,4 @@
-r requirements.txt
pytest
pytest
black
isort
5 changes: 1 addition & 4 deletions scripts/convert_multilingual_uf.py
@@ -18,11 +18,8 @@


def get_args():
parser = argparse.ArgumentParser(
description="Convert a HuggingFace dataset into the RewardBench format."
)

# fmt: off
parser = argparse.ArgumentParser(description="Convert a HuggingFace dataset into the RewardBench format.")
parser.add_argument("--dataset", type=str, default="nthakur/multilingual-ultrafeedback-dpo-v0.1", help="Dataset to convert.")
parser.add_argument("--output_path", type=Path, default="data/multilingual-ultrafeedback-dpo-v0.1.json", help="Path to save converted dataset as JSON file.")
parser.add_argument("--en", action="store_true", help="Use the english columns.")
40 changes: 10 additions & 30 deletions scripts/generative.py
@@ -79,9 +79,7 @@
"command-light-nightly",
)

API_MODEL_LIST = (
OPENAI_MODEL_LIST + ANTHROPIC_MODEL_LIST + TOGETHER_MODEL_LIST + COHERE_MODEL_LIST
)
API_MODEL_LIST = OPENAI_MODEL_LIST + ANTHROPIC_MODEL_LIST + TOGETHER_MODEL_LIST + COHERE_MODEL_LIST


# API setting constants
@@ -285,9 +283,7 @@ def format_judge_answers(
system_prompt = (
REL_SYSTEM_PROMPT
if not include_langs
else M_REL_SYSTEM_PROMPT.format(
src_lang=include_langs[0], tgt_lang=include_langs[1]
)
else M_REL_SYSTEM_PROMPT.format(src_lang=include_langs[0], tgt_lang=include_langs[1])
)
user_prompt = RELATIVE_PROMPT.format(
orig_instruction=question,
@@ -302,9 +298,7 @@
system_prompt = (
MTBENCH_MULTI_V2["system_prompt"]
if not include_langs
else m_prompt_multi_v2.format(
src_lang=include_langs[0], tgt_lang=include_langs[1]
)
else m_prompt_multi_v2.format(src_lang=include_langs[0], tgt_lang=include_langs[1])
)
user_prompt = MTBENCH_MULTI_V2["prompt_template"].format(
question_1=question,
@@ -319,9 +313,7 @@
system_prompt = (
MTBENCH_V2["system_prompt"]
if not include_langs
else m_prompt_v2.format(
src_lang=include_langs[0], tgt_lang=include_langs[1]
)
else m_prompt_v2.format(src_lang=include_langs[0], tgt_lang=include_langs[1])
)
user_prompt = MTBENCH_V2["prompt_template"].format(
question=question,
@@ -335,19 +327,15 @@
prefix = (
prompt_v2_gemini
if not include_langs
else m_prompt_v2_gemini.format(
src_lang=include_langs[0], tgt_lang=include_langs[1]
)
else m_prompt_v2_gemini.format(src_lang=include_langs[0], tgt_lang=include_langs[1])
)
user_prompt = prefix + user_prompt
system_prompt = None

return system_prompt, user_prompt


def process_judgement(
judgment: str, is_prometheus: bool = False
) -> Literal["A", "B", "error"]:
def process_judgement(judgment: str, is_prometheus: bool = False) -> Literal["A", "B", "error"]:
if is_prometheus:
if "[RESULT]" in judgment:
# after [RESULT] is A or B, else error (maybe spaces)
@@ -394,9 +382,7 @@ def run_judge_pair(
winners = []
judgments = []
for m in model:
winner, _, judgment = run_judge_pair(
question, answer_a, answer_b, m, multi_turn
)
winner, _, judgment = run_judge_pair(question, answer_a, answer_b, m, multi_turn)
winners.append(winner)
judgments.append(judgment)
return winners, user_prompt, judgments
@@ -419,9 +405,7 @@ def run_judge_pair(
conv.append_message(conv.roles[1], None)
conv.messages = conv.to_openai_api_messages()

judgment = chat_completion_anthropic(
model, conv, temperature=0, max_tokens=1024
)
judgment = chat_completion_anthropic(model, conv, temperature=0, max_tokens=1024)
elif model in GEMINI_MODEL_LIST:
text = user_prompt
judgment = chat_completion_gemini(model, text, temperature=0, max_tokens=4096)
@@ -506,9 +490,7 @@ def chat_completion_gemini(
max_output_tokens=max_tokens,
temperature=temperature,
),
request_options={
"timeout": 1000
}, # eliminate Failed to connect to Gemini API: 504 Deadline Exceeded
request_options={"timeout": 1000}, # eliminate Failed to connect to Gemini API: 504 Deadline Exceeded
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
@@ -529,9 +511,7 @@
# If the response doesn't contain text, check if the prompt was blocked.
print(f"Prompt feedback {response.prompt_feedback}")
# Also check the finish reason to see if the response was blocked.
print(
f"Finish reason {response.candidates[0].finish_reason}"
) # 5 is "unknown reason"
print(f"Finish reason {response.candidates[0].finish_reason}") # 5 is "unknown reason"
# If the finish reason was SAFETY, the safety ratings have more details.
print(f"Safety ratings {response.candidates[0].safety_ratings}")
else:
53 changes: 13 additions & 40 deletions scripts/run_generative.py
@@ -40,10 +40,9 @@
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

from scripts.generative import ANTHROPIC_MODEL_LIST, API_MODEL_LIST
from scripts.generative import GEMINI_MODEL_LIST, OPENAI_MODEL_LIST
from scripts.generative import format_judge_answers, process_judgement
from scripts.generative import run_judge_pair
from scripts.generative import ANTHROPIC_MODEL_LIST, API_MODEL_LIST, GEMINI_MODEL_LIST
from scripts.generative import OPENAI_MODEL_LIST, format_judge_answers
from scripts.generative import process_judgement, run_judge_pair

# get token from HF_TOKEN env variable, but if it doesn't exist pass none
HF_TOKEN = os.getenv("HF_TOKEN", None)
@@ -88,9 +87,7 @@ def main():
log_level = logging.INFO
logger.setLevel(log_level)

logger.info(
f"Running reward model on {args.model} with chat template {args.chat_template}"
)
logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}")

model_type = "Generative RM"

@@ -103,11 +100,7 @@ def main():
assert len(args.model) % 2 == 1

# define variable if is API or local
is_api_models = (
isinstance(args.model, list)
or args.model in API_MODEL_LIST
or not args.force_local
)
is_api_models = isinstance(args.model, list) or args.model in API_MODEL_LIST or not args.force_local

# if model isn't API, load via vllm
if not is_api_models:
@@ -148,9 +141,7 @@ def main():
logger.info("*** Load dataset ***")
dataset = load_dataset(args.dataset_name, split=args.split)
# Rename columns for compatibility with existing API
dataset = dataset.rename_columns(
{"chosen": "text_chosen", "rejected": "text_rejected"}
)
dataset = dataset.rename_columns({"chosen": "text_chosen", "rejected": "text_rejected"})

if args.sample:
logger.debug(f"Running on first {args.sample} examples")
@@ -165,11 +156,7 @@ def main():
def update_progress_bar(done, total):
# Simple text-based progress bar
progress = int(50 * done / total) # Calculate progress (50 chars width)
sys.stdout.write(
"\r[{}{}] {}/{}".format(
"#" * progress, "." * (50 - progress), done, total
)
)
sys.stdout.write("\r[{}{}] {}/{}".format("#" * progress, "." * (50 - progress), done, total))
sys.stdout.flush()

def get_judgement(batch, debug=args.debug):
@@ -225,9 +212,7 @@ def get_judgement(batch, debug=args.debug):

with ThreadPoolExecutor(max_workers=args.num_threads) as executor:
# Submit all tasks and hold their futures in a list
future_to_index = {
executor.submit(get_judgement, x): i for i, x in enumerate(dataset)
}
future_to_index = {executor.submit(get_judgement, x): i for i, x in enumerate(dataset)}

# As tasks complete, update progress and store results in the original order
for future in as_completed(future_to_index):
@@ -268,34 +253,22 @@ def format_judgements(batch, optional_chat_template=None):
if optional_chat_template is not None:
optional_chat_template.set_system_message(system_prompt)
optional_chat_template.messages = []
optional_chat_template.append_message(
optional_chat_template.roles[0], user_prompt
)
optional_chat_template.append_message(
optional_chat_template.roles[1], None
)
optional_chat_template.append_message(optional_chat_template.roles[0], user_prompt)
optional_chat_template.append_message(optional_chat_template.roles[1], None)
prompt = optional_chat_template.get_prompt()
elif model_modifier:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
batch["text"] = prompt
batch["is_shuffled"] = is_shuffled
return batch

# format the dataset for the model, with optional fastchat templating
chat_template = (
get_conv_template(args.chat_template)
if args.chat_template is not None
else None
)
dataset_prompts = dataset.map(
format_judgements, fn_kwargs={"optional_chat_template": chat_template}
)
chat_template = get_conv_template(args.chat_template) if args.chat_template is not None else None
dataset_prompts = dataset.map(format_judgements, fn_kwargs={"optional_chat_template": chat_template})

# collect texts of dataset in list
prompts = dataset_prompts["text"]
22 changes: 19 additions & 3 deletions tests/test_generative.py
@@ -1,8 +1,8 @@
"""Testing new functions in generative.py"""

import pytest

from scripts.generative import chat_completion_anthropic
from scripts.generative import chat_completion_cohere, chat_completion_gemini
from scripts.generative import chat_completion_together, format_judge_answers
from scripts.generative import chat_completion_cohere, format_judge_answers
from scripts.generative import process_judgement, run_judge_pair


@@ -40,3 +40,19 @@ def test_format_judge_answers_multilingual_includes_language():

assert src_lang in sys_prompt
assert tgt_lang in sys_prompt


@pytest.mark.parametrize("judgment,expected", [("[[A]]", "A"), ("[[B]]", "B"), ("I don't know", "error")])
def test_process_judgment_answers(judgment, expected):
answer = process_judgement(judgment, is_prometheus=False)
assert answer == expected


@pytest.mark.api
def test_cohere_api():
pass


@pytest.mark.api
def test_run_judge_pair():
pass
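
As a possible follow-up (not part of this commit), the `is_prometheus=True` branch of `process_judgement` could be exercised in the same parametrized style. Assuming the `[RESULT]`-tag parsing suggested by the comment in `scripts/generative.py`, a sketch might look like:

```python
# Hypothetical test sketch (not in this commit): exercises the Prometheus-style
# path, where the verdict is expected to follow a "[RESULT]" tag.
@pytest.mark.parametrize(
    "judgment,expected",
    [
        ("Feedback: helpful and accurate. [RESULT] A", "A"),
        ("Feedback: second answer is better. [RESULT] B", "B"),
        ("No verdict given.", "error"),
    ],
)
def test_process_judgment_prometheus(judgment, expected):
    assert process_judgement(judgment, is_prometheus=True) == expected
```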
