-
Notifications
You must be signed in to change notification settings - Fork 0
/
cc1010_training.py
465 lines (422 loc) · 15.8 KB
/
cc1010_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import logging
import sys
import time
import monai
import numpy as np
import torch
from monai.apps.detection.metrics.coco import COCOMetric
from monai.apps.detection.metrics.matching import matching_batch
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
from monai.apps.detection.networks.retinanet_network import (
RetinaNet,
resnet_fpn_feature_extractor,
)
from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape
from monai.data import DataLoader, Dataset, box_utils, load_decathlon_datalist
from monai.data.utils import no_collation
from monai.networks.nets import resnet
from monai.transforms import NormalizeIntensityd
from monai.utils import set_determinism
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from generate_transforms import (
generate_detection_train_transform,
generate_detection_val_transform,
)
from visualize_image import visualize_one_xy_slice_in_3d_image
from warmup_scheduler import GradualWarmupScheduler
def _load_json(path):
    """Return the parsed contents of the JSON file at ``path``.

    The file is opened in a ``with`` block so the handle is closed right after
    parsing instead of being left for the garbage collector.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _build_dataloaders(args, train_transforms, val_transforms):
    """Create the training dataset and the training/validation data loaders.

    Both datalists are read from ``args.data_list_file_path`` (keys
    ``"training"`` and ``"validation"``), with paths resolved against
    ``args.data_base_dir``.

    Returns:
        tuple ``(train_ds, train_loader, val_loader)``.
    """
    train_data = load_decathlon_datalist(
        args.data_list_file_path,
        is_segmentation=True,
        data_list_key="training",
        base_dir=args.data_base_dir,
    )
    train_ds = Dataset(data=train_data, transform=train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True,
        num_workers=7,
        pin_memory=torch.cuda.is_available(),
        collate_fn=no_collation,
        persistent_workers=True,
    )

    val_data = load_decathlon_datalist(
        args.data_list_file_path,
        is_segmentation=True,
        data_list_key="validation",
        base_dir=args.data_base_dir,
    )
    val_ds = Dataset(data=val_data, transform=val_transforms)
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
        collate_fn=no_collation,
        persistent_workers=True,
    )
    return train_ds, train_loader, val_loader


def _build_detector(args, env_dict, device):
    """Build the RetinaNet detector and configure its train/validation components.

    If ``env_dict`` has a non-empty ``"pretrained_path"``, the freshly scripted
    network is replaced by the TorchScript module loaded from that path.

    Returns:
        The configured ``RetinaNetDetector``, moved to ``device``.
    """
    # 1) anchor generator
    # returned_layers: when target boxes are small, set it smaller
    # base_anchor_shapes: anchor shape for the most high-resolution output,
    # when target boxes are small, set it smaller
    anchor_generator = AnchorGeneratorWithAnchorShape(
        feature_map_scales=[2**i for i in range(len(args.returned_layers) + 1)],
        base_anchor_shapes=args.base_anchor_shapes,
    )

    # 2) network
    conv1_t_size = [max(7, 2 * s + 1) for s in args.conv1_t_stride]
    backbone = resnet.ResNet(
        block=resnet.ResNetBottleneck,
        layers=[3, 4, 6, 3],
        block_inplanes=resnet.get_inplanes(),
        n_input_channels=args.n_input_channels,
        conv1_t_stride=args.conv1_t_stride,
        conv1_t_size=conv1_t_size,
    )
    feature_extractor = resnet_fpn_feature_extractor(
        backbone=backbone,
        spatial_dims=args.spatial_dims,
        pretrained_backbone=False,
        trainable_backbone_layers=None,
        returned_layers=args.returned_layers,
    )
    num_anchors = anchor_generator.num_anchors_per_location()[0]
    # inputs are padded to a multiple of the total downsampling of the deepest returned level
    size_divisible = [
        s * 2 * 2 ** max(args.returned_layers)
        for s in feature_extractor.body.conv1.stride
    ]
    net = torch.jit.script(
        RetinaNet(
            spatial_dims=args.spatial_dims,
            num_classes=len(args.fg_labels),
            num_anchors=num_anchors,
            feature_extractor=feature_extractor,
            size_divisible=size_divisible,
        )
    )
    pretrained_path = env_dict.get("pretrained_path", "")
    if pretrained_path:
        print(f"Loading pretrained network {pretrained_path}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host
        net = torch.jit.load(pretrained_path, map_location=device)
        print("Loaded pretrained network")

    # 3) detector
    detector = RetinaNetDetector(
        network=net, anchor_generator=anchor_generator, debug=args.verbose
    ).to(device)
    # training components
    detector.set_atss_matcher(num_candidates=4, center_in_gt=False)
    detector.set_hard_negative_sampler(
        batch_size_per_image=64,
        positive_fraction=args.balanced_sampler_pos_fraction,
        pool_size=20,
        min_neg=16,
    )
    detector.set_target_keys(box_key="box", label_key="label")
    # validation components
    detector.set_box_selector_parameters(
        score_thresh=args.score_thresh,
        topk_candidates_per_level=1000,
        nms_thresh=args.nms_thresh,
        detections_per_img=100,
    )
    detector.set_sliding_window_inferer(
        roi_size=args.val_patch_size,
        overlap=0.25,
        sw_batch_size=1,
        mode="constant",
        device="cpu",
    )
    return detector


def _train_one_epoch(
    detector,
    train_loader,
    optimizer,
    scaler,
    amp,
    device,
    w_cls,
    epoch,
    epoch_len,
    tensorboard_writer,
    verbose,
):
    """Run one optimization pass over ``train_loader``.

    Per-step total loss is logged to TensorBoard as ``"train_loss"`` at global
    step ``epoch_len * epoch + step``.

    Returns:
        tuple ``(avg_loss, avg_cls_loss, avg_box_reg_loss)`` averaged over the
        steps taken; all zeros if the loader yields no batches.
    """
    detector.train()
    epoch_loss = 0.0
    epoch_cls_loss = 0.0
    epoch_box_reg_loss = 0.0
    step = 0
    start_time = time.time()
    for batch_data in tqdm(train_loader):
        step += 1
        # With ``no_collation`` each batch is a list of samples, and each sample
        # is itself a list of dicts (presumably the random crops produced by the
        # train transform — confirm), so both levels are flattened here.
        inputs = [
            batch_data_ii["image"].to(device)
            for batch_data_i in batch_data
            for batch_data_ii in batch_data_i
        ]
        targets = [
            dict(
                label=batch_data_ii["label"].to(device),
                box=batch_data_ii["box"].to(device),
            )
            for batch_data_i in batch_data
            for batch_data_ii in batch_data_i
        ]
        # drop gradients instead of zero-filling them
        for param in detector.network.parameters():
            param.grad = None
        if amp and (scaler is not None):
            with torch.cuda.amp.autocast():
                outputs = detector(inputs, targets)
                loss = (
                    w_cls * outputs[detector.cls_key]
                    + outputs[detector.box_reg_key]
                )
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = detector(inputs, targets)
            loss = w_cls * outputs[detector.cls_key] + outputs[detector.box_reg_key]
            loss.backward()
            optimizer.step()

        loss_value = loss.detach().item()
        epoch_loss += loss_value
        epoch_cls_loss += outputs[detector.cls_key].detach().item()
        epoch_box_reg_loss += outputs[detector.box_reg_key].detach().item()
        if verbose:
            print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
        tensorboard_writer.add_scalar(
            "train_loss", loss_value, epoch_len * epoch + step
        )
    end_time = time.time()
    print(f"Training time: {end_time-start_time}s")

    # guard against an empty loader instead of dividing by zero
    n_steps = max(step, 1)
    return (
        epoch_loss / n_steps,
        epoch_cls_loss / n_steps,
        epoch_box_reg_loss / n_steps,
    )


def _validate(
    detector, val_loader, coco_metric, amp, device, val_patch_size, tensorboard_writer, epoch
):
    """Run inference on ``val_loader``, log an example and COCO metrics, return the mean metric.

    Raises:
        RuntimeError: if ``val_loader`` yields no batches.
    """
    detector.eval()
    val_outputs_all = []
    val_targets_all = []
    start_time = time.time()
    with torch.no_grad():
        for val_batch in val_loader:
            # if all val_data_i["image"] smaller than val_patch_size, no need to use inferer
            # otherwise, need inferer to handle large input images.
            use_inferer = not all(
                val_data_i["image"][0, ...].numel() < np.prod(val_patch_size)
                for val_data_i in val_batch
            )
            # pop the image so the remaining dicts (box/label) serve as targets
            val_inputs = [val_data_i.pop("image").to(device) for val_data_i in val_batch]
            if amp:
                with torch.cuda.amp.autocast():
                    val_outputs = detector(val_inputs, use_inferer=use_inferer)
            else:
                val_outputs = detector(val_inputs, use_inferer=use_inferer)
            # save outputs for evaluation
            val_outputs_all += val_outputs
            val_targets_all += val_batch
    if not val_targets_all:
        raise RuntimeError("validation data loader yielded no samples")
    end_time = time.time()
    print(f"Validation time: {end_time-start_time}s")

    # visualize an inference image and boxes (from the last batch) to tensorboard
    draw_img = visualize_one_xy_slice_in_3d_image(
        gt_boxes=val_batch[0]["box"].cpu().detach().numpy(),
        image=val_inputs[0][0, ...].cpu().detach().numpy(),
        pred_boxes=val_outputs[0][detector.target_box_key].cpu().detach().numpy(),
    )
    tensorboard_writer.add_image("val_img_xy", draw_img.transpose([2, 1, 0]), epoch + 1)

    # compute metrics
    del val_inputs
    torch.cuda.empty_cache()
    results_metric = matching_batch(
        iou_fn=box_utils.box_iou,
        iou_thresholds=coco_metric.iou_thresholds,
        pred_boxes=[
            val_data_i[detector.target_box_key].cpu().detach().numpy()
            for val_data_i in val_outputs_all
        ],
        pred_classes=[
            val_data_i[detector.target_label_key].cpu().detach().numpy()
            for val_data_i in val_outputs_all
        ],
        pred_scores=[
            val_data_i[detector.pred_score_key].cpu().detach().numpy()
            for val_data_i in val_outputs_all
        ],
        gt_boxes=[
            val_data_i[detector.target_box_key].cpu().detach().numpy()
            for val_data_i in val_targets_all
        ],
        gt_classes=[
            val_data_i[detector.target_label_key].cpu().detach().numpy()
            for val_data_i in val_targets_all
        ],
    )
    val_epoch_metric_dict = coco_metric(results_metric)[0]
    print(val_epoch_metric_dict)
    # write to tensorboard event
    for k, v in val_epoch_metric_dict.items():
        tensorboard_writer.add_scalar("val_" + k, v, epoch + 1)
    metric_values = list(val_epoch_metric_dict.values())
    val_epoch_metric = sum(metric_values) / len(metric_values)
    tensorboard_writer.add_scalar("val_metric", val_epoch_metric, epoch + 1)
    return val_epoch_metric


def main():
    """Train the CC1010 RetinaNet detector and keep the best-validating checkpoint.

    Command-line options select an environment JSON and a config JSON.
    Every key of both files is copied onto the parsed ``args`` namespace.

    Side effects:
        * saves the TorchScript network after every epoch to
          ``<model_path stem>_last.pt``;
        * saves it to ``model_path`` whenever the mean validation metric improves;
        * writes TensorBoard events to ``args.tfevent_path``.
    """
    import os  # local import: only used to derive the "last model" checkpoint name

    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./config/environment.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./config/config_train_cc1010.json",
        help="config json file that stores hyper-parameters",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=False,
        action="store_true",
        help="whether to print verbose detail during training, recommend True when you are not sure about hyper-parameters",
    )
    args = parser.parse_args()

    set_determinism(seed=0)
    amp = True
    monai.config.print_config()
    torch.backends.cudnn.benchmark = True
    torch.set_num_threads(4)

    env_dict = _load_json(args.environment_file)
    config_dict = _load_json(args.config_file)
    for k, v in env_dict.items():
        setattr(args, k, v)
    for k, v in config_dict.items():
        setattr(args, k, v)

    # 1. define transform
    intensity_transform = NormalizeIntensityd(
        keys=["image"],
        nonzero=False,
        channel_wise=True,
    )
    train_transforms = generate_detection_train_transform(
        "image",
        "box",
        "label",
        args.gt_box_mode,
        intensity_transform,
        args.patch_size,
        args.batch_size,
        affine_lps_to_ras=True,
        amp=amp,
    )
    val_transforms = generate_detection_val_transform(
        "image",
        "box",
        "label",
        args.gt_box_mode,
        intensity_transform,
        affine_lps_to_ras=True,
        amp=amp,
    )

    # 2. prepare training and validation data
    train_ds, train_loader, val_loader = _build_dataloaders(
        args, train_transforms, val_transforms
    )

    # 3. build model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    detector = _build_detector(args, env_dict, device)

    # 4. initialize training
    optimizer = torch.optim.SGD(
        detector.network.parameters(),
        args.lr,
        momentum=0.9,
        weight_decay=3e-5,
        nesterov=True,
    )
    after_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=150, gamma=0.1
    )
    scheduler_warmup = GradualWarmupScheduler(
        optimizer, multiplier=1, total_epoch=10, after_scheduler=after_scheduler
    )
    scaler = torch.cuda.amp.GradScaler() if amp else None
    # No gradients exist yet, so this step changes no parameters. It presumably
    # exists so the first scheduler step is not reported as preceding
    # optimizer.step().
    optimizer.zero_grad()
    optimizer.step()

    # Resolve checkpoint paths up front so a missing "model_path" fails before training.
    model_path = env_dict["model_path"]
    # splitext handles any extension length (the old [:-3] slice assumed ".pt")
    last_model_path = os.path.splitext(model_path)[0] + "_last.pt"

    tensorboard_writer = SummaryWriter(args.tfevent_path)
    try:
        # 5. train
        val_interval = 5  # do validation every val_interval epochs
        # NOTE(review): class names are hard-coded; confirm they match the
        # order/length of args.fg_labels used for num_classes.
        coco_metric = COCOMetric(
            classes=["Nz", "Iz", "LPA", "RPA"], iou_list=[0.1], max_detection=[10]
        )
        best_val_epoch_metric = 0.0
        best_val_epoch = -1  # the epoch that gives best validation metrics
        max_epochs = 300
        epoch_len = len(train_ds) // train_loader.batch_size
        # weight between classification loss and box regression loss, default 1.0
        w_cls = config_dict.get("w_cls", 1.0)

        for epoch in range(max_epochs):
            # ------------- Training -------------
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            scheduler_warmup.step()
            epoch_loss, epoch_cls_loss, epoch_box_reg_loss = _train_one_epoch(
                detector,
                train_loader,
                optimizer,
                scaler,
                amp,
                device,
                w_cls,
                epoch,
                epoch_len,
                tensorboard_writer,
                args.verbose,
            )
            # the last batch's tensors were released when the helper returned
            torch.cuda.empty_cache()
            gc.collect()

            # save to tensorboard
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
            tensorboard_writer.add_scalar("avg_train_loss", epoch_loss, epoch + 1)
            tensorboard_writer.add_scalar("avg_train_cls_loss", epoch_cls_loss, epoch + 1)
            tensorboard_writer.add_scalar(
                "avg_train_box_reg_loss", epoch_box_reg_loss, epoch + 1
            )
            tensorboard_writer.add_scalar(
                "train_lr", optimizer.param_groups[0]["lr"], epoch + 1
            )

            # save last trained model
            torch.jit.save(detector.network, last_model_path)
            print("saved last model")

            # ------------- Validation for model selection -------------
            if (epoch + 1) % val_interval != 0:
                continue
            val_epoch_metric = _validate(
                detector,
                val_loader,
                coco_metric,
                amp,
                device,
                args.val_patch_size,
                tensorboard_writer,
                epoch,
            )
            # save best trained model
            if val_epoch_metric > best_val_epoch_metric:
                best_val_epoch_metric = val_epoch_metric
                best_val_epoch = epoch + 1
                torch.jit.save(detector.network, model_path)
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current metric: {val_epoch_metric:.4f} "
                f"best metric: {best_val_epoch_metric:.4f} at epoch {best_val_epoch}"
            )

        print(
            f"train completed, best_metric: {best_val_epoch_metric:.4f} "
            f"at epoch: {best_val_epoch}"
        )
    finally:
        # close the event file even if training is interrupted
        tensorboard_writer.close()
if __name__ == "__main__":
    # Script entry point: send INFO-level log records to stdout with
    # millisecond timestamps, then run training.
    logging.basicConfig(
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        stream=sys.stdout,
    )
    main()