This repository has been archived by the owner on May 1, 2023. It is now read-only.

Merge pull request #350 from Aarhus-Psychiatry-Research/martbern/refactor_eval_model

Refactor eval model
MartinBernstorff authored Dec 22, 2022
2 parents 625e1ff + f683fe6 commit 288feb6
Showing 43 changed files with 654 additions and 595 deletions.
23 changes: 23 additions & 0 deletions application/artifacts/custom_artifacts.py
@@ -0,0 +1,23 @@
from pathlib import Path

from application.artifacts.plots.performance_by_n_hba1c import (
    plot_performance_by_n_hba1c,
)
from psycop_model_training.model_eval.dataclasses import ArtifactContainer, EvalDataset


def create_custom_plot_artifacts(
    eval_dataset: EvalDataset,
    save_dir: Path,
) -> list[ArtifactContainer]:
    """A collection of plots that are only generated for your specific use
    case."""
    return [
        ArtifactContainer(
            label="performance_by_n_hba1c",
            artifact=plot_performance_by_n_hba1c(
                eval_dataset=eval_dataset,
                save_path=save_dir / "performance_by_n_hba1c.png",
            ),
        ),
    ]
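
For orientation, a rough sketch (not part of the commit) of how the returned containers could be built and inspected on their own; it assumes an eval_dataset produced by train_and_predict and that ArtifactContainer exposes the label and artifact fields it is constructed with:

# Illustrative only: building the custom artifacts outside the training entry point.
# eval_dataset is assumed to already exist as an EvalDataset.
from pathlib import Path

save_dir = Path("outputs/custom_artifacts")
save_dir.mkdir(parents=True, exist_ok=True)

for container in create_custom_plot_artifacts(eval_dataset=eval_dataset, save_dir=save_dir):
    print(container.label, container.artifact)
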
@@ -6,9 +6,13 @@

from sklearn.metrics import roc_auc_score

+from psycop_model_training.model_eval.base_artifacts.plots.base_charts import (
+    plot_basic_chart,
+)
+from psycop_model_training.model_eval.base_artifacts.plots.utils import (
+    create_performance_by_input,
+)
from psycop_model_training.model_eval.dataclasses import EvalDataset
-from psycop_model_training.model_eval.plots.base_charts import plot_basic_chart
-from psycop_model_training.model_eval.plots.utils import create_performance_by_input


def plot_performance_by_n_hba1c(
1 change: 0 additions & 1 deletion application/config/train/default_training.yaml
@@ -2,4 +2,3 @@ n_splits: 3 # (int, Null): Number of k-folds during CV. If Null, loads pre-defin
n_trials_per_lookahead: 300
n_jobs_per_trainer: 1
n_active_trainers: 10
-random_delay_per_job_seconds: 0
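
As an aside, a minimal sketch of how these training options could be represented as a pydantic schema in the style of the config_schemas package; the class name TrainSchema and the Optional typing are assumptions, not code from this commit:

# Hypothetical sketch only, mirroring the fields left in default_training.yaml.
from typing import Optional

from pydantic import BaseModel


class TrainSchema(BaseModel):
    n_splits: Optional[int] = 3  # Number of k-folds during CV; None loads a pre-defined split
    n_trials_per_lookahead: int = 300
    n_jobs_per_trainer: int = 1
    n_active_trainers: int = 10
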
@@ -11,7 +11,7 @@
from psycop_model_training.application_modules.get_search_space import (
    SearchSpaceInferrer,
)
-from psycop_model_training.application_modules.setup import setup
+from psycop_model_training.application_modules.process_manager_setup import setup
from psycop_model_training.application_modules.trainer_spawner import spawn_trainers
from psycop_model_training.data_loader.data_loader import DataLoader

4 changes: 2 additions & 2 deletions application/train_model_from_application_module.py
@@ -5,8 +5,8 @@
"""
import hydra

-from psycop_model_training.application_modules.train_model import train_model
-from psycop_model_training.training.train_and_eval import CONFIG_PATH
+from psycop_model_training.application_modules.train_model.main import train_model
+from psycop_model_training.training.train_and_predict import CONFIG_PATH


@hydra.main(
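
The remainder of this entry point is collapsed in the diff; a hedged sketch of how such a Hydra entry point is typically completed (the function name, config name, and body below are guesses, not the file's actual contents):

# Illustrative completion only; config_name="default_config" is an assumption.
@hydra.main(
    config_path=str(CONFIG_PATH),
    config_name="default_config",
)
def main(cfg):
    return train_model(cfg)


if __name__ == "__main__":
    main()
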
138 changes: 0 additions & 138 deletions src/psycop_model_training/application_modules/train_model.py

This file was deleted.

Empty file.
65 changes: 65 additions & 0 deletions src/psycop_model_training/application_modules/train_model/main.py
@@ -0,0 +1,65 @@
"""Train a single model and evaluate it."""
import wandb
from omegaconf import DictConfig

from application.artifacts.custom_artifacts import create_custom_plot_artifacts
from psycop_model_training.application_modules.wandb_handler import WandbHandler
from psycop_model_training.data_loader.utils import (
load_and_filter_train_and_val_from_cfg,
)
from psycop_model_training.model_eval.model_evaluator import ModelEvaluator
from psycop_model_training.preprocessing.post_split.pipeline import (
create_post_split_pipeline,
)
from psycop_model_training.training.train_and_predict import train_and_predict
from psycop_model_training.utils.col_name_inference import get_col_names
from psycop_model_training.utils.config_schemas.conf_utils import (
convert_omegaconf_to_pydantic_object,
)
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.utils.utils import PROJECT_ROOT, SHARED_RESOURCES_PATH


def train_model(cfg: DictConfig):
"""Main function for training a single model."""
if not isinstance(cfg, FullConfigSchema):
cfg = convert_omegaconf_to_pydantic_object(cfg)

WandbHandler(cfg=cfg).setup_wandb()

dataset = load_and_filter_train_and_val_from_cfg(cfg)
pipe = create_post_split_pipeline(cfg)
outcome_col_name, train_col_names = get_col_names(cfg, dataset.train)

eval_dataset = train_and_predict(
cfg=cfg,
train=dataset.train,
val=dataset.val,
pipe=pipe,
outcome_col_name=outcome_col_name,
train_col_names=train_col_names,
n_splits=cfg.train.n_splits,
)

    if wandb.run.id and cfg.project.wandb.mode != "offline":
        eval_dir_path = SHARED_RESOURCES_PATH / cfg.project.name / wandb.run.id
    else:
        eval_dir_path = PROJECT_ROOT / "tests" / "test_eval_results"
        eval_dir_path.mkdir(parents=True, exist_ok=True)

    custom_artifacts = create_custom_plot_artifacts(
        eval_dataset=eval_dataset,
        save_dir=eval_dir_path,
    )

    roc_auc = ModelEvaluator(
        eval_dir_path=eval_dir_path,
        cfg=cfg,
        pipe=pipe,
        eval_ds=eval_dataset,
        raw_train_set=dataset.train,
        custom_artifacts=custom_artifacts,
        upload_to_wandb=cfg.project.wandb.mode != "offline",
    ).evaluate()

    return roc_auc
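
Outside of Hydra's command-line entry point, train_model can in principle be exercised directly, for example from a test; the sketch below is not part of the commit, and the config path, config name, and override key are assumptions:

# Illustrative only: composing a config and running a single training + evaluation.
from hydra import compose, initialize

with initialize(config_path="../../application/config", version_base=None):
    cfg = compose(
        config_name="default_config",
        overrides=["project.wandb.mode=offline"],  # keep the run local
    )

roc_auc = train_model(cfg)
print(f"ROC AUC on the validation set: {roc_auc:.3f}")
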
42 changes: 42 additions & 0 deletions src/psycop_model_training/application_modules/wandb_handler.py
@@ -0,0 +1,42 @@
from typing import Any, Dict

import wandb
from omegaconf import DictConfig, OmegaConf

from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.utils.config_schemas.project import WandbSchema
from psycop_model_training.utils.utils import create_wandb_folders, flatten_nested_dict


class WandbHandler:
    """Class for handling wandb setup and logging."""

    def __init__(self, cfg: FullConfigSchema):
        self.cfg = cfg

        # Required on Windows because the wandb process is sometimes unable to initialise
        create_wandb_folders()

    def _get_cfg_as_dict(self) -> dict[str, Any]:
        if isinstance(self.cfg, DictConfig):
            # Create flattened dict for logging to wandb
            # Wandb doesn't allow configs to be nested, so we
            # flatten it.
            return flatten_nested_dict(
                OmegaConf.to_container(self.cfg),
                sep=".",
            )  # type: ignore
        else:
            # For testing, we can take a FullConfig object instead. Simplifies boilerplate.
            return self.cfg.__dict__

    def setup_wandb(self):
        """Setup wandb for the current run."""
        wandb.init(
            project=f"{self.cfg.project.name}-baseline-model-training",
            reinit=True,
            mode=self.cfg.project.wandb.mode,
            group=self.cfg.project.wandb.group,
            config=self._get_cfg_as_dict(),
            entity=self.cfg.project.wandb.entity,
        )
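
The code's own comment explains why the config is flattened before it reaches wandb.init: nested configs don't log cleanly. A rough, self-contained illustration of what flatten_nested_dict is expected to produce (this is not the repository's implementation, and the example keys are made up):

# Illustrative stand-in for flatten_nested_dict; the real helper lives in
# psycop_model_training.utils.utils and may differ.
from typing import Any


def flatten(d: dict[str, Any], sep: str = ".", prefix: str = "") -> dict[str, Any]:
    out: dict[str, Any] = {}
    for key, value in d.items():
        full_key = f"{prefix}{sep}{key}" if prefix else key
        if isinstance(value, dict):
            out.update(flatten(value, sep=sep, prefix=full_key))
        else:
            out[full_key] = value
    return out


print(flatten({"project": {"name": "t2d", "wandb": {"mode": "offline"}}}))
# -> {'project.name': 't2d', 'project.wandb.mode': 'offline'}
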
4 changes: 2 additions & 2 deletions src/psycop_model_training/archive/model_training_watcher.py
@@ -14,7 +14,7 @@
from wasabi import msg

from psycop_model_training.model_eval.dataclasses import ModelEvalData
-from psycop_model_training.model_eval.evaluate_model import run_full_evaluation
+from psycop_model_training.model_eval.evaluate_model import evaluate_performance
from psycop_model_training.utils.config_schemas.full_config import FullConfigSchema
from psycop_model_training.utils.utils import (
    MODEL_PREDICTIONS_PATH,
@@ -166,7 +166,7 @@ def _do_evaluation(self, run_id: str) -> None:
        run: Run = wandb.init(project=self.project_name, entity=self.entity, id=run_id)  # type: ignore

        # run evaluation
-        run_full_evaluation(
+        evaluate_performance(
            cfg=eval_data.cfg,
            eval_dataset=eval_data.eval_dataset,
            pipe_metadata=eval_data.pipe_metadata,