Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #324 from Aarhus-Psychiatry-Research/bokajgd/issue322
Browse files Browse the repository at this point in the history
Feat: Plot performance by cyclic time periods
  • Loading branch information
bokajgd authored Nov 14, 2022
2 parents 390e5d9 + 7cbf66f commit f5ff9a2
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 15 deletions.
26 changes: 25 additions & 1 deletion src/psycopt2d/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from psycopt2d.visualization.performance_over_time import (
plot_auc_by_time_from_first_visit,
plot_metric_by_calendar_time,
plot_metric_by_cyclic_time,
plot_metric_by_time_until_diagnosis,
)
from psycopt2d.visualization.roc_auc import plot_auc_roc
Expand Down Expand Up @@ -93,7 +94,30 @@ def create_base_plot_artifacts(
label="auc_by_calendar_time",
artifact=plot_metric_by_calendar_time(
eval_dataset=eval_dataset,
y_title="AUC",
save_path=save_dir / "auc_by_calendar_time.png",
),
),
ArtifactContainer(
label="auc_by_hour_of_day",
artifact=plot_metric_by_cyclic_time(
eval_dataset=eval_dataset,
bin_period="H",
save_path=save_dir / "auc_by_calendar_time.png",
),
),
ArtifactContainer(
label="auc_by_day_of_week",
artifact=plot_metric_by_cyclic_time(
eval_dataset=eval_dataset,
bin_period="D",
save_path=save_dir / "auc_by_calendar_time.png",
),
),
ArtifactContainer(
label="auc_by_month_of_year",
artifact=plot_metric_by_cyclic_time(
eval_dataset=eval_dataset,
bin_period="M",
save_path=save_dir / "auc_by_calendar_time.png",
),
),
Expand Down
134 changes: 127 additions & 7 deletions src/psycopt2d/visualization/performance_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def create_performance_by_calendar_time_df(
y_hat (Iterable[int, float]): Predicted probabilities or labels depending on metric
timestamps (Iterable[pd.Timestamp]): Timestamps of predictions
metric_fn (Callable): Callable which returns the metric to calculate
bin_period (str): How to bin time. "M" for year/month, "Y" for year, "Q" for quarter
bin_period (str): How to bin time. Takes "M" for month, "Q" for quarter or "Y" for year
Returns:
pd.DataFrame: Dataframe ready for plotting
Expand All @@ -41,15 +41,16 @@ def create_performance_by_calendar_time_df(

df["time_bin"] = pd.PeriodIndex(df["timestamp"], freq=bin_period).format()

output_df = df.groupby("time_bin").apply(calc_performance, metric_fn)
output_df = df.groupby("time_bin").apply(func=calc_performance, metric=metric_fn)

output_df = output_df.reset_index().rename({0: "metric"}, axis=1)

return output_df


def plot_metric_by_calendar_time(
eval_dataset: EvalDataset,
y_title: str,
y_title: str = "AUC",
bin_period: str = "Y",
save_path: Optional[str] = None,
metric_fn: Callable = roc_auc_score,
Expand All @@ -58,10 +59,10 @@ def plot_metric_by_calendar_time(
Args:
eval_dataset (EvalDataset): EvalDataset object
y_title (str): Title of y-axis.
bin_period (str): Which time period to bin on. Takes "M" or "Y".
y_title (str): Title of y-axis. Defaults to "AUC".
bin_period (str): Which time period to bin on. Takes "M" for month, "Q" for quarter or "Y" for year
save_path (str, optional): Path to save figure. Defaults to None.
metric_fn (Callable): Function which returns the metric.
metric_fn (Callable): Function which returns the metric. Defaults to roc_auc_score.
Returns:
Union[None, Path]: Path to saved figure or None if not saved.
Expand All @@ -77,14 +78,133 @@ def plot_metric_by_calendar_time(
return plot_basic_chart(
x_values=df["time_bin"],
y_values=df["metric"],
x_title="Calendar time",
x_title="Month"
if bin_period == "M"
else "Quarter"
if bin_period == "Q"
else "Year",
y_title=y_title,
sort_x=sort_order,
plot_type=["line", "scatter"],
save_path=save_path,
)


def create_performance_by_cyclic_time_df(
    labels: Iterable[int],
    y_hat: Iterable[Union[int, float]],
    timestamps: Iterable[pd.Timestamp],
    metric_fn: Callable,
    bin_period: str,
) -> pd.DataFrame:
    """Calculate performance by cyclic time period of the prediction time.

    Cyclic time periods are periods that repeat, e.g. hour of day, day of
    week or month of year.

    Args:
        labels (Iterable[int]): True labels
        y_hat (Iterable[int, float]): Predicted probabilities or labels depending on metric
        timestamps (Iterable[pd.Timestamp]): Timestamps of predictions
        metric_fn (Callable): Callable which returns the metric to calculate
        bin_period (str): Which cyclic time period to bin on. Takes "H" for hour of day, "D" for day of week and "M" for month of year.

    Returns:
        pd.DataFrame: Dataframe ready for plotting, with one row per time bin
            in a "time_bin" column and the per-bin result in a "metric" column.

    Raises:
        ValueError: If bin_period is not one of "H", "D" or "M".
    """
    weekday_order = [
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    ]
    month_order = [
        "January",
        "February",
        "March",
        "April",
        "May",
        "June",
        "July",
        "August",
        "September",
        "October",
        "November",
        "December",
    ]

    # Reject unsupported periods up front, before doing any work.
    if bin_period not in ("H", "D", "M"):
        raise ValueError(
            "bin_period must be 'H' for hour of day, 'D' for day of week or 'M' for month of year",
        )

    df = pd.DataFrame({"y": labels, "y_hat": y_hat, "timestamp": timestamps})
    prediction_times = pd.to_datetime(df["timestamp"])

    if bin_period == "H":
        # Zero-padded hour strings ("00".."23") already sort chronologically.
        df["time_bin"] = prediction_times.dt.strftime("%H")
    elif bin_period == "D":
        # day_name() defaults to English names regardless of the process
        # locale, so the labels always match the English categories below
        # (strftime("%A") would not under a non-English LC_TIME).
        # The ordered categorical sorts days Monday..Sunday, not alphabetically.
        df["time_bin"] = pd.Categorical(
            prediction_times.dt.day_name(),
            categories=weekday_order,
            ordered=True,
        )
    else:
        # Same reasoning as for weekdays: English month names, calendar order.
        df["time_bin"] = pd.Categorical(
            prediction_times.dt.month_name(),
            categories=month_order,
            ordered=True,
        )

    output_df = df.groupby("time_bin").apply(func=calc_performance, metric=metric_fn)

    output_df = output_df.reset_index().rename({0: "metric"}, axis=1)

    return output_df


def plot_metric_by_cyclic_time(
    eval_dataset: EvalDataset,
    y_title: str = "AUC",
    bin_period: str = "H",
    save_path: Optional[str] = None,
    metric_fn: Callable = roc_auc_score,
) -> Union[None, Path]:
    """Plot performance by cyclic time period of prediction time.

    Cyclic time periods are periods that repeat, e.g. hour of day, day of
    week or month of year.

    Args:
        eval_dataset (EvalDataset): EvalDataset object
        y_title (str): Title for y-axis (metric name). Defaults to "AUC".
        bin_period (str): Which cyclic time period to bin on. Takes "H" for hour of day, "D" for day of week and "M" for month of year. Defaults to "H".
        save_path (str, optional): Path to save figure. Defaults to None.
        metric_fn (Callable): Function which returns the metric. Defaults to roc_auc_score.

    Returns:
        Union[None, Path]: Path to saved figure or None if not saved.

    Raises:
        ValueError: If bin_period is not one of "H", "D" or "M" (raised by
            create_performance_by_cyclic_time_df).
    """
    # NOTE: the default used to be "Y", which is not a cyclic period and made
    # every call relying on the default raise ValueError.
    x_titles = {
        "H": "Hour of day",
        "D": "Day of week",
        "M": "Month of year",
    }

    # Unsupported bin_period values are rejected here, so the title lookup
    # below only ever sees "H", "D" or "M".
    df = create_performance_by_cyclic_time_df(
        labels=eval_dataset.y,
        y_hat=eval_dataset.y_hat_probs,
        timestamps=eval_dataset.pred_timestamps,
        metric_fn=metric_fn,
        bin_period=bin_period,
    )

    return plot_basic_chart(
        x_values=df["time_bin"],
        y_values=df["metric"],
        x_title=x_titles[bin_period],
        y_title=y_title,
        plot_type=["line", "scatter"],
        save_path=save_path,
    )


def create_performance_by_time_from_event_df(
labels: Iterable[int],
y_hat: Iterable[Union[int, float]],
Expand Down
27 changes: 20 additions & 7 deletions tests/model_evaluation/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from psycopt2d.visualization.performance_over_time import (
plot_auc_by_time_from_first_visit,
plot_metric_by_calendar_time,
plot_metric_by_cyclic_time,
plot_metric_by_time_until_diagnosis,
)
from psycopt2d.visualization.roc_auc import plot_auc_roc
Expand Down Expand Up @@ -99,21 +100,33 @@ def test_plot_performance_by_age(synth_eval_dataset: EvalDataset):
plot_performance_by_age(eval_dataset=synth_eval_dataset)


def test_plot_performance_by_calendar_time(synth_eval_dataset: EvalDataset):
@pytest.mark.parametrize(
"bin_period",
["M", "Q", "Y"],
)
def test_plot_performance_by_calendar_time(
synth_eval_dataset: EvalDataset,
bin_period: str,
):
plot_metric_by_calendar_time(
eval_dataset=synth_eval_dataset,
bin_period="M",
bin_period=bin_period,
metric_fn=roc_auc_score,
y_title="AUC",
)


def test_plot_performance_by_calendar_time_quarterly(synth_eval_dataset: EvalDataset):
plot_metric_by_calendar_time(
@pytest.mark.parametrize(
"bin_period",
["H", "D", "M"],
)
def test_plot_performance_by_cyclic_time(
synth_eval_dataset: EvalDataset,
bin_period: str,
):
plot_metric_by_cyclic_time(
eval_dataset=synth_eval_dataset,
bin_period="Q",
bin_period=bin_period,
metric_fn=roc_auc_score,
y_title="AUC",
)


Expand Down

0 comments on commit f5ff9a2

Please sign in to comment.