forked from zubata88/mdgru
-
Notifications
You must be signed in to change notification settings - Fork 0
/
RUN_mdgru.py
executable file
·139 lines (115 loc) · 6.77 KB
/
RUN_mdgru.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from mdgru.helper import define_arguments
__author__ = "Simon Andermatt"
__copyright__ = "Copyright (C) 2017 Simon Andermatt"
import logging
# logging.basicConfig(level=logging.INFO)  # NOTE: enabling this would cause duplicate log messages
import os
import numpy as np
import sys
from mdgru.data.grid_collection import GridDataCollection, ThreadedGridDataCollection
# from options.parser import clean_datacollection_args
from mdgru.runner import Runner
from mdgru.helper import compile_arguments
import argparse
def run_mdgru(args=None):
    """Execute a training, testing, or combined training/testing run of the MD-GRU network.

    Parses the command-line options (from ``args`` or ``sys.argv``), selects the
    backend (TensorFlow by default, experimental PyTorch with ``--use_pytorch``)
    and the loss variant, restricts the visible GPUs, assembles the model,
    evaluation and data-collection classes, and delegates the actual
    computation to a :class:`Runner`.

    :param args: optional list of argument strings; when ``None``, arguments
        are read from ``sys.argv``.
    :return: the result of ``Runner.run()``, or ``None`` when only the help
        message was requested.
    """
    # Attach a logging.StreamHandler already here in RUN_mdgru.py so that early
    # debug logs are captured; a FileHandler is still added later in runner.py.
    loggers = [logging.getLogger(n) for n in ['model', 'eval', 'runner', 'helper', 'data']]
    formatter = logging.Formatter('%(asctime)s %(name)s\t%(levelname)s:\t%(message)s')
    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    ch.setLevel(logging.DEBUG)
    for logger in loggers:
        logger.setLevel(logging.DEBUG)
        logger.addHandler(ch)

    logging.getLogger('runner').info('Starting argument parsing!')
    # Keep the full invocation string so it can be recorded with the experiment.
    fullparameters = " ".join(args if args is not None else sys.argv)

    # Pre-parse only the options that decide WHICH classes get imported below.
    # add_help=False so that --help is handled manually after all argument
    # groups (which depend on the chosen classes) have been registered.
    parser = argparse.ArgumentParser(description="evaluate any data with given parameters", add_help=False)
    pre_parameter = parser.add_argument_group('Options changing parameter. Use together with --help')
    pre_parameter.add_argument('--use_pytorch', action='store_true', help='use experimental pytorch version. Only core functionality is provided')
    pre_parameter.add_argument('--gpu', type=int, nargs='+', default=[0], help='set gpu ids')
    pre_parameter.add_argument('--nonthreaded', action="store_true",
                               help="disallow threading during training to preload data before the processing")
    pre_parameter.add_argument('--dice_loss_label', default=None, type=int, nargs="+", help='labels for which the dice losses shall be calculated')
    pre_parameter.add_argument('--dice_loss_weight', default=None, type=float, nargs="+", help='weights for the dice losses of the individual classes. same size as dice_loss_label or scalar if dice_autoweighted. final loss: sum(dice_loss_weight)*diceloss + (1-sum(dice_loss_weight))*crossentropy')
    pre_parameter.add_argument('--dice_autoweighted', action="store_true", help='weights the label Dices with the squared inverse gold standard area/volume; specify which labels with dice_loss_label; sum(dice_loss_weight) is used as a weighting between crossentropy and diceloss')
    pre_parameter.add_argument('--dice_generalized', action="store_true", help='total intersections of all labels over total sums of all labels, instead of linearly combined class Dices')
    pre_parameter.add_argument('--dice_cc', action='store_true', help='dice loss for binary segmentation per true component')
    pre_args, _ = parser.parse_known_args(args=args)
    parser.add_argument('-h', '--help', action='store_true', help='print this help message')

    # Restrict the visible GPUs BEFORE importing the deep-learning framework,
    # then import the model/evaluation classes matching the chosen options.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(g) for g in pre_args.gpu])
    if pre_args.use_pytorch:
        if pre_args.dice_cc:
            from mdgru.model_pytorch.mdgru_classification import MDGRUClassificationCC as modelcls
        else:
            from mdgru.model_pytorch.mdgru_classification import MDGRUClassification as modelcls
        from mdgru.eval.torch import SupervisedEvaluationTorch as evalcls
    else:
        if pre_args.dice_generalized:
            from mdgru.model.mdgru_classification import MDGRUClassificationWithGeneralizedDiceLoss as modelcls
        elif pre_args.dice_loss_label is not None or pre_args.dice_autoweighted:
            from mdgru.model.mdgru_classification import MDGRUClassificationWithDiceLoss as modelcls
        else:
            from mdgru.model.mdgru_classification import MDGRUClassification as modelcls
        from mdgru.eval.tf import SupervisedEvaluationTensorflow as evalcls

    # Select the data-collection class; the threaded variant preloads data
    # during training unless explicitly disabled.
    tdc = GridDataCollection if pre_args.nonthreaded else ThreadedGridDataCollection

    # Register the argument groups contributed by each participating class.
    define_arguments(modelcls, parser.add_argument_group('Model Parameters'))
    define_arguments(evalcls, parser.add_argument_group('Evaluation Parameters'))
    define_arguments(Runner, parser.add_argument_group('Runner Parameters'))
    define_arguments(tdc, parser.add_argument_group('Data Parameters'))
    args = parser.parse_args(args=args)
    if args.help:
        parser.print_help()
        return

    if not args.use_pytorch:
        if args.gpubound != 1:
            modelcls.set_allowed_gpu_memory_fraction(args.gpubound)

    # Compile the per-class keyword arguments from the parsed namespace and
    # merge them into a single dict for the evaluation class.
    kw = vars(args)
    args_eval, _ = compile_arguments(evalcls, kw, True, keep_entries=True)
    args_model, _ = compile_arguments(modelcls, kw, True, keep_entries=True)
    args_data, _ = compile_arguments(tdc, kw, True, keep_entries=True)
    args_eval.update(args_model)
    args_eval.update(args_data)
    if not args.use_pytorch:
        if args.checkpointfiles is not None:
            # Recover the model namespace from the first checkpoint file.
            args_eval['namespace'] = modelcls.get_model_name_from_ckpt(args.checkpointfiles[0])
    # PyTorch expects channels-first tensors; TensorFlow channels-last.
    args_eval['channels_first'] = args.use_pytorch
    # Forward the dice-loss options gathered during pre-parsing.
    args_eval['dice_loss_label'] = args.dice_loss_label
    args_eval['dice_loss_weight'] = args.dice_loss_weight
    args_eval['dice_autoweighted'] = args.dice_autoweighted

    logging.getLogger('runner').debug('Starting class initialization!')
    ev = evalcls(modelcls, tdc, args_eval)
    logging.getLogger('runner').debug('Finished class initialization!')

    # Set up and launch the runner, which performs the actual computation.
    args_runner, _ = compile_arguments(Runner, kw, True, keep_entries=True)
    args_runner.update({
        "experimentloc": os.path.join(args.datapath, 'experiments'),
        "fullparameters": fullparameters,
    })
    runner = Runner(ev, **args_runner)
    return runner.run()
# Script entry point: run a full MD-GRU train/test cycle with CLI arguments.
if __name__ == "__main__":
    run_mdgru()