From e0d9e55202697c102325277b5ad4576888a2b4af Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Tue, 18 Oct 2022 13:10:26 +0200 Subject: [PATCH 1/3] Fix sensitivity heatmap. Fixes #247 --- src/psycopt2d/evaluation.py | 6 +- src/psycopt2d/visualization/sens_over_time.py | 147 +++++++++++------- tests/test_sens_over_time.py | 41 ----- tests/test_visualizations.py | 4 +- 4 files changed, 100 insertions(+), 98 deletions(-) delete mode 100644 tests/test_sens_over_time.py diff --git a/src/psycopt2d/evaluation.py b/src/psycopt2d/evaluation.py index aa90e6a0..c9bc9f08 100644 --- a/src/psycopt2d/evaluation.py +++ b/src/psycopt2d/evaluation.py @@ -27,7 +27,9 @@ plot_metric_by_time_until_diagnosis, plot_performance_by_calendar_time, ) -from psycopt2d.visualization.sens_over_time import plot_sensitivity_by_time_to_outcome +from psycopt2d.visualization.sens_over_time import ( + plot_sensitivity_by_time_to_outcome_heatmap, +) from psycopt2d.visualization.utils import log_image_to_wandb @@ -177,7 +179,7 @@ def evaluate_model( # Add plots plots.update( { - "sensitivity_by_time_by_threshold": plot_sensitivity_by_time_to_outcome( + "sensitivity_by_time_by_threshold": plot_sensitivity_by_time_to_outcome_heatmap( labels=y, y_hat_probs=y_hat_probs, pred_proba_thresholds=pred_proba_thresholds, diff --git a/src/psycopt2d/visualization/sens_over_time.py b/src/psycopt2d/visualization/sens_over_time.py index 1f4c0da3..6ec5f15b 100644 --- a/src/psycopt2d/visualization/sens_over_time.py +++ b/src/psycopt2d/visualization/sens_over_time.py @@ -48,7 +48,7 @@ def create_sensitivity_by_time_to_outcome_df( }, ) - # Get proportion of y_hat == 1, which is equal to the positive rate + # Get proportion of y_hat == 1, which is equal to the positive rate in the data threshold_percentile = round( df[df["y_hat"] == 1].shape[0] / df.shape[0] * 100, 2, @@ -100,26 +100,33 @@ def create_sensitivity_by_time_to_outcome_df( def _generate_sensitivity_array( df: pd.DataFrame, n_decimals_y_axis: int, + y_label_col_name: str, ): """Generate sensitivity array for plotting heatmap. Args: df (pd.DataFrame): Dataframe with columns "sens", "days_to_outcome_binned" and "threshold". + y_label_col_name (str): Name of the column to use for the y-axis labels. n_decimals_y_axis (int): Number of decimals to round y axis labels to. Returns: A tuple containing the generated sensitivity array (np.ndarray), the x axis labels and the y axis labels rounded to n_decimals_y_axis. """ x_labels = df["days_to_outcome_binned"].unique().tolist() - y_labels = df["threshold"].unique().tolist() + + y_labels = df[y_label_col_name].unique().tolist() + y_labels_rounded = [ round(y_labels[value], n_decimals_y_axis) for value in range(len(y_labels)) ] sensitivity_array = [] - for threshold in y_labels: + + for threshold in df["threshold"].unique().tolist(): sensitivity_current_threshold = [] + df_subset_y = df[df["threshold"] == threshold] + for days_interval in x_labels: df_subset_y_x = df_subset_y[ df_subset_y["days_to_outcome_binned"] == days_interval @@ -205,7 +212,64 @@ def _annotate_heatmap( return texts -def plot_sensitivity_by_time_to_outcome( +def _format_sens_by_time_heatmap( + colorbar_label, + x_title, + y_title, + data, + x_labels, + y_labels, + fig, + axes, + image, +) -> tuple[plt.Figure, plt.Axes]: + # Create colorbar + cbar = axes.figure.colorbar(image, ax=axes) + cbar.ax.set_ylabel(colorbar_label, rotation=-90, va="bottom") + + # Show all ticks and label them with the respective list entries. + axes.set_xticks(np.arange(data.shape[1]), labels=x_labels) + axes.set_yticks(np.arange(data.shape[0]), labels=y_labels) + + # Let the horizontal axes labeling appear on top. + axes.tick_params( + top=False, + bottom=True, + labeltop=False, + labelbottom=True, + ) + + # Rotate the tick labels and set their alignment. + plt.setp( + axes.get_xticklabels(), + rotation=90, + ha="right", + rotation_mode="anchor", + ) + + # Turn spines off and create white grid. + axes.spines[:].set_visible(False) + + axes.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) + axes.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True) + axes.grid(which="minor", color="w", linestyle="-", linewidth=3) + axes.tick_params(which="minor", bottom=False, left=False) + + # Add annotations + _ = _annotate_heatmap(image, value_formatter="{x:.1f}") + + # Set axis labels and title + axes.set( + xlabel=x_title, + ylabel=y_title, + ) + + fig.tight_layout() + + return fig, axes + + +def plot_sensitivity_by_time_to_outcome_heatmap( labels: Iterable[int], y_hat_probs: Iterable[int], pred_proba_thresholds: list[float], @@ -216,7 +280,7 @@ def plot_sensitivity_by_time_to_outcome( colorbar_label: Optional[str] = "Sensitivity", x_title: Optional[str] = "Days to outcome", y_title: Optional[str] = "Positive rate", - n_decimals_y_axis: Optional[int] = 4, + n_decimals_y_axis: int = 4, save_path: Optional[Path] = None, ) -> Union[None, Path]: """Plot heatmap of sensitivity by time to outcome according to different @@ -232,8 +296,8 @@ def plot_sensitivity_by_time_to_outcome( color_map (str, optional): Colormap to use. Defaults to "PuBu". colorbar_label (str, optional): Colorbar label. Defaults to "Sensitivity". x_title (str, optional): X axis title. Defaults to "Days to outcome". - y_title (str, optional): Y axis title. Defaults to "Positive rate". - n_decimals_y_axis (int, optional): Number of decimals to round y axis labels. Defaults to 4. + y_title (str, optional): Y axis title. Defaults to "y_hat percentile". + n_decimals_y_axis (int): Number of decimals to round y axis labels. Defaults to 4. save_path (Optional[Path], optional): Path to save the plot. Defaults to None. Returns: @@ -264,6 +328,10 @@ def plot_sensitivity_by_time_to_outcome( >>> ) """ # Construct sensitivity dataframe + # Note that threshold_percentile IS equal to the positive rate, + # since it is calculated on the entire dataset, not just those + # whose true label is 1. + func = partial( create_sensitivity_by_time_to_outcome_df, labels=labels, @@ -284,56 +352,29 @@ def plot_sensitivity_by_time_to_outcome( ) # Prepare data for plotting - data, x_labels, y_labels = _generate_sensitivity_array(df, n_decimals_y_axis) - - fig, ax = plt.subplots() # pylint: disable=invalid-name - - # Plot the heatmap - im = ax.imshow(data, cmap=color_map) # pylint: disable=invalid-name - - # Create colorbar - cbar = ax.figure.colorbar(im, ax=ax) - cbar.ax.set_ylabel(colorbar_label, rotation=-90, va="bottom") - - # Show all ticks and label them with the respective list entries. - ax.set_xticks(np.arange(data.shape[1]), labels=x_labels) - ax.set_yticks(np.arange(data.shape[0]), labels=y_labels) - - # Let the horizontal axes labeling appear on top. - ax.tick_params( - top=False, - bottom=True, - labeltop=False, - labelbottom=True, + data, x_labels, y_labels = _generate_sensitivity_array( + df, + n_decimals_y_axis=n_decimals_y_axis, + y_label_col_name="threshold_percentile", ) - # Rotate the tick labels and set their alignment. - plt.setp( - ax.get_xticklabels(), - rotation=90, - ha="right", - rotation_mode="anchor", - ) - - # Turn spines off and create white grid. - ax.spines[:].set_visible(False) - - ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) - ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True) - ax.grid(which="minor", color="w", linestyle="-", linewidth=3) - ax.tick_params(which="minor", bottom=False, left=False) + fig, axes = plt.subplots() # pylint: disable=invalid-name - # Add annotations - _ = _annotate_heatmap(im, value_formatter="{x:.1f}") - - # Set axis labels and title - ax.set( - xlabel=x_title, - ylabel=y_title, + # Plot the heatmap + image = axes.imshow(data, cmap=color_map) # pylint: disable=invalid-name + + fig, axes = _format_sens_by_time_heatmap( + colorbar_label=colorbar_label, + x_title=x_title, + y_title=y_title, + data=data, + x_labels=x_labels, + y_labels=y_labels, + fig=fig, + axes=axes, + image=image, ) - fig.tight_layout() - if save_path is None: plt.show() else: @@ -358,7 +399,7 @@ def plot_sensitivity_by_time_to_outcome( positive_rate_thresholds=positive_rate_thresholds, ) - plot_sensitivity_by_time_to_outcome( + plot_sensitivity_by_time_to_outcome_heatmap( labels=df["label"], y_hat_probs=df["pred_prob"], pred_proba_thresholds=pred_proba_thresholds, diff --git a/tests/test_sens_over_time.py b/tests/test_sens_over_time.py deleted file mode 100644 index fd7d8e72..00000000 --- a/tests/test_sens_over_time.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Tests of sens over time.""" -# pylint: disable=missing-function-docstring -from pathlib import Path - -import pandas as pd -import pytest - -from psycopt2d.utils import positive_rate_to_pred_probs -from psycopt2d.visualization.sens_over_time import plot_sensitivity_by_time_to_outcome - - -@pytest.fixture(scope="function") -def df(): - repo_path = Path(__file__).parent - path = repo_path / "test_data" / "synth_eval_data.csv" - df = pd.read_csv(path) - - # Convert all timestamp cols to datetime[64]ns - for col in [col for col in df.columns if "timestamp" in col]: - df[col] = pd.to_datetime(df[col]) - - return df - - -def test_plot_sensitivity_by_time_to_outcome(df, tmp_path): - positive_rates = [0.95, 0.99, 0.999, 0.9999] - - pred_proba_thresholds = positive_rate_to_pred_probs( - pred_probs=df["pred_prob"], - positive_rate_thresholds=positive_rates, - ) - - plot_sensitivity_by_time_to_outcome( - labels=df["label"], - y_hat_probs=df["pred_prob"], - pred_proba_thresholds=pred_proba_thresholds, - outcome_timestamps=df["timestamp_t2d_diag"], - prediction_timestamps=df["timestamp"], - bins=[0, 28, 182, 365, 730, 1825], - save_path=tmp_path / "sensitivity_by_time_by_threshold.png", - ) diff --git a/tests/test_visualizations.py b/tests/test_visualizations.py index 137859e4..b108e04f 100644 --- a/tests/test_visualizations.py +++ b/tests/test_visualizations.py @@ -21,7 +21,7 @@ ) from psycopt2d.visualization.sens_over_time import ( create_sensitivity_by_time_to_outcome_df, - plot_sensitivity_by_time_to_outcome, + plot_sensitivity_by_time_to_outcome_heatmap, ) @@ -116,7 +116,7 @@ def test_plot_sens_by_time_to_outcome(df, tmp_path): positive_rate_thresholds=positive_rate_thresholds, ) - plot_sensitivity_by_time_to_outcome( # noqa + plot_sensitivity_by_time_to_outcome_heatmap( # noqa labels=df["label"], y_hat_probs=df["pred_prob"], outcome_timestamps=df["timestamp_t2d_diag"], From 554ddf817e99748f6c999d02f38f490ba2c46430 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Tue, 18 Oct 2022 13:25:40 +0200 Subject: [PATCH 2/3] fix: incorrect order of y-labels on sens heatmap From 7f288bbb2cb8168f2a9953a140230930550c36fe Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Tue, 18 Oct 2022 14:27:16 +0200 Subject: [PATCH 3/3] refactor: remove argument that should always be default --- src/psycopt2d/visualization/sens_over_time.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/psycopt2d/visualization/sens_over_time.py b/src/psycopt2d/visualization/sens_over_time.py index 6ec5f15b..c7f6be5a 100644 --- a/src/psycopt2d/visualization/sens_over_time.py +++ b/src/psycopt2d/visualization/sens_over_time.py @@ -100,13 +100,11 @@ def create_sensitivity_by_time_to_outcome_df( def _generate_sensitivity_array( df: pd.DataFrame, n_decimals_y_axis: int, - y_label_col_name: str, ): """Generate sensitivity array for plotting heatmap. Args: df (pd.DataFrame): Dataframe with columns "sens", "days_to_outcome_binned" and "threshold". - y_label_col_name (str): Name of the column to use for the y-axis labels. n_decimals_y_axis (int): Number of decimals to round y axis labels to. Returns: @@ -114,7 +112,7 @@ def _generate_sensitivity_array( """ x_labels = df["days_to_outcome_binned"].unique().tolist() - y_labels = df[y_label_col_name].unique().tolist() + y_labels = df["threshold_percentile"].unique().tolist() y_labels_rounded = [ round(y_labels[value], n_decimals_y_axis) for value in range(len(y_labels)) @@ -355,7 +353,6 @@ def plot_sensitivity_by_time_to_outcome_heatmap( data, x_labels, y_labels = _generate_sensitivity_array( df, n_decimals_y_axis=n_decimals_y_axis, - y_label_col_name="threshold_percentile", ) fig, axes = plt.subplots() # pylint: disable=invalid-name