Merge pull request #351 from Aarhus-Psychiatry-Research/martbern/remove_t2d_specifics

Remove t2d specifics
MartinBernstorff authored Dec 23, 2022
2 parents 8ee9c69 + 25e1823 commit 9231b13
Showing 49 changed files with 192 additions and 172 deletions.
6 changes: 5 additions & 1 deletion application/artifacts/plots/performance_by_n_hba1c.py
@@ -22,6 +22,7 @@ def plot_performance_by_n_hba1c(
metric_fn: Callable = roc_auc_score,
y_limits: Optional[tuple[float, float]] = (0.5, 1.0),
save_path: Optional[Path] = None,
n_hba1c_col_name: Optional[str] = "eval_hba1c_within_9999_days_count_fallback_nan",
) -> Union[None, Path]:
"""Plot bar plot of performance (default AUC) by number of HbA1c
measurements.
@@ -34,21 +35,24 @@ def plot_performance_by_n_hba1c(
metric_fn (Callable): Callable which returns the metric to calculate
y_limits (tuple[float, float]): y-axis limits. Defaults to (0.5, 1.0).
save_path (Path, optional): Path to save figure. Defaults to None.
n_hba1c_col_name (str, optional): Name of the column containing the number of
HbA1c measurements. Defaults to "eval_hba1c_within_9999_days_count_fallback_nan".
Returns:
Union[None, Path]: Path to saved figure or None if not saved.
"""

df = create_performance_by_input(
eval_dataset=eval_dataset,
input=eval_dataset.custom.n_hba1c,
input=eval_dataset[n_hba1c_col_name],
input_name="n_hba1c",
metric_fn=metric_fn,
bins=bins,
prettify_bins=prettify_bins,
)

sort_order = sorted(df["n_hba1c_binned"].unique())

return plot_basic_chart(
x_values=df["n_hba1c_binned"],
y_values=df["metric"],
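The plot no longer reaches into eval_dataset.custom.n_hba1c; callers pass the column name instead. A minimal sketch of the new call site, assuming the module path from this diff (my_eval_dataset and the save path are placeholders, not part of this commit):

from pathlib import Path

from application.artifacts.plots.performance_by_n_hba1c import (
    plot_performance_by_n_hba1c,
)

plot_performance_by_n_hba1c(
    eval_dataset=my_eval_dataset,  # placeholder: an EvalDataset built by the caller
    n_hba1c_col_name="eval_hba1c_within_9999_days_count_fallback_nan",
    save_path=Path("auc_by_n_hba1c.png"),
)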
4 changes: 2 additions & 2 deletions application/config/data/default_data.yaml
@@ -15,5 +15,5 @@ data:
id: dw_ek_borger
age: pred_age_in_years
exclusion_timestamp: timestamp_exclusion
custom:
n_hba1c: eval_hba1c_within_9999_days_count_fallback_nan
custom_columns:
- eval_hba1c_within_9999_days_count_fallback_nan
2 changes: 1 addition & 1 deletion application/inspect_dataset.py
@@ -1,9 +1,9 @@
"""Example of how to inspect a dataset using the configs."""
from psycop_model_training.config_schemas import load_test_cfg_as_pydantic
from psycop_model_training.data_loader.utils import (
load_and_filter_train_from_cfg,
load_train_raw,
)
from psycop_model_training.utils.config_schemas import load_test_cfg_as_pydantic


def main():
15 changes: 10 additions & 5 deletions application/train_model_from_application_module.py
@@ -4,8 +4,14 @@
file, rather than an installed module.
"""
import hydra
from omegaconf import DictConfig

from application.artifacts.custom_artifacts import create_custom_plot_artifacts
from psycop_model_training.application_modules.train_model.main import train_model
from psycop_model_training.config_schemas.conf_utils import (
convert_omegaconf_to_pydantic_object,
)
from psycop_model_training.config_schemas.full_config import FullConfigSchema
from psycop_model_training.training.train_and_predict import CONFIG_PATH


@@ -14,10 +20,9 @@
config_name="default_config",
version_base="1.2",
)
def main():
def main(cfg: DictConfig):
"""Main."""
train_model()

if not isinstance(cfg, FullConfigSchema):
cfg = convert_omegaconf_to_pydantic_object(cfg)

if __name__ == "__main__":
main()
train_model(cfg=cfg, custom_artifact_fn=create_custom_plot_artifacts)
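The entrypoint now receives hydra's DictConfig, converts it to the typed FullConfigSchema, and injects the application's artifact function. The same flow can be driven without the hydra decorator, e.g. from a test or notebook; a sketch using the test-config loaders this commit moves into config_schemas.conf_utils:

from application.artifacts.custom_artifacts import create_custom_plot_artifacts
from psycop_model_training.application_modules.train_model.main import train_model
from psycop_model_training.config_schemas.conf_utils import (
    convert_omegaconf_to_pydantic_object,
    load_test_cfg_as_omegaconf,
)

omegaconf_cfg = load_test_cfg_as_omegaconf(config_file_name="default_config")
cfg = convert_omegaconf_to_pydantic_object(omegaconf_cfg)  # DictConfig -> pydantic
train_model(cfg=cfg, custom_artifact_fn=create_custom_plot_artifacts)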
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "psycop_model_training"
version = "0.27.2"
description = "Training scripts for the psycop-t2d project"
description = "PSYCOP model training utilities"
authors = ["Your Name <you@example.com>"]

[tool.poetry.dependencies]
src/psycop_model_training/application_modules/get_search_space.py
@@ -1,15 +1,14 @@
import random
from typing import List, Optional, Union

import pandas as pd
from wasabi import Printer

from psycop_model_training.utils.basemodel import BaseModel
from psycop_model_training.config_schemas.basemodel import BaseModel
from psycop_model_training.config_schemas.full_config import FullConfigSchema
from psycop_model_training.utils.col_name_inference import (
infer_look_distance,
infer_outcome_col_name,
)
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema


class TrainerSpec(BaseModel):
Expand Down
@@ -1,10 +1,8 @@
import wandb
from random_word import RandomWords

from psycop_model_training.utils.config_schemas.conf_utils import (
load_app_cfg_as_pydantic,
)
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.config_schemas.conf_utils import load_app_cfg_as_pydantic
from psycop_model_training.config_schemas.full_config import FullConfigSchema


def create_random_wandb_group_name():
43 changes: 24 additions & 19 deletions src/psycop_model_training/application_modules/train_model/main.py
@@ -1,31 +1,38 @@
"""Train a single model and evaluate it."""
from typing import Callable, Optional

import wandb
from omegaconf import DictConfig

from application.artifacts.custom_artifacts import create_custom_plot_artifacts
from psycop_model_training.application_modules.wandb_handler import WandbHandler
from psycop_model_training.config_schemas.full_config import FullConfigSchema
from psycop_model_training.data_loader.utils import (
load_and_filter_train_and_val_from_cfg,
)
from psycop_model_training.model_eval.dataclasses import ArtifactContainer
from psycop_model_training.model_eval.model_evaluator import ModelEvaluator
from psycop_model_training.preprocessing.post_split.pipeline import (
create_post_split_pipeline,
)
from psycop_model_training.training.train_and_predict import train_and_predict
from psycop_model_training.utils.col_name_inference import get_col_names
from psycop_model_training.utils.config_schemas.conf_utils import (
convert_omegaconf_to_pydantic_object,
)
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.utils.utils import PROJECT_ROOT, SHARED_RESOURCES_PATH


def train_model(cfg: DictConfig):
"""Main function for training a single model."""
if not isinstance(cfg, FullConfigSchema):
cfg = convert_omegaconf_to_pydantic_object(cfg)
def get_eval_dir(cfg: FullConfigSchema):
"""Get the directory to save evaluation results to."""
if wandb.run.id and cfg.project.wandb.mode != "offline":
eval_dir_path = SHARED_RESOURCES_PATH / cfg.project.name / wandb.run.name
else:
eval_dir_path = PROJECT_ROOT / "tests" / "test_eval_results"
eval_dir_path.mkdir(parents=True, exist_ok=True)

return eval_dir_path


def train_model(cfg: FullConfigSchema, custom_artifact_fn: Optional[Callable] = None):
"""Main function for training a single model."""
WandbHandler(cfg=cfg).setup_wandb()
eval_dir_path = get_eval_dir(cfg)

dataset = load_and_filter_train_and_val_from_cfg(cfg)
pipe = create_post_split_pipeline(cfg)
@@ -41,15 +48,13 @@ def train_model(cfg: DictConfig):
n_splits=cfg.train.n_splits,
)

if wandb.run.id and cfg.project.wandb.mode != "offline":
eval_dir_path = SHARED_RESOURCES_PATH / cfg.project.name / wandb.run.id
else:
eval_dir_path = PROJECT_ROOT / "tests" / "test_eval_results"
eval_dir_path.mkdir(parents=True, exist_ok=True)

custom_artifacts = create_custom_plot_artifacts(
eval_dataset=eval_dataset,
save_dir=eval_dir_path,
custom_artifacts = (
custom_artifact_fn(
eval_dataset=eval_dataset,
save_dir=eval_dir_path,
)
if custom_artifact_fn
else None
)

roc_auc = ModelEvaluator(
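train_model no longer hard-codes the T2D-specific create_custom_plot_artifacts; any project can inject its own function, and the only interface train_model relies on is the eval_dataset/save_dir keyword arguments at the call site above. A hedged sketch of such a function (the name and empty body are illustrative; ArtifactContainer's constructor is not shown in this diff, so no containers are built here):

from pathlib import Path


def create_my_artifacts(eval_dataset, save_dir: Path) -> list:
    """Create and save project-specific plots, returning artifact containers."""
    save_dir.mkdir(parents=True, exist_ok=True)
    # ... build plots/tables from eval_dataset, save them under save_dir ...
    return []


# Injected at the entrypoint instead of the old hard-coded import:
# train_model(cfg=cfg, custom_artifact_fn=create_my_artifacts)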
@@ -6,7 +6,7 @@
from wasabi import Printer

from psycop_model_training.application_modules.get_search_space import TrainerSpec
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.config_schemas.full_config import FullConfigSchema


def start_trainer(
src/psycop_model_training/application_modules/wandb_handler.py
@@ -3,8 +3,8 @@
import wandb
from omegaconf import DictConfig, OmegaConf

from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.utils.config_schemas.project import WandbSchema
from psycop_model_training.config_schemas.full_config import FullConfigSchema
from psycop_model_training.config_schemas.project import WandbSchema
from psycop_model_training.utils.utils import create_wandb_folders, flatten_nested_dict


src/psycop_model_training/model_eval/model_evaluator.py
@@ -13,9 +13,9 @@
from wandb.sdk.wandb_run import Run # pylint: disable=no-name-in-module
from wasabi import msg

from psycop_model_training.config_schemas.full_config import FullConfigSchema
from psycop_model_training.model_eval.dataclasses import ModelEvalData
from psycop_model_training.model_eval.evaluate_model import evaluate_performance
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.utils.utils import (
MODEL_PREDICTIONS_PATH,
PROJECT_ROOT,
src/psycop_model_training/config_schemas/conf_utils.py
@@ -7,8 +7,8 @@
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf

from psycop_model_training.utils.basemodel import BaseModel
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.config_schemas.basemodel import BaseModel
from psycop_model_training.config_schemas.full_config import FullConfigSchema


def convert_omegaconf_to_pydantic_object(
@@ -30,7 +30,7 @@ def convert_omegaconf_to_pydantic_object(

def load_test_cfg_as_omegaconf(
config_file_name: str = "default_config",
config_dir_path_rel: str = "../../../../tests/config/",
config_dir_path_rel: str = "../../../tests/config/",
overrides: Optional[list[str]] = None,
) -> DictConfig:
"""Load config as omegaconf object."""
@@ -81,7 +81,6 @@ def load_test_cfg_as_pydantic(
cfg = load_test_cfg_as_omegaconf(
config_file_name=config_file_name,
overrides=overrides,
config_dir_path_rel="../../../../tests/config/",
)

return convert_omegaconf_to_pydantic_object(conf=cfg, allow_mutation=allow_mutation)
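With the move from utils/config_schemas/ to config_schemas/, the package sits one directory higher, so the relative test-config path loses one level (../../../../ becomes ../../../) and load_test_cfg_as_pydantic now relies on the default rather than passing the path explicitly. A sketch of loading a test config with an override (the override key is an assumption based on cfg.project.wandb.mode used elsewhere in this diff):

from psycop_model_training.config_schemas.conf_utils import load_test_cfg_as_pydantic

cfg = load_test_cfg_as_pydantic(
    config_file_name="default_config",
    overrides=["project.wandb.mode=offline"],  # illustrative override
)
print(cfg.project.wandb.mode)  # "offline"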
src/psycop_model_training/config_schemas/data.py
@@ -1,14 +1,7 @@
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union

from psycop_model_training.utils.basemodel import BaseModel


class CustomColNames(BaseModel):
"""All custom column names, i.e. columns that won't generalise across
projects."""

n_hba1c: str
from psycop_model_training.config_schemas.basemodel import BaseModel


class ColumnNamesSchema(BaseModel):
@@ -20,9 +13,7 @@ class ColumnNamesSchema(BaseModel):
age: str # Name of the age column
exclusion_timestamp: str # Name of the exclusion timestamps column.
# Drops all visits whose pred_timestamp <= exclusion_timestamp.

custom: Optional[CustomColNames] = None
# Column names that are custom to the given prediction problem.
custom_columns: Optional[list[str]] = None


class DataSchema(BaseModel):
@@ -39,3 +30,4 @@ class DataSchema(BaseModel):

pred_prefix: str # prefix of predictor columns
outc_prefix: str # prefix of outcome columns
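The dedicated CustomColNames model, with its hard-coded n_hba1c field, is gone; project-specific columns are now an optional list of strings. A pared-down sketch of the new shape using plain pydantic (the repo's own BaseModel wrapper adds settings not shown here):

from typing import Optional

from pydantic import BaseModel


class ColumnNamesSketch(BaseModel):
    id: str  # patient id column
    age: str  # age at prediction time
    exclusion_timestamp: str  # visits with pred_timestamp <= this are dropped
    custom_columns: Optional[list[str]] = None  # project-specific columns


cols = ColumnNamesSketch(
    id="dw_ek_borger",
    age="pred_age_in_years",
    exclusion_timestamp="timestamp_exclusion",
    custom_columns=["eval_hba1c_within_9999_days_count_fallback_nan"],
)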
src/psycop_model_training/config_schemas/eval.py
@@ -1,5 +1,5 @@
"""Eval config schema."""
from psycop_model_training.utils.basemodel import BaseModel
from psycop_model_training.config_schemas.basemodel import BaseModel


class EvalConfSchema(BaseModel):
Expand Down
19 changes: 19 additions & 0 deletions src/psycop_model_training/config_schemas/full_config.py
@@ -0,0 +1,19 @@
"""Full configuration schema."""
from psycop_model_training.config_schemas.basemodel import BaseModel
from psycop_model_training.config_schemas.data import DataSchema
from psycop_model_training.config_schemas.eval import EvalConfSchema
from psycop_model_training.config_schemas.model import ModelConfSchema
from psycop_model_training.config_schemas.preprocessing import PreprocessingConfigSchema
from psycop_model_training.config_schemas.project import ProjectSchema
from psycop_model_training.config_schemas.train import TrainConfSchema


class FullConfigSchema(BaseModel):
"""A recipe for a full configuration object."""

project: ProjectSchema
data: DataSchema
preprocessing: PreprocessingConfigSchema
model: ModelConfSchema
train: TrainConfSchema
eval: EvalConfSchema
src/psycop_model_training/config_schemas/model.py
@@ -1,5 +1,5 @@
"""Model configuration schemas."""
from psycop_model_training.utils.basemodel import BaseModel
from psycop_model_training.config_schemas.basemodel import BaseModel


class ModelConfSchema(BaseModel):
src/psycop_model_training/config_schemas/preprocessing.py
@@ -2,7 +2,7 @@
from datetime import datetime
from typing import Literal, Optional, Union

from psycop_model_training.utils.basemodel import BaseModel
from psycop_model_training.config_schemas.basemodel import BaseModel


class FeatureSelectionSchema(BaseModel):
src/psycop_model_training/config_schemas/project.py
@@ -1,5 +1,5 @@
"""Project configuration schemas."""
from psycop_model_training.utils.basemodel import BaseModel
from psycop_model_training.config_schemas.basemodel import BaseModel


class WandbSchema(BaseModel):
src/psycop_model_training/config_schemas/train.py
@@ -1,6 +1,4 @@
from typing import Optional

from psycop_model_training.utils.basemodel import BaseModel
from psycop_model_training.config_schemas.basemodel import BaseModel


class TrainConfSchema(BaseModel):
25 changes: 18 additions & 7 deletions src/psycop_model_training/data_loader/col_name_checker.py
@@ -3,7 +3,7 @@
import Levenshtein
import pandas as pd

from psycop_model_training.utils.config_schemas.data import ColumnNamesSchema
from psycop_model_training.config_schemas.data import ColumnNamesSchema


def get_most_likely_str_from_edit_distance(
@@ -46,27 +46,38 @@ def check_columns_exist_in_dataset(
"""Check that all columns in the config exist in the dataset."""
# Iterate over attributes in the config
error_strs = []
col_names = []

# Get the column names to check
for attr in dir(col_name_schema):
# Skip private attributes
if attr.startswith("_"):
continue

col_name = getattr(col_name_schema, attr)
attr_val = getattr(col_name_schema, attr)

# Skip col names that are not string
if not isinstance(col_name, str):
if isinstance(attr_val, str):
col_names.append(attr_val)
continue

# Skip col names that are not string
if not isinstance(attr_val, str):
if isinstance(attr_val, list):
for item in attr_val:
if isinstance(item, str):
col_names.append(item)

# Check that the col_names exist in the dataset
for item in col_names:
# Check that the column exists in the dataset
if not col_name in df:
if item not in df:
most_likely_alternatives = get_most_likely_str_from_edit_distance(
candidate_strs=df.columns,
input_str=col_name,
input_str=item,
n_str_to_return=3,
)

error_str = f"Column '{col_name}' in config but not in dataset.\n"
error_str = f"Column '{item}' in config but not in dataset.\n"
error_str += f" Did you mean {most_likely_alternatives}? \n"
error_strs.append(error_str)

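The check now collects string and list-of-string attributes in a first pass, then validates every collected name against the dataframe. A standalone sketch of that collect-then-check pattern, minus the Levenshtein suggestions (SimpleSchema is a hypothetical stand-in for ColumnNamesSchema):

from dataclasses import dataclass
from typing import Optional

import pandas as pd


@dataclass
class SimpleSchema:  # hypothetical stand-in for ColumnNamesSchema
    id: str = "dw_ek_borger"
    custom_columns: Optional[list] = None


def missing_columns(schema, df: pd.DataFrame) -> list:
    col_names = []
    for attr in dir(schema):
        if attr.startswith("_"):
            continue
        val = getattr(schema, attr)
        if isinstance(val, str):
            col_names.append(val)
        elif isinstance(val, list):
            col_names.extend(v for v in val if isinstance(v, str))
    return [col for col in col_names if col not in df.columns]


df = pd.DataFrame(columns=["dw_ek_borger"])
print(missing_columns(SimpleSchema(custom_columns=["eval_hba1c_count"]), df))
# -> ["eval_hba1c_count"]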
