This repository has been archived by the owner on May 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #350 from Aarhus-Psychiatry-Research/martbern/refactor_eval_model: Refactor eval model
- Loading branch information
Showing 43 changed files with 654 additions and 595 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from pathlib import Path | ||
|
||
from application.artifacts.plots.performance_by_n_hba1c import ( | ||
plot_performance_by_n_hba1c, | ||
) | ||
from psycop_model_training.model_eval.dataclasses import ArtifactContainer, EvalDataset | ||
|
||
|
||
def create_custom_plot_artifacts(
    eval_dataset: EvalDataset,
    save_dir: Path,
) -> list[ArtifactContainer]:
    """Generate the plot artifacts that are specific to this use case.

    Args:
        eval_dataset: Evaluation dataset to generate the plots from.
        save_dir: Directory the plot files are written to.

    Returns:
        A list with one ArtifactContainer per custom plot.
    """
    hba1c_plot = ArtifactContainer(
        label="performance_by_n_hba1c",
        artifact=plot_performance_by_n_hba1c(
            eval_dataset=eval_dataset,
            save_path=save_dir / "performance_by_n_hba1c.png",
        ),
    )
    return [hba1c_plot]
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
138 changes: 0 additions & 138 deletions
138
src/psycop_model_training/application_modules/train_model.py
This file was deleted.
Oops, something went wrong.
Empty file.
65 changes: 65 additions & 0 deletions
65
src/psycop_model_training/application_modules/train_model/main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Train a single model and evaluate it.""" | ||
import wandb | ||
from omegaconf import DictConfig | ||
|
||
from application.artifacts.custom_artifacts import create_custom_plot_artifacts | ||
from psycop_model_training.application_modules.wandb_handler import WandbHandler | ||
from psycop_model_training.data_loader.utils import ( | ||
load_and_filter_train_and_val_from_cfg, | ||
) | ||
from psycop_model_training.model_eval.model_evaluator import ModelEvaluator | ||
from psycop_model_training.preprocessing.post_split.pipeline import ( | ||
create_post_split_pipeline, | ||
) | ||
from psycop_model_training.training.train_and_predict import train_and_predict | ||
from psycop_model_training.utils.col_name_inference import get_col_names | ||
from psycop_model_training.utils.config_schemas.conf_utils import ( | ||
convert_omegaconf_to_pydantic_object, | ||
) | ||
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema | ||
from psycop_model_training.utils.utils import PROJECT_ROOT, SHARED_RESOURCES_PATH | ||
|
||
|
||
def train_model(cfg: DictConfig):
    """Train a single model, evaluate it, and return its performance.

    Args:
        cfg: Either an omegaconf DictConfig or an already-parsed
            FullConfigSchema.

    Returns:
        The ROC AUC reported by ModelEvaluator.evaluate().
    """
    # Accept both omegaconf and pydantic configs; normalise to pydantic.
    if not isinstance(cfg, FullConfigSchema):
        cfg = convert_omegaconf_to_pydantic_object(cfg)

    WandbHandler(cfg=cfg).setup_wandb()

    data = load_and_filter_train_and_val_from_cfg(cfg)
    pipe = create_post_split_pipeline(cfg)
    outcome_col_name, train_col_names = get_col_names(cfg, data.train)

    eval_dataset = train_and_predict(
        cfg=cfg,
        train=data.train,
        val=data.val,
        pipe=pipe,
        outcome_col_name=outcome_col_name,
        train_col_names=train_col_names,
        n_splits=cfg.train.n_splits,
    )

    # Online runs are persisted to the shared drive under the wandb run id;
    # offline runs (e.g. tests) write to a local test directory instead.
    upload_to_wandb = cfg.project.wandb.mode != "offline"
    if wandb.run.id and upload_to_wandb:
        eval_dir_path = SHARED_RESOURCES_PATH / cfg.project.name / wandb.run.id
    else:
        eval_dir_path = PROJECT_ROOT / "tests" / "test_eval_results"
        eval_dir_path.mkdir(parents=True, exist_ok=True)

    custom_artifacts = create_custom_plot_artifacts(
        eval_dataset=eval_dataset,
        save_dir=eval_dir_path,
    )

    evaluator = ModelEvaluator(
        eval_dir_path=eval_dir_path,
        cfg=cfg,
        pipe=pipe,
        eval_ds=eval_dataset,
        raw_train_set=data.train,
        custom_artifacts=custom_artifacts,
        upload_to_wandb=upload_to_wandb,
    )
    return evaluator.evaluate()
42 changes: 42 additions & 0 deletions
42
src/psycop_model_training/application_modules/wandb_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Any, Dict | ||
|
||
import wandb | ||
from omegaconf import DictConfig, OmegaConf | ||
|
||
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema | ||
from psycop_model_training.utils.config_schemas.project import WandbSchema | ||
from psycop_model_training.utils.utils import create_wandb_folders, flatten_nested_dict | ||
|
||
|
||
class WandbHandler:
    """Handles wandb initialisation and config logging for a run."""

    def __init__(self, cfg: FullConfigSchema):
        self.cfg = cfg

        # The wandb process is sometimes unable to initialise on Windows
        # unless its folders already exist, so create them up front.
        create_wandb_folders()

    def _get_cfg_as_dict(self) -> dict[str, Any]:
        """Return the config as a flat dict suitable for wandb logging."""
        # For testing, we can take a FullConfig object instead of a
        # DictConfig. Simplifies boilerplate.
        if not isinstance(self.cfg, DictConfig):
            return self.cfg.__dict__

        # Wandb doesn't allow configs to be nested, so flatten before logging.
        return flatten_nested_dict(
            OmegaConf.to_container(self.cfg),
            sep=".",
        )  # type: ignore

    def setup_wandb(self):
        """Setup wandb for the current run."""
        wandb.init(
            project=f"{self.cfg.project.name}-baseline-model-training",
            reinit=True,
            mode=self.cfg.project.wandb.mode,
            group=self.cfg.project.wandb.group,
            config=self._get_cfg_as_dict(),
            entity=self.cfg.project.wandb.entity,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.