Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
feat: plot sensitivity by time to diagnosis by threshold (#397)
Browse files Browse the repository at this point in the history
## Notes for reviewers
Primary downside is the lack of a legend for each `pred_proba`
threshold. Unsure about how to do this. Perhaps you have an idea,
@HLasse? :-)

Looks like this on synth data:

![test_recall_by_calendar_time](https://user-images.githubusercontent.com/8526086/220144105-5e618de0-baea-49c2-88fa-32c82d97f9bf.png)

Closes #396.
  • Loading branch information
MartinBernstorff authored Feb 22, 2023
2 parents c2b8062 + 33198b7 commit 1195673
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
plot_metric_by_calendar_time,
plot_metric_by_cyclic_time,
plot_metric_by_time_until_diagnosis,
plot_recall_by_calendar_time,
)
from psycop_model_training.model_eval.base_artifacts.plots.precision_recall import (
plot_precision_recall,
Expand Down Expand Up @@ -162,6 +163,17 @@ def create_base_plot_artifacts(self) -> list[ArtifactContainer]:
save_path=self.save_dir / "precision_recall.png",
),
),
ArtifactContainer(
label="precision_recall",
artifact=plot_recall_by_calendar_time(
eval_dataset=self.eval_ds,
pred_proba_percentile=[0.95, 0.97, 0.99],
bins=self.cfg.eval.lookahead_bins,
legend=True,
y_limits=(0, 0.5),
save_path=self.save_dir / "recall_by_calendar_time.png",
),
),
]

def get_feature_selection_artifacts(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@

import matplotlib.pyplot as plt
import pandas as pd
from numpy import isin


def plot_basic_chart(
x_values: Iterable,
y_values: Iterable,
y_values: Union[Iterable[int, float], Iterable[Iterable[int, float]]],
x_title: str,
y_title: str,
plot_type: Union[list[str], str],
labels: Optional[list[str]] = None,
sort_x: Optional[Iterable[int]] = None,
sort_y: Optional[Iterable[int]] = None,
flip_x_axis: bool = False,
flip_y_axis: bool = False,
y_limits: Optional[tuple[float, float]] = None,
fig_size: Optional[tuple] = (5, 5),
dpi: Optional[int] = 300,
Expand All @@ -30,12 +34,15 @@ def plot_basic_chart(
y_title (str): title of y axis
plot_type (Union[list[str], str]): type(s) of plot to draw; a single string is treated as a one-element list.
Options are combinations of ["bar", "hbar", "line", "scatter"].
labels: (Optional[list[str]]): Optional labels to add to the plot(s).
sort_x (Optional[Iterable[int]], optional): order of values on the x-axis. Defaults to None.
sort_y (Optional[Iterable[int]], optional): order of values on the y-axis. Defaults to None.
y_limits (Optional[tuple[float, float]], optional): y-axis limits. Defaults to None.
fig_size (Optional[tuple], optional): figure size. Defaults to None.
dpi (Optional[int], optional): dpi of figure. Defaults to 300.
save_path (Optional[Path], optional): path to save figure. Defaults to None.
flip_x_axis (bool, optional): Whether to flip the x axis. Defaults to False.
flip_y_axis (bool, optional): Whether to flip the y axis. Defaults to False.
Returns:
Union[None, Path]: None if save_path is None, else path to saved figure
Expand All @@ -44,7 +51,7 @@ def plot_basic_chart(
plot_type = [plot_type]

df = pd.DataFrame(
{"x": x_values, "y": y_values, "sort_x": sort_x, "sort_y": sort_y},
{"x": x_values, "sort_x": sort_x, "sort_y": sort_y},
)

if sort_x is not None:
Expand All @@ -55,15 +62,29 @@ def plot_basic_chart(

plt.figure(figsize=fig_size, dpi=dpi)

if "bar" in plot_type:
plt.bar(df["x"], df["y"])
if "hbar" in plot_type:
plt.barh(df["x"], df["y"])
plt.yticks(fontsize=7)
if "line" in plot_type:
plt.plot(df["x"], df["y"])
if "scatter" in plot_type:
plt.scatter(df["x"], df["y"])
if not isinstance(y_values[0], Iterable):
y_values = [y_values] # Make y_values an iterable

plot_functions = {
"bar": plt.bar,
"hbar": plt.barh,
"line": plt.plot,
"scatter": plt.scatter,
}

# choose the first plot type as the one to use for legend
label_plot_type = plot_type[0]

label_plots = []
for y_series in y_values:
for p_type in plot_type:
plot_function = plot_functions.get(p_type)
plot = plot_function(df["x"], y_series)
if p_type == label_plot_type:
# keep the handle from one plot type per series so the legend gets exactly one entry per label
label_plots.append(plot)
if p_type == "hbar":
plt.yticks(fontsize=7)

plt.xlabel(x_title)
plt.ylabel(y_title)
Expand All @@ -73,6 +94,19 @@ def plot_basic_chart(
if y_limits is not None:
plt.ylim(y_limits)

if flip_x_axis:
plt.gca().invert_xaxis()
if flip_y_axis:
plt.gca().invert_yaxis()
if labels is not None:
plt.figlegend(
[plot[0] for plot in label_plots],
[str(label) for label in labels],
loc="upper right",
bbox_to_anchor=(0.9, 0.88),
frameon=True,
)

if save_path is not None:
plt.savefig(save_path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.metrics import f1_score, recall_score, roc_auc_score

from psycop_model_training.model_eval.base_artifacts.plots.base_charts import (
plot_basic_chart,
)
from psycop_model_training.model_eval.base_artifacts.plots.sens_over_time import (
create_sensitivity_by_time_to_outcome_df,
)
from psycop_model_training.model_eval.base_artifacts.plots.utils import calc_performance
from psycop_model_training.model_eval.dataclasses import EvalDataset
from psycop_model_training.utils.utils import bin_continuous_data, round_floats_to_edge
Expand Down Expand Up @@ -50,6 +53,58 @@ def create_performance_by_calendar_time_df(
return output_df


def plot_recall_by_calendar_time(
    eval_dataset: EvalDataset,
    pred_proba_percentile: Union[float, Iterable[float]],
    bins: Iterable[float],
    y_title: str = "Sensitivity (Recall)",
    y_limits: Optional[tuple[float, float]] = None,
    save_path: Optional[Union[str, Path]] = None,
    legend: bool = True,
) -> Union[None, Path]:
    """Plot sensitivity (recall) by time to outcome, one series per threshold.

    Predicted probabilities are first converted to percentile ranks, so each
    value in ``pred_proba_percentile`` is a percentile cut-off rather than a
    raw probability. One sensitivity-by-time-to-outcome table is built per
    cut-off and all of them are drawn on the same axes as line + scatter, with
    the x-axis flipped.

    Args:
        eval_dataset (EvalDataset): EvalDataset object.
        pred_proba_percentile (Union[float, Iterable[float]]): Percentile
            threshold(s) of predicted probability. A single float is treated
            as a one-element list.
        bins (Iterable[float]): Bins to use for time to outcome.
        y_title (str): Title of y-axis. Defaults to "Sensitivity (Recall)".
        y_limits (tuple[float, float], optional): Limits of y-axis.
            Defaults to None.
        save_path (Union[str, Path], optional): Path to save figure.
            Defaults to None.
        legend (bool): Whether to add a legend with one entry per threshold.
            Defaults to True.

    Returns:
        Union[None, Path]: Whatever ``plot_basic_chart`` returns; presumably
            the path to the saved figure, or None if not saved.

    Raises:
        ValueError: If ``pred_proba_percentile`` is an empty iterable.
    """
    if not isinstance(pred_proba_percentile, Iterable):
        pred_proba_percentile = [pred_proba_percentile]

    # Materialise once: the thresholds are iterated to build the tables and
    # again as legend labels, which would silently fail for a generator.
    thresholds = list(pred_proba_percentile)
    if not thresholds:
        raise ValueError("pred_proba_percentile must contain at least one value.")

    # Get percentiles from a series of predicted probabilities
    pred_proba_percentiles = eval_dataset.y_hat_probs.rank(pct=True)

    dfs = [
        create_sensitivity_by_time_to_outcome_df(
            labels=eval_dataset.y,
            y_hat_probs=pred_proba_percentiles,
            pred_proba_threshold=threshold,
            outcome_timestamps=eval_dataset.outcome_timestamps,
            prediction_timestamps=eval_dataset.pred_timestamps,
            bins=bins,
        )
        for threshold in thresholds
    ]

    return plot_basic_chart(
        # All tables are built with the same bins, so the first one supplies
        # the shared x-axis.
        x_values=dfs[0]["days_to_outcome_binned"],
        y_values=[df["sens"] for df in dfs],
        x_title="Days from event",
        labels=thresholds if legend else None,
        y_title=y_title,
        y_limits=y_limits,
        flip_x_axis=True,
        plot_type=["line", "scatter"],
        save_path=save_path,
    )


def plot_metric_by_calendar_time(
eval_dataset: EvalDataset,
y_title: str = "AUC",
Expand Down
16 changes: 15 additions & 1 deletion tests/model_evaluation/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import pandas as pd
import pytest
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.metrics import f1_score, recall_score, roc_auc_score

from psycop_model_training.model_eval.base_artifacts.plots.base_charts import (
plot_basic_chart,
Expand All @@ -24,6 +24,7 @@
plot_metric_by_calendar_time,
plot_metric_by_cyclic_time,
plot_metric_by_time_until_diagnosis,
plot_recall_by_calendar_time,
)
from psycop_model_training.model_eval.base_artifacts.plots.precision_recall import (
plot_precision_recall,
Expand Down Expand Up @@ -119,6 +120,19 @@ def test_plot_performance_by_calendar_time(
eval_dataset=synth_eval_dataset,
bin_period=bin_period,
metric_fn=roc_auc_score,
save_path=PROJECT_ROOT / "test.png",
)


def test_plot_recall_by_calendar_time(
    synth_eval_dataset: EvalDataset,
):
    """Smoke test: the recall plot renders on synthetic data for several thresholds."""
    percentile_thresholds = [0.8, 0.9, 0.95]
    day_bins = list(range(0, 1460, 180))
    output_path = PROJECT_ROOT / "test_recall_by_calendar_time.png"

    plot_recall_by_calendar_time(
        eval_dataset=synth_eval_dataset,
        pred_proba_percentile=percentile_thresholds,
        bins=day_bins,
        y_limits=(0, 0.5),
        save_path=output_path,
    )


Expand Down

0 comments on commit 1195673

Please sign in to comment.