Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Plot perf by sex (#402)
Browse files Browse the repository at this point in the history
- [x] I have battle-tested on Overtaci (RMAPPS1279)
- [x] At least one of the commits is prefixed with either "fix:" or
"feat:"

Fixes issue #393.

## Notes for reviewers
We changed the name of the argument "prettify_bins" to
"continuous_input_to_bins", since the functionality is whether is should
be binned or not, rather than whether it should be pretty. Also, for
both performance plots by age and sex, the metric is optional (default
auc), though metric_fn_to_input can only handle roc_auc_score.
  • Loading branch information
MartinBernstorff authored Feb 28, 2023
2 parents a91e34f + 194966d commit 2479a7d
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def plot_basic_chart(
frameon=True,
)

plt.tight_layout()

if save_path is not None:
plt.savefig(save_path)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def plot_performance_by_age(
eval_dataset: EvalDataset,
save_path: Optional[Path] = None,
bins: Sequence[Union[int, float]] = (18, 25, 35, 50, 70),
prettify_bins: Optional[bool] = True,
bin_continuous_input: Optional[bool] = True,
metric_fn: Callable = roc_auc_score,
y_limits: Optional[tuple[float, float]] = (0.5, 1.0),
) -> Union[None, Path]:
Expand All @@ -28,8 +28,7 @@ def plot_performance_by_age(
Args:
eval_dataset: EvalDataset object
bins (Sequence[Union[int, float]]): Bins to group by. Defaults to (18, 25, 35, 50, 70, 100).
prettify_bins (bool, optional): Whether to prettify bin names. I.e. make
bins look like "18-25" instead of "[18-25])". Defaults to True.
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
metric_fn (Callable): Callable which returns the metric to calculate
save_path (Path, optional): Path to save figure. Defaults to None.
y_limits (tuple[float, float], optional): y-axis limits. Defaults to (0.5, 1.0).
Expand All @@ -44,7 +43,7 @@ def plot_performance_by_age(
input_name="age",
metric_fn=metric_fn,
bins=bins,
prettify_bins=prettify_bins,
bin_continuous_input=bin_continuous_input,
)

sort_order = sorted(df["age_binned"].unique())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Optional, Union

from sklearn.metrics import roc_auc_score

from psycop_model_training.model_eval.base_artifacts.plots.base_charts import (
plot_basic_chart,
)
from psycop_model_training.model_eval.base_artifacts.plots.utils import (
create_performance_by_input,
)
from psycop_model_training.model_eval.dataclasses import EvalDataset


def plot_performance_by_sex(
eval_dataset: EvalDataset,
save_path: Optional[Path] = None,
metric_fn: Callable = roc_auc_score,
y_limits: Optional[tuple[float, float]] = (0.5, 1.0),
) -> Union[None, Path]:
"""Plot bar plot of performance (default AUC) by sex at time of prediction.
Args:
eval_dataset: EvalDataset object
save_path (Path, optional): Path to save figure. Defaults to None.
metric_fn (Callable): Callable which returns the metric to calculate
y_limits (tuple[float, float], optional): y-axis limits. Defaults to (0.0, 1.0).
Returns:
Union[None, Path]: Path to saved figure or None if not saved.
"""

df = create_performance_by_input(
eval_dataset=eval_dataset,
input=eval_dataset.is_female,
input_name="sex",
metric_fn=metric_fn,
bins=None,
bin_continuous_input=False,
)

df.sex = df.sex.replace({1: "female", 0: "male"})

return plot_basic_chart(
x_values=df["sex"],
y_values=df["metric"],
x_title="Sex",
y_title="AUC",
y_limits=y_limits,
plot_type=["bar"],
save_path=save_path,
)
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def create_performance_by_time_from_event_df(
metric_fn: Callable,
direction: str,
bins: Iterable[float],
prettify_bins: Optional[bool] = True,
bin_continuous_input: Optional[bool] = True,
drop_na_events: Optional[bool] = True,
) -> pd.DataFrame:
"""Create dataframe for plotting performance metric from time to or from
Expand All @@ -291,8 +291,7 @@ def create_performance_by_time_from_event_df(
direction (str): Which direction to calculate time difference.
Can either be 'prediction-event' or 'event-prediction'.
bins (Iterable[float]): Bins to group by.
prettify_bins (bool, optional): Whether to prettify bin names. I.e. make
bins look like "1-7" instead of "[1-7]". Defaults to True.
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
drop_na_events (bool, optional): Whether to drop rows where the event is NA. Defaults to True.
Returns:
Expand Down Expand Up @@ -328,7 +327,7 @@ def create_performance_by_time_from_event_df(
)

# bin data
bin_fn = bin_continuous_data if prettify_bins else round_floats_to_edge
bin_fn = bin_continuous_data if bin_continuous_input else round_floats_to_edge

# Convert df["days_from_event"] to int if possible
df["days_from_event_binned"] = bin_fn(df["days_from_event"], bins=bins)
Expand All @@ -342,7 +341,7 @@ def create_performance_by_time_from_event_df(
def plot_auc_by_time_from_first_visit(
eval_dataset: EvalDataset,
bins: tuple = (0, 28, 182, 365, 730, 1825),
prettify_bins: Optional[bool] = True,
bin_continuous_input: Optional[bool] = True,
y_limits: Optional[tuple[float, float]] = (0.5, 1.0),
save_path: Optional[Path] = None,
) -> Union[None, Path]:
Expand All @@ -351,8 +350,7 @@ def plot_auc_by_time_from_first_visit(
Args:
eval_dataset (EvalDataset): EvalDataset object
bins (list, optional): Bins to group by. Defaults to [0, 28, 182, 365, 730, 1825].
prettify_bins (bool, optional): Prettify bin names. I.e. make
bins look like "1-7" instead of "[1-7)" Defaults to True.
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
y_limits (tuple[float, float], optional): Limits of y-axis. Defaults to (0.5, 1.0).
save_path (Path, optional): Path to save figure. Defaults to None.
Expand All @@ -372,7 +370,7 @@ def plot_auc_by_time_from_first_visit(
prediction_timestamps=eval_dataset.pred_timestamps,
direction="prediction-event",
bins=bins,
prettify_bins=prettify_bins,
bin_continuous_input=bin_continuous_input,
drop_na_events=False,
metric_fn=roc_auc_score,
)
Expand Down Expand Up @@ -400,7 +398,7 @@ def plot_metric_by_time_until_diagnosis(
-28,
-0,
),
prettify_bins: bool = True,
bin_continuous_input: bool = True,
metric_fn: Callable = f1_score,
y_title: str = "F1",
y_limits: Optional[tuple[float, float]] = None,
Expand All @@ -414,7 +412,7 @@ def plot_metric_by_time_until_diagnosis(
eval_dataset (EvalDataset): EvalDataset object
bins (list, optional): Bins to group by. Negative values indicate days after
diagnosis. Defaults to (-1825, -730, -365, -182, -28, -14, -7, -1, 0)
prettify_bins (bool, optional): Whether to prettify bin names. Defaults to True.
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
metric_fn (Callable): Which performance metric function to use.
y_title (str): Title for y-axis (metric name)
y_limits (tuple[float, float], optional): Limits of y-axis. Defaults to None.
Expand All @@ -430,7 +428,7 @@ def plot_metric_by_time_until_diagnosis(
prediction_timestamps=eval_dataset.pred_timestamps,
direction="event-prediction",
bins=bins,
prettify_bins=prettify_bins,
bin_continuous_input=bin_continuous_input,
drop_na_events=True,
metric_fn=metric_fn,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def create_performance_by_input(
input: Sequence[Union[int, float]],
input_name: str,
bins: Sequence[Union[int, float]] = (0, 1, 2, 5, 10),
prettify_bins: Optional[bool] = True,
bin_continuous_input: Optional[bool] = True,
metric_fn: Callable = roc_auc_score,
) -> pd.DataFrame:
"""Calculate performance by given input values, e.g. age or number of hbac1
Expand All @@ -81,8 +81,7 @@ def create_performance_by_input(
input (Sequence[Union[int, float]]): Input values to calculate performance by
input_name (str): Name of the input
bins (Sequence[Union[int, float]]): Bins to group by. Defaults to (0, 1, 2, 5, 10, 100).
prettify_bins (bool, optional): Whether to prettify bin names. I.e. make
bins look like "1-7" instead of "[1-7)". Defaults to True.
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
metric_fn (Callable): Callable which returns the metric to calculate
Returns:
Expand All @@ -97,7 +96,7 @@ def create_performance_by_input(
)

# bin data
if prettify_bins:
if bin_continuous_input:
df[f"{input_name}_binned"] = bin_continuous_data(df[input_name], bins=bins)

output_df = df.groupby(f"{input_name}_binned").apply(
Expand Down
1 change: 1 addition & 0 deletions src/psycop_model_training/model_eval/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class EvalDataset(BaseModel):
y_hat_probs: pd.Series
y_hat_int: pd.Series
age: Optional[pd.Series] = None
is_female: Optional[pd.Series] = None
exclusion_timestamps: Optional[pd.Series] = None
custom_columns: Optional[dict[str, pd.Series]] = None

Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
CONFIG_DIR_PATH_REL = "../application/config"


def add_age_gender(df: pd.DataFrame):
def add_age_is_female(df: pd.DataFrame):
"""Add age and gender columns to dataframe.
Args:
df (pd.DataFrame): The dataframe to add age
"""
ids = pd.DataFrame({"dw_ek_borger": df["dw_ek_borger"].unique()})
ids["age"] = np.random.randint(17, 95, len(ids))
ids["gender"] = np.where(ids["dw_ek_borger"] > 30_000, "F", "M")
ids["is_female"] = np.where(ids["dw_ek_borger"] > 30_000, 1, 0)

return df.merge(ids)

Expand All @@ -32,7 +32,7 @@ def synth_eval_dataset() -> EvalDataset:
"""Load synthetic data."""
csv_path = Path("tests") / "test_data" / "synth_eval_data.csv"
df = pd.read_csv(csv_path)
df = add_age_gender(df)
df = add_age_is_female(df)

# Convert all timestamp cols to datetime
for col in [col for col in df.columns if "timestamp" in col]:
Expand All @@ -46,6 +46,7 @@ def synth_eval_dataset() -> EvalDataset:
pred_timestamps=df["timestamp"],
outcome_timestamps=df["timestamp_t2d_diag"],
age=df["age"],
is_female=df["is_female"],
)


Expand Down
12 changes: 11 additions & 1 deletion tests/model_evaluation/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from psycop_model_training.model_eval.base_artifacts.plots.performance_by_age import (
plot_performance_by_age,
)
from psycop_model_training.model_eval.base_artifacts.plots.performance_by_sex import (
plot_performance_by_sex,
)
from psycop_model_training.model_eval.base_artifacts.plots.performance_over_time import (
plot_auc_by_time_from_first_visit,
plot_metric_by_calendar_time,
Expand Down Expand Up @@ -104,7 +107,14 @@ def test_plot_bar_chart(synth_eval_dataset: EvalDataset):
def test_plot_performance_by_age(synth_eval_dataset: EvalDataset):
plot_performance_by_age(
eval_dataset=synth_eval_dataset,
save_path=PROJECT_ROOT / "test.png",
save_path=PROJECT_ROOT / "test_performance_plot_by_age.png",
)


def test_plot_performance_by_sex(synth_eval_dataset: EvalDataset):
plot_performance_by_sex(
eval_dataset=synth_eval_dataset,
save_path=PROJECT_ROOT / "test_performance_plot_by_sex.png",
)


Expand Down

0 comments on commit 2479a7d

Please sign in to comment.