forked from Tencent/MedicalNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
151 lines (129 loc) · 5.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
'''
Training code for segmentation on the MRBrainS18 dataset
Written by Whalechen
'''
from setting import parse_opts
from datasets.brains18 import BrainS18Dataset
from model import generate_model
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import time
from utils.logger import log
from scipy import ndimage
import os
def _resize_label_masks(label_masks, out_shape):
    """Resample ground-truth masks to the spatial size of the network output.

    Args:
        label_masks: indexable batch of per-sample label volumes, each of shape
            ``(1, ori_d, ori_h, ori_w)``. The leading axis is dropped by a
            reshape, so it must have size 1. May be a numpy array or a CPU torch
            tensor (presumably the DataLoader's collated tensor; confirm).
        out_shape: shape of the network output, ``(n, _, d, h, w)``.

    Returns:
        ``torch.int64`` tensor of shape ``(n, d, h, w)``. It holds the labels
        resampled with nearest-neighbour interpolation (``order=0``), so no
        new label values are invented.
    """
    n, _, d, h, w = out_shape
    resized = np.zeros([n, d, h, w])
    for idx in range(n):
        # np.asarray accepts both numpy arrays and CPU torch tensors.
        mask = np.asarray(label_masks[idx])
        _, ori_d, ori_h, ori_w = mask.shape
        mask = mask.reshape(ori_d, ori_h, ori_w)
        scale = [d * 1.0 / ori_d, h * 1.0 / ori_h, w * 1.0 / ori_w]
        # ndimage.zoom is the public name; ndimage.interpolation is deprecated.
        resized[idx] = ndimage.zoom(mask, scale, order=0)
    return torch.from_numpy(resized).to(torch.int64)


def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder, sets):
    """Train ``model`` for voxel-wise segmentation.

    For every batch, the model output is compared with the ground-truth masks
    using cross-entropy. The masks are first resampled to the output
    resolution, and label value ``-1`` is ignored by the loss. The
    learning-rate scheduler is stepped once per epoch, after that epoch's
    optimizer updates.

    Args:
        data_loader: iterable yielding ``(volumes, label_masks)`` batches.
            ``len(data_loader)`` must give the number of batches per epoch.
        model: network returning a 5-D output ``(n, classes, d, h, w)``.
            ``nn.CrossEntropyLoss`` treats dim 1 as the class scores.
        optimizer: optimizer updating the model's parameters.
        scheduler: learning-rate scheduler attached to ``optimizer``.
        total_epochs: number of passes over ``data_loader``.
        save_interval: a checkpoint is written at the first batch of an epoch
            whenever the global batch index is a non-zero multiple of this.
        save_folder: path prefix for checkpoints; files are named
            ``<save_folder>_epoch_<e>_batch_<b>.pth.tar``.
        sets: option namespace; reads ``no_cuda`` and ``ci_test``.

    In CI-test mode (``sets.ci_test``), no checkpoints are written and the
    process exits once training finishes.
    """
    import sys

    # settings
    batches_per_epoch = len(data_loader)
    log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch))
    loss_seg = nn.CrossEntropyLoss(ignore_index=-1)
    print("Current setting is:")
    print(sets)
    print("\n\n")
    if not sets.no_cuda:
        loss_seg = loss_seg.cuda()

    model.train()
    train_time_sp = time.time()
    for epoch in range(total_epochs):
        log.info('Start epoch {}'.format(epoch))
        # Read the rate actually in use from the optimizer. Outside of step(),
        # scheduler.get_lr() is deprecated and may report the *next* rate.
        log.info('lr = {}'.format([group['lr'] for group in optimizer.param_groups]))

        for batch_id, batch_data in enumerate(data_loader):
            # Global batch index across epochs (drives timing and checkpointing).
            batch_id_sp = epoch * batches_per_epoch + batch_id
            volumes, label_masks = batch_data
            if not sets.no_cuda:
                volumes = volumes.cuda()

            optimizer.zero_grad()
            out_masks = model(volumes)

            # Labels are resized to the output resolution, not the other way round.
            new_label_masks = _resize_label_masks(label_masks, out_masks.shape)
            if not sets.no_cuda:
                new_label_masks = new_label_masks.cuda()

            # calculating loss
            loss_value_seg = loss_seg(out_masks, new_label_masks)
            loss = loss_value_seg
            loss.backward()
            optimizer.step()

            avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp)
            log.info(
                'Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}'
                .format(epoch, batch_id, batch_id_sp, loss.item(), loss_value_seg.item(), avg_batch_time))

            # Save only at the start of an epoch. There batch_id_sp equals
            # epoch * batches_per_epoch, so the cadence is unchanged by the
            # batch-index fix above.
            if (not sets.ci_test and batch_id == 0 and batch_id_sp != 0
                    and batch_id_sp % save_interval == 0):
                model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id)
                model_save_dir = os.path.dirname(model_save_path)
                # dirname may be '' (prefix without a directory); makedirs('') raises.
                if model_save_dir:
                    os.makedirs(model_save_dir, exist_ok=True)
                log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id))
                torch.save({
                    'epoch': epoch,
                    # Legacy misspelled key, kept for readers of older checkpoints.
                    'ecpoch': epoch,
                    'batch_id': batch_id,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()},
                    model_save_path)

        # Since PyTorch 1.1, scheduler.step() must follow optimizer.step().
        scheduler.step()

    print('Finished training')
    if sets.ci_test:
        sys.exit()
if __name__ == '__main__':
    # settings
    sets = parse_opts()
    if sets.ci_test:
        # Tiny CPU-only configuration on the toy data, for CI smoke tests.
        sets.img_list = './toy_data/test_ci.txt'
        sets.n_epochs = 1
        sets.no_cuda = True
        sets.data_root = './toy_data'
        sets.pretrain_path = ''
        sets.num_workers = 0
        sets.model_depth = 10
        sets.resnet_shortcut = 'A'
        sets.input_D = 14
        sets.input_H = 28
        sets.input_W = 28

    # getting model
    torch.manual_seed(sets.manual_seed)
    model, parameters = generate_model(sets)
    print(model)

    # optimizer
    if sets.ci_test:
        params = [{'params': parameters, 'lr': sets.learning_rate}]
    else:
        # 'new_parameters' train at 100x the base rate (presumably the freshly
        # initialised layers; confirm in generate_model).
        params = [
            {'params': parameters['base_parameters'], 'lr': sets.learning_rate},
            {'params': parameters['new_parameters'], 'lr': sets.learning_rate * 100},
        ]
    optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    # train from resume
    if sets.resume_path:
        if os.path.isfile(sets.resume_path):
            print("=> loading checkpoint '{}'".format(sets.resume_path))
            # Map to CPU when CUDA is disabled so GPU-saved checkpoints still load.
            checkpoint = torch.load(sets.resume_path,
                                    map_location='cpu' if sets.no_cuda else None)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Checkpoints written by older train() versions store the epoch under
            # the misspelled key 'ecpoch'; accept either.
            resumed_epoch = checkpoint.get('epoch', checkpoint.get('ecpoch'))
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(sets.resume_path, resumed_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(sets.resume_path))

    # getting data
    sets.phase = 'train'
    sets.pin_memory = not sets.no_cuda
    training_dataset = BrainS18Dataset(sets.data_root, sets.img_list, sets)
    data_loader = DataLoader(training_dataset, batch_size=sets.batch_size, shuffle=True,
                             num_workers=sets.num_workers, pin_memory=sets.pin_memory)

    # training
    train(data_loader, model, optimizer, scheduler, total_epochs=sets.n_epochs,
          save_interval=sets.save_intervals, save_folder=sets.save_folder, sets=sets)