Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Feat: Generate table 1 (#398)
Browse files Browse the repository at this point in the history
- [x] I have battle-tested on Overtaci (RMAPPS1279)
- [ ] I have assigned ranges (e.g. `>=0.1, <0.2`) to all new
dependencies (allows dependabot to keep dependency ranges wide for
better compatability)
- [x] At least one of the commits is prefixed with either "fix:" or
"feat:"

## Notes for reviewers
Many things are perhaps still a bit messy. Furthermore, I haven't had
time to write up a proper test. However, I feel I need to have some new
eyes look it through to see if there are any major changes that need to
be made.

I feel like it is pretty generalisable. Currently, the base function
adds:

**Visit level:**
- Age (mean/std)
- Age counts in 5 intervals (count/%)
- Sex (count/%)
- eval_ column stats: (count/%) for binary columns and (mean/std) for
continuous
- visits followed by a positive outcome (count/%)
**Patient level:**
- patients with positive outcome (count/%)
- time from first contact to positive outcome (mean/std in days)
  • Loading branch information
MartinBernstorff authored Mar 6, 2023
2 parents abeb1a7 + 62493fe commit bf0edb1
Show file tree
Hide file tree
Showing 15 changed files with 1,407 additions and 1,009 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ outputs

# virtual env
.venv/
venv/

# MacOS
.DS_Store
Expand All @@ -125,6 +126,7 @@ package.json
*.pkl
.aucs/
*.parquet
*.csv

evaluation_results
*.pqt
1 change: 1 addition & 0 deletions src/psycop_model_training/config_schemas/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class ColumnNamesSchema(BaseModel):
outcome_timestamp: str # Column name for outcome timestamps
id: str # Citizen colnames
age: str # Name of the age column
is_female: str # Name of the sex column
exclusion_timestamp: Optional[str] # Name of the exclusion timestamps column.
# Drops all visits whose pred_timestamp <= exclusion_timestamp.
custom_columns: Optional[list[str]] = None
Expand Down
3 changes: 3 additions & 0 deletions src/psycop_model_training/config_schemas/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ class EvalConfSchema(BaseModel):
force: bool = False
# Whether to force evaluation even if wandb is not "run". Used for testing.

descriptive_stats_table: bool = True
# Whether to generate table 1.

top_n_feature_importances: int
# How many feature_importances to plot. Plots the most important n features. A table with all features is also logged.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from psycop_model_training.model_eval.base_artifacts.plots.time_from_first_positive_to_event import (
plot_time_from_first_positive_to_event,
)
from psycop_model_training.model_eval.base_artifacts.tables.descriptive_stats_table import (
DescriptiveStatsTable,
)
from psycop_model_training.model_eval.base_artifacts.tables.performance_by_threshold import (
generate_performance_by_positive_rate_table,
)
Expand Down Expand Up @@ -176,13 +179,24 @@ def create_base_plot_artifacts(self) -> list[ArtifactContainer]:
),
]

def get_descriptive_stats_table_artifact(self):
"""Returns descriptive stats table artifact."""
return [
ArtifactContainer(
label="descriptive_stats_table",
artifact=DescriptiveStatsTable(
self.eval_ds,
).generate_descriptive_stats_table(),
),
]

def get_feature_selection_artifacts(self):
"""Returns a list of artifacts related to feature selection."""
return [
ArtifactContainer(
label="selected_features",
artifact=generate_selected_features_table(
selected_features_dict=self.pipe_metadata.selected_features,
eval_dataset=self.pipe_metadata.selected_features,
output_format="df",
),
),
Expand Down Expand Up @@ -211,6 +225,9 @@ def get_all_artifacts(self) -> list[ArtifactContainer]:
"""Generates artifacts from an EvalDataset."""
artifact_containers = self.create_base_plot_artifacts()

if self.cfg.eval.descriptive_stats_table:
artifact_containers += self.get_descriptive_stats_table_artifact()

if self.pipe_metadata and self.pipe_metadata.feature_importances:
artifact_containers += self.get_feature_importance_artifacts()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
"""Code for generating a descriptive stats table."""
import warnings
from pathlib import Path
from typing import Optional, Union

import pandas as pd
import wandb

from psycop_model_training.model_eval.base_artifacts.tables.tables import output_table
from psycop_model_training.model_eval.dataclasses import EvalDataset
from psycop_model_training.utils.utils import bin_continuous_data


class DescriptiveStatsTable:
"""Class for generating a descriptive stats table."""

def __init__(
self,
eval_dataset: EvalDataset,
) -> None:
self.eval_dataset = eval_dataset.to_df()

def _get_column_header_df(self):
"""Create empty dataframe with default columns headers.
Returns:
pd.DataFrame: Empty dataframe with default columns headers. Includes columns for category, two statistics and there units.
"""
return pd.DataFrame(
columns=["category", "stat_1", "stat_1_unit", "stat_2", "stat_2_unit"],
)

def _generate_age_stats(
self,
) -> pd.DataFrame:
"""Add age stats to table 1."""

df = self._get_column_header_df()

age_mean = round(self.eval_dataset["age"].mean(), 2)

age_span = f'{self.eval_dataset["age"].quantile(0.05)} - {self.eval_dataset["age"].quantile(0.95)}'

df = df.append(
{
"category": "(visit_level) age (mean / interval)",
"stat_1": age_mean,
"stat_1_unit": "years",
"stat_2": age_span,
"stat_2_unit": "years",
},
ignore_index=True,
)
age_counts = bin_continuous_data(
self.eval_dataset["age"],
bins=[0, 18, 35, 60, 100],
).value_counts()

age_percentages = round(age_counts / len(self.eval_dataset) * 100, 2)

for i, _ in enumerate(age_counts):
df = df.append(
{
"category": f"(visit level) age {age_counts.index[i]}",
"stat_1": int(age_counts.iloc[i]),
"stat_1_unit": "patients",
"stat_2": age_percentages.iloc[i],
"stat_2_unit": "%",
},
ignore_index=True,
)

return df

def _generate_sex_stats(
self,
) -> pd.DataFrame:
"""Add sex stats to table 1."""

df = self._get_column_header_df()

sex_counts = self.eval_dataset["is_female"].value_counts()
sex_percentages = sex_counts / len(self.eval_dataset) * 100

for i, n in enumerate(sex_counts):
if n < 5:
warnings.warn(
"WARNING: One of the sex categories has less than 5 individuals. This category will be excluded from the table.",
)
return df

df = df.append(
{
"category": f"(visit level) {sex_counts.index[i]}",
"stat_1": int(sex_counts[i]),
"stat_1_unit": "patients",
"stat_2": sex_percentages[i],
"stat_2_unit": "%",
},
ignore_index=True,
)

return df

def _generate_eval_col_stats(self) -> pd.DataFrame:
"""Generate stats for all eval_ columns to table 1.
Finds all columns starting with 'eval_' and adds visit level
stats for these columns. Checks if the column is binary or
continuous and adds stats accordingly.
"""

df = self._get_column_header_df()

eval_cols = [
col for col in self.eval_dataset.columns if col.startswith("eval_")
]

for col in eval_cols:
if len(self.eval_dataset[col].unique()) == 2:
# Binary variable stats:
col_count = self.eval_dataset[col].value_counts()
col_percentage = col_count / len(self.eval_dataset) * 100

if col_count[0] < 5 or col_count[1] < 5:
warnings.warn(
f"WARNING: One of categories in {col} has less than 5 individuals. This category will be excluded from the table.",
)
else:
df = df.append(
{
"category": f"(visit level) {col} ",
"stat_1": int(col_count[1]),
"stat_1_unit": "patients",
"stat_2": col_percentage[1],
"stat_2_unit": "%",
},
ignore_index=True,
)

elif len(self.eval_dataset[col].unique()) > 2:
# Continuous variable stats:
col_mean = round(self.eval_dataset[col].mean(), 2)
col_std = round(self.eval_dataset[col].std(), 2)
df = df.append(
{
"category": f"(visit level) {col}",
"stat_1": col_mean,
"stat_1_unit": "mean",
"stat_2": col_std,
"stat_2_unit": "std",
},
ignore_index=True,
)

else:
warnings.warn(
f"WARNING: {col} has only one value. This column will be excluded from the table.",
)

return df

def _generate_visit_level_stats(
self,
) -> pd.DataFrame:
"""Generate all visit level stats to table 1."""

# Stats for eval_ cols
df = self._generate_eval_col_stats()

# General stats
visits_followed_by_positive_outcome = self.eval_dataset["y"].sum()
visits_followed_by_positive_outcome_percentage = round(
(visits_followed_by_positive_outcome / len(self.eval_dataset) * 100),
2,
)

df = df.append(
{
"category": "(visit_level) visits followed by positive outcome",
"stat_1": visits_followed_by_positive_outcome,
"stat_1_unit": "visits",
"stat_2": visits_followed_by_positive_outcome_percentage,
"stat_2_unit": "%",
},
ignore_index=True,
)

return df

def _calc_time_to_first_positive_outcome_stats(
self,
patients_with_positive_outcome_data: pd.DataFrame,
) -> float:
"""Calculate mean time to first positive outcome (currently very
slow)."""

grouped_data = patients_with_positive_outcome_data.groupby("ids")

time_to_first_positive_outcome = grouped_data.apply(
lambda x: x["outcome_timestamps"].min() - x["pred_timestamps"].min(),
)

# Convert to days (float)
time_to_first_positive_outcome = (
time_to_first_positive_outcome.dt.total_seconds() / (24 * 60 * 60)
) # Not using timedelta.days to keep higher temporal precision

return round(time_to_first_positive_outcome.mean(), 2), round(
time_to_first_positive_outcome.std(),
2,
)

def _generate_patient_level_stats(
self,
) -> pd.DataFrame:
"""Add patient level stats to table 1."""

df = self._get_column_header_df()

# General stats
patients_with_positive_outcome = self.eval_dataset[self.eval_dataset["y"] == 1][
"ids"
].unique()
n_patients_with_positive_outcome = len(patients_with_positive_outcome)
patients_with_positive_outcome_percentage = round(
(
n_patients_with_positive_outcome
/ len(self.eval_dataset["ids"].unique())
* 100
),
2,
)

df = df.append(
{
"category": "(patient_level) patients_with_positive_outcome",
"stat_1": n_patients_with_positive_outcome,
"stat_1_unit": "visits",
"stat_2": patients_with_positive_outcome_percentage,
"stat_2_unit": "%",
},
ignore_index=True,
)

patients_with_positive_outcome_data = self.eval_dataset[
self.eval_dataset["ids"].isin(patients_with_positive_outcome)
]

(
mean_time_to_first_positive_outcome,
std_time_to_first_positive_outomce,
) = self._calc_time_to_first_positive_outcome_stats(
patients_with_positive_outcome_data,
)

df = df.append(
{
"category": "(patient level) time_to_first_positive_outcome",
"stat_1": mean_time_to_first_positive_outcome,
"stat_1_unit": "mean",
"stat_2": std_time_to_first_positive_outomce,
"stat_2_unit": "std",
},
ignore_index=True,
)

return df

def generate_descriptive_stats_table(
self,
output_format: str = "df",
save_path: Optional[Path] = None,
) -> Union[pd.DataFrame, wandb.Table]:
"""Generate descriptive stats table. Calculates relevant statistics
from the evaluation dataset and returns a pandas dataframe or wandb
table. If save_path is provided, the table is saved as a csv file.
Args:
output_format (str, optional): Output format. Defaults to "df".
save_path (Optional[Path], optional): Path to save the table as a csv file. Defaults to None.
Returns:
Union[pd.DataFrame, wandb.Table]: Table 1.
"""

if "age" in self.eval_dataset.columns:
age_stats = self._generate_age_stats()

if "is_female" in self.eval_dataset.columns:
sex_stats = self._generate_sex_stats()

visit_level_stats = self._generate_visit_level_stats()

patient_level_stats = self._generate_patient_level_stats()

table_1_df_list = [age_stats, sex_stats, visit_level_stats, patient_level_stats]
table_1 = pd.concat(table_1_df_list, ignore_index=True)

if save_path is not None:
output_table(output_format="df", df=table_1)

table_1.to_csv(save_path, index=False)

return output_table(output_format=output_format, df=table_1)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tables for evaluation of models."""
"""Tables for description and evaluation of models and patient population."""
from typing import Union

import pandas as pd
Expand Down
Loading

0 comments on commit bf0edb1

Please sign in to comment.