Merge pull request #35 from for-ai/update-analysis
Update analysis
ljvmiranda921 authored Oct 2, 2024
2 parents d230c18 + a40c03f commit d729fe5
Showing 2 changed files with 140 additions and 25 deletions.
54 changes: 32 additions & 22 deletions analysis/plot_leaderboard.py
@@ -19,8 +19,9 @@ def get_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=Path, help="Directory to save the output plots."),
parser.add_argument("--dataset", type=str, default="aya-rm-multilingual/eval-results", help="HuggingFace dataset that stores the eval results.")
parser.add_argument("--dataset", type=str, default="aya-rm-multilingual/eval-results-gtranslate-v2", help="HuggingFace dataset that stores the eval results.")
parser.add_argument("--force_download", action="store_true", help="If set, will redownload the dataset.")
parser.add_argument("--show_english_drop", action="store_true", help="If set, will show English drop.")
# fmt: on
return parser.parse_args()
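
With these flags, the English-drop section further below only runs when explicitly requested, e.g. python analysis/plot_leaderboard.py --output_dir plots/ --show_english_drop (the output directory name here is illustrative, not taken from the repository).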

@@ -59,7 +60,12 @@ def main():
model_types = leaderboard_df["Type"].unique().tolist()
for model_type in model_types:
model_type_df = leaderboard_df[leaderboard_df["Type"] == model_type]
data = model_type_df.drop(["eng_Latn", "Type", "Std"], axis=1)
columns = ["Type", "Std"]
if "eng_Latn" not in model_type_df.columns:
logging.warning(f"Language 'eng_Latn' not found for {model_type}!")
else:
columns += ["eng_Latn"]
data = model_type_df.drop(columns, axis=1)
avg_col = "Avg"
data = data[[avg_col] + [c for c in data.columns if c != avg_col]]
data = data.dropna()
@@ -76,28 +82,31 @@
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
fig.tight_layout()
output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.png"
csv_output_file = output_dir / f"leaderboard-{model_type.replace(' ', '_')}.csv"
data.to_csv(csv_output_file)
fig.savefig(output_file, dpi=120)
logging.info(f"Saved to {output_file}")

# *** English drop ***
eng_drop_df = pd.DataFrame(
{
"Overall": get_eng_drop(leaderboard_df)["Percentage_Change"],
"Chat": get_eng_drop(chat_leaderboard_df)["Percentage_Change"],
"Chat Hard": get_eng_drop(chat_hard_leaderboard_df)["Percentage_Change"],
"Safety": get_eng_drop(safety_leaderboard_df)["Percentage_Change"],
"Reasoning": get_eng_drop(reasoning_leaderboard_df)["Percentage_Change"],
}
)
# Only get top-3 and bottom-3. Put bottom 3 at the top rows
top_bottom_n = pd.concat([eng_drop_df.nsmallest(3, "Overall"), eng_drop_df.nlargest(3, "Overall")])
fig, ax = plt.subplots(figsize=(9, 4))
sns.heatmap(top_bottom_n, annot=True, cmap="Reds_r", fmt=".1f", annot_kws={"size": 18}, cbar=False)
ax.xaxis.tick_top()
fig.tight_layout()
output_file = output_dir / "eng-drop-overall.png"
fig.savefig(output_file, dpi=120)
logging.info(f"Saved to {output_file}")
if args.show_english_drop:
eng_drop_df = pd.DataFrame(
{
"Overall": get_eng_drop(leaderboard_df)["Percentage_Change"],
"Chat": get_eng_drop(chat_leaderboard_df)["Percentage_Change"],
"Chat Hard": get_eng_drop(chat_hard_leaderboard_df)["Percentage_Change"],
"Safety": get_eng_drop(safety_leaderboard_df)["Percentage_Change"],
"Reasoning": get_eng_drop(reasoning_leaderboard_df)["Percentage_Change"],
}
)
# Only get top-3 and bottom-3. Put bottom 3 at the top rows
top_bottom_n = pd.concat([eng_drop_df.nsmallest(3, "Overall"), eng_drop_df.nlargest(3, "Overall")])
fig, ax = plt.subplots(figsize=(9, 4))
sns.heatmap(top_bottom_n, annot=True, cmap="Reds_r", fmt=".1f", annot_kws={"size": 18}, cbar=False)
ax.xaxis.tick_top()
fig.tight_layout()
output_file = output_dir / "eng-drop-overall.png"
fig.savefig(output_file, dpi=120)
logging.info(f"Saved to {output_file}")


def get_eng_drop(df: pd.DataFrame) -> pd.DataFrame:
@@ -134,8 +143,9 @@ def get_leaderboard(dataset: str, force_download: bool, category: Optional[str]
)

# Get average but don't include eng_Latn (if present)
lang_scores_df["Avg"] = lang_scores_df.drop(["eng_Latn", "Type"], axis=1).mean(axis=1, skipna=False)
lang_scores_df["Std"] = lang_scores_df.drop(["eng_Latn", "Type"], axis=1).std(axis=1, skipna=False)
columns = ["Type"] if "eng_Latn" not in lang_scores_df else ["eng_Latn", "Type"]
lang_scores_df["Avg"] = lang_scores_df.drop(columns, axis=1).mean(axis=1, skipna=False)
lang_scores_df["Std"] = lang_scores_df.drop(columns, axis=1).std(axis=1, skipna=False)
lang_scores_df = lang_scores_df.sort_values(by=["Type", "Avg"], ascending=False)
return lang_scores_df
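
Note that skipna=False makes Avg and Std NaN for any model missing even one language score; such rows are later removed by data.dropna() in main(). A toy illustration (the language columns here are illustrative):

row = pd.Series({"arb_Arab": 0.71, "deu_Latn": float("nan")})
row.mean(skipna=False)  # -> nan, so the model would later be dropped by dropna()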

111 changes: 108 additions & 3 deletions analysis/plot_utils.py
@@ -3,13 +3,11 @@
from pathlib import Path
from typing import Any, Dict, List

from rewardbench.constants import EXAMPLE_COUNTS, SUBSET_MAPPING

logging.basicConfig(level=logging.INFO)


PLOT_PARAMS = {
"text.usetex": True,
"text.usetex": False,
"xtick.labelsize": 18,
"ytick.labelsize": 18,
"legend.fontsize": 18,
@@ -52,6 +50,19 @@ def _compute_category_scores(results: Dict[str, float]) -> Dict[str, float]:
"subset_scores": result["subset"],
}
)
elif result.get("ref_model"):
# Most likely DPO:
category_scores = _compute_category_scores(result["extra_results"])
model_scores.append(
{
"model": result["model"],
"model_type": "DPO",
"chat_template": result["chat_template"],
"score": sum(category_scores.values()) / len(category_scores),
"category_scores": category_scores,
"subset_scores": result["extra_results"],
}
)
else:
category_scores = _compute_category_scores(result["extra_results"])
model_scores.append(
@@ -65,3 +76,97 @@
}
)
return model_scores
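
For reference, the new elif branch classifies any result that carries a ref_model field as a DPO model. A hypothetical result entry of that shape (all field values are illustrative, not taken from the dataset):

dpo_result = {
    "model": "my-org/my-dpo-model",       # illustrative model id
    "ref_model": "my-org/my-base-model",  # presence of this key routes into the DPO branch
    "chat_template": "tulu",              # illustrative chat template name
    "extra_results": {"alpacaeval-easy": 0.91, "hep-python": 0.74},  # per-subset scores
}

Such an entry is scored through _compute_category_scores just like the default branch; presumably only the assigned model_type differs.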


EXAMPLE_COUNTS = {
"alpacaeval-easy": 100,
"alpacaeval-length": 95,
"alpacaeval-hard": 95,
"mt-bench-easy": 28,
"mt-bench-med": 40,
"mt-bench-hard": 37,
"math-prm": 984, # actual length 447, upweighting to be equal to code
"refusals-dangerous": 100,
"refusals-offensive": 100,
"llmbar-natural": 100,
"llmbar-adver-neighbor": 134,
"llmbar-adver-GPTInst": 92,
"llmbar-adver-GPTOut": 47,
"llmbar-adver-manual": 46,
"xstest-should-refuse": 154,
"xstest-should-respond": 250,
"donotanswer": 136,
"hep-cpp": 164,
"hep-go": 164,
"hep-java": 164,
"hep-js": 164,
"hep-python": 164,
"hep-rust": 164,
}

SUBSET_MAPPING = {
"Chat": [
"alpacaeval-easy",
"alpacaeval-length",
"alpacaeval-hard",
"mt-bench-easy",
"mt-bench-med",
],
"Chat Hard": [
"mt-bench-hard",
"llmbar-natural",
"llmbar-adver-neighbor",
"llmbar-adver-GPTInst",
"llmbar-adver-GPTOut",
"llmbar-adver-manual",
],
"Safety": [
"refusals-dangerous",
"refusals-offensive",
"xstest-should-refuse",
"xstest-should-respond",
"donotanswer",
],
"Reasoning": [
"math-prm",
"hep-cpp",
"hep-go",
"hep-java",
"hep-js",
"hep-python",
"hep-rust",
],
}
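
EXAMPLE_COUNTS and SUBSET_MAPPING mirror the constants previously imported from rewardbench.constants and are presumably what _compute_category_scores consumes; its body is not shown in this diff. A minimal sketch, assuming the usual count-weighted average of subset scores per category (an assumption, not the repository's code):

def _compute_category_scores_sketch(results: Dict[str, float]) -> Dict[str, float]:
    # Hypothetical: weight each subset score by its example count within a category.
    category_scores = {}
    for category, subsets in SUBSET_MAPPING.items():
        present = [s for s in subsets if s in results]
        if not present:
            continue
        total = sum(EXAMPLE_COUNTS[s] for s in present)
        category_scores[category] = sum(results[s] * EXAMPLE_COUNTS[s] for s in present) / total
    return category_scores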

SUBSET_NAME_TO_PAPER_READY = {
"alpacaeval-easy": "AlpacaEval Easy",
"alpacaeval-length": "AlpacaEval Length",
"alpacaeval-hard": "AlpacaEval Hard",
"mt-bench-easy": "MT Bench Easy",
"mt-bench-med": "MT Bench Medium",
"mt-bench-hard": "MT Bench Hard",
"llmbar-natural": "LLMBar Natural",
"llmbar-adver-neighbor": "LLMBar Adver. Neighbor",
"llmbar-adver-GPTInst": "LLMBar Adver. GPTInst",
"llmbar-adver-GPTOut": "LLMBar Adver. GPTOut",
"llmbar-adver-manual": "LLMBar Adver. Manual",
"refusals-dangerous": "Refusals Dangerous",
"refusals-offensive": "Refusals Offensive",
"xstest-should-refuse": "XSTest Should Refuse",
"xstest-should-respond": "XSTest Should Respond",
"donotanswer": "Do Not Answer",
"math-prm": "PRM Math",
"hep-cpp": "HumanEvalPack CPP",
"hep-go": "HumanEvalPack Go",
"hep-java": "HumanEvalPack Java",
"hep-js": "HumanEvalPack Javascript",
"hep-python": "HumanEvalPack Python",
"hep-rust": "HumanEvalPack Rust",
"anthropic_harmless": "Anthropic Harmless",
"anthropic_helpful": "Anthropic Helpful",
"anthropic_hhh": "Anthropic HHH",
"mtbench_gpt4": "MT Bench GPT-4",
"mtbench_human": "MT Bench Human",
"shp": "SHP",
"summarize": "Summarize",
}
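
SUBSET_NAME_TO_PAPER_READY maps raw subset keys to display labels. A typical (assumed) use is renaming the columns of a per-subset score table before plotting:

# Hypothetical usage; subset_df is an illustrative per-subset score DataFrame.
subset_df = subset_df.rename(columns=SUBSET_NAME_TO_PAPER_READY)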
