Skip to content

Commit

Permalink
Finalize structure for plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 committed Oct 5, 2024
1 parent 3b337c5 commit 5f99e2d
Showing 1 changed file with 26 additions and 73 deletions.
99 changes: 26 additions & 73 deletions analysis/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt

FONT_SIZES = {"small": 12, "medium": 16, "large": 18}
Expand Down Expand Up @@ -40,9 +41,12 @@ def get_args():
shared_args.add_argument("--figsize", type=int, nargs=2, default=[10, 10], help="Matplotlib figure size.")

parser_main_results = subparsers.add_parser("main_heatmap", help="Plot results as a heatmap.", parents=[shared_args])
parser_main_results.add_argument("--input_path", action="append", help="Path to the results file and model category (e.g., DPO::path/to/dpo_results.csv).")
parser_main_results.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
parser_main_results.add_argument("--top_ten_only", action="store_true", help="If set, will only show the top-10 of all models.")
parser_main_results.add_argument("--print_latex", action="store_true", help="If set, print LaTeX table.")

parser_eng_drop = subparsers.add_parser("eng_drop_line", help="Plot english drop as a line chart.", parents=[shared_args])
parser_eng_drop.add_argument("--input_path", type=Path, required=True, help="Path to the results file.")
# fmt: on
return parser.parse_args()

Expand All @@ -52,6 +56,7 @@ def main():

cmd_map = {
"main_heatmap": plot_main_heatmap,
"eng_drop_line": plot_eng_drop_line,
}

def _filter_args(func, kwargs):
Expand All @@ -67,83 +72,31 @@ def _filter_args(func, kwargs):


def plot_main_heatmap(
input_path: list[str],
input_path: Path,
output_path: Optional[Path] = None,
figsize: Optional[tuple[int, int]] = None,
top_ten_only: bool = False,
print_latex: bool = False,
figsize: Optional[tuple[int, int]] = (18, 5),
):
category_results = {path.split("::")[0]: pd.read_csv(path.split("::")[1]) for path in input_path}

if top_ten_only:
logging.info("Passed --top_ten_only tag, will print LaTeX table of top ten models")
df_with_tags = []
for category, df in category_results.items():
df = df.set_index(df.columns[0]) * 100
df["model_type"] = category
df.index.name = "model"
df_with_tags.append(df)
top_ten_df = pd.concat(df_with_tags).sort_values(by="Avg", ascending=False).head(10)
model_type_col = top_ten_df.pop("model_type")
avg_col = top_ten_df.pop("Avg")
top_ten_df = top_ten_df.reindex(sorted(top_ten_df.columns), axis=1)
top_ten_df.insert(0, "Model", model_type_col)
top_ten_df.insert(1, "Avg", avg_col)

if print_latex:
top_ten_df.columns = top_ten_df.columns.str.replace("_", r"\_", regex=False)
print(top_ten_df.to_latex(float_format="%.2f"))

# Plot
top_ten_df.pop("Model")
fig, ax = plt.subplots(1, 1, figsize=figsize)

sns.heatmap(
top_ten_df,
ax=ax,
cmap="BuPu",
cbar=False,
annot=True,
annot_kws={"size": 14},
fmt=".2f",
)

# cbar = ax.collections[0].colorbar
# cbar.set_label("Score")
ax.xaxis.set_ticks_position("top")
ax.tick_params(axis="x", rotation=45)
ax.set_ylabel("")

plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")

else:
fig, axs = plt.subplots(3, 1, figsize=figsize, gridspec_kw={"height_ratios": [4, 2, 2]}, sharex=True)
cbar_ax = fig.add_axes([1.05, 0.3, 0.03, 0.4])
for idx, (ax, (category, df)) in enumerate(zip(axs, category_results.items())):
df = df.set_index(df.columns[0]) * 100
df.index.name = "model"
sns.heatmap(
df,
ax=ax,
cmap="BuPu",
annot=True,
annot_kws={"size": 12},
fmt=".2f",
# Ticklabels and colorbar on first heatmap only
xticklabels=(idx == 0),
cbar=(idx == 0),
cbar_ax=None if idx else cbar_ax,
)

if idx == 0:
cbar = ax.collections[0].colorbar
cbar.set_label("Score")
ax.xaxis.set_ticks_position("top")
ax.tick_params(axis="x", rotation=45)

plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")
df = pd.read_csv(input_path)
# Remove unnecessary column
df.pop("eng_Latn")

df = df.sort_values(by="Avg_Multilingual", ascending=False).head(10).reset_index(drop=True)
data = df[[col for col in df.columns if col not in ("Model_Type", "Avg_Multilingual")]]
data = data.set_index("Model")
data = data * 100

fig, ax = plt.subplots(1, 1, figsize=figsize)
sns.heatmap(data, ax=ax, cmap="YlGn", annot=True, annot_kws={"size": 14}, fmt=".2f", cbar=False)
ax.xaxis.set_ticks_position("top")
ax.tick_params(axis="x", rotation=45)
ax.set_ylabel("")
ax.set_yticklabels([f"{model} " for model in data.index])

plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")


if __name__ == "__main__":
Expand Down

0 comments on commit 5f99e2d

Please sign in to comment.