-
Notifications
You must be signed in to change notification settings - Fork 43
/
run_cellvit.py
103 lines (96 loc) · 3.68 KB
/
run_cellvit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -*- coding: utf-8 -*-
# Running an Experiment Using CellViT cell segmentation network
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import inspect
import os
import sys
# Put the parent of this script's directory at the front of sys.path so the
# project-local packages imported below (base_ml, cell_segmentation) can be
# found when the script is launched directly.
_script_path = os.path.abspath(inspect.getfile(inspect.currentframe()))
currentdir = os.path.dirname(_script_path)
parentdir = os.path.split(currentdir)[0]
sys.path.insert(0, parentdir)
import wandb
from base_ml.base_cli import ExperimentBaseParser
from cell_segmentation.experiments.experiment_cellvit_pannuke import (
ExperimentCellVitPanNuke,
)
from cell_segmentation.experiments.experiment_cellvit_conic import (
ExperimentCellViTCoNic,
)
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import (
InferenceCellViT,
)
def _get_experiment_class(configuration: dict) -> type:
    """Return the experiment class for ``configuration["data"]["dataset"]``.

    The dataset name is matched case-insensitively.

    Args:
        configuration: Parsed experiment configuration.

    Returns:
        ``ExperimentCellVitPanNuke`` for ``"pannuke"`` or
        ``ExperimentCellViTCoNic`` for ``"conic"``.

    Raises:
        ValueError: If the dataset is not one of the supported datasets.
            Failing here gives a clear message instead of a later crash on an
            undefined experiment class.
    """
    dataset = configuration["data"]["dataset"]
    dataset_key = dataset.lower()
    if dataset_key == "pannuke":
        return ExperimentCellVitPanNuke
    if dataset_key == "conic":
        return ExperimentCellViTCoNic
    raise ValueError(
        f"Unsupported dataset '{dataset}'. Supported datasets: pannuke, conic."
    )


def _run_patch_inference(configuration: dict, outdir) -> None:
    """Run patch-level inference for the run stored in ``outdir``.

    Args:
        configuration: Parsed experiment configuration. Reads ``gpu``,
            ``eval_checkpoint`` and ``data.magnification`` (defaults to 40
            when absent).
        outdir: Run directory, as returned by ``run_experiment()``.
    """
    inference = InferenceCellViT(
        run_dir=outdir,
        gpu=configuration["gpu"],
        checkpoint_name=configuration["eval_checkpoint"],
        magnification=configuration["data"].get("magnification", 40),
    )
    trained_model, inference_dataloader, dataset_config = (
        inference.setup_patch_inference()
    )
    inference.run_patch_inference(
        trained_model, inference_dataloader, dataset_config, generate_plots=False
    )


def _set_wandb_dir(configuration: dict) -> None:
    """Set ``WANDB_DIR`` to the absolute path of ``logging.wandb_dir``."""
    os.environ["WANDB_DIR"] = os.path.abspath(configuration["logging"]["wandb_dir"])


if __name__ == "__main__":
    # Parse arguments
    configuration_parser = ExperimentBaseParser()
    configuration = configuration_parser.parse_arguments()
    experiment_class = _get_experiment_class(configuration)

    # Setup experiment
    if "checkpoint" in configuration:
        # Continue from checkpoint, then evaluate the resulting run
        experiment = experiment_class(
            default_conf=configuration, checkpoint=configuration["checkpoint"]
        )
        outdir = experiment.run_experiment()
        _run_patch_inference(configuration, outdir)
    else:
        experiment = experiment_class(default_conf=configuration)
        if configuration["run_sweep"] is True:
            # Run a new sweep
            sweep_configuration = experiment_class.extract_sweep_arguments(
                configuration
            )
            _set_wandb_dir(configuration)
            sweep_id = wandb.sweep(
                sweep=sweep_configuration, project=configuration["logging"]["project"]
            )
            wandb.agent(sweep_id=sweep_id, function=experiment.run_experiment)
        elif configuration.get("agent") is not None:
            # Add an agent to an already existing sweep; note that run_sweep
            # must be set to True in this case
            configuration["run_sweep"] = True
            _set_wandb_dir(configuration)
            wandb.agent(
                sweep_id=configuration["agent"], function=experiment.run_experiment
            )
        else:
            # Casual run followed by patch-level evaluation
            outdir = experiment.run_experiment()
            _run_patch_inference(configuration, outdir)
    wandb.finish()