
Commit
Handle special case for Gemma (#20)
* Gemma special case handled

Gemma does not support system prompts, and the FastChat template does not work in the current setup. Fixed by concatenating the system and user prompts with a double newline and applying the tokenizer's chat template (see the first sketch below).

* GLM-4 and Qwen-2 cases added for the tokenizer chat template (see the second sketch below).

* python module calls added: the shell scripts now invoke the evaluation scripts with `python3 -m` instead of by file path (see the note below).
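
A minimal sketch of the Gemma workaround, assuming a Hugging Face tokenizer with a chat template; the model ID and prompt strings are illustrative, not taken from the repo:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-7b-it")  # illustrative checkpoint

system_prompt = "You are a careful pairwise judge."  # example content
user_prompt = "Which response is better, A or B?"    # example content

# Gemma's chat template rejects a "system" role, so the system prompt is
# folded into the single user turn, separated by a blank line.
messages = [{"role": "user", "content": system_prompt + "\n\n" + user_prompt}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # ends with Gemma's model-turn marker, ready for generation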
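
For GLM-4 and Qwen-2, whose templates do accept a system role, the generic path keeps the usual two-message layout, as the diff below shows (reusing the names from the sketch above):

# Standard layout for models whose chat templates support a system role.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)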
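
On the module-call change, a sketch of the likely rationale (an assumption, not stated in the commit): running a script by file path puts the script's own directory on sys.path, so imports that go through the scripts package fail, while python3 -m from the repository root keeps the root on sys.path and runs the file as a package module:

import sys

# python3 scripts/run_generative.py  -> sys.path[0] is scripts/, so
#                                       "import scripts.x" fails
# python3 -m scripts.run_generative  -> sys.path[0] is the working dir,
#                                       so package-level imports resolve
print(sys.path[0])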
ShayekhBinIslam authored Aug 10, 2024
1 parent ace3efc commit 168f8b7
Showing 3 changed files with 15 additions and 3 deletions.
experiments/run_llm_evals.sh (2 changes: 1 addition & 1 deletion)
@@ -64,7 +64,7 @@ declare -A languages=(
 )
 
 for lang_code in "${!languages[@]}"; do
-    python3 scripts/run_generative.py \
+    python3 -m scripts.run_generative \
         --model "$MODEL" \
         --dataset "$DATASET" \
         --lang_code "$lang_code" \
experiments/run_rm_evals.sh (4 changes: 2 additions & 2 deletions)
@@ -82,7 +82,7 @@ declare -A languages=(
 
 # Loop through each language and run the command
 for lang_code in "${!languages[@]}"; do
-    python3 scripts/run_rewardbench.py \
+    python3 -m scripts.run_rewardbench \
         --model "$MODEL" \
         --chat_template "$CHAT_TEMPLATE" \
         --dataset "$DATASET" \
@@ -93,4 +93,4 @@ for lang_code in "${!languages[@]}"; do
         --trust_remote_code \
         --force_truncation \
         --save_all
-done
\ No newline at end of file
+done
scripts/run_generative.py (12 changes: 12 additions & 0 deletions)
@@ -142,6 +142,12 @@ def main():
         model_modifier = "gemini"
     elif "Llama-3.1" in args.model:
         model_modifier = "llama-3.1"
+    elif "gemma" in args.model:
+        model_modifier = "gemma"
+    elif "glm-4" in args.model:
+        model_modifier = "glm-4"
+    elif "Qwen2" in args.model:
+        model_modifier = "qwen-2"
     else:
         model_modifier = None

@@ -277,6 +283,12 @@ def format_judgements(batch, optional_chat_template=None):
             optional_chat_template.append_message(optional_chat_template.roles[0], user_prompt)
             optional_chat_template.append_message(optional_chat_template.roles[1], None)
             prompt = optional_chat_template.get_prompt()
+        elif model_modifier == "gemma":
+            # Gemma models do not support a system prompt.
+            messages = [
+                {"role": "user", "content": system_prompt + "\n\n" + user_prompt},
+            ]
+            prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         elif model_modifier:
             messages = [
                 {"role": "system", "content": system_prompt},
