Merge pull request #317 from Aarhus-Psychiatry-Research/bokajgd/issue292
Fixing bugs in eval table and plots
bokajgd authored Nov 14, 2022
2 parents 0be0d7f + 2a00325 commit f1851de
Showing 6 changed files with 24 additions and 53 deletions.
47 changes: 6 additions & 41 deletions src/psycopt2d/evaluate_model.py
@@ -1,7 +1,7 @@
"""_summary_"""
-from collections.abc import Iterable, Sequence
+from collections.abc import Iterable
from pathlib import Path, PosixPath, WindowsPath
-from typing import Optional, Union
+from typing import Optional

import pandas as pd
import wandb
@@ -58,66 +58,34 @@ def upload_artifacts_to_wandb(
    run.log({artifact_container.label: wandb_table})


-def filter_plot_bins(
-    cfg: FullConfigSchema,
-):
-    """Remove bins that don't make sense given the other items in the config.
-    E.g. it doesn't make sense to plot bins with no contents.
-    """
-    # Bins for plotting
-    lookahead_bins: Iterable[int] = cfg.eval.lookahead_bins
-    lookbehind_bins: Iterable[int] = cfg.eval.lookbehind_bins
-
-    # Drop date_bins_direction if they are further away than min_lookdirection_days
-    if cfg.data.min_lookbehind_days:
-        lookbehind_bins = [
-            b for b in lookbehind_bins if cfg.data.min_lookbehind_days < b
-        ]
-
-    if cfg.data.min_lookahead_days:
-        lookahead_bins = [
-            b for b in lookahead_bins if cfg.data.min_lookahead_days < abs(b)
-        ]
-
-    # Invert date_bins_behind to negative if it's not already
-    if min(lookbehind_bins) >= 0:
-        lookbehind_bins = [-d for d in lookbehind_bins]
-
-    # Sort so they're monotonically increasing
-    lookbehind_bins = sorted(lookbehind_bins)
-
-    return lookahead_bins, lookbehind_bins
-
-
def create_base_plot_artifacts(
    cfg: FullConfigSchema,
    eval_dataset: EvalDataset,
    save_dir: Path,
-    lookahead_bins: Sequence[Union[int, float]],
-    lookbehind_bins: Sequence[Union[int, float]],
) -> list[ArtifactContainer]:
    """A collection of plots that are always generated."""
    pred_proba_percentiles = positive_rate_to_pred_probs(
        pred_probs=eval_dataset.y_hat_probs,
        positive_rate_thresholds=cfg.eval.positive_rate_thresholds,
    )

+    lookahead_bins = cfg.eval.lookahead_bins

    return [
        ArtifactContainer(
            label="sensitivity_by_time_by_threshold",
            artifact=plot_sensitivity_by_time_to_outcome_heatmap(
                eval_dataset=eval_dataset,
                pred_proba_thresholds=pred_proba_percentiles,
-                bins=lookbehind_bins,
+                bins=lookahead_bins,
                save_path=save_dir / "sensitivity_by_time_by_threshold.png",
            ),
        ),
        ArtifactContainer(
            label="auc_by_time_from_first_visit",
            artifact=plot_auc_by_time_from_first_visit(
                eval_dataset=eval_dataset,
-                bins=lookbehind_bins,
+                bins=lookahead_bins,
                save_path=save_dir / "auc_by_time_from_first_visit.png",
            ),
        ),
@@ -199,16 +167,13 @@ def run_full_evaluation(
        pipe_metadata: The metadata for the pipe.
        upload_to_wandb: Whether to upload to wandb.
    """
-    lookahead_bins, lookbehind_bins = filter_plot_bins(cfg=cfg)

    # Create the directory if it doesn't exist
    save_dir.mkdir(parents=True, exist_ok=True)

    artifact_containers = create_base_plot_artifacts(
        cfg=cfg,
        eval_dataset=eval_dataset,
-        lookahead_bins=lookahead_bins,
-        lookbehind_bins=lookbehind_bins,
        save_dir=save_dir,
    )

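The deleted filter_plot_bins helper pruned, negated and sorted the configured bins before plotting; with this commit the two plots above read cfg.eval.lookahead_bins as-is. For orientation, a minimal sketch of what the removed helper did, with illustrative values (only the logic mirrors the deleted code; the numbers and threshold are made up):

```python
# Illustrative values; in the project these come from cfg.eval.* and cfg.data.*.
lookbehind_bins = [0, 30, 90, 180, 365]  # days before the prediction time
min_lookbehind_days = 30

# Drop bins inside the minimum lookbehind window, then negate and sort,
# mirroring the removed filter_plot_bins logic.
filtered = [b for b in lookbehind_bins if min_lookbehind_days < b]
filtered = sorted(-b for b in filtered)
print(filtered)  # [-365, -180, -90]
```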
2 changes: 1 addition & 1 deletion src/psycopt2d/load.py
@@ -412,7 +412,7 @@ def _process_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame:
        - Drop patients with outcome before drop_patient_if_outcome_before_date
        - Process timestamp columns
-        - Drop visits where mmin_lookahead, min_lookbehind or min_prediction_time_date are not met
+        - Drop visits where min_lookahead, min_lookbehind or min_prediction_time_date are not met
        - Drop features with lookbehinds not in lookbehind_combination

    Returns:
2 changes: 2 additions & 0 deletions src/psycopt2d/visualization/base_charts.py
@@ -62,6 +62,8 @@ def plot_basic_chart(

    plt.xlabel(x_title)
    plt.ylabel(y_title)
+    plt.xticks(fontsize=7)
+    plt.xticks(rotation=45)
    if save_path is not None:
        plt.savefig(save_path)
    plt.close()
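Both added plt.xticks calls act on the current axes, so the smaller font size and the 45° rotation both take effect. An equivalent single-call form, shown here only as a sketch on synthetic data:

```python
import matplotlib.pyplot as plt

# Synthetic data; shrink and rotate the x tick labels in one call.
plt.plot(["2021Q1", "2021Q2", "2021Q3", "2021Q4"], [0.71, 0.74, 0.69, 0.72])
plt.xticks(fontsize=7, rotation=45)
plt.tight_layout()
plt.savefig("basic_chart_sketch.png")
```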
7 changes: 4 additions & 3 deletions src/psycopt2d/visualization/performance_over_time.py
@@ -32,13 +32,14 @@ def create_performance_by_calendar_time_df(
        y_hat (Iterable[int, float]): Predicted probabilities or labels depending on metric
        timestamps (Iterable[pd.Timestamp]): Timestamps of predictions
        metric_fn (Callable): Callable which returns the metric to calculate
-        bin_period (str): How to bin time. "M" for year/month, "Y" for year
+        bin_period (str): How to bin time. "M" for year/month, "Y" for year, "Q" for quarter

    Returns:
        pd.DataFrame: Dataframe ready for plotting
    """
    df = pd.DataFrame({"y": labels, "y_hat": y_hat, "timestamp": timestamps})
-    df["time_bin"] = df["timestamp"].astype(f"datetime64[{bin_period}]")
+
+    df["time_bin"] = pd.PeriodIndex(df["timestamp"], freq=bin_period).format()

    output_df = df.groupby("time_bin").apply(calc_performance, metric_fn)

@@ -159,7 +160,7 @@ def plot_auc_by_time_from_first_visit(
    prettify_bins: Optional[bool] = True,
    save_path: Optional[Path] = None,
) -> Union[None, Path]:
-    """Plot AUC as a function of time to first visit.
+    """Plot AUC as a function of time from first visit.

    Args:
        eval_dataset (EvalDataset): EvalDataset object
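The switch from datetime64 casting to pd.PeriodIndex(...).format() is what enables the new "Q" option: each timestamp becomes a calendar-period label that the subsequent groupby bins on. A small standalone illustration with synthetic timestamps:

```python
import pandas as pd

timestamps = pd.Series(pd.to_datetime(["2022-01-15", "2022-05-02", "2022-11-14"]))

# Monthly bins give labels like "2022-01"; quarterly bins give labels like "2022Q1".
print(pd.PeriodIndex(timestamps, freq="M").format())  # ['2022-01', '2022-05', '2022-11']
print(pd.PeriodIndex(timestamps, freq="Q").format())  # ['2022Q1', '2022Q2', '2022Q4']
```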
10 changes: 2 additions & 8 deletions src/psycopt2d/visualization/sens_over_time.py
@@ -150,7 +150,6 @@ def _generate_sensitivity_array(
def _annotate_heatmap(
    image: matplotlib.image.AxesImage,
    data: Optional[np.ndarray] = None,
-    value_formatter: str = "{x:.2f}",
    textcolors: tuple = ("black", "white"),
    threshold: Optional[float] = None,
    **textkw,
@@ -160,7 +159,6 @@ def _annotate_heatmap(
    Args:
        image (matplotlib.image.AxesImage): The AxesImage to be labeled.
        data (np.ndarray): Data used to annotate. If None, the image's data is used. Defaults to None.
-        value_formatter (str, optional): The format of the annotations inside the heatmap. This should either use the string format method, e.g. "$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`. Defaults to "{x:.2f}".
        textcolors (tuple, optional): A pair of colors. The first is used for values below a threshold, the second for those above. Defaults to ("black", "white").
        threshold (float, optional): Value in data units according to which the colors from textcolors are applied. If None (the default) uses the middle of the colormap as separation. Defaults to None.
        **kwargs (dict, optional): All other arguments are forwarded to each call to `text` used to create the text labels. Defaults to {}.
@@ -186,10 +184,6 @@ def _annotate_heatmap(
        | textkw
    )

-    # Get the formatter in case a string is supplied
-    if isinstance(value_formatter, str):
-        value_formatter = matplotlib.ticker.StrMethodFormatter(value_formatter)
-
    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
@@ -204,7 +198,7 @@ def _annotate_heatmap(
            text = image.axes.text(
                heat_col_idx,
                heat_row_idx,
-                value_formatter(data[heat_row_idx, heat_col_idx], None),  # type: ignore
+                str(data[heat_row_idx, heat_col_idx]),  # type: ignore
                **test_kwargs,
            )
            texts.append(text)
@@ -256,7 +250,7 @@ def _format_sens_by_time_heatmap(
    axes.tick_params(which="minor", bottom=False, left=False)

    # Add annotations
-    _ = _annotate_heatmap(image, value_formatter="{x:.1f}")  # type: ignore
+    _annotate_heatmap(image)  # type: ignore

    # Set axis labels and title
    axes.set(
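With value_formatter removed, each heatmap cell is now labelled with str(value). A self-contained sketch of that annotation pattern (illustrative only; the project's _annotate_heatmap additionally takes text colours, a threshold and extra text kwargs):

```python
import matplotlib.pyplot as plt
import numpy as np

data = np.array([[0.52, 0.61], [0.48, 0.73]])  # e.g. sensitivity per threshold/bin

fig, ax = plt.subplots()
ax.imshow(data)

# Write each cell's value as plain text, flipping the colour around the midpoint
# so labels stay readable on both light and dark cells.
threshold = (data.min() + data.max()) / 2
for row in range(data.shape[0]):
    for col in range(data.shape[1]):
        colour = "black" if data[row, col] > threshold else "white"
        ax.text(col, row, str(data[row, col]), ha="center", va="center", color=colour)

fig.savefig("annotated_heatmap_sketch.png")
```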
9 changes: 9 additions & 0 deletions tests/model_evaluation/test_visualizations.py
@@ -108,6 +108,15 @@ def test_plot_performance_by_calendar_time(synth_eval_dataset: EvalDataset):
    )


+def test_plot_performance_by_calendar_time_quarterly(synth_eval_dataset: EvalDataset):
+    plot_metric_by_calendar_time(
+        eval_dataset=synth_eval_dataset,
+        bin_period="Q",
+        metric_fn=roc_auc_score,
+        y_title="AUC",
+    )
+
+
def test_plot_metric_until_diagnosis(synth_eval_dataset: EvalDataset):
    plot_metric_by_time_until_diagnosis(
        eval_dataset=synth_eval_dataset,
