From 81c0f9314c839485b847dfbeaad50210c02f307e Mon Sep 17 00:00:00 2001 From: bokajgd Date: Wed, 9 Nov 2022 10:32:50 +0100 Subject: [PATCH 1/9] feat: adjustments to eval Fixes #292 --- src/psycopt2d/visualization/performance_over_time.py | 8 ++++++-- tests/model_evaluation/test_visualizations.py | 9 +++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/psycopt2d/visualization/performance_over_time.py b/src/psycopt2d/visualization/performance_over_time.py index 320c1397..62d9d152 100644 --- a/src/psycopt2d/visualization/performance_over_time.py +++ b/src/psycopt2d/visualization/performance_over_time.py @@ -32,13 +32,17 @@ def create_performance_by_calendar_time_df( y_hat (Iterable[int, float]): Predicted probabilities or labels depending on metric timestamps (Iterable[pd.Timestamp]): Timestamps of predictions metric_fn (Callable): Callable which returns the metric to calculate - bin_period (str): How to bin time. "M" for year/month, "Y" for year + bin_period (str): How to bin time. "M" for year/month, "Y" for year, "Q" for quarter Returns: pd.DataFrame: Dataframe ready for plotting """ df = pd.DataFrame({"y": labels, "y_hat": y_hat, "timestamp": timestamps}) - df["time_bin"] = df["timestamp"].astype(f"datetime64[{bin_period}]") + + if bin_period == "Q": + df["time_bin"] = pd.PeriodIndex(df["timestamp"], freq="Q") + else: + df["time_bin"] = df["timestamp"].astype(f"datetime64[{bin_period}]") output_df = df.groupby("time_bin").apply(calc_performance, metric_fn) diff --git a/tests/model_evaluation/test_visualizations.py b/tests/model_evaluation/test_visualizations.py index f8b87a9f..7a336acb 100644 --- a/tests/model_evaluation/test_visualizations.py +++ b/tests/model_evaluation/test_visualizations.py @@ -107,6 +107,15 @@ def test_plot_performance_by_calendar_time(synth_eval_dataset: EvalDataset): ) +def test_plot_performance_by_calendar_time_quarterly(synth_eval_dataset: EvalDataset): + plot_metric_by_calendar_time( + eval_dataset=synth_eval_dataset, + bin_period="Q", + metric_fn=roc_auc_score, + y_title="AUC", + ) + + def test_plot_metric_until_diagnosis(synth_eval_dataset: EvalDataset): plot_metric_by_time_until_diagnosis( eval_dataset=synth_eval_dataset, From d543d39f81ad368d22978673231edd12e4d781d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 11:49:38 +0100 Subject: [PATCH 2/9] # fix: bin by quarterly + rotated xticks --- src/psycopt2d/visualization/base_charts.py | 1 + src/psycopt2d/visualization/performance_over_time.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/psycopt2d/visualization/base_charts.py b/src/psycopt2d/visualization/base_charts.py index bbad7fe6..138d18db 100644 --- a/src/psycopt2d/visualization/base_charts.py +++ b/src/psycopt2d/visualization/base_charts.py @@ -62,6 +62,7 @@ def plot_basic_chart( plt.xlabel(x_title) plt.ylabel(y_title) + plt.xticks(rotation=45) if save_path is not None: plt.savefig(save_path) plt.close() diff --git a/src/psycopt2d/visualization/performance_over_time.py b/src/psycopt2d/visualization/performance_over_time.py index 62d9d152..d9df657c 100644 --- a/src/psycopt2d/visualization/performance_over_time.py +++ b/src/psycopt2d/visualization/performance_over_time.py @@ -40,7 +40,7 @@ def create_performance_by_calendar_time_df( df = pd.DataFrame({"y": labels, "y_hat": y_hat, "timestamp": timestamps}) if bin_period == "Q": - df["time_bin"] = pd.PeriodIndex(df["timestamp"], freq="Q") + df["time_bin"] = pd.PeriodIndex(df["timestamp"], freq="Q").format() else: df["time_bin"] = df["timestamp"].astype(f"datetime64[{bin_period}]") From beae28ba9e05f9aaafd27fe8be2679ea989801fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 13:28:33 +0100 Subject: [PATCH 3/9] fix: remove automatic bin trimming, not needed --- src/psycopt2d/evaluate_model.py | 41 ++----------------- src/psycopt2d/load.py | 2 +- .../visualization/performance_over_time.py | 2 +- 3 files changed, 6 insertions(+), 39 deletions(-) diff --git a/src/psycopt2d/evaluate_model.py b/src/psycopt2d/evaluate_model.py index 29b5040a..32896e6d 100644 --- a/src/psycopt2d/evaluate_model.py +++ b/src/psycopt2d/evaluate_model.py @@ -57,38 +57,6 @@ def upload_artifacts_to_wandb( run.log({artifact_container.label: wandb_table}) -def filter_plot_bins( - cfg: FullConfigSchema, -): - """Remove bins that don't make sense given the other items in the config. - - E.g. it doesn't make sense to plot bins with no contents. - """ - # Bins for plotting - lookahead_bins: Iterable[int] = cfg.eval.lookahead_bins - lookbehind_bins: Iterable[int] = cfg.eval.lookbehind_bins - - # Drop date_bins_direction if they are further away than min_lookdirection_days - if cfg.data.min_lookbehind_days: - lookbehind_bins = [ - b for b in lookbehind_bins if cfg.data.min_lookbehind_days < b - ] - - if cfg.data.min_lookahead_days: - lookahead_bins = [ - b for b in lookahead_bins if cfg.data.min_lookahead_days < abs(b) - ] - - # Invert date_bins_behind to negative if it's not already - if min(lookbehind_bins) >= 0: - lookbehind_bins = [-d for d in lookbehind_bins] - - # Sort so they're monotonically increasing - lookbehind_bins = sorted(lookbehind_bins) - - return lookahead_bins, lookbehind_bins - - def create_base_plot_artifacts( cfg: FullConfigSchema, eval_dataset: EvalDataset, @@ -108,7 +76,7 @@ def create_base_plot_artifacts( artifact=plot_sensitivity_by_time_to_outcome_heatmap( eval_dataset=eval_dataset, pred_proba_thresholds=pred_proba_percentiles, - bins=lookbehind_bins, + bins=lookahead_bins, save_path=save_dir / "sensitivity_by_time_by_threshold.png", ), ), @@ -116,7 +84,7 @@ def create_base_plot_artifacts( label="auc_by_time_from_first_visit", artifact=plot_auc_by_time_from_first_visit( eval_dataset=eval_dataset, - bins=lookbehind_bins, + bins=lookahead_bins, save_path=save_dir / "auc_by_time_from_first_visit.png", ), ), @@ -191,7 +159,6 @@ def run_full_evaluation( pipe_metadata: The metadata for the pipe. upload_to_wandb: Whether to upload to wandb. """ - lookahead_bins, lookbehind_bins = filter_plot_bins(cfg=cfg) # Create the directory if it doesn't exist save_dir.mkdir(parents=True, exist_ok=True) @@ -199,8 +166,8 @@ def run_full_evaluation( artifact_containers = create_base_plot_artifacts( cfg=cfg, eval_dataset=eval_dataset, - lookahead_bins=lookahead_bins, - lookbehind_bins=lookbehind_bins, + lookahead_bins=cfg.eval.lookahead_bins, + lookbehind_bins=cfg.eval.lookbehind_bins, save_dir=save_dir, ) diff --git a/src/psycopt2d/load.py b/src/psycopt2d/load.py index efdc17e1..af910b37 100644 --- a/src/psycopt2d/load.py +++ b/src/psycopt2d/load.py @@ -412,7 +412,7 @@ def _process_dataset(self, dataset: pd.DataFrame) -> pd.DataFrame: - Drop patients with outcome before drop_patient_if_outcome_before_date - Process timestamp columns - - Drop visits where mmin_lookahead, min_lookbehind or min_prediction_time_date are not met + - Drop visits where min_lookahead, min_lookbehind or min_prediction_time_date are not met - Drop features with lookbehinds not in lookbehind_combination Returns: diff --git a/src/psycopt2d/visualization/performance_over_time.py b/src/psycopt2d/visualization/performance_over_time.py index d9df657c..7d37331e 100644 --- a/src/psycopt2d/visualization/performance_over_time.py +++ b/src/psycopt2d/visualization/performance_over_time.py @@ -163,7 +163,7 @@ def plot_auc_by_time_from_first_visit( prettify_bins: Optional[bool] = True, save_path: Optional[Path] = None, ) -> Union[None, Path]: - """Plot AUC as a function of time to first visit. + """Plot AUC as a function of time from first visit. Args: eval_dataset (EvalDataset): EvalDataset object From cd498e33fc464d9ce2eb82d6e9a00da6290f0691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 14:11:54 +0100 Subject: [PATCH 4/9] fix: keep two decimals on heatmap annotations --- src/psycopt2d/visualization/sens_over_time.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/psycopt2d/visualization/sens_over_time.py b/src/psycopt2d/visualization/sens_over_time.py index 569d51d6..1439fb7a 100644 --- a/src/psycopt2d/visualization/sens_over_time.py +++ b/src/psycopt2d/visualization/sens_over_time.py @@ -150,7 +150,6 @@ def _generate_sensitivity_array( def _annotate_heatmap( image: matplotlib.image.AxesImage, data: Optional[np.ndarray] = None, - value_formatter: str = "{x:.2f}", textcolors: tuple = ("black", "white"), threshold: Optional[float] = None, **textkw, @@ -160,7 +159,6 @@ def _annotate_heatmap( Args: image (matplotlib.image.AxesImage): The AxesImage to be labeled. data (np.ndarray): Data used to annotate. If None, the image's data is used. Defaults to None. - value_formatter (str, optional): The format of the annotations inside the heatmap. This should either use the string format method, e.g. "$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`. Defaults to "{x:.2f}". textcolors (tuple, optional): A pair of colors. The first is used for values below a threshold, the second for those above. Defaults to ("black", "white"). threshold (float, optional): Value in data units according to which the colors from textcolors are applied. If None (the default) uses the middle of the colormap as separation. Defaults to None. **kwargs (dict, optional): All other arguments are forwarded to each call to `text` used to create the text labels. Defaults to {}. @@ -186,10 +184,6 @@ def _annotate_heatmap( | textkw ) - # Get the formatter in case a string is supplied - if isinstance(value_formatter, str): - value_formatter = matplotlib.ticker.StrMethodFormatter(value_formatter) - # Loop over the data and create a `Text` for each "pixel". # Change the text's color depending on the data. texts = [] @@ -204,7 +198,7 @@ def _annotate_heatmap( text = image.axes.text( heat_col_idx, heat_row_idx, - value_formatter(data[heat_row_idx, heat_col_idx], None), # type: ignore + str(data[heat_row_idx, heat_col_idx]), # type: ignore **test_kwargs, ) texts.append(text) @@ -256,7 +250,7 @@ def _format_sens_by_time_heatmap( axes.tick_params(which="minor", bottom=False, left=False) # Add annotations - _ = _annotate_heatmap(image, value_formatter="{x:.1f}") # type: ignore + _ = _annotate_heatmap(image) # type: ignore # Set axis labels and title axes.set( From 51e6b74241f6c72d71c1b5ef8b51c79d0d3c94b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 14:57:02 +0100 Subject: [PATCH 5/9] fix: remove unused args --- src/psycopt2d/evaluate_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/psycopt2d/evaluate_model.py b/src/psycopt2d/evaluate_model.py index 32896e6d..82d67d28 100644 --- a/src/psycopt2d/evaluate_model.py +++ b/src/psycopt2d/evaluate_model.py @@ -62,7 +62,6 @@ def create_base_plot_artifacts( eval_dataset: EvalDataset, save_dir: Path, lookahead_bins: Sequence[Union[int, float]], - lookbehind_bins: Sequence[Union[int, float]], ) -> list[ArtifactContainer]: """A collection of plots that are always generated.""" pred_proba_percentiles = positive_rate_to_pred_probs( From a60659fc2b229ab79e494a8b9492523bce74e571 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Wed, 9 Nov 2022 14:58:20 +0100 Subject: [PATCH 6/9] fix: remove unused arg --- src/psycopt2d/evaluate_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/psycopt2d/evaluate_model.py b/src/psycopt2d/evaluate_model.py index 82d67d28..f4b50792 100644 --- a/src/psycopt2d/evaluate_model.py +++ b/src/psycopt2d/evaluate_model.py @@ -166,7 +166,6 @@ def run_full_evaluation( cfg=cfg, eval_dataset=eval_dataset, lookahead_bins=cfg.eval.lookahead_bins, - lookbehind_bins=cfg.eval.lookbehind_bins, save_dir=save_dir, ) From f5e5a4250ddbf894a66d26306727844e8e748c18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Mon, 14 Nov 2022 09:47:45 +0100 Subject: [PATCH 7/9] fix: review comments --- src/psycopt2d/evaluate_model.py | 4 ++-- src/psycopt2d/visualization/performance_over_time.py | 5 +---- src/psycopt2d/visualization/sens_over_time.py | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/psycopt2d/evaluate_model.py b/src/psycopt2d/evaluate_model.py index f4b50792..e87b17cb 100644 --- a/src/psycopt2d/evaluate_model.py +++ b/src/psycopt2d/evaluate_model.py @@ -61,7 +61,6 @@ def create_base_plot_artifacts( cfg: FullConfigSchema, eval_dataset: EvalDataset, save_dir: Path, - lookahead_bins: Sequence[Union[int, float]], ) -> list[ArtifactContainer]: """A collection of plots that are always generated.""" pred_proba_percentiles = positive_rate_to_pred_probs( @@ -69,6 +68,8 @@ def create_base_plot_artifacts( positive_rate_thresholds=cfg.eval.positive_rate_thresholds, ) + lookahead_bins = cfg.eval.lookahead_bins + return [ ArtifactContainer( label="sensitivity_by_time_by_threshold", @@ -165,7 +166,6 @@ def run_full_evaluation( artifact_containers = create_base_plot_artifacts( cfg=cfg, eval_dataset=eval_dataset, - lookahead_bins=cfg.eval.lookahead_bins, save_dir=save_dir, ) diff --git a/src/psycopt2d/visualization/performance_over_time.py b/src/psycopt2d/visualization/performance_over_time.py index 7d37331e..b8aa3296 100644 --- a/src/psycopt2d/visualization/performance_over_time.py +++ b/src/psycopt2d/visualization/performance_over_time.py @@ -39,10 +39,7 @@ def create_performance_by_calendar_time_df( """ df = pd.DataFrame({"y": labels, "y_hat": y_hat, "timestamp": timestamps}) - if bin_period == "Q": - df["time_bin"] = pd.PeriodIndex(df["timestamp"], freq="Q").format() - else: - df["time_bin"] = df["timestamp"].astype(f"datetime64[{bin_period}]") + df["time_bin"] = pd.PeriodIndex(df["timestamp"], freq=bin_period).format() output_df = df.groupby("time_bin").apply(calc_performance, metric_fn) diff --git a/src/psycopt2d/visualization/sens_over_time.py b/src/psycopt2d/visualization/sens_over_time.py index 1439fb7a..ceccc0f0 100644 --- a/src/psycopt2d/visualization/sens_over_time.py +++ b/src/psycopt2d/visualization/sens_over_time.py @@ -250,7 +250,7 @@ def _format_sens_by_time_heatmap( axes.tick_params(which="minor", bottom=False, left=False) # Add annotations - _ = _annotate_heatmap(image) # type: ignore + _annotate_heatmap(image) # type: ignore # Set axis labels and title axes.set( From 36245c9056288c941ea7f6a9bfa2d104fde364ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Mon, 14 Nov 2022 09:48:03 +0100 Subject: [PATCH 8/9] chore: linting --- src/psycopt2d/evaluate_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/psycopt2d/evaluate_model.py b/src/psycopt2d/evaluate_model.py index e87b17cb..54336801 100644 --- a/src/psycopt2d/evaluate_model.py +++ b/src/psycopt2d/evaluate_model.py @@ -1,7 +1,7 @@ """_summary_""" -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from pathlib import Path, PosixPath, WindowsPath -from typing import Optional, Union +from typing import Optional import pandas as pd import wandb From 2a003251bbb7df70a7d3cba8e95b30c89e272019 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakob=20Gr=C3=B8hn?= Date: Mon, 14 Nov 2022 10:03:16 +0100 Subject: [PATCH 9/9] fix: smaller xticks fontsize --- src/psycopt2d/visualization/base_charts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/psycopt2d/visualization/base_charts.py b/src/psycopt2d/visualization/base_charts.py index 138d18db..bd8ecf63 100644 --- a/src/psycopt2d/visualization/base_charts.py +++ b/src/psycopt2d/visualization/base_charts.py @@ -62,6 +62,7 @@ def plot_basic_chart( plt.xlabel(x_title) plt.ylabel(y_title) + plt.xticks(fontsize=7) plt.xticks(rotation=45) if save_path is not None: plt.savefig(save_path)