Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #320 from Aarhus-Psychiatry-Research/bokajgd/issue315
Browse files Browse the repository at this point in the history
Adding AUC-ROC plot
  • Loading branch information
MartinBernstorff authored Nov 11, 2022
2 parents 83e5ff6 + f50aa1d commit c0ff467
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/psycopt2d/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
plot_metric_by_calendar_time,
plot_metric_by_time_until_diagnosis,
)
from psycopt2d.visualization.roc_auc import plot_auc_roc
from psycopt2d.visualization.sens_over_time import (
plot_sensitivity_by_time_to_outcome_heatmap,
)
Expand Down Expand Up @@ -128,6 +129,13 @@ def create_base_plot_artifacts(
save_path=save_dir / "auc_by_calendar_time.png",
),
),
ArtifactContainer(
label="auc_roc",
artifact=plot_auc_roc(
eval_dataset=eval_dataset,
save_path=save_dir / "auc_roc.png",
),
),
ArtifactContainer(
label="recall_by_time_to_diagnosis",
artifact=plot_metric_by_time_until_diagnosis(
Expand Down
40 changes: 40 additions & 0 deletions src/psycopt2d/visualization/roc_auc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""AUC ROC curve."""
from pathlib import Path
from typing import Optional, Union

import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve

from psycopt2d.evaluation_dataclasses import EvalDataset


def plot_auc_roc(
eval_dataset: EvalDataset,
fig_size: Optional[tuple] = (10, 10),
save_path: Optional[Path] = None,
) -> Union[None, Path]:
"""Plot AUC ROC curve.
Args:
eval_dataset (EvalDataset): Evaluation dataset.
fig_size (Optional[tuple], optional): figure size. Defaults to None.
save_path (Optional[Path], optional): path to save figure. Defaults to None.
Returns:
Union[None, Path]: None if save_path is None, else path to saved figure.
"""
fpr, tpr, _ = roc_curve(eval_dataset.y, eval_dataset.y_hat_probs)
auc = roc_auc_score(eval_dataset.y, eval_dataset.y_hat_probs)

plt.figure(figsize=fig_size)
plt.plot(fpr, tpr, label="AUC score = " + str(auc))
plt.title("AUC ROC Curve")
plt.legend(loc=4)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")

if save_path is not None:
plt.savefig(save_path)
plt.close()

return save_path
5 changes: 5 additions & 0 deletions tests/model_evaluation/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
plot_metric_by_calendar_time,
plot_metric_by_time_until_diagnosis,
)
from psycopt2d.visualization.roc_auc import plot_auc_roc
from psycopt2d.visualization.sens_over_time import (
create_sensitivity_by_time_to_outcome_df,
plot_sensitivity_by_time_to_outcome_heatmap,
Expand Down Expand Up @@ -151,3 +152,7 @@ def test_plot_feature_importances():
top_n_feature_importances=n_features,
save_path="tmp",
)


def test_plot_roc_auc(synth_eval_dataset: EvalDataset):
plot_auc_roc(eval_dataset=synth_eval_dataset)

0 comments on commit c0ff467

Please sign in to comment.