-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
103 lines (87 loc) · 3.98 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# --------------------------------------------------------
# SPCNet training
# Copyright (c) 2021 PIMED@STanford
#
# Written by Arun Seetharaman
# --------------------------------------------------------
import argparse
import numpy as np
import os
import matplotlib.pyplot as plt
from model.SPCNet import SPCNet
from utils.util_functions import create_folds, prepare_data, concatenate_data
def train(args):
    """
    Train SPCNet with k-fold cross-validation.

    Cases are loaded with ``prepare_data`` and split into ``args.folds`` folds
    with ``create_folds``. For each fold, a fresh network is built and fit on
    the training cases, using the held-out cases as validation data. Two files
    are written to ``args.output_filepath``:

    * ``hed_fold_{fold_idx}.h5`` -- the trained network (``network.save``)
    * ``loss_{fold_idx}.png``    -- training/validation ``ofuse_loss`` curves

    :param args: parsed command-line namespace providing the input paths
        (``t2_filepath``, ``adc_filepath``, ``mask_filepath``,
        ``all_cancer_filepath``, ``agg_cancer_filepath``,
        ``ind_cancer_filepath``), plus ``output_filepath``, ``folds``, ``lr``,
        ``batch_size`` and ``epochs``
    """
    # Every fold writes into this directory, so make sure it exists up front.
    os.makedirs(args.output_filepath, exist_ok=True)

    # Create the SPCNet model wrapper; a new network is built from it per fold.
    model = SPCNet()

    # A dictionary of training data paths: T2, ADC and prostate gland mask,
    # plus the corresponding lesion labels.
    path_dict = {'t2': args.t2_filepath,
                 'adc': args.adc_filepath,
                 'prostate': args.mask_filepath,
                 'all_cancer': args.all_cancer_filepath,
                 'agg_cancer': args.agg_cancer_filepath,
                 'ind_cancer': args.ind_cancer_filepath}
    t2_list, adc_list, mask_list, label_list, _, _ = prepare_data(path_dict=path_dict,
                                                                  cancer_only=True)

    case_ids = np.arange(len(t2_list))
    splits = create_folds(case_ids, args.folds)
    # `train_idx` / `test_idx` are used so the loop does not shadow this
    # function's own name.
    for fold_idx, (train_idx, test_idx) in enumerate(splits):
        # Renormalize each fold from 0 to 1 individually. The `stats` returned
        # for the training cases are passed back in for the held-out cases
        # (presumably so both share the training normalization -- confirm in
        # concatenate_data).
        t2_np, adc_np, y, stats = concatenate_data(t2_list,
                                                   adc_list,
                                                   label_list,
                                                   mask_list,
                                                   train_idx,
                                                   fold_idx,
                                                   args.output_filepath)
        x, y, _ = model.get_x_y(t2_np, adc_np, y)

        t2_val, adc_val, y_val, _ = concatenate_data(t2_list,
                                                     adc_list,
                                                     label_list,
                                                     mask_list,
                                                     test_idx,
                                                     fold_idx,
                                                     args.output_filepath,
                                                     stats)
        x_val, y_val, num_channels = model.get_x_y(t2_val, adc_val, y_val)

        # Build a fresh network for this fold and train it on the preprocessed data.
        network = model.network(lr=args.lr, num_channels=num_channels)
        history = network.fit(x,
                              y,
                              batch_size=args.batch_size,
                              epochs=args.epochs,
                              validation_data=(x_val, y_val),
                              verbose=2)

        # Save the trained network for each fold.
        network.save(os.path.join(args.output_filepath, f'hed_fold_{fold_idx}.h5'))
        # Plot training and validation loss for each fold.
        _save_loss_plot(history, args.output_filepath, fold_idx)


def _save_loss_plot(history, output_filepath, fold_idx):
    """
    Save the training and validation 'ofuse_loss' curves of one fold as a PNG.

    :param history: object returned by ``network.fit``; its ``history`` dict
        must contain the keys 'ofuse_loss' and 'val_ofuse_loss'
    :param output_filepath: directory the image is written to
    :param fold_idx: fold index, used in the file name ``loss_{fold_idx}.png``
    """
    fig = plt.figure()
    try:
        plt.plot(history.history['ofuse_loss'])
        plt.plot(history.history['val_ofuse_loss'])
        plt.title('ofuse_loss')
        plt.legend(['train', 'val'])
        plt.savefig(os.path.join(output_filepath, f'loss_{fold_idx}.png'))
    finally:
        # Close the figure so figures do not accumulate in memory across folds.
        plt.close(fig)
if __name__ == "__main__":
    # Parse command-line parameters and launch cross-validated training.
    parser = argparse.ArgumentParser(
        description='Train SPCNet with k-fold cross-validation on T2/ADC data.')
    parser.add_argument('--output_filepath', type=str, required=True,
                        help='directory where per-fold models (hed_fold_<k>.h5) '
                             'and loss plots (loss_<k>.png) are written')
    parser.add_argument('--t2_filepath', type=str, required=True,
                        help='path to the T2 image data')
    parser.add_argument('--adc_filepath', type=str, required=True,
                        help='path to the ADC image data')
    parser.add_argument('--mask_filepath', type=str, required=True,
                        help='path to the prostate gland masks')
    parser.add_argument('--all_cancer_filepath', type=str, required=True,
                        help="path to the 'all_cancer' lesion labels")
    parser.add_argument('--agg_cancer_filepath', type=str, required=True,
                        help="path to the 'agg_cancer' lesion labels")
    parser.add_argument('--ind_cancer_filepath', type=str, required=True,
                        help="path to the 'ind_cancer' lesion labels")
    parser.add_argument('--batch_size', type=int, default=8,
                        help='training batch size (default: 8)')
    parser.add_argument('--epochs', type=int, default=2,
                        help='number of training epochs per fold (default: 2)')
    parser.add_argument('--folds', type=int, default=5,
                        help='number of cross-validation folds (default: 5)')
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='learning rate (default: 1e-3)')
    args = parser.parse_args()
    train(args)