From 6247eca924a215cfd99f0a18aca18a64b5d0c917 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Mon, 20 Feb 2023 14:49:04 +0000 Subject: [PATCH 1/8] feat: working plot with only one threshold --- .../base_artifacts/plots/base_charts.py | 9 +++ .../plots/performance_over_time.py | 56 ++++++++++++++++++- tests/model_evaluation/test_visualizations.py | 16 +++++- 3 files changed, 79 insertions(+), 2 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py index c038e2ee..17995d7f 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py @@ -15,6 +15,8 @@ def plot_basic_chart( plot_type: Union[list[str], str], sort_x: Optional[Iterable[int]] = None, sort_y: Optional[Iterable[int]] = None, + flip_x_axis: bool = False, + flip_y_axis: bool = False, y_limits: Optional[tuple[float, float]] = None, fig_size: Optional[tuple] = (5, 5), dpi: Optional[int] = 300, @@ -36,6 +38,8 @@ def plot_basic_chart( fig_size (Optional[tuple], optional): figure size. Defaults to None. dpi (Optional[int], optional): dpi of figure. Defaults to 300. save_path (Optional[Path], optional): path to save figure. Defaults to None. + flip_x_axis (bool, optional): Whether to flip the x axis. Defaults to False. + flip_y_axis (bool, optional): Whether to flip the y axis. Defaults to False. Returns: Union[None, Path]: None if save_path is None, else path to saved figure @@ -73,6 +77,11 @@ def plot_basic_chart( if y_limits is not None: plt.ylim(y_limits) + if flip_x_axis: + plt.gca().invert_xaxis() + if flip_y_axis: + plt.gca().invert_yaxis() + if save_path is not None: plt.savefig(save_path) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py index c7a86e09..1f49deae 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py @@ -10,11 +10,14 @@ import numpy as np import pandas as pd -from sklearn.metrics import f1_score, roc_auc_score +from sklearn.metrics import f1_score, recall_score, roc_auc_score from psycop_model_training.model_eval.base_artifacts.plots.base_charts import ( plot_basic_chart, ) +from psycop_model_training.model_eval.base_artifacts.plots.sens_over_time import ( + create_sensitivity_by_time_to_outcome_df, +) from psycop_model_training.model_eval.base_artifacts.plots.utils import calc_performance from psycop_model_training.model_eval.dataclasses import EvalDataset from psycop_model_training.utils.utils import bin_continuous_data, round_floats_to_edge @@ -35,6 +38,7 @@ def create_performance_by_calendar_time_df( timestamps (Iterable[pd.Timestamp]): Timestamps of predictions metric_fn (Callable): Callable which returns the metric to calculate bin_period (str): How to bin time. Takes "M" for month, "Q" for quarter or "Y" for year + threshold_percentile (float, optional): Threshold percentile of highest predicted probabilities to mark as positive in binary classification. Defaults to None. Returns: pd.DataFrame: Dataframe ready for plotting @@ -50,6 +54,56 @@ def create_performance_by_calendar_time_df( return output_df +def plot_recall_by_calendar_time( + eval_dataset: EvalDataset, + pred_proba_threshold: float, + bins: Iterable[float], + y_title: str = "Sensitivity (Recall)", + bin_period: str = "Y", + y_limits: Optional[tuple[float, float]] = None, + save_path: Optional[str] = None, +) -> Union[None, Path]: + """Plot performance by calendar time of prediciton. + + Args: + eval_dataset (EvalDataset): EvalDataset object + pred_proba_threshold (float): Threshold for predicted probabilities to mark as positive in binary classification. + bins (Iterable[float], optional): Bins to use for time to outcome. + y_title (str): Title of y-axis. Defaults to "AUC". + bin_period (str): Which time period to bin on. Takes "M" for month, "Q" for quarter or "Y" for year + save_path (str, optional): Path to save figure. Defaults to None. + metric_fn (Callable): Function which returns the metric. Defaults to roc_auc_score. + y_limits (tuple[float, float], optional): Limits of y-axis. Defaults to (0.5, 1.0). + + Returns: + Union[None, Path]: Path to saved figure or None if not saved. + """ + df = create_sensitivity_by_time_to_outcome_df( + labels=eval_dataset.y, + y_hat_probs=eval_dataset.y_hat_probs, + pred_proba_threshold=pred_proba_threshold, + outcome_timestamps=eval_dataset.outcome_timestamps, + prediction_timestamps=eval_dataset.pred_timestamps, + bins=bins, + ) + + return plot_basic_chart( + x_values=df["days_to_outcome_binned"], + y_values=df["sens"], + x_title="Month" + if bin_period == "M" + else "Quarter" + if bin_period == "Q" + else "Year", + sort_x=df["days_to_outcome_binned"][::-1], # Reverse the order of the bins + y_title=y_title, + y_limits=y_limits, + flip_x_axis=True, + plot_type=["line", "scatter"], + save_path=save_path, + ) + + def plot_metric_by_calendar_time( eval_dataset: EvalDataset, y_title: str = "AUC", diff --git a/tests/model_evaluation/test_visualizations.py b/tests/model_evaluation/test_visualizations.py index 5b1a8575..abd96884 100644 --- a/tests/model_evaluation/test_visualizations.py +++ b/tests/model_evaluation/test_visualizations.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd import pytest -from sklearn.metrics import f1_score, roc_auc_score +from sklearn.metrics import f1_score, recall_score, roc_auc_score from psycop_model_training.model_eval.base_artifacts.plots.base_charts import ( plot_basic_chart, @@ -24,6 +24,7 @@ plot_metric_by_calendar_time, plot_metric_by_cyclic_time, plot_metric_by_time_until_diagnosis, + plot_recall_by_calendar_time, ) from psycop_model_training.model_eval.base_artifacts.plots.precision_recall import ( plot_precision_recall, @@ -119,6 +120,19 @@ def test_plot_performance_by_calendar_time( eval_dataset=synth_eval_dataset, bin_period=bin_period, metric_fn=roc_auc_score, + save_path=PROJECT_ROOT / "test.png", + ) + + +def test_plot_recall_by_calendar_time( + synth_eval_dataset: EvalDataset, +): + plot_recall_by_calendar_time( + eval_dataset=synth_eval_dataset, + bin_period="Q", + pred_proba_threshold=0.5, + bins=(0, 14, 28, 162, 365, 730, 1460), + save_path=PROJECT_ROOT / "test_recall_by_calendar_time.png", ) From 447daaeec1b0376abcd887d84c88e41b089d9fcc Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Mon, 20 Feb 2023 15:08:40 +0000 Subject: [PATCH 2/8] feat: add support for multiple y-series on the same base-plot --- .../base_artifacts/plots/base_charts.py | 28 +++++++++------ .../plots/performance_over_time.py | 36 ++++++++++++------- tests/model_evaluation/test_visualizations.py | 6 ++-- 3 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py index 17995d7f..e0d04011 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py @@ -5,11 +5,12 @@ import matplotlib.pyplot as plt import pandas as pd +from numpy import isin def plot_basic_chart( x_values: Iterable, - y_values: Iterable, + y_values: Union[Iterable[int, float], Iterable[Iterable[int, float]]], x_title: str, y_title: str, plot_type: Union[list[str], str], @@ -17,6 +18,7 @@ def plot_basic_chart( sort_y: Optional[Iterable[int]] = None, flip_x_axis: bool = False, flip_y_axis: bool = False, + legend: bool = False, y_limits: Optional[tuple[float, float]] = None, fig_size: Optional[tuple] = (5, 5), dpi: Optional[int] = 300, @@ -48,7 +50,7 @@ def plot_basic_chart( plot_type = [plot_type] df = pd.DataFrame( - {"x": x_values, "y": y_values, "sort_x": sort_x, "sort_y": sort_y}, + {"x": x_values, "sort_x": sort_x, "sort_y": sort_y}, ) if sort_x is not None: @@ -59,15 +61,19 @@ def plot_basic_chart( plt.figure(figsize=fig_size, dpi=dpi) - if "bar" in plot_type: - plt.bar(df["x"], df["y"]) - if "hbar" in plot_type: - plt.barh(df["x"], df["y"]) - plt.yticks(fontsize=7) - if "line" in plot_type: - plt.plot(df["x"], df["y"]) - if "scatter" in plot_type: - plt.scatter(df["x"], df["y"]) + if not isinstance(y_values[0], Iterable): + y_values = [y_values] # Make y_values an iterable + + for y_series in y_values: + if "bar" in plot_type: + plt.bar(df["x"], y_series) + if "hbar" in plot_type: + plt.barh(df["x"], y_series) + plt.yticks(fontsize=7) + if "line" in plot_type: + plt.plot(df["x"], y_series) + if "scatter" in plot_type: + plt.scatter(df["x"], y_series) plt.xlabel(x_title) plt.ylabel(y_title) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py index 1f49deae..acd66657 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py @@ -56,10 +56,11 @@ def create_performance_by_calendar_time_df( def plot_recall_by_calendar_time( eval_dataset: EvalDataset, - pred_proba_threshold: float, + pred_proba_percentile: Union[float, Iterable[float]], bins: Iterable[float], y_title: str = "Sensitivity (Recall)", bin_period: str = "Y", + legend: bool = True, y_limits: Optional[tuple[float, float]] = None, save_path: Optional[str] = None, ) -> Union[None, Path]: @@ -67,7 +68,7 @@ def plot_recall_by_calendar_time( Args: eval_dataset (EvalDataset): EvalDataset object - pred_proba_threshold (float): Threshold for predicted probabilities to mark as positive in binary classification. + pred_proba_percentile (Union[float, Iterable[float]]): Percentile of highest predicted probabilities to mark as positive in binary classification. bins (Iterable[float], optional): Bins to use for time to outcome. y_title (str): Title of y-axis. Defaults to "AUC". bin_period (str): Which time period to bin on. Takes "M" for month, "Q" for quarter or "Y" for year @@ -78,29 +79,38 @@ def plot_recall_by_calendar_time( Returns: Union[None, Path]: Path to saved figure or None if not saved. """ - df = create_sensitivity_by_time_to_outcome_df( - labels=eval_dataset.y, - y_hat_probs=eval_dataset.y_hat_probs, - pred_proba_threshold=pred_proba_threshold, - outcome_timestamps=eval_dataset.outcome_timestamps, - prediction_timestamps=eval_dataset.pred_timestamps, - bins=bins, - ) + if not isinstance(pred_proba_percentile, Iterable): + pred_proba_percentile = [pred_proba_percentile] + + # Get percentiles from a series of predicted probabilities + pred_proba_percentiles = eval_dataset.y_hat_probs.rank(pct=True) + + dfs = [ + create_sensitivity_by_time_to_outcome_df( + labels=eval_dataset.y, + y_hat_probs=pred_proba_percentiles, + pred_proba_threshold=threshold, + outcome_timestamps=eval_dataset.outcome_timestamps, + prediction_timestamps=eval_dataset.pred_timestamps, + bins=bins, + ) + for threshold in pred_proba_percentile + ] return plot_basic_chart( - x_values=df["days_to_outcome_binned"], - y_values=df["sens"], + x_values=dfs[0]["days_to_outcome_binned"], + y_values=[df["sens"] for df in dfs], x_title="Month" if bin_period == "M" else "Quarter" if bin_period == "Q" else "Year", - sort_x=df["days_to_outcome_binned"][::-1], # Reverse the order of the bins y_title=y_title, y_limits=y_limits, flip_x_axis=True, plot_type=["line", "scatter"], save_path=save_path, + legend=True, ) diff --git a/tests/model_evaluation/test_visualizations.py b/tests/model_evaluation/test_visualizations.py index abd96884..28d73daa 100644 --- a/tests/model_evaluation/test_visualizations.py +++ b/tests/model_evaluation/test_visualizations.py @@ -130,8 +130,10 @@ def test_plot_recall_by_calendar_time( plot_recall_by_calendar_time( eval_dataset=synth_eval_dataset, bin_period="Q", - pred_proba_threshold=0.5, - bins=(0, 14, 28, 162, 365, 730, 1460), + pred_proba_percentile=[0.8, 0.9, 0.95], + bins=list(range(0, 1460, 180)), + legend=True, + y_limits=(0, 0.5), save_path=PROJECT_ROOT / "test_recall_by_calendar_time.png", ) From e2fa846417590768aa94a4fa8ae5f2726bb00fb6 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Mon, 20 Feb 2023 15:09:55 +0000 Subject: [PATCH 3/8] style: linting --- .../model_eval/base_artifacts/plots/base_charts.py | 1 - .../model_eval/base_artifacts/plots/performance_over_time.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py index e0d04011..4933d164 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py @@ -18,7 +18,6 @@ def plot_basic_chart( sort_y: Optional[Iterable[int]] = None, flip_x_axis: bool = False, flip_y_axis: bool = False, - legend: bool = False, y_limits: Optional[tuple[float, float]] = None, fig_size: Optional[tuple] = (5, 5), dpi: Optional[int] = 300, diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py index acd66657..c8d936c9 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py @@ -38,7 +38,6 @@ def create_performance_by_calendar_time_df( timestamps (Iterable[pd.Timestamp]): Timestamps of predictions metric_fn (Callable): Callable which returns the metric to calculate bin_period (str): How to bin time. Takes "M" for month, "Q" for quarter or "Y" for year - threshold_percentile (float, optional): Threshold percentile of highest predicted probabilities to mark as positive in binary classification. Defaults to None. Returns: pd.DataFrame: Dataframe ready for plotting @@ -60,7 +59,6 @@ def plot_recall_by_calendar_time( bins: Iterable[float], y_title: str = "Sensitivity (Recall)", bin_period: str = "Y", - legend: bool = True, y_limits: Optional[tuple[float, float]] = None, save_path: Optional[str] = None, ) -> Union[None, Path]: @@ -73,7 +71,6 @@ def plot_recall_by_calendar_time( y_title (str): Title of y-axis. Defaults to "AUC". bin_period (str): Which time period to bin on. Takes "M" for month, "Q" for quarter or "Y" for year save_path (str, optional): Path to save figure. Defaults to None. - metric_fn (Callable): Function which returns the metric. Defaults to roc_auc_score. y_limits (tuple[float, float], optional): Limits of y-axis. Defaults to (0.5, 1.0). Returns: From 830c9ba79c6474bf003c6a457074e53700d61666 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Mon, 20 Feb 2023 15:13:34 +0000 Subject: [PATCH 4/8] feat: add the plot to base artifact generator --- .../base_artifacts/base_artifact_generator.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py b/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py index 34b92c8b..b5ea03ca 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py +++ b/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py @@ -14,6 +14,7 @@ plot_metric_by_calendar_time, plot_metric_by_cyclic_time, plot_metric_by_time_until_diagnosis, + plot_recall_by_calendar_time, ) from psycop_model_training.model_eval.base_artifacts.plots.precision_recall import ( plot_precision_recall, @@ -162,6 +163,18 @@ def create_base_plot_artifacts(self) -> list[ArtifactContainer]: save_path=self.save_dir / "precision_recall.png", ), ), + ArtifactContainer( + label="precision_recall", + artifact=plot_recall_by_calendar_time( + eval_dataset=self.eval_ds, + bin_period="Q", + pred_proba_percentile=[0.95, 0.97, 0.99], + bins=self.cfg.eval.lookahead_bins, + legend=True, + y_limits=(0, 0.5), + save_path=self.save_dir / "recall_by_calendar_time.png", + ), + ), ] def get_feature_selection_artifacts(self): From d23357b670bd75f44c2aa3b9d3eb8642ac0ba8ac Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Mon, 20 Feb 2023 15:14:19 +0000 Subject: [PATCH 5/8] fix: match x-title and configuration --- .../model_eval/base_artifacts/base_artifact_generator.py | 1 - .../base_artifacts/plots/performance_over_time.py | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py b/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py index b5ea03ca..1a1b491a 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py +++ b/src/psycop_model_training/model_eval/base_artifacts/base_artifact_generator.py @@ -167,7 +167,6 @@ def create_base_plot_artifacts(self) -> list[ArtifactContainer]: label="precision_recall", artifact=plot_recall_by_calendar_time( eval_dataset=self.eval_ds, - bin_period="Q", pred_proba_percentile=[0.95, 0.97, 0.99], bins=self.cfg.eval.lookahead_bins, legend=True, diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py index c8d936c9..644ad23d 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py @@ -58,7 +58,6 @@ def plot_recall_by_calendar_time( pred_proba_percentile: Union[float, Iterable[float]], bins: Iterable[float], y_title: str = "Sensitivity (Recall)", - bin_period: str = "Y", y_limits: Optional[tuple[float, float]] = None, save_path: Optional[str] = None, ) -> Union[None, Path]: @@ -69,7 +68,6 @@ def plot_recall_by_calendar_time( pred_proba_percentile (Union[float, Iterable[float]]): Percentile of highest predicted probabilities to mark as positive in binary classification. bins (Iterable[float], optional): Bins to use for time to outcome. y_title (str): Title of y-axis. Defaults to "AUC". - bin_period (str): Which time period to bin on. Takes "M" for month, "Q" for quarter or "Y" for year save_path (str, optional): Path to save figure. Defaults to None. y_limits (tuple[float, float], optional): Limits of y-axis. Defaults to (0.5, 1.0). @@ -97,11 +95,7 @@ def plot_recall_by_calendar_time( return plot_basic_chart( x_values=dfs[0]["days_to_outcome_binned"], y_values=[df["sens"] for df in dfs], - x_title="Month" - if bin_period == "M" - else "Quarter" - if bin_period == "Q" - else "Year", + x_title="Days from event", y_title=y_title, y_limits=y_limits, flip_x_axis=True, From edda59c25a326bcad4302faea55aacf9098a804d Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Mon, 20 Feb 2023 15:27:40 +0000 Subject: [PATCH 6/8] tests: remove deprecated arguments --- .../model_eval/base_artifacts/plots/performance_over_time.py | 1 - tests/model_evaluation/test_visualizations.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py index 644ad23d..c6162c23 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py @@ -101,7 +101,6 @@ def plot_recall_by_calendar_time( flip_x_axis=True, plot_type=["line", "scatter"], save_path=save_path, - legend=True, ) diff --git a/tests/model_evaluation/test_visualizations.py b/tests/model_evaluation/test_visualizations.py index 28d73daa..6e93aeb2 100644 --- a/tests/model_evaluation/test_visualizations.py +++ b/tests/model_evaluation/test_visualizations.py @@ -129,10 +129,8 @@ def test_plot_recall_by_calendar_time( ): plot_recall_by_calendar_time( eval_dataset=synth_eval_dataset, - bin_period="Q", pred_proba_percentile=[0.8, 0.9, 0.95], bins=list(range(0, 1460, 180)), - legend=True, y_limits=(0, 0.5), save_path=PROJECT_ROOT / "test_recall_by_calendar_time.png", ) From 94edcdda63d0d9b2f8e867a80d63a6f1f76e7ed5 Mon Sep 17 00:00:00 2001 From: Lasse Date: Wed, 22 Feb 2023 10:39:48 +0100 Subject: [PATCH 7/8] feat: add labels to base chart --- .../base_artifacts/plots/base_charts.py | 38 ++++++++++++++----- .../plots/performance_over_time.py | 1 + 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py index 4933d164..36da72e8 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py @@ -14,6 +14,7 @@ def plot_basic_chart( x_title: str, y_title: str, plot_type: Union[list[str], str], + labels: Optional[list[str]] = None, sort_x: Optional[Iterable[int]] = None, sort_y: Optional[Iterable[int]] = None, flip_x_axis: bool = False, @@ -33,6 +34,7 @@ def plot_basic_chart( y_title (str): title of y axis plot_type (Optional[Union[List[str], str]], optional): type of plots. Options are combinations of ["bar", "hbar", "line", "scatter"] Defaults to "bar". + labels: (Optional[list[str]]): Optional labels to add to the plot(s). sort_x (Optional[Iterable[int]], optional): order of values on the x-axis. Defaults to None. sort_y (Optional[Iterable[int]], optional): order of values on the y-axis. Defaults to None. y_limits (Optional[tuple[float, float]], optional): y-axis limits. Defaults to None. @@ -63,16 +65,26 @@ def plot_basic_chart( if not isinstance(y_values[0], Iterable): y_values = [y_values] # Make y_values an iterable + if len(plot_type) > 1: + label_plot_type = plot_type[0] + plot_functions = { + "bar": plt.bar, + "hbar": plt.barh, + "line": plt.plot, + "scatter": plt.scatter, + } + + label_plot = plot_functions.get(label_plot_type) + + label_plots = [] for y_series in y_values: - if "bar" in plot_type: - plt.bar(df["x"], y_series) - if "hbar" in plot_type: - plt.barh(df["x"], y_series) - plt.yticks(fontsize=7) - if "line" in plot_type: - plt.plot(df["x"], y_series) - if "scatter" in plot_type: - plt.scatter(df["x"], y_series) + for p_type in plot_type: + plot_function = plot_functions.get(p_type) + plot = plot_function(df["x"], y_series) + if plot_function == label_plot: + label_plots.append(plot) + if p_type == "hbar": + plt.yticks(fontsize=7) plt.xlabel(x_title) plt.ylabel(y_title) @@ -86,6 +98,14 @@ def plot_basic_chart( plt.gca().invert_xaxis() if flip_y_axis: plt.gca().invert_yaxis() + if labels is not None: + plt.figlegend( + [plot[0] for plot in label_plots], + [str(label) for label in labels], + loc="upper right", + bbox_to_anchor=(0.9, 0.88), + frameon=True, + ) if save_path is not None: plt.savefig(save_path) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py index c6162c23..6a026505 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/performance_over_time.py @@ -96,6 +96,7 @@ def plot_recall_by_calendar_time( x_values=dfs[0]["days_to_outcome_binned"], y_values=[df["sens"] for df in dfs], x_title="Days from event", + labels=pred_proba_percentile, y_title=y_title, y_limits=y_limits, flip_x_axis=True, From 33198b75d2bac7b90aef2309210e69b7a8148651 Mon Sep 17 00:00:00 2001 From: Lasse Date: Wed, 22 Feb 2023 10:43:59 +0100 Subject: [PATCH 8/8] fix: minor refactor --- .../base_artifacts/plots/base_charts.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py index 36da72e8..614eb393 100644 --- a/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py +++ b/src/psycop_model_training/model_eval/base_artifacts/plots/base_charts.py @@ -65,23 +65,23 @@ def plot_basic_chart( if not isinstance(y_values[0], Iterable): y_values = [y_values] # Make y_values an iterable - if len(plot_type) > 1: - label_plot_type = plot_type[0] - plot_functions = { - "bar": plt.bar, - "hbar": plt.barh, - "line": plt.plot, - "scatter": plt.scatter, - } + plot_functions = { + "bar": plt.bar, + "hbar": plt.barh, + "line": plt.plot, + "scatter": plt.scatter, + } - label_plot = plot_functions.get(label_plot_type) + # choose the first plot type as the one to use for legend + label_plot_type = plot_type[0] label_plots = [] for y_series in y_values: for p_type in plot_type: plot_function = plot_functions.get(p_type) plot = plot_function(df["x"], y_series) - if plot_function == label_plot: + if p_type == label_plot_type: + # need to one of the plot types for labelling label_plots.append(plot) if p_type == "hbar": plt.yticks(fontsize=7)