Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
feat: update descriptive stats table (#444)
Browse files Browse the repository at this point in the history
- [ ] I have battle-tested on Overtaci (RMAPPS1279)
- [ ] I have assigned ranges (e.g. `>=0.1, <0.2`) to all new
dependencies (allows dependabot to keep dependency ranges wide for
better compatibility)
- [ ] At least one of the commits is prefixed with either "fix:" or
"feat:"

## Notes for reviewers
(Template placeholder was left unfilled — no specific review guidance was provided for this PR.)
  • Loading branch information
MartinBernstorff authored Mar 23, 2023
2 parents ca4fc6e + 14716a1 commit 8498bc8
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 16 deletions.
1 change: 0 additions & 1 deletion .cookiecutter.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"github_user": "",
"version": "0.0.0",
"copyright_year": "2023",
"setup_for_dev_bash_script": ["y", "n"],
"license": ["MIT", "None"],
"_copy_without_render": [
"*.github"
Expand Down
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ figures/
# wandb
wandb
outputs
.*venv*

# virtual env
.venv*/
venv/

# MacOS
Expand Down
35 changes: 27 additions & 8 deletions src/psycop_model_training/data_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from typing import Literal, Optional

import pandas as pd
from psycop_model_training.config_schemas.data import DataSchema
from psycop_model_training.config_schemas.full_config import FullConfigSchema
from psycop_model_training.config_schemas.preprocessing import (
PreSplitPreprocessingConfigSchema,
)
from psycop_model_training.data_loader.data_classes import SplitDataset
from psycop_model_training.data_loader.data_loader import DataLoader
from psycop_model_training.preprocessing.pre_split.full_processor import (
Expand All @@ -20,25 +24,27 @@ def get_latest_dataset_dir(path: Path) -> Path:


def load_and_filter_split_from_cfg(
cfg: FullConfigSchema,
data_cfg: DataSchema,
pre_split_cfg: PreSplitPreprocessingConfigSchema,
split: Literal["train", "test", "val"],
cache_dir: Optional[Path] = None,
) -> pd.DataFrame:
"""Load train dataset from config.
Args:
cfg (FullConfig): Config
data_cfg (DataSchema): Data config
pre_split_cfg (PreSplitPreprocessingConfigSchema): Pre-split config
split (Literal["train", "test", "val"]): Split to load
cache_dir (Optional[Path], optional): Directory. Defaults to None, in which case no caching is used.
Returns:
pd.DataFrame: Train dataset
"""
dataset = DataLoader(data_cfg=cfg.data).load_dataset_from_dir(split_names=split)
dataset = DataLoader(data_cfg=data_cfg).load_dataset_from_dir(split_names=split)
filtered_data = pre_split_process_full_dataset(
dataset=dataset,
pre_split_cfg=cfg.preprocessing.pre_split,
data_cfg=cfg.data,
pre_split_cfg=pre_split_cfg,
data_cfg=data_cfg,
cache_dir=cache_dir,
)

Expand All @@ -58,14 +64,27 @@ def load_and_filter_train_from_cfg(
Returns:
pd.DataFrame: Train dataset
"""
return load_and_filter_split_from_cfg(cfg=cfg, split="train", cache_dir=cache_dir)
return load_and_filter_split_from_cfg(
pre_split_cfg=cfg.preprocessing.pre_split,
data_cfg=cfg.data,
split="train",
cache_dir=cache_dir,
)


def load_and_filter_train_and_val_from_cfg(cfg: FullConfigSchema) -> SplitDataset:
    """Load and pre-split-filter the train and validation splits.

    The scraped diff interleaved the pre-commit and post-commit call sites,
    leaving duplicate ``train=``/``val=`` keyword arguments (a SyntaxError).
    This is the post-commit form: ``load_and_filter_split_from_cfg`` now takes
    the data config and pre-split config separately instead of the full config.

    Args:
        cfg (FullConfigSchema): Full config; ``cfg.data`` and
            ``cfg.preprocessing.pre_split`` are forwarded to the split loader.

    Returns:
        SplitDataset: Container with the filtered ``train`` and ``val`` frames.
    """
    return SplitDataset(
        train=load_and_filter_split_from_cfg(
            pre_split_cfg=cfg.preprocessing.pre_split,
            data_cfg=cfg.data,
            split="train",
        ),
        val=load_and_filter_split_from_cfg(
            pre_split_cfg=cfg.preprocessing.pre_split,
            data_cfg=cfg.data,
            split="val",
        ),
    )


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,26 @@ def _generate_age_stats(

df = self._get_column_header_df()

age_mean = round(self.eval_dataset.age.mean(), 2)

age_span = f"{self.eval_dataset.age.quantile(0.05)} - {self.eval_dataset.age.quantile(0.95)}"
age_mean = round(self.eval_dataset.age.mean(), 1)
age_span = f"{self.eval_dataset.age.quantile(0.25)} - {self.eval_dataset.age.quantile(0.75)}"

df = df.append( # type: ignore
{
"category": "(visit_level) age (mean / interval)",
"category": "(visit_level) age (mean / 25-75 quartile interval)",
"stat_1": age_mean,
"stat_1_unit": "years",
"stat_2": age_span,
"stat_2_unit": "years",
},
ignore_index=True,
)

age_counts = bin_continuous_data(
self.eval_dataset.age,
bins=[0, 18, 35, 60, 100],
bins=[0, 17, *range(24, 75, 10)],
)[0].value_counts()

age_percentages = round(age_counts / len(self.eval_dataset.age) * 100, 2)
age_percentages = round(age_counts / len(self.eval_dataset.age) * 100, 1)

for i, _ in enumerate(age_counts):
df = df.append( # type: ignore
Expand Down Expand Up @@ -111,6 +111,12 @@ def _generate_eval_col_stats(self) -> pd.DataFrame:

df = self._get_column_header_df()

if (
not hasattr(self.eval_dataset, "custom_columns")
or self.eval_dataset.custom_columns is None
):
return df

eval_cols: list[dict[str, pd.Series]] = [
{name: values}
for name, values in self.eval_dataset.custom_columns.items()
Expand Down Expand Up @@ -174,6 +180,7 @@ def _generate_visit_level_stats(

# General stats
visits_followed_by_positive_outcome = self.eval_dataset.y.sum()

visits_followed_by_positive_outcome_percentage = round(
(visits_followed_by_positive_outcome / len(self.eval_dataset.ids) * 100),
2,
Expand Down Expand Up @@ -290,6 +297,8 @@ def generate_descriptive_stats_table(
Returns:
Union[pd.DataFrame, wandb.Table]: Table 1.
"""
if save_path is not None:
save_path.parent.mkdir(parents=True, exist_ok=True)

if self.eval_dataset.age is not None:
age_stats = self._generate_age_stats()
Expand Down

0 comments on commit 8498bc8

Please sign in to comment.