From 4bb748ea932e1c1e3a3abd105e5d49f1292a3bbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 16:53:32 +0100 Subject: [PATCH 1/8] feat: ROC-AUC for eval Fixes #315 --- src/psycopt2d/visualization/roc_auc.py | 34 +++++++++++++++++++ tests/model_evaluation/test_visualizations.py | 5 +++ 2 files changed, 39 insertions(+) create mode 100644 src/psycopt2d/visualization/roc_auc.py diff --git a/src/psycopt2d/visualization/roc_auc.py b/src/psycopt2d/visualization/roc_auc.py new file mode 100644 index 00000000..4a787481 --- /dev/null +++ b/src/psycopt2d/visualization/roc_auc.py @@ -0,0 +1,34 @@ +from sklearn.metrics import roc_curve, roc_auc_score +from typing import Optional, Union +import matplotlib.pyplot as plt +from pathlib import Path + +from psycopt2d.evaluation_dataclasses import EvalDataset + + +def plot_auc_roc( + eval_dataset: EvalDataset, + fig_size: Optional[tuple] = (10, 10), + save_path: Optional[Path] = None, +) -> Union[None, Path]: + """Plot AUC ROC curve. + + Args: + eval_dataset (EvalDataset): Evaluation dataset + fig_size (Optional[tuple], optional): figure size. Defaults to None. + save_path (Optional[Path], optional): path to save figure. Defaults to None. + """ + fpr, tpr, _ = roc_curve(eval_dataset.y, eval_dataset.y_hat_probs) + auc = roc_auc_score(eval_dataset.y, eval_dataset.y_hat_probs) + plt.plot(fpr, tpr, label="AUC score = " + str(auc)) + plt.title("AUC ROC Curve") + plt.legend(loc=4) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.show() + + if save_path is not None: + plt.savefig(save_path) + plt.close() + + return save_path diff --git a/tests/model_evaluation/test_visualizations.py b/tests/model_evaluation/test_visualizations.py index f8b87a9f..47df5ef5 100644 --- a/tests/model_evaluation/test_visualizations.py +++ b/tests/model_evaluation/test_visualizations.py @@ -22,6 +22,7 @@ plot_metric_by_calendar_time, plot_metric_by_time_until_diagnosis, ) +from psycopt2d.visualization.roc_auc import plot_auc_roc from psycopt2d.visualization.sens_over_time import ( create_sensitivity_by_time_to_outcome_df, plot_sensitivity_by_time_to_outcome_heatmap, @@ -151,3 +152,7 @@ def test_plot_feature_importances(): top_n_feature_importances=n_features, save_path="tmp", ) + + +def test_plot_roc_auc(synth_eval_dataset: EvalDataset): + plot_auc_roc(eval_dataset=synth_eval_dataset) From e2adf617db649fbdc2180749fd75cbda1db8ba06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 16:55:08 +0100 Subject: [PATCH 2/8] fix: unused argument --- src/psycopt2d/visualization/roc_auc.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/psycopt2d/visualization/roc_auc.py b/src/psycopt2d/visualization/roc_auc.py index 4a787481..c17c5207 100644 --- a/src/psycopt2d/visualization/roc_auc.py +++ b/src/psycopt2d/visualization/roc_auc.py @@ -1,7 +1,8 @@ -from sklearn.metrics import roc_curve, roc_auc_score +from pathlib import Path from typing import Optional, Union + import matplotlib.pyplot as plt -from pathlib import Path +from sklearn.metrics import roc_auc_score, roc_curve from psycopt2d.evaluation_dataclasses import EvalDataset @@ -20,6 +21,8 @@ def plot_auc_roc( """ fpr, tpr, _ = roc_curve(eval_dataset.y, eval_dataset.y_hat_probs) auc = roc_auc_score(eval_dataset.y, eval_dataset.y_hat_probs) + + plt.figure(figsize=fig_size) plt.plot(fpr, tpr, label="AUC score = " + str(auc)) plt.title("AUC ROC Curve") plt.legend(loc=4) From f6938e171d5c88855ae37980e922e915c0b3a3a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 16:57:35 +0100 Subject: [PATCH 3/8] chore: linting --- src/psycopt2d/visualization/roc_auc.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/psycopt2d/visualization/roc_auc.py b/src/psycopt2d/visualization/roc_auc.py index c17c5207..ab63fdb2 100644 --- a/src/psycopt2d/visualization/roc_auc.py +++ b/src/psycopt2d/visualization/roc_auc.py @@ -18,6 +18,9 @@ def plot_auc_roc( eval_dataset (EvalDataset): Evaluation dataset fig_size (Optional[tuple], optional): figure size. Defaults to None. save_path (Optional[Path], optional): path to save figure. Defaults to None. + + Returns: + Union[None, Path]: None if save_path is None, else path to saved figure """ fpr, tpr, _ = roc_curve(eval_dataset.y, eval_dataset.y_hat_probs) auc = roc_auc_score(eval_dataset.y, eval_dataset.y_hat_probs) From ff9dc7d0c3e145c284ab21e4a646f83302a2ca6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 16:59:04 +0100 Subject: [PATCH 4/8] chore: linting --- src/psycopt2d/visualization/roc_auc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/psycopt2d/visualization/roc_auc.py b/src/psycopt2d/visualization/roc_auc.py index ab63fdb2..7879f0a6 100644 --- a/src/psycopt2d/visualization/roc_auc.py +++ b/src/psycopt2d/visualization/roc_auc.py @@ -15,7 +15,7 @@ def plot_auc_roc( """Plot AUC ROC curve. Args: - eval_dataset (EvalDataset): Evaluation dataset + eval_dataset (EvalDataset): Evaluation dataset. fig_size (Optional[tuple], optional): figure size. Defaults to None. save_path (Optional[Path], optional): path to save figure. Defaults to None. From b7e020fca03c79ad18000d5b1608da6e99326141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 16:59:48 +0100 Subject: [PATCH 5/8] chore: linting --- src/psycopt2d/visualization/roc_auc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/psycopt2d/visualization/roc_auc.py b/src/psycopt2d/visualization/roc_auc.py index 7879f0a6..f50aa0e1 100644 --- a/src/psycopt2d/visualization/roc_auc.py +++ b/src/psycopt2d/visualization/roc_auc.py @@ -20,7 +20,7 @@ def plot_auc_roc( save_path (Optional[Path], optional): path to save figure. Defaults to None. Returns: - Union[None, Path]: None if save_path is None, else path to saved figure + Union[None, Path]: None if save_path is None, else path to saved figure. """ fpr, tpr, _ = roc_curve(eval_dataset.y, eval_dataset.y_hat_probs) auc = roc_auc_score(eval_dataset.y, eval_dataset.y_hat_probs) From f43ab3c19f6e8c0416bc298b27f9715bd7ee3fdc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 17:01:45 +0100 Subject: [PATCH 6/8] chore: linting --- src/psycopt2d/visualization/roc_auc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/psycopt2d/visualization/roc_auc.py b/src/psycopt2d/visualization/roc_auc.py index f50aa0e1..d5bc8163 100644 --- a/src/psycopt2d/visualization/roc_auc.py +++ b/src/psycopt2d/visualization/roc_auc.py @@ -1,3 +1,4 @@ +"""AUC ROC curve.""" from pathlib import Path from typing import Optional, Union From b1fff1bd1e79d4b8e3085270f819cb78f38c4615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 17:11:03 +0100 Subject: [PATCH 7/8] feat: add auc roc plot to evaluat_model --- src/psycopt2d/evaluate_model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/psycopt2d/evaluate_model.py b/src/psycopt2d/evaluate_model.py index 29b5040a..5a9464b3 100644 --- a/src/psycopt2d/evaluate_model.py +++ b/src/psycopt2d/evaluate_model.py @@ -27,6 +27,7 @@ plot_metric_by_calendar_time, plot_metric_by_time_until_diagnosis, ) +from psycopt2d.visualization.roc_auc import plot_auc_roc from psycopt2d.visualization.sens_over_time import ( plot_sensitivity_by_time_to_outcome_heatmap, ) @@ -128,6 +129,13 @@ def create_base_plot_artifacts( save_path=save_dir / "auc_by_calendar_time.png", ), ), + ArtifactContainer( + label="auc_roc", + artifact=plot_auc_roc( + eval_dataset=eval_dataset, + save_path=save_dir / "auc_roc.png", + ), + ), ArtifactContainer( label="recall_by_time_to_diagnosis", artifact=plot_metric_by_time_until_diagnosis( From f50aa1d1220812306e849bb71bae4a5193014ab4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 17:12:27 +0100 Subject: [PATCH 8/8] fix: remove plt.show --- src/psycopt2d/visualization/roc_auc.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/psycopt2d/visualization/roc_auc.py b/src/psycopt2d/visualization/roc_auc.py index d5bc8163..f6f1a05c 100644 --- a/src/psycopt2d/visualization/roc_auc.py +++ b/src/psycopt2d/visualization/roc_auc.py @@ -32,7 +32,6 @@ def plot_auc_roc( plt.legend(loc=4) plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") - plt.show() if save_path is not None: plt.savefig(save_path)