Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #272 from Aarhus-Psychiatry-Research/MartinBernsto…
Browse files Browse the repository at this point in the history
…rff/Fix-sensitivity-heatmap

fix: incorrect ordering of y_labels in sensitivity heatmap
  • Loading branch information
MartinBernstorff authored Oct 18, 2022
2 parents e1ab5c5 + 7f288bb commit 9a373c9
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 98 deletions.
6 changes: 4 additions & 2 deletions src/psycopt2d/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
plot_metric_by_time_until_diagnosis,
plot_performance_by_calendar_time,
)
from psycopt2d.visualization.sens_over_time import plot_sensitivity_by_time_to_outcome
from psycopt2d.visualization.sens_over_time import (
plot_sensitivity_by_time_to_outcome_heatmap,
)
from psycopt2d.visualization.utils import log_image_to_wandb


Expand Down Expand Up @@ -177,7 +179,7 @@ def evaluate_model(
# Add plots
plots.update(
{
"sensitivity_by_time_by_threshold": plot_sensitivity_by_time_to_outcome(
"sensitivity_by_time_by_threshold": plot_sensitivity_by_time_to_outcome_heatmap(
labels=y,
y_hat_probs=y_hat_probs,
pred_proba_thresholds=pred_proba_thresholds,
Expand Down
144 changes: 91 additions & 53 deletions src/psycopt2d/visualization/sens_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def create_sensitivity_by_time_to_outcome_df(
},
)

# Get proportion of y_hat == 1, which is equal to the positive rate
# Get proportion of y_hat == 1, which is equal to the positive rate in the data
threshold_percentile = round(
df[df["y_hat"] == 1].shape[0] / df.shape[0] * 100,
2,
Expand Down Expand Up @@ -111,15 +111,20 @@ def _generate_sensitivity_array(
A tuple containing the generated sensitivity array (np.ndarray), the x axis labels and the y axis labels rounded to n_decimals_y_axis.
"""
x_labels = df["days_to_outcome_binned"].unique().tolist()
y_labels = df["threshold"].unique().tolist()

y_labels = df["threshold_percentile"].unique().tolist()

y_labels_rounded = [
round(y_labels[value], n_decimals_y_axis) for value in range(len(y_labels))
]

sensitivity_array = []
for threshold in y_labels:

for threshold in df["threshold"].unique().tolist():
sensitivity_current_threshold = []

df_subset_y = df[df["threshold"] == threshold]

for days_interval in x_labels:
df_subset_y_x = df_subset_y[
df_subset_y["days_to_outcome_binned"] == days_interval
Expand Down Expand Up @@ -205,7 +210,64 @@ def _annotate_heatmap(
return texts


def plot_sensitivity_by_time_to_outcome(
def _format_sens_by_time_heatmap(
    colorbar_label,
    x_title,
    y_title,
    data,
    x_labels,
    y_labels,
    fig,
    axes,
    image,
) -> tuple[plt.Figure, plt.Axes]:
    """Decorate an already-drawn heatmap with colorbar, ticks, grid and titles.

    Args:
        colorbar_label: Text used as the y-label of the colorbar.
        x_title: Label for the x axis.
        y_title: Label for the y axis.
        data: Array behind the heatmap. Only ``data.shape[0]`` (rows) and
            ``data.shape[1]`` (columns) are read here.
        x_labels: Tick labels, one per column of ``data``.
        y_labels: Tick labels, one per row of ``data``.
        fig: Figure on which ``tight_layout`` is called.
        axes: Axes the heatmap was drawn on.
        image: The heatmap image; it is handed to the colorbar and to
            ``_annotate_heatmap``.

    Returns:
        The ``fig`` and ``axes`` objects that were passed in, after formatting.
    """
    # Colorbar beside the heatmap, with its label rotated by -90 degrees.
    colorbar = axes.figure.colorbar(image, ax=axes)
    colorbar.ax.set_ylabel(colorbar_label, rotation=-90, va="bottom")

    n_rows = data.shape[0]
    n_cols = data.shape[1]

    # One major tick per cell, labelled with the supplied entries.
    axes.set_xticks(np.arange(n_cols), labels=x_labels)
    axes.set_yticks(np.arange(n_rows), labels=y_labels)

    # Keep the x tick marks and their labels on the bottom edge only.
    axes.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)

    # Vertical x tick labels, right-aligned and rotated around their anchor.
    x_tick_labels = axes.get_xticklabels()
    plt.setp(x_tick_labels, rotation=90, ha="right", rotation_mode="anchor")

    # Hide the frame; the white minor grid on the cell borders separates cells.
    axes.spines[:].set_visible(False)
    axes.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
    axes.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)
    axes.grid(which="minor", color="w", linestyle="-", linewidth=3)
    axes.tick_params(which="minor", bottom=False, left=False)

    # Write each cell's value into the cell, formatted with one decimal.
    _annotate_heatmap(image, value_formatter="{x:.1f}")

    # Axis titles.
    axes.set(xlabel=x_title, ylabel=y_title)

    fig.tight_layout()
    return fig, axes


def plot_sensitivity_by_time_to_outcome_heatmap(
labels: Iterable[int],
y_hat_probs: Iterable[int],
pred_proba_thresholds: list[float],
Expand All @@ -216,7 +278,7 @@ def plot_sensitivity_by_time_to_outcome(
colorbar_label: Optional[str] = "Sensitivity",
x_title: Optional[str] = "Days to outcome",
y_title: Optional[str] = "Positive rate",
n_decimals_y_axis: Optional[int] = 4,
n_decimals_y_axis: int = 4,
save_path: Optional[Path] = None,
) -> Union[None, Path]:
"""Plot heatmap of sensitivity by time to outcome according to different
Expand All @@ -232,8 +294,8 @@ def plot_sensitivity_by_time_to_outcome(
color_map (str, optional): Colormap to use. Defaults to "PuBu".
colorbar_label (str, optional): Colorbar label. Defaults to "Sensitivity".
x_title (str, optional): X axis title. Defaults to "Days to outcome".
y_title (str, optional): Y axis title. Defaults to "Positive rate".
n_decimals_y_axis (int, optional): Number of decimals to round y axis labels. Defaults to 4.
y_title (str, optional): Y axis title. Defaults to "Positive rate".
n_decimals_y_axis (int): Number of decimals to round y axis labels. Defaults to 4.
save_path (Optional[Path], optional): Path to save the plot. Defaults to None.
Returns:
Expand Down Expand Up @@ -264,6 +326,10 @@ def plot_sensitivity_by_time_to_outcome(
>>> )
"""
# Construct sensitivity dataframe
# Note that threshold_percentile IS equal to the positive rate,
# since it is calculated on the entire dataset, not just those
# whose true label is 1.

func = partial(
create_sensitivity_by_time_to_outcome_df,
labels=labels,
Expand All @@ -284,56 +350,28 @@ def plot_sensitivity_by_time_to_outcome(
)

# Prepare data for plotting
data, x_labels, y_labels = _generate_sensitivity_array(df, n_decimals_y_axis)

fig, ax = plt.subplots() # pylint: disable=invalid-name

# Plot the heatmap
im = ax.imshow(data, cmap=color_map) # pylint: disable=invalid-name

# Create colorbar
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel(colorbar_label, rotation=-90, va="bottom")

# Show all ticks and label them with the respective list entries.
ax.set_xticks(np.arange(data.shape[1]), labels=x_labels)
ax.set_yticks(np.arange(data.shape[0]), labels=y_labels)

# Let the horizontal axes labeling appear on top.
ax.tick_params(
top=False,
bottom=True,
labeltop=False,
labelbottom=True,
data, x_labels, y_labels = _generate_sensitivity_array(
df,
n_decimals_y_axis=n_decimals_y_axis,
)

# Rotate the tick labels and set their alignment.
plt.setp(
ax.get_xticklabels(),
rotation=90,
ha="right",
rotation_mode="anchor",
)

# Turn spines off and create white grid.
ax.spines[:].set_visible(False)

ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
ax.tick_params(which="minor", bottom=False, left=False)
fig, axes = plt.subplots() # pylint: disable=invalid-name

# Add annotations
_ = _annotate_heatmap(im, value_formatter="{x:.1f}")

# Set axis labels and title
ax.set(
xlabel=x_title,
ylabel=y_title,
# Plot the heatmap
image = axes.imshow(data, cmap=color_map) # pylint: disable=invalid-name

fig, axes = _format_sens_by_time_heatmap(
colorbar_label=colorbar_label,
x_title=x_title,
y_title=y_title,
data=data,
x_labels=x_labels,
y_labels=y_labels,
fig=fig,
axes=axes,
image=image,
)

fig.tight_layout()

if save_path is None:
plt.show()
else:
Expand All @@ -358,7 +396,7 @@ def plot_sensitivity_by_time_to_outcome(
positive_rate_thresholds=positive_rate_thresholds,
)

plot_sensitivity_by_time_to_outcome(
plot_sensitivity_by_time_to_outcome_heatmap(
labels=df["label"],
y_hat_probs=df["pred_prob"],
pred_proba_thresholds=pred_proba_thresholds,
Expand Down
41 changes: 0 additions & 41 deletions tests/test_sens_over_time.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from psycopt2d.visualization.sens_over_time import (
create_sensitivity_by_time_to_outcome_df,
plot_sensitivity_by_time_to_outcome,
plot_sensitivity_by_time_to_outcome_heatmap,
)


Expand Down Expand Up @@ -116,7 +116,7 @@ def test_plot_sens_by_time_to_outcome(df, tmp_path):
positive_rate_thresholds=positive_rate_thresholds,
)

plot_sensitivity_by_time_to_outcome( # noqa
plot_sensitivity_by_time_to_outcome_heatmap( # noqa
labels=df["label"],
y_hat_probs=df["pred_prob"],
outcome_timestamps=df["timestamp_t2d_diag"],
Expand Down

0 comments on commit 9a373c9

Please sign in to comment.