From f325297b7e9417658a226adb7b2cbcc925ae924d Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Fri, 1 Sep 2023 15:32:01 +0800 Subject: [PATCH] Add YOLOv8x Segmentation --- MODEL_ZOO.md | 12 +- configs/yolov8/README.md | 15 +- configs/yolov8/seg/hyp.scratch.high.seg.yaml | 73 ++ configs/yolov8/seg/yolov8-seg-base.yaml | 49 + configs/yolov8/seg/yolov8x-seg.yaml | 13 + demo/predict.py | 150 ++- deploy/predict.py | 30 +- deploy/test.py | 14 +- mindyolo/data/albumentations.py | 35 +- mindyolo/data/copypaste.py | 48 - mindyolo/data/dataset.py | 1133 ++++++++++++------ mindyolo/data/loader.py | 10 +- mindyolo/data/perspective.py | 110 -- mindyolo/data/utils.py | 129 ++ mindyolo/models/heads/__init__.py | 10 +- mindyolo/models/heads/yolov8_head.py | 44 + mindyolo/models/losses/label_assignment.py | 1 - mindyolo/models/losses/yolov8_loss.py | 182 ++- mindyolo/models/model_factory.py | 6 +- mindyolo/models/yolov3.py | 11 +- mindyolo/models/yolov4.py | 10 +- mindyolo/models/yolov5.py | 11 +- mindyolo/models/yolov7.py | 11 +- mindyolo/models/yolov8.py | 11 +- mindyolo/models/yolox.py | 11 +- mindyolo/utils/callback.py | 3 +- mindyolo/utils/metrics.py | 148 ++- mindyolo/{data => utils}/poly.py | 0 mindyolo/utils/train_step_factory.py | 102 +- mindyolo/utils/trainer_factory.py | 16 +- mindyolo/utils/utils.py | 9 +- mindyolo/version.py | 2 +- test.py | 201 +++- tests/dataset_plots.py | 5 +- tests/modules/test_create_loader.py | 11 +- tests/modules/test_create_trainer.py | 1 + train.py | 16 +- 37 files changed, 1929 insertions(+), 714 deletions(-) create mode 100644 configs/yolov8/seg/hyp.scratch.high.seg.yaml create mode 100644 configs/yolov8/seg/yolov8-seg-base.yaml create mode 100644 configs/yolov8/seg/yolov8x-seg.yaml delete mode 100644 mindyolo/data/copypaste.py delete mode 100644 mindyolo/data/perspective.py create mode 100644 mindyolo/data/utils.py delete mode 100644 mindyolo/models/losses/label_assignment.py rename mindyolo/{data => utils}/poly.py (100%) diff --git a/MODEL_ZOO.md b/MODEL_ZOO.md index b67323f6..ecca1395 100644 --- a/MODEL_ZOO.md +++ b/MODEL_ZOO.md @@ -1,5 +1,7 @@ # MindYOLO Model Zoo and Baselines +## Detection + | Name | Scale | Context | ImageSize | Dataset | Box mAP (%) | Params | FLOPs | Recipe | Download | |--------|--------------------|----------|-----------|--------------|-------------|--------|--------|--------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------| | YOLOv8 | N | D910x8-G | 640 | MS COCO 2017 | 37.2 | 3.2M | 8.7G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/yolov8n.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-n_500e_mAP372-cc07f5bd.ckpt) | @@ -26,12 +28,16 @@ | YOLOX | X | D910x8-G | 640 | MS COCO 2017 | 51.6 | 99.1M | 281.9G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolox/yolox-x.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolox/yolox-x_300e_map516-52216d90.ckpt) | | YOLOX | Darknet53 | D910x8-G | 640 | MS COCO 2017 | 47.7 | 63.7M | 185.3G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolox/yolox-darknet53.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolox/yolox-darknet53_300e_map477-b5fcaba9.ckpt) | -
+## Segmentation
+
+| Name       | Scale | Context  | ImageSize | Dataset      | Box mAP (%) | Mask mAP (%) | Params | FLOPs  | Recipe                                                                                              | Download                                                                                                         |
+|------------|-------|----------|-----------|--------------|-------------|--------------|--------|--------|-----------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|
+| YOLOv8-seg | X     | D910x8-G | 640       | MS COCO 2017 | 52.5        | 42.9         | 71.8M  | 344.1G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/seg/yolov8x-seg.yaml)   | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-x-seg_300e_mAP_mask_429-b4920557.ckpt)  |

-#### Depoly inference
+## Deploy inference

- See [support list](./deploy/README.md)

-#### Notes
+## Notes

- Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode.
- Box mAP: Accuracy reported on the validation set.
diff --git a/configs/yolov8/README.md b/configs/yolov8/README.md
index 17e6fc0b..c92de332 100644
--- a/configs/yolov8/README.md
+++ b/configs/yolov8/README.md
@@ -9,6 +9,8 @@ Ultralytics YOLOv8, developed by Ultralytics, is a cutting-edge, state-of-the-ar

 ## Results

+### Detection
+
| Name | Scale | Arch | Context | ImageSize | Dataset | Box mAP (%) | Params | FLOPs | Recipe | Download | @@ -20,9 +22,18 @@ Ultralytics YOLOv8, developed by Ultralytics, is a cutting-edge, state-of-the-ar | YOLOv8 | X | P5 | D910x8-G | 640 | MS COCO 2017 | 53.7 | 68.2M | 257.8G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/yolov8x.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-x_500e_mAP537-b958e1c7.ckpt) |
-
-#### Notes +### Segmentation + +
+ +| Name | Scale | Arch | Context | ImageSize | Dataset | Box mAP (%) | Mask mAP (%) | Params | FLOPs | Recipe | Download | +|------------|-------|------|----------|-----------|--------------|-------------|--------------|--------|--------|---------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------| +| YOLOv8-seg | X | P5 | D910x8-G | 640 | MS COCO 2017 | 52.5 | 42.9 | 71.8M | 344.1G | [yaml](https://github.com/mindspore-lab/mindyolo/blob/master/configs/yolov8/seg/yolov8x-seg.yaml) | [weights](https://download.mindspore.cn/toolkits/mindyolo/yolov8/yolov8-x-seg_300e_mAP_mask_429-b4920557.ckpt) | + +
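For completeness, the row above can be exercised through the new `--task segment` path added to `demo/predict.py` in this patch. Below is a minimal sketch of driving the added `segment()` helper directly from Python; the network construction (`build_seg_network()`) and file paths are illustrative assumptions, not part of this PR, while the keyword values simply mirror `yolov8-seg-base.yaml`:

```python
# Sketch: single-image inference with the segment() helper from demo/predict.py.
# Assumes it is run from the repository root; build_seg_network() is a
# hypothetical stand-in for the create_model(...) call performed in infer().
import cv2

from demo.predict import segment

network = build_seg_network()    # hypothetical: YOLOv8x-seg built as in infer()
img = cv2.imread("example.jpg")  # BGR HWC; segment() resizes and pads internally

result = segment(
    network=network,
    img=img,
    conf_thres=0.25,
    iou_thres=0.7,       # iou_thres in yolov8-seg-base.yaml
    conf_free=True,      # conf_free: True in yolov8-seg-base.yaml
    nms_time_limit=60.0,
    img_size=640,
    stride=32,           # largest stride of the P3/P4/P5 head
    num_class=80,
    is_coco_dataset=True,
)
# result["bbox"]: top-left [x, y, w, h] in original-image coordinates
# result["segmentation"]: one (h_ori, w_ori) mask per detected instance
```

The same path is reachable from the command line via the new `--task segment` flag together with the usual config and weight arguments.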
+ +### Notes - Context: Training context denoted as {device}x{pieces}-{MS mode}, where mindspore mode can be G - graph mode or F - pynative mode with ms function. For example, D910x8-G is for training on 8 pieces of Ascend 910 NPU using graph mode. - Box mAP: Accuracy reported on the validation set. diff --git a/configs/yolov8/seg/hyp.scratch.high.seg.yaml b/configs/yolov8/seg/hyp.scratch.high.seg.yaml new file mode 100644 index 00000000..09a3e8c1 --- /dev/null +++ b/configs/yolov8/seg/hyp.scratch.high.seg.yaml @@ -0,0 +1,73 @@ +epochs: 300 # total train epochs + +optimizer: + optimizer: momentum + lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) + momentum: 0.937 # SGD momentum/Adam beta1 + nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm + loss_scale: 1.0 # loss scale for optimizer + warmup_epochs: 3 # warmup epochs (fractions ok) + warmup_momentum: 0.8 # warmup initial momentum + warmup_bias_lr: 0.1 # warmup initial bias lr + min_warmup_step: 1000 # minimum warmup step + group_param: yolov8 # group param strategy + gp_weight_decay: 0.0010078125 # group param weight decay 5e-4 + start_factor: 1.0 + end_factor: 0.01 + +loss: + name: YOLOv8SegLoss + box: 7.5 # box loss gain + cls: 0.5 # cls loss gain + dfl: 1.5 # dfl loss gain + reg_max: 16 + nm: 32 + overlap: True + max_object_num: 600 + +data: + num_parallel_workers: 4 + + train_transforms: { + stage_epochs: [ 290, 10 ], + trans_list: [ + [ + {func_name: resample_segments}, + {func_name: mosaic, prob: 1.0}, + {func_name: copy_paste, prob: 0.3}, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0}, + {func_name: mixup, alpha: 32.0, beta: 32.0, prob: 0.15, pre_transform: [ + { func_name: resample_segments }, + { func_name: mosaic, prob: 1.0 }, + { func_name: copy_paste, prob: 0.3 }, + { func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0 },] + }, + {func_name: albumentations, random_resized_crop: False}, # random_resized_crop not support seg task + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4 }, + {func_name: fliplr, prob: 0.5 }, + {func_name: segment_poly2mask, mask_overlap: True, mask_ratio: 4 }, + {func_name: label_norm, xyxy2xywh_: True }, + {func_name: label_pad, padding_size: 160, padding_value: -1 }, + {func_name: image_norm, scale: 255. }, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True } + ], + [ + {func_name: resample_segments}, + {func_name: letterbox, scaleup: True }, + {func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.9, shear: 0.0 }, + {func_name: albumentations, random_resized_crop: False}, # random_resized_crop not support seg task + {func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4 }, + {func_name: fliplr, prob: 0.5 }, + {func_name: segment_poly2mask, mask_overlap: True, mask_ratio: 4 }, + {func_name: label_norm, xyxy2xywh_: True }, + {func_name: label_pad, padding_size: 160, padding_value: -1 }, + {func_name: image_norm, scale: 255. }, + {func_name: image_transpose, bgr2rgb: True, hwc2chw: True } + ]] + } + + test_transforms: [ + { func_name: letterbox, scaleup: False }, + { func_name: image_norm, scale: 255. 
}, + { func_name: image_transpose, bgr2rgb: True, hwc2chw: True } + ] diff --git a/configs/yolov8/seg/yolov8-seg-base.yaml b/configs/yolov8/seg/yolov8-seg-base.yaml new file mode 100644 index 00000000..a78de63a --- /dev/null +++ b/configs/yolov8/seg/yolov8-seg-base.yaml @@ -0,0 +1,49 @@ +task: segment +epochs: 500 # total train epochs +per_batch_size: 16 # 16 * 8 = 128 +img_size: 640 +iou_thres: 0.7 +conf_free: True +sync_bn: True +opencv_threads_num: 0 # opencv: disable threading optimizations + +network: + model_name: yolov8 + nc: 80 # number of classes + reg_max: 16 + + stride: [8, 16, 32] + + # YOLOv8.0n backbone + backbone: + # [from, repeats, module, args] + - [-1, 1, ConvNormAct, [64, 3, 2]] # 0-P1/2 + - [-1, 1, ConvNormAct, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, ConvNormAct, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, ConvNormAct, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, ConvNormAct, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + + # YOLOv8.0n head + head: + - [-1, 1, Upsample, [None, 2, 'nearest']] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, Upsample, [None, 2, 'nearest']] + - [[-1, 4], 1, Concat, [1] ] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, ConvNormAct, [256, 3, 2]] + - [[ -1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, ConvNormAct, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, YOLOv8Head, [nc, reg_max, stride]] # Detect(P3, P4, P5) diff --git a/configs/yolov8/seg/yolov8x-seg.yaml b/configs/yolov8/seg/yolov8x-seg.yaml new file mode 100644 index 00000000..3528e80a --- /dev/null +++ b/configs/yolov8/seg/yolov8x-seg.yaml @@ -0,0 +1,13 @@ +__BASE__: [ + '../../coco.yaml', + './hyp.scratch.high.seg.yaml', + './yolov8-seg-base.yaml' +] + +recompute: True +recompute_layers: 2 + +network: + depth_multiple: 1.00 # scales module repeats + width_multiple: 1.25 # scales convolution channels + max_channels: 512 diff --git a/demo/predict.py b/demo/predict.py index e4868dcb..61f4d3fe 100644 --- a/demo/predict.py +++ b/demo/predict.py @@ -15,12 +15,13 @@ from mindyolo.models import create_model from mindyolo.utils import logger from mindyolo.utils.config import parse_args -from mindyolo.utils.metrics import non_max_suppression, scale_coords, xyxy2xywh +from mindyolo.utils.metrics import non_max_suppression, scale_coords, xyxy2xywh, process_mask_upsample, scale_image from mindyolo.utils.utils import draw_result, set_seed def get_parser_infer(parents=None): parser = argparse.ArgumentParser(description="Infer", parents=[parents] if parents else []) + parser.add_argument("--task", type=str, default="detect", choices=["detect", "segment"]) parser.add_argument("--device_target", type=str, default="Ascend", help="device target, Ascend/GPU/CPU") parser.add_argument("--ms_mode", type=int, default=0, help="train mode, graph/pynative") parser.add_argument("--ms_amp_level", type=str, default="O0", help="amp level, O0/O1/O2") @@ -84,6 +85,7 @@ def detect( nms_time_limit: float = 60.0, img_size: int = 640, stride: int = 32, + num_class: int = 80, is_coco_dataset: bool = True, ): # Resize @@ -159,6 +161,106 @@ def detect( return result_dict +def segment( + network: nn.Cell, + img: np.ndarray, + conf_thres: float = 0.25, + iou_thres: float = 0.65, + conf_free: bool = False, + nms_time_limit: 
float = 60.0,
+    img_size: int = 640,
+    stride: int = 32,
+    num_class: int = 80,
+    is_coco_dataset: bool = True,
+):
+    # Resize
+    h_ori, w_ori = img.shape[:2]  # orig hw
+    r = img_size / max(h_ori, w_ori)  # resize image to img_size
+    if r != 1:  # always resize down, only resize up if training with augmentation
+        interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
+        img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp)
+    h, w = img.shape[:2]
+    if h < img_size or w < img_size:
+        new_h, new_w = math.ceil(h / stride) * stride, math.ceil(w / stride) * stride
+        dh, dw = (new_h - h) / 2, (new_w - w) / 2
+        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+        img = cv2.copyMakeBorder(
+            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+        )  # add border
+
+    # Transpose Norm
+    img = img[:, :, ::-1].transpose(2, 0, 1) / 255.0
+    imgs_tensor = Tensor(img[None], ms.float32)
+
+    # Run infer
+    _t = time.time()
+    out, (_, _, prototypes) = network(imgs_tensor)  # inference and training outputs
+    infer_times = time.time() - _t
+
+    # Run NMS
+    t = time.time()
+    _c = num_class + 4 if conf_free else num_class + 5
+    out = out.asnumpy()
+    bboxes, mask_coefficient = out[:, :, :_c], out[:, :, _c:]
+    out = non_max_suppression(
+        bboxes,
+        mask_coefficient,
+        conf_thres=conf_thres,
+        iou_thres=iou_thres,
+        conf_free=conf_free,
+        multi_label=True,
+        time_limit=nms_time_limit,
+    )
+    nms_times = time.time() - t
+
+    prototypes = prototypes.asnumpy()
+
+    result_dict = {"category_id": [], "bbox": [], "score": [], "segmentation": []}
+    total_category_ids, total_bboxes, total_scores, total_seg = [], [], [], []
+    for si, (pred, proto) in enumerate(zip(out, prototypes)):
+        if len(pred) == 0:
+            continue
+
+        # Predictions
+        pred_masks = process_mask_upsample(proto, pred[:, 6:], pred[:, :4], shape=imgs_tensor[si].shape[1:])
+        pred_masks = pred_masks.astype(np.float32)
+        pred_masks = scale_image(pred_masks.transpose(1, 2, 0), (h_ori, w_ori))
+        predn = np.copy(pred)
+        scale_coords(img.shape[1:], predn[:, :4], (h_ori, w_ori))  # native-space pred
+
+        box = xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+        category_ids, bboxes, scores, segs = [], [], [], []
+        for ii, (p, b) in enumerate(zip(pred.tolist(), box.tolist())):
+            category_ids.append(COCO80_TO_COCO91_CLASS[int(p[5])] if is_coco_dataset else int(p[5]))
+            bboxes.append([round(x, 3) for x in b])
+            scores.append(round(p[4], 5))
+            segs.append(pred_masks[:, :, ii])
+
+        total_category_ids.extend(category_ids)
+        total_bboxes.extend(bboxes)
+        total_scores.extend(scores)
+        total_seg.extend(segs)
+
+    result_dict["category_id"].extend(total_category_ids)
+    result_dict["bbox"].extend(total_bboxes)
+    result_dict["score"].extend(total_scores)
+    result_dict["segmentation"].extend(total_seg)
+
+    t = tuple(x * 1e3 for x in (infer_times, nms_times, infer_times + nms_times)) + (img_size, img_size, 1)  # tuple
+    logger.info("Predict result is:")
+    for k, v in result_dict.items():
+        if k == "segmentation":
+            logger.info(f"{k} shape: {v[0].shape}")
+        else:
+            logger.info(f"{k}: {v}")
+    logger.info("Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g;" % t)
+    logger.info("Image segmented successfully.")
+
+    return result_dict
+
+
 def infer(args):
     # Init
     set_seed(args.seed)
@@ -184,20 +286,38 @@ def infer(args):
 
     # Detect
     is_coco_dataset = "coco" in args.data.dataset_name
-    result_dict = detect(
-        network=network,
-        img=img,
- conf_thres=args.conf_thres, - iou_thres=args.iou_thres, - conf_free=args.conf_free, - nms_time_limit=args.nms_time_limit, - img_size=args.img_size, - stride=max(max(args.network.stride), 32), - is_coco_dataset=is_coco_dataset, - ) - if args.save_result: - save_path = os.path.join(args.save_dir, "detect_results") - draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path) + if args.task == "detect": + result_dict = detect( + network=network, + img=img, + conf_thres=args.conf_thres, + iou_thres=args.iou_thres, + conf_free=args.conf_free, + nms_time_limit=args.nms_time_limit, + img_size=args.img_size, + stride=max(max(args.network.stride), 32), + num_class=args.data.nc, + is_coco_dataset=is_coco_dataset, + ) + if args.save_result: + save_path = os.path.join(args.save_dir, "detect_results") + draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path) + elif args.task == "segment": + result_dict = segment( + network=network, + img=img, + conf_thres=args.conf_thres, + iou_thres=args.iou_thres, + conf_free=args.conf_free, + nms_time_limit=args.nms_time_limit, + img_size=args.img_size, + stride=max(max(args.network.stride), 32), + num_class=args.data.nc, + is_coco_dataset=is_coco_dataset, + ) + if args.save_result: + save_path = os.path.join(args.save_dir, "segment_results") + draw_result(args.image_path, result_dict, args.data.names, is_coco_dataset=is_coco_dataset, save_path=save_path) logger.info("Infer completed.") diff --git a/deploy/predict.py b/deploy/predict.py index 6d766f7c..85145470 100644 --- a/deploy/predict.py +++ b/deploy/predict.py @@ -21,6 +21,7 @@ def get_parser_infer(parents=None): parser = argparse.ArgumentParser(description="Infer", parents=[parents] if parents else []) + parser.add_argument("--task", type=str, default="detect", choices=["detect"]) parser.add_argument("--device_target", type=str, default="Ascend", help="device target, Ascend/GPU/CPU") parser.add_argument("--ms_mode", type=int, default=0, help="train mode, graph/pynative") parser.add_argument("--ms_amp_level", type=str, default="O0", help="amp level, O0/O1/O2") @@ -184,19 +185,22 @@ def infer(args): # Detect is_coco_dataset = "coco" in args.data.dataset_name - result_dict = detect( - network=network, - img=img, - conf_thres=args.conf_thres, - iou_thres=args.iou_thres, - conf_free=args.conf_free, - nms_time_limit=args.nms_time_limit, - img_size=args.img_size, - is_coco_dataset=is_coco_dataset, - ) - if args.save_result: - save_path = os.path.join(args.save_dir, "detect_results") - draw_result(args.image_path, result_dict, args.data.names, save_path=save_path) + if args.task == "detect": + result_dict = detect( + network=network, + img=img, + conf_thres=args.conf_thres, + iou_thres=args.iou_thres, + conf_free=args.conf_free, + nms_time_limit=args.nms_time_limit, + img_size=args.img_size, + is_coco_dataset=is_coco_dataset, + ) + if args.save_result: + save_path = os.path.join(args.save_dir, "detect_results") + draw_result(args.image_path, result_dict, args.data.names, save_path=save_path) + else: + raise NotImplementedError logger.info("Infer completed.") diff --git a/deploy/test.py b/deploy/test.py index 9bbd1bf2..82d9a08d 100644 --- a/deploy/test.py +++ b/deploy/test.py @@ -17,6 +17,13 @@ def test(args): + if args.task == "detect": + return test_detect(args) + else: + raise NotImplementedError + + +def test_detect(args): # Create Network if args.model_type == "MindX": from infer_engine.mindx import 
MindXModel @@ -41,7 +48,8 @@ def test(args): dataloader = create_loader( dataset=dataset, batch_collate_fn=dataset.test_collate_fn, - dataset_column_names=dataset.dataset_column_names, + column_names_getitem=dataset.column_names_getitem, + column_names_collate=dataset.column_names_collate, batch_size=args.batch_size, epoch_size=1, rank=0, @@ -63,9 +71,8 @@ def test(args): nms_times = 0.0 result_dicts = [] for i, data in enumerate(loader): - imgs, _, paths, ori_shape, pad, hw_scale = ( + imgs, paths, ori_shape, pad, hw_scale = ( data["image"], - data["labels"], data["img_files"], data["hw_ori"], data["pad"], @@ -147,6 +154,7 @@ def test(args): def get_parser_test(parents=None): parser = argparse.ArgumentParser(description="Test", parents=[parents] if parents else []) + parser.add_argument("--task", type=str, default="detect", choices=["detect"]) parser.add_argument("--img_size", type=int, default=640, help="inference size (pixels)") parser.add_argument("--rect", type=ast.literal_eval, default=False, help="rectangular training") parser.add_argument( diff --git a/mindyolo/data/albumentations.py b/mindyolo/data/albumentations.py index 141b9967..b0234631 100644 --- a/mindyolo/data/albumentations.py +++ b/mindyolo/data/albumentations.py @@ -4,19 +4,25 @@ import numpy as np import pkg_resources as pkg +from .utils import xyxy2xywh + class Albumentations: # Implement Albumentations augmentation https://github.com/ultralytics/yolov5 # YOLOv5 Albumentations class (optional, only used if package is installed) - def __init__(self, size=640): + def __init__(self, size=640, random_resized_crop=True): self.transform = None prefix = _colorstr("albumentations: ") try: import albumentations as A _check_version(A.__version__, "1.0.3", hard=True) # version requirement - T = [ - A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0), + T = [] + if random_resized_crop: + T.extend([ + A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0), + ]) + T.extend([ A.Blur(p=0.01), A.MedianBlur(p=0.01), A.ToGray(p=0.01), @@ -24,7 +30,7 @@ def __init__(self, size=640): A.RandomBrightnessContrast(p=0.0), A.RandomGamma(p=0.0), A.ImageCompression(quality_lower=75, p=0.0), - ] # transforms + ]) self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"])) print(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p), flush=True) @@ -36,11 +42,24 @@ def __init__(self, size=640): print(f"{prefix}{e}", flush=True) print("[WARNING] albumentations load failed", flush=True) - def __call__(self, im, labels, p=1.0): + def __call__(self, sample, p=1.0): if self.transform and random.random() < p: - new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed - im, labels = new["image"], np.array([[c, *b] for c, b in zip(new["class_labels"], new["bboxes"])]) - return im, labels + im, bboxes, cls, bbox_format = sample['img'], sample['bboxes'], sample['cls'], sample['bbox_format'] + assert bbox_format in ("ltrb", "xywhn") + if bbox_format == "ltrb" and bboxes.shape[0] > 0: + h, w = im.shape[:2] + bboxes = xyxy2xywh(bboxes) + bboxes[:, [0, 2]] /= w + bboxes[:, [1, 3]] /= h + + new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed + + sample['img'] = new['image'] + sample['bboxes'] = np.array(new['bboxes']) + sample['cls'] = np.array(new['class_labels']) + sample['bbox_format'] = "xywhn" + + return sample def _check_version(current="0.0.0", 
minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False): diff --git a/mindyolo/data/copypaste.py b/mindyolo/data/copypaste.py deleted file mode 100644 index a81d6765..00000000 --- a/mindyolo/data/copypaste.py +++ /dev/null @@ -1,48 +0,0 @@ -import random - -import cv2 -import numpy as np - - -def copy_paste(img, labels, segments, probability=0.5): - # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) - n = len(segments) - if probability and n: - h, w, c = img.shape # height, width, channels - im_new = np.zeros(img.shape, np.uint8) - for j in random.sample(range(n), k=round(probability * n)): - l, s = labels[j], segments[j] - box = w - l[3], l[2], w - l[1], l[4] - ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area - if (ioa < 0.30).all(): # allow 30% obscuration of existing labels - labels = np.concatenate((labels, [[l[0], *box]]), 0) - segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) - cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) - - result = cv2.bitwise_and(src1=img, src2=im_new) - result = cv2.flip(result, 1) # augment segments (flip left-right) - i = result > 0 # pixels to replace - # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch - img[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug - - return img, labels, segments - - -def bbox_ioa(box1, box2): - # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. boxes are x1y1x2y2 - box2 = box2.transpose() - - # Get the coordinates of bounding boxes - b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] - b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] - - # Intersection area - inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * ( - np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1) - ).clip(0) - - # box2 area - box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16 - - # Intersection over box2 area - return inter_area / box2_area diff --git a/mindyolo/data/dataset.py b/mindyolo/data/dataset.py index 89807f24..9a179db0 100644 --- a/mindyolo/data/dataset.py +++ b/mindyolo/data/dataset.py @@ -1,19 +1,19 @@ import os - import cv2 -from pathlib import Path -import numpy as np -from PIL import ExifTags, Image -from tqdm import tqdm +import math import hashlib import random import glob +import numpy as np +from pathlib import Path +from PIL import ExifTags, Image +from tqdm import tqdm +from copy import deepcopy from mindyolo.utils import logger - -from .albumentations import Albumentations -from .copypaste import copy_paste -from .perspective import random_perspective +from mindyolo.data.albumentations import Albumentations +from mindyolo.data.utils import xywhn2xyxy, xyxy2xywh, xyn2xy, segment2box, segments2boxes, \ + box_candidates, polygons2masks, polygons2masks_overlap, bbox_ioa __all__ = ["COCODataset"] @@ -59,25 +59,42 @@ def __init__( single_cls=False, batch_size=32, stride=32, + num_cls=80, pad=0.0, + return_segments=False, # for segment + return_keypoints=False, # for keypoint + nkpt=0, # for keypoint + ndim=0 # for keypoint ): - self.cache_version = 0.1 - self.path = dataset_path - # acceptable image suffixes self.img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] - self.help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data' + self.cache_version = 0.2 + + self.return_segments = return_segments + self.return_keypoints = return_keypoints + assert not 
(return_segments and return_keypoints), 'Can not return both segments and keypoints.' + self.path = dataset_path self.img_size = img_size self.augment = augment self.rect = rect self.stride = stride + self.num_cls = num_cls + self.nkpt = nkpt + self.ndim = ndim self.transforms_dict = transforms_dict self.is_training = is_training - if is_training: - self.dataset_column_names = ["image", "labels", "img_files"] + + # set column names + self.column_names_getitem = ['samples'] + if self.is_training: + self.column_names_collate = ['images', 'labels'] + if self.return_segments: + self.column_names_collate = ['images', 'labels', 'masks'] + elif self.return_keypoints: + self.column_names_collate = ['images', 'labels', 'keypoints'] else: - self.dataset_column_names = ["image", "labels", "img_files", "hw_ori", "hw_scale", "pad"] + self.column_names_collate = ["images", "img_files", "hw_ori", "hw_scale", "pad"] try: f = [] # image files @@ -102,9 +119,8 @@ def __init__( cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix(".cache.npy") # cached labels if cache_path.is_file(): cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict - if cache["version"] == self.cache_version and cache["hash"] == self._get_hash( - self.label_files + self.img_files - ): + if cache["version"] == self.cache_version \ + and cache["hash"] == self._get_hash(self.label_files + self.img_files): logger.info(f"Dataset Cache file hash/version check success.") logger.info(f"Load dataset cache from [{cache_path}] success.") else: @@ -127,22 +143,33 @@ def __init__( # Read cache cache.pop("hash") # remove hash cache.pop("version") # remove version - labels, shapes, self.segments = zip(*cache.values()) - self.labels = list(labels) - self.img_shapes = np.array(shapes, dtype=np.float64) - self.img_files = list(cache.keys()) # update - self.label_files = self._img2label_paths(cache.keys()) # update + self.labels = cache['labels'] + self.img_files = [lb['im_file'] for lb in self.labels] # update im_files + + # Check if the dataset is all boxes or all segments + lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in self.labels) + len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths)) + if len_segments and len_boxes != len_segments: + print( + f'WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, ' + f'len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. 
' + 'To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.') + for lb in self.labels: + lb['segments'] = [] + if len_cls == 0: + raise ValueError(f'All labels empty in {cache_path}, can not start training without labels.') + if single_cls: for x in self.labels: - x[:, 0] = 0 + x['cls'][:, 0] = 0 - n = len(labels) # number of images + n = len(self.labels) # number of images bi = np.floor(np.arange(n) / batch_size).astype(np.int_) # batch index nb = bi[-1] + 1 # number of batches self.batch = bi # batch index of image # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM) - self.imgs, self.img_hw_ori, self.indices = [None, ] * n, [None, ] * n, range(n) + self.imgs, self.img_hw_ori, self.indices = None, None, range(n) # Rectangular Train/Test if self.rect: @@ -170,16 +197,14 @@ def __init__( self.imgIds = [int(Path(im_file).stem) for im_file in self.img_files] - def cache_labels(self, path=Path("./labels.cache")): - # Get orientation exif tag - for orientation in ExifTags.TAGS.keys(): - if ExifTags.TAGS[orientation] == "Orientation": - break - + def cache_labels(self, path=Path("./labels.cache.npy")): # Cache dataset labels, check images and read shapes - x = {} # dict - nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate + x = {'labels': []} # dict + nm, nf, ne, nc, segments, keypoints = 0, 0, 0, 0, [], None # number missing, found, empty, duplicate pbar = tqdm(zip(self.img_files, self.label_files), desc="Scanning images", total=len(self.img_files)) + if self.return_keypoints and (self.nkpt <= 0 or self.ndim not in (2, 3)): + raise ValueError("'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " + "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'") for i, (im_file, lb_file) in enumerate(pbar): try: # verify images @@ -194,26 +219,62 @@ def cache_labels(self, path=Path("./labels.cache")): if os.path.isfile(lb_file): nf += 1 # label found with open(lb_file, "r") as f: - l = [x.split() for x in f.read().strip().splitlines()] - if any([len(x) > 8 for x in l]): # is segment - classes = np.array([x[0] for x in l], dtype=np.float32) - segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...) - l = np.concatenate( - (classes.reshape(-1, 1), self._segments2boxes(segments)), 1 + lb = [x.split() for x in f.read().strip().splitlines()] + if any([len(x) > 6 for x in lb]) and (not self.return_keypoints): # is segment + classes = np.array([x[0] for x in lb], dtype=np.float32) + segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...) 
+                        lb = np.concatenate(
+                            (classes.reshape(-1, 1), segments2boxes(segments)), 1
                         )  # (cls, xywh)
-                    l = np.array(l, dtype=np.float32)
-                    if len(l):
-                        assert l.shape[1] == 5, "labels require 5 columns each"
-                        assert (l >= 0).all(), "negative labels"
-                        assert (l[:, 1:] <= 1).all(), "non-normalized or out of bounds coordinate labels"
-                        assert np.unique(l, axis=0).shape[0] == l.shape[0], "duplicate labels"
+                    lb = np.array(lb, dtype=np.float32)
+                    nl = len(lb)
+                    if nl:
+                        if self.return_keypoints:
+                            assert lb.shape[1] == (5 + self.nkpt * self.ndim), \
+                                f'labels require {(5 + self.nkpt * self.ndim)} columns each'
+                            assert (lb[:, 5::self.ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
+                            assert (lb[:, 6::self.ndim] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
+                        else:
+                            assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
+                            assert (lb[:, 1:] <= 1).all(), \
+                                f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
+                        assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
+                        # All labels
+                        max_cls = int(lb[:, 0].max())  # max class id
+                        assert max_cls < self.num_cls, \
+                            f'Label class {max_cls} exceeds dataset class count {self.num_cls}. ' \
+                            f'Possible class labels are 0-{self.num_cls - 1}'
+                        _, j = np.unique(lb, axis=0, return_index=True)
+                        if len(j) < nl:  # duplicate row check
+                            lb = lb[j]  # remove duplicates
+                            if segments:
+                                segments = [segments[x] for x in j]
+                            print(f'WARNING ⚠️ {im_file}: {nl - len(j)} duplicate labels removed')
                 else:
                     ne += 1  # label empty
-                    l = np.zeros((0, 5), dtype=np.float32)
+                    lb = np.zeros((0, (5 + self.nkpt * self.ndim)), dtype=np.float32) \
+                        if self.return_keypoints else np.zeros((0, 5), dtype=np.float32)
             else:
                 nm += 1  # label missing
-                l = np.zeros((0, 5), dtype=np.float32)
-            x[im_file] = [l, shape, segments]
+                lb = np.zeros((0, (5 + self.nkpt * self.ndim)), dtype=np.float32) \
+                    if self.return_keypoints else np.zeros((0, 5), dtype=np.float32)
+            if self.return_keypoints:
+                keypoints = lb[:, 5:].reshape(-1, self.nkpt, self.ndim)
+                if self.ndim == 2:
+                    kpt_mask = np.ones(keypoints.shape[:2], dtype=np.float32)
+                    kpt_mask = np.where(keypoints[..., 0] < 0, 0.0, kpt_mask)
+                    kpt_mask = np.where(keypoints[..., 1] < 0, 0.0, kpt_mask)
+                    keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
+                lb = lb[:, :5]
+            x['labels'].append(
+                dict(
+                    im_file=im_file,
+                    cls=lb[:, 0:1],  # n, 1
+                    bboxes=lb[:, 1:],  # n, 4
+                    segments=segments,
+                    keypoints=keypoints,
+                    bbox_format='xywhn',
+                    segment_format='polygon'))
         except Exception as e:
             nc += 1
             print(f"WARNING: Ignoring corrupted image and/or label {im_file}: {e}")
@@ -226,56 +287,39 @@ def cache_labels(self, path=Path("./labels.cache")):
             print(f"WARNING: No labels found in {path}. 
See {self.help_url}") x["hash"] = self._get_hash(self.label_files + self.img_files) - x["results"] = nf, nm, ne, nc, i + 1 + x["results"] = nf, nm, ne, nc, len(self.img_files) x["version"] = self.cache_version # cache version np.save(path, x) # save for next time logger.info(f"New cache created: {path}") return x def __getitem__(self, index): - index, image, labels, segment, hw_ori, hw_scale, pad = index, None, None, None, None, None, None + sample = self.get_sample(index) + for _i, ori_trans in enumerate(self.transforms_dict): _trans = ori_trans.copy() func_name, prob = _trans.pop("func_name"), _trans.pop("prob", 1.0) - if random.random() < prob: - if func_name == "mosaic": - image, labels = self.mosaic(index, **_trans) - elif func_name == "letterbox": - image, hw_ori = self.load_image(index) - labels = self.labels[index].copy() + if func_name == 'copy_paste': + sample = self.copy_paste(sample, prob) + elif random.random() < prob: + if func_name == "albumentations" and getattr(self, "albumentations", None) is None: + self.albumentations = Albumentations(size=self.img_size) + if func_name == "letterbox": new_shape = self.img_size if not self.rect else self.batch_shapes[self.batch[index]] - image, labels, hw_ori, hw_scale, pad = self.letterbox(image, labels, hw_ori, new_shape, **_trans) - elif func_name == "albumentations": - if getattr(self, "albumentations", None) is None: - self.albumentations = Albumentations(size=self.img_size) - image, labels = self.albumentations(image, labels, **_trans) + sample = self.letterbox(sample, new_shape, **_trans) else: - if image is None: - image, hw_ori = self.load_image(index) - labels = self.labels[index].copy() - new_shape = self.img_size if not self.rect else self.batch_shapes[self.batch[index]] - image, labels, hw_ori, hw_scale, pad = self.letterbox( - image, - labels, - hw_ori, - new_shape, - ) - image, labels = getattr(self, func_name)(image, labels, **_trans) - - image = np.ascontiguousarray(image) + sample = getattr(self, func_name)(sample, **_trans) - if self.is_training: - return image, labels, self.img_files[index] - else: - return image, labels, self.img_files[index], hw_ori, hw_scale, pad + sample['img'] = np.ascontiguousarray(sample['img']) + return sample def __len__(self): return len(self.img_files) - def load_image(self, index): - # loads 1 image from dataset, returns img, original hw, resized hw - img = self.imgs[index] - if img is None: # not cached + def get_sample(self, index): + """Get and return label information from the dataset.""" + sample = deepcopy(self.labels[index]) + if self.imgs is None: path = self.img_files[index] img = cv2.imread(path) # BGR assert img is not None, "Image Not Found " + path @@ -285,86 +329,48 @@ def load_image(self, index): interp = cv2.INTER_AREA if r < 1 and not self.augment else cv2.INTER_LINEAR img = cv2.resize(img, (int(w_ori * r), int(h_ori * r)), interpolation=interp) - return img, np.array([h_ori, w_ori]) # img, hw_original - else: - return self.imgs[index], self.img_hw_ori[index] # img, hw_original + sample['img'], sample['ori_shape'] = img, np.array([h_ori, w_ori]) # img, hw_original - def load_samples(self, index): - # loads images in a 4-mosaic - labels4, segments4 = [], [] - s = self.img_size - mosaic_border = [-s // 2, -s // 2] - yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in mosaic_border] # mosaic center x, y - indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices - for i, index in enumerate(indices): - # Load image - img, _ = self.load_image(index) - 
(h, w) = img.shape[:2] - - # place img in img4 - if i == 0: # top left - img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles - x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image) - x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image) - elif i == 1: # top right - x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc - x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h - elif i == 2: # bottom left - x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h) - x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h) - elif i == 3: # bottom right - x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h) - x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h) - - img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] - padw = x1a - x1b - padh = y1a - y1b - - # Labels - labels, segments = self.labels[index].copy(), self.segments[index].copy() - if labels.size: - labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format - segments = [xyn2xy(x, w, h, padw, padh) for x in segments] - labels4.append(labels) - segments4.extend(segments) - - # Concat/clip labels - labels4 = np.concatenate(labels4, 0) - for x in (labels4[:, 1:], *segments4): - np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective() - - # Augment - sample_labels, sample_images, sample_masks = self._sample_segments(img4, labels4, segments4, probability=0.5) + else: + sample['img'], sample['ori_shape'] = self.imgs[index], self.img_hw_ori[index] # img, hw_original - return sample_labels, sample_images, sample_masks + return sample def mosaic( self, - index, + sample, mosaic9_prob=0.0, - copy_paste_prob=0.0, - degrees=0.0, - translate=0.2, - scale=0.9, - shear=0.0, - perspective=0.0, ): - assert mosaic9_prob >= 0.0 and mosaic9_prob <= 1.0 + segment_format = sample['segment_format'] + bbox_format = sample['bbox_format'] + assert segment_format == 'polygon', f'The segment format should be polygon, but got {segment_format}' + assert bbox_format == 'xywhn', f'The bbox format should be xywhn, but got {bbox_format}' + + mosaic9_prob = min(1.0, max(mosaic9_prob, 0.0)) if random.random() < (1 - mosaic9_prob): - return self.mosaic4(index, copy_paste_prob, degrees, translate, scale, shear, perspective) + return self._mosaic4(sample) else: - return self.mosaic9(index, copy_paste_prob, degrees, translate, scale, shear, perspective) + return self._mosaic9(sample) - def mosaic4(self, index, copy_paste_prob=0.0, degrees=0.0, translate=0.2, scale=0.9, shear=0.0, perspective=0.0): + def _mosaic4(self, sample): # loads images in a 4-mosaic - labels4, segments4 = [], [] + classes4, bboxes4, segments4 = [], [], [] + mosaic_samples = [sample, ] + indices = random.choices(self.indices, k=3) # 3 additional image indices + + segments_is_list = isinstance(sample['segments'], list) + if segments_is_list: + mosaic_samples += [self.get_sample(i) for i in indices] + else: + mosaic_samples += [self.resample_segments(self.get_sample(i)) for i in indices] + s = self.img_size mosaic_border = [-s // 2, -s // 2] yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in mosaic_border] # mosaic center x, y - indices = [index] + random.choices(self.indices, k=3) # 3 additional image indices - for i, index in enumerate(indices): + + for i, mosaic_sample in enumerate(mosaic_samples): # Load image - img, _ = self.load_image(index) 
+ img = mosaic_sample['img'] (h, w) = img.shape[:2] # place img in img4 @@ -386,45 +392,60 @@ def mosaic4(self, index, copy_paste_prob=0.0, degrees=0.0, translate=0.2, scale= padw = x1a - x1b padh = y1a - y1b - # Labels - labels, segments = self.labels[index].copy(), self.segments[index].copy() - if labels.size: - labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh) # normalized xywh to pixel xyxy format + # box and cls + cls, bboxes = mosaic_sample['cls'], mosaic_sample['bboxes'] + assert mosaic_sample['bbox_format'] == 'xywhn' + bboxes = xywhn2xyxy(bboxes, w, h, padw, padh) # normalized xywh to pixel xyxy format + classes4.append(cls) + bboxes4.append(bboxes) + + # seg + assert mosaic_sample['segment_format'] == 'polygon' + segments = mosaic_sample['segments'] + if segments_is_list: segments = [xyn2xy(x, w, h, padw, padh) for x in segments] - labels4.append(labels) - segments4.extend(segments) + segments4.extend(segments) + else: + segments = xyn2xy(segments, w, h, padw, padh) + segments4.append(segments) - # Concat/clip labels - labels4 = np.concatenate(labels4, 0) - for x in (labels4[:, 1:], *segments4): - np.clip(x, 0, 2 * s, out=x) # clip when using random_perspective() + classes4 = np.concatenate(classes4, 0) + bboxes4 = np.concatenate(bboxes4, 0) + bboxes4 = bboxes4.clip(0, 2 * s) - # Augment - img4, labels4, segments4 = copy_paste(img4, labels4, segments4, probability=copy_paste_prob) - img4, labels4 = random_perspective( - img4, - labels4, - segments4, - degrees=degrees, - translate=translate, - scale=scale, - shear=shear, - perspective=perspective, - border=mosaic_border, - ) # border to remove - - return img4, labels4 - - def mosaic9(self, index, copy_paste_prob=0.0, degrees=0.0, translate=0.2, scale=0.9, shear=0.0, perspective=0.0): + if segments_is_list: + for x in segments4: + np.clip(x, 0, 2 * s, out=x) + else: + segments4 = np.concatenate(segments4, 0) + segments4 = segments4.clip(0, 2 * s) + + sample['img'] = img4 + sample['cls'] = classes4 + sample['bboxes'] = bboxes4 + sample['bbox_format'] = 'ltrb' + sample['segments'] = segments4 + sample['mosaic_border'] = mosaic_border + + return sample + + def _mosaic9(self, sample): # loads images in a 9-mosaic + classes9, bboxes9, segments9 = [], [], [] + mosaic_samples = [sample, ] + indices = random.choices(self.indices, k=8) # 8 additional image indices - labels9, segments9 = [], [] + segments_is_list = isinstance(sample['segments'], list) + if segments_is_list: + mosaic_samples += [self.get_sample(i) for i in indices] + else: + mosaic_samples += [self.resample_segments(self.get_sample(i)) for i in indices] s = self.img_size mosaic_border = [-s // 2, -s // 2] - indices = [index] + random.choices(self.indices, k=8) # 8 additional image indices - for i, index in enumerate(indices): + + for i, mosaic_sample in enumerate(mosaic_samples): # Load image - img, _ = self.load_image(index) + img = mosaic_sample['img'] (h, w) = img.shape[:2] # place img in img9 @@ -452,73 +473,263 @@ def mosaic9(self, index, copy_paste_prob=0.0, degrees=0.0, translate=0.2, scale= padx, pady = c[:2] x1, y1, x2, y2 = [max(x, 0) for x in c] # allocate coords - # Labels - labels, segments = self.labels[index].copy(), self.segments[index].copy() - if labels.size: - labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady) # normalized xywh to pixel xyxy format + # box and cls + assert mosaic_sample['bbox_format'] == 'xywhn' + cls, bboxes = mosaic_sample['cls'], mosaic_sample['bboxes'] + bboxes = xywhn2xyxy(bboxes, w, h, padx, pady) # normalized xywh 
to pixel xyxy format
+            classes9.append(cls)
+            bboxes9.append(bboxes)
+
+            # seg
+            assert mosaic_sample['segment_format'] == 'polygon'
+            segments = mosaic_sample['segments']
+            if segments_is_list:
                 segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
-            labels9.append(labels)
-            segments9.extend(segments)
+                segments9.extend(segments)
+            else:
+                segments = xyn2xy(segments, w, h, padx, pady)
+                segments9.append(segments)
 
             # Image
-            img9[y1:y2, x1:x2] = img[y1 - pady :, x1 - padx :]  # img9[ymin:ymax, xmin:xmax]
+            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
             hp, wp = h, w  # height, width previous
 
         # Offset
         yc, xc = [int(random.uniform(0, s)) for _ in mosaic_border]  # mosaic center x, y
-        img9 = img9[yc : yc + 2 * s, xc : xc + 2 * s]
+        img9 = img9[yc: yc + 2 * s, xc: xc + 2 * s]
 
         # Concat/clip labels
-        labels9 = np.concatenate(labels9, 0)
-        labels9[:, [1, 3]] -= xc
-        labels9[:, [2, 4]] -= yc
-        c = np.array([xc, yc])  # centers
-        segments9 = [x - c for x in segments9]
+        classes9 = np.concatenate(classes9, 0)
+        bboxes9 = np.concatenate(bboxes9, 0)
+        bboxes9[:, [0, 2]] -= xc
+        bboxes9[:, [1, 3]] -= yc
+        bboxes9 = bboxes9.clip(0, 2 * s)
+
+        if segments_is_list:
+            c = np.array([xc, yc])  # centers
+            segments9 = [x - c for x in segments9]
+            for x in segments9:
+                np.clip(x, 0, 2 * s, out=x)
+        else:
+            segments9 = np.concatenate(segments9, 0)
+            segments9[..., 0] -= xc
+            segments9[..., 1] -= yc
+            segments9 = segments9.clip(0, 2 * s)
+
+        sample['img'] = img9
+        sample['cls'] = classes9
+        sample['bboxes'] = bboxes9
+        sample['bbox_format'] = 'ltrb'
+        sample['segments'] = segments9
+        sample['mosaic_border'] = mosaic_border
+
+        return sample
+
+    def resample_segments(self, sample, n=1000):
+        segment_format = sample['segment_format']
+        assert segment_format == 'polygon', f'The segment format should be polygon, but got {segment_format}'
+
+        segments = sample['segments']
+        if len(segments) > 0:
+            # Up-sample an (n,2) segment
+            for i, s in enumerate(segments):
+                s = np.concatenate((s, s[0:1, :]), axis=0)
+                x = np.linspace(0, len(s) - 1, n)
+                xp = np.arange(len(s))
+                segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy
+            segments = np.stack(segments, axis=0)
+        else:
+            segments = np.zeros((0, n, 2), dtype=np.float32)
+        sample['segments'] = segments
+        return sample
+
+    def copy_paste(self, sample, probability=0.5):
+        # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
+        bbox_format, segment_format = sample['bbox_format'], sample['segment_format']
+        assert bbox_format == 'ltrb', f'The bbox format should be ltrb, but got {bbox_format}'
+        assert segment_format == 'polygon', f'The segment format should be polygon, but got {segment_format}'
 
-        for x in (labels9[:, 1:], *segments9):
-            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
+        img = sample['img']
+        cls = sample['cls']
+        bboxes = sample['bboxes']
+        segments = sample['segments']
 
-        # Augment
-        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, probability=copy_paste_prob)
-        img9, labels9 = random_perspective(
-            img9,
-            labels9,
-            segments9,
-            degrees=degrees,
-            translate=translate,
-            scale=scale,
-            shear=shear,
-            perspective=perspective,
-            border=mosaic_border,
-        )  # border to remove
-
-        return img9, labels9
-
-    def mixup(self, image, labels, alpha=8.0, beta=8.0, needed_mosaic=True):
-        if needed_mosaic:
-            mosaic_trans = None
-            for _trans in self.transforms_dict:
-                if _trans["func_name"] == "mosaic":
-                    
mosaic_trans = _trans.copy() - break - assert mosaic_trans is not None, "Mixup needed mosaic bug 'mosaic' not in transforms_dict" - _, _ = mosaic_trans.pop("func_name"), mosaic_trans.pop("prob", 1.0) - image2, labels2 = self.mosaic(random.randint(0, len(self.labels) - 1), **mosaic_trans) - else: - index2 = random.randint(0, len(self.labels) - 1) - image2, _ = self.load_image(index2) - labels2 = self.labels[index2] + n = len(segments) + if probability and n: + h, w, _ = img.shape # height, width, channels + im_new = np.zeros(img.shape, np.uint8) + for j in random.sample(range(n), k=round(probability * n)): + c, l, s = cls[j], bboxes[j], segments[j] + box = w - l[2], l[1], w - l[0], l[3] + ioa = bbox_ioa(box, bboxes) # intersection over area + if (ioa < 0.30).all(): # allow 30% obscuration of existing labels + cls = np.concatenate((cls, [c]), 0) + bboxes = np.concatenate((bboxes, [box]), 0) + if isinstance(segments, list): + segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) + else: + segments = np.concatenate((segments, [np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)]), 0) + cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) + + result = cv2.bitwise_and(src1=img, src2=im_new) + result = cv2.flip(result, 1) # augment segments (flip left-right) + i = result > 0 # pixels to replace + img[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug + + sample['img'] = img + sample['cls'] = cls + sample['bboxes'] = bboxes + sample['segments'] = segments + + return sample + def random_perspective( + self, sample, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0) + ): + bbox_format, segment_format = sample['bbox_format'], sample['segment_format'] + assert bbox_format == 'ltrb', f'The bbox format should be ltrb, but got {bbox_format}' + assert segment_format == 'polygon', f'The segment format should be polygon, but got {segment_format}' + + img = sample['img'] + cls = sample['cls'] + targets = sample['bboxes'] + segments = sample['segments'] + assert isinstance(segments, np.ndarray), f"segments type expect numpy.ndarray, but got {type(segments)}; " \ + f"maybe you should resample_segments before that." 
+
+
+        border = sample.pop('mosaic_border', border)
+        height = img.shape[0] + border[0] * 2  # shape(h,w,c)
+        width = img.shape[1] + border[1] * 2
+
+        # Center
+        C = np.eye(3)
+        C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
+        C[1, 2] = -img.shape[0] / 2  # y translation (pixels)
+
+        # Perspective
+        P = np.eye(3)
+        P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
+        P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)
+
+        # Rotation and Scale
+        R = np.eye(3)
+        a = random.uniform(-degrees, degrees)
+        s = random.uniform(1 - scale, 1 + scale)
+        R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
+
+        # Shear
+        S = np.eye(3)
+        S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
+        S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)
+
+        # Translation
+        T = np.eye(3)
+        T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
+        T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)
+
+        # Combined rotation matrix
+        M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
+        if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
+            if perspective:
+                img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114))
+            else:  # affine
+                img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
+
+        # Transform label coordinates
+        n = len(targets)
+        if n:
+            use_segments = len(segments)
+            if use_segments:  # warp segments
+                n, num = segments.shape[:2]
+                xy = np.ones((n * num, 3), dtype=segments.dtype)
+                segments = segments.reshape(-1, 2)
+                xy[:, :2] = segments
+                xy = xy @ M.T  # transform
+                xy = xy[:, :2] / xy[:, 2:3]
+                segments = xy.reshape(n, -1, 2)
+                segments[..., 0] = segments[..., 0].clip(0, width)
+                segments[..., 1] = segments[..., 1].clip(0, height)
+                new_bboxes = np.stack([segment2box(xy) for xy in segments], 0)
+
+            else:  # warp boxes
+                xy = np.ones((n * 4, 3))
+                xy[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
+                xy = xy @ M.T  # transform
+                xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine
+
+                # create new boxes
+                x = xy[:, [0, 2, 4, 6]]
+                y = xy[:, [1, 3, 5, 7]]
+                new_bboxes = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+
+            # clip
+            new_bboxes[:, [0, 2]] = new_bboxes[:, [0, 2]].clip(0, width)
+            new_bboxes[:, [1, 3]] = new_bboxes[:, [1, 3]].clip(0, height)
+
+            # filter candidates
+            i = box_candidates(box1=targets.T * s, box2=new_bboxes.T, area_thr=0.01 if use_segments else 0.10)
+
+            cls = cls[i]
+            targets = new_bboxes[i]
+            segments = segments[i]
+            sample['cls'] = cls
+            sample['bboxes'] = targets
+            sample['segments'] = segments
+
+        sample['img'] = img
+
+        return sample
+
+    def mixup(self, sample, alpha=32.0, beta=32.0, pre_transform=None):
+        bbox_format, segment_format = sample['bbox_format'], sample['segment_format']
+        assert bbox_format == 'ltrb', f'The bbox format should be ltrb, but got {bbox_format}'
+        assert segment_format == 'polygon', f'The segment format should be polygon, but got {segment_format}'
+
+        index = random.choices(self.indices, k=1)[0]
+        sample2 = self.get_sample(index)
+        if pre_transform:
+            for _i, ori_trans in enumerate(pre_transform):
+                _trans = ori_trans.copy()
+                func_name, prob = _trans.pop("func_name"), _trans.pop("prob", 1.0)
+                if func_name == 'copy_paste':
+                    
sample2 = self.copy_paste(sample2, prob) + elif random.random() < prob: + if func_name == "albumentations" and getattr(self, "albumentations", None) is None: + self.albumentations = Albumentations(size=self.img_size) + sample2 = getattr(self, func_name)(sample2, **_trans) + + assert isinstance(sample['segments'], np.ndarray), \ + f"MixUp: sample segments type expect numpy.ndarray, but got {type(sample['segments'])}; " \ + f"maybe you should resample_segments before that." + assert isinstance(sample2['segments'], np.ndarray), \ + f"MixUp: sample2 segments type expect numpy.ndarray, but got {type(sample2['segments'])}; " \ + f"maybe you should add resample_segments in pre_transform." + + image, image2 = sample['img'], sample2['img'] r = np.random.beta(alpha, beta) # mixup ratio, alpha=beta=8.0 image = (image * r + image2 * (1 - r)).astype(np.uint8) - labels = np.concatenate((labels, labels2), 0) - return image, labels - def pastein(self, image, labels, num_sample=30): + sample['img'] = image + sample['cls'] = np.concatenate((sample['cls'], sample2['cls']), 0) + sample['bboxes'] = np.concatenate((sample['bboxes'], sample2['bboxes']), 0) + sample['segments'] = np.concatenate((sample['segments'], sample2['segments']), 0) + return sample + + def pastein(self, sample, num_sample=30): + bbox_format = sample['bbox_format'] + assert bbox_format == 'ltrb', f'The bbox format should be ltrb, but got {bbox_format}' + assert not self.return_segments, "pastein currently does not support seg data." + assert not self.return_keypoints, "pastein currently does not support keypoint data." + sample.pop('segments', None) + sample.pop('keypoints', None) + + image = sample['img'] + cls = sample['cls'] + bboxes = sample['bboxes'] # load sample sample_labels, sample_images, sample_masks = [], [], [] while len(sample_labels) < num_sample: - sample_labels_, sample_images_, sample_masks_ = self.load_samples(random.randint(0, len(self.labels) - 1)) + sample_labels_, sample_images_, sample_masks_ = self._pastin_load_samples() sample_labels += sample_labels_ sample_images += sample_images_ sample_masks += sample_masks_ @@ -543,13 +754,13 @@ def pastein(self, image, labels, num_sample=30): ymax = min(h, ymin + mask_h) box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32) - if len(labels): - ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area + if len(bboxes): + ioa = bbox_ioa(box, bboxes) # intersection over area else: ioa = np.zeros(1) if ( - (ioa < 0.30).all() and len(sample_labels) and (xmax > xmin + 20) and (ymax > ymin + 20) + (ioa < 0.30).all() and len(sample_labels) and (xmax > xmin + 20) and (ymax > ymin + 20) ): # allow 30% obscuration of existing labels sel_ind = random.randint(0, len(sample_labels) - 1) hs, ws, cs = sample_images[sel_ind].shape @@ -560,37 +771,129 @@ def pastein(self, image, labels, num_sample=30): if (r_w > 10) and (r_h > 10): r_mask = cv2.resize(sample_masks[sel_ind], (r_w, r_h)) r_image = cv2.resize(sample_images[sel_ind], (r_w, r_h)) - temp_crop = image[ymin : ymin + r_h, xmin : xmin + r_w] + temp_crop = image[ymin: ymin + r_h, xmin: xmin + r_w] m_ind = r_mask > 0 if m_ind.astype(np.int).sum() > 60: temp_crop[m_ind] = r_image[m_ind] box = np.array([xmin, ymin, xmin + r_w, ymin + r_h], dtype=np.float32) - if len(labels): - labels = np.concatenate((labels, [[sample_labels[sel_ind], *box]]), 0) + if len(bboxes): + cls = np.concatenate((cls, [[sample_labels[sel_ind]]]), 0) + bboxes = np.concatenate((bboxes, [box]), 0) else: - labels = np.array([[sample_labels[sel_ind], *box]]) + 
cls = np.array([[sample_labels[sel_ind]]]) + bboxes = np.array([box]) - image[ymin : ymin + r_h, xmin : xmin + r_w] = temp_crop # Modify on the original image + image[ymin: ymin + r_h, xmin: xmin + r_w] = temp_crop # Modify on the original image - return image, labels + sample['img'] = image + sample['bboxes'] = bboxes + sample['cls'] = cls + return sample - def random_perspective( - self, image, labels, segments=(), degrees=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0) - ): - image, labels = random_perspective( - image, - labels, - segments, - degrees=degrees, - translate=translate, - scale=scale, - shear=shear, - perspective=perspective, - border=border, - ) - return image, labels + def _pastin_load_samples(self): + # loads images in a 4-mosaic + classes4, bboxes4, segments4 = [], [], [] + mosaic_samples = [] + indices = random.choices(self.indices, k=4) # 3 additional image indices + mosaic_samples += [self.get_sample(i) for i in indices] + s = self.img_size + mosaic_border = [-s // 2, -s // 2] + yc, xc = [int(random.uniform(-x, 2 * s + x)) for x in mosaic_border] # mosaic center x, y - def hsv_augment(self, image, labels, hgain=0.5, sgain=0.5, vgain=0.5): + for i, sample in enumerate(mosaic_samples): + # Load image + img = sample['img'] + (h, w) = img.shape[:2] + + # place img in img4 + if i == 0: # top left + img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles + x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image) + x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image) + elif i == 1: # top right + x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc + x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h + elif i == 2: # bottom left + x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h) + x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h) + elif i == 3: # bottom right + x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h) + x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h) + + img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] + padw = x1a - x1b + padh = y1a - y1b + + # Labels + cls, bboxes = sample['cls'], sample['bboxes'] + bboxes = xywhn2xyxy(bboxes, w, h, padw, padh) # normalized xywh to pixel xyxy format + + classes4.append(cls) + bboxes4.append(bboxes) + + segments = sample['segments'] + segments_is_list = isinstance(segments, list) + if segments_is_list: + segments = [xyn2xy(x, w, h, padw, padh) for x in segments] + segments4.extend(segments) + else: + segments = xyn2xy(segments, w, h, padw, padh) + segments4.append(segments) + + # Concat/clip labels + classes4 = np.concatenate(classes4, 0) + bboxes4 = np.concatenate(bboxes4, 0) + bboxes4 = bboxes4.clip(0, 2 * s) + + if segments_is_list: + for x in segments4: + np.clip(x, 0, 2 * s, out=x) + else: + segments4 = np.concatenate(segments4, 0) + segments4 = segments4.clip(0, 2 * s) + + # Augment + sample_labels, sample_images, sample_masks = \ + self._pastin_sample_segments(img4, classes4, bboxes4, segments4, probability=0.5) + + return sample_labels, sample_images, sample_masks + + def _pastin_sample_segments(self, img, classes, bboxes, segments, probability=0.5): + # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) + n = len(segments) + sample_labels = [] + sample_images = [] + sample_masks = [] + if probability and n: + h, w, c = 
img.shape # height, width, channels + for j in random.sample(range(n), k=round(probability * n)): + cls, l, s = classes[j], bboxes[j], segments[j] + box = ( + l[0].astype(int).clip(0, w - 1), + l[1].astype(int).clip(0, h - 1), + l[2].astype(int).clip(0, w - 1), + l[3].astype(int).clip(0, h - 1), + ) + + if (box[2] <= box[0]) or (box[3] <= box[1]): + continue + + sample_labels.append(cls[0]) + + mask = np.zeros(img.shape, np.uint8) + + cv2.drawContours(mask, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) + sample_masks.append(mask[box[1]: box[3], box[0]: box[2], :]) + + result = cv2.bitwise_and(src1=img, src2=mask) + i = result > 0 # pixels to replace + mask[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug + sample_images.append(mask[box[1]: box[3], box[0]: box[2], :]) + + return sample_labels, sample_images, sample_masks + + def hsv_augment(self, sample, hgain=0.5, sgain=0.5, vgain=0.5): + image = sample['img'] r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains hue, sat, val = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV)) dtype = image.dtype # uint8 @@ -602,27 +905,67 @@ def hsv_augment(self, image, labels, hgain=0.5, sgain=0.5, vgain=0.5): img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype) cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=image) # Modify on the original image - return image, labels - def fliplr(self, image, labels): + sample['img'] = image + return sample + + def fliplr(self, sample): # flip left-right + bbox_format = sample['bbox_format'] + assert bbox_format == 'ltrb', f'FlipLR: The bbox format should be ltrb, but got {bbox_format}' + + # flip image + image = sample['img'] image = np.fliplr(image) - if len(labels): - labels[:, 1] = 1 - labels[:, 1] - return image, labels - - def flipud(self, image, labels): - # flip up-down - image = np.flipud(image) - if len(labels): - labels[:, 2] = 1 - labels[:, 2] - return image, labels - - def letterbox(self, image, labels, hw_ori, new_shape, scaleup=False, color=(114, 114, 114)): + sample['img'] = image + + # flip box + _, w = image.shape[:2] + bboxes, bbox_format = sample['bboxes'], sample['bbox_format'] + if bbox_format == "ltrb": + if len(bboxes): + x1 = bboxes[:, 0].copy() + x2 = bboxes[:, 2].copy() + bboxes[:, 0] = w - x2 + bboxes[:, 2] = w - x1 + elif bbox_format == "xywhn": + if len(bboxes): + bboxes[:, 0] = 1 - bboxes[:, 0] + else: + raise NotImplementedError + sample['bboxes'] = bboxes + + # flip seg + if 'segments' in sample: + segment_format, segments = sample['segment_format'], sample['segments'] + assert segment_format == 'polygon', \ + f'FlipLR: The segment format should be polygon, but got {segment_format}' + assert isinstance(segments, np.ndarray), \ + f"FlipLR: segments type expect numpy.ndarray, but got {type(segments)}; " \ + f"maybe you should resample_segments before that." 
+ + if len(segments): + segments[..., 0] = w - segments[..., 0] + + sample['segments'] = segments + + return sample + + def letterbox(self, sample, new_shape=None, xywhn2xyxy_=True, scaleup=False, color=(114, 114, 114)): # Resize and pad image while meeting stride-multiple constraints + + if sample['bbox_format'] == 'ltrb': + xywhn2xyxy_ = False + + if not new_shape: + new_shape = self.img_size + image = sample['img'] + bboxes = sample['bboxes'] + ori_shape = sample['ori_shape'] + shape = image.shape[:2] # current shape [height, width] h, w = shape[:] - h0, w0 = hw_ori + h0, w0 = ori_shape hw_scale = np.array([h / h0, w / w0]) if isinstance(new_shape, int): new_shape = (new_shape, new_shape) @@ -646,78 +989,128 @@ def letterbox(self, image, labels, hw_ori, new_shape, scaleup=False, color=(114, left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border - # convert labels - if labels.size: # normalized xywh to pixel xyxy format - labels[:, 1:] = xywhn2xyxy(labels[:, 1:], r * w, r * h, padw=hw_pad[1], padh=hw_pad[0]) + # convert bboxes + if len(bboxes): + if xywhn2xyxy_: + bboxes = xywhn2xyxy(bboxes, r * w, r * h, padw=dw, padh=dh) + else: + bboxes *= r + bboxes[:, [0, 2]] += dw + bboxes[:, [1, 3]] += dh + + # convert segments + if 'segments' in sample: + segments, segment_format = sample['segments'], sample['segment_format'] + assert segment_format == 'polygon', f'The segment format should be polygon, but got {segment_format}' + assert isinstance(segments, np.ndarray), \ + f"LetterBox: segments type expect numpy.ndarray, but got {type(segments)}; " \ + f"maybe you should resample_segments before that." + + if len(segments): + if xywhn2xyxy_: + segments[..., 0] *= w + segments[..., 1] *= h + else: + segments *= r + segments[..., 0] += dw + segments[..., 1] += dh + sample['segments'] = segments - return image, labels, hw_ori, hw_scale, hw_pad + sample['bboxes'] = bboxes + sample['img'] = image + sample['hw_scale'] = hw_scale + sample['hw_pad'] = hw_pad + sample['bbox_format'] = 'ltrb' - def label_norm(self, image, labels, xyxy2xywh_=True): - if len(labels) == 0: - return image, labels + return sample - if xyxy2xywh_: - labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) # convert xyxy to xywh + def label_norm(self, sample, xyxy2xywh_=True): + bbox_format = sample['bbox_format'] + if bbox_format == "xywhn": + return sample - labels[:, [2, 4]] /= image.shape[0] # normalized height 0-1 - labels[:, [1, 3]] /= image.shape[1] # normalized width 0-1 + bboxes = sample['bboxes'] + if len(bboxes) == 0: + sample['bbox_format'] = 'xywhn' + return sample - return image, labels + if xyxy2xywh_: + bboxes = xyxy2xywh(bboxes) # convert xyxy to xywh + height, width = sample['img'].shape[:2] + bboxes[:, [1, 3]] /= height # normalized height 0-1 + bboxes[:, [0, 2]] /= width # normalized width 0-1 + sample['bboxes'] = bboxes + sample['bbox_format'] = 'xywhn' - def label_pad(self, image, labels, padding_size=160, padding_value=-1): + return sample + + def label_pad(self, sample, padding_size=160, padding_value=-1): # create fixed label, avoid dynamic shape problem. 
- labels_out = np.full((padding_size, 6), padding_value, dtype=np.float32) - nL = len(labels) - if nL: - labels_out[: min(nL, padding_size), 0:1] = 0.0 - labels_out[: min(nL, padding_size), 1:] = labels[: min(nL, padding_size), :] - return image, labels_out + bbox_format = sample['bbox_format'] + assert bbox_format == 'xywhn', f'The bbox format should be xywhn, but got {bbox_format}' - def image_norm(self, image, labels, scale=255.0): + cls, bboxes = sample['cls'], sample['bboxes'] + cls_pad = np.full((padding_size, 1), padding_value, dtype=np.float32) + bboxes_pad = np.full((padding_size, 4), padding_value, dtype=np.float32) + nL = len(bboxes) + if nL: + cls_pad[:min(nL, padding_size)] = cls[:min(nL, padding_size)] + bboxes_pad[:min(nL, padding_size)] = bboxes[:min(nL, padding_size)] + sample['cls'] = cls_pad + sample['bboxes'] = bboxes_pad + + if "segments" in sample: + if sample['segment_format'] == "mask": + segments = sample['segments'] + assert isinstance(segments, np.ndarray), \ + f"Label Pad: segments type expect numpy.ndarray, but got {type(segments)}; " \ + f"maybe you should resample_segments before that." + assert nL == segments.shape[0], f"Label Pad: segments len not equal bboxes" + h, w = segments.shape[1:] + segments_pad = np.full((padding_size, h, w), padding_value, dtype=np.float32) + segments_pad[:min(nL, padding_size)] = segments[:min(nL, padding_size)] + sample['segments'] = segments_pad + + return sample + + def image_norm(self, sample, scale=255.0): + image = sample['img'] image = image.astype(np.float32, copy=False) image /= scale - return image, labels + sample['img'] = image + return sample - def image_transpose(self, image, labels, bgr2rgb=True, hwc2chw=True): + def image_transpose(self, sample, bgr2rgb=True, hwc2chw=True): + image = sample['img'] if bgr2rgb: image = image[:, :, ::-1] if hwc2chw: image = image.transpose(2, 0, 1) - return image, labels - - def _sample_segments(self, img, labels, segments, probability=0.5): - # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) - n = len(segments) - sample_labels = [] - sample_images = [] - sample_masks = [] - if probability and n: - h, w, c = img.shape # height, width, channels - for j in random.sample(range(n), k=round(probability * n)): - l, s = labels[j], segments[j] - box = ( - l[1].astype(int).clip(0, w - 1), - l[2].astype(int).clip(0, h - 1), - l[3].astype(int).clip(0, w - 1), - l[4].astype(int).clip(0, h - 1), - ) - - if (box[2] <= box[0]) or (box[3] <= box[1]): - continue - - sample_labels.append(l[0]) - - mask = np.zeros(img.shape, np.uint8) + sample['image'] = image + return sample + + def segment_poly2mask(self, sample, mask_overlap, mask_ratio): + """convert polygon points to bitmap.""" + segments, segment_format = sample['segments'], sample['segment_format'] + assert segment_format == 'polygon', f'The segment format should be polygon, but got {segment_format}' + assert isinstance(segments, np.ndarray), \ + f"Segment Poly2Mask: segments type expect numpy.ndarray, but got {type(segments)}; " \ + f"maybe you should resample_segments before that." 
+ + h, w = sample['img'].shape[:2] + if mask_overlap: + masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=mask_ratio) + masks = masks[None] # (1, h/mask_ratio, w/mask_ratio) + sample['cls'] = sample['cls'][sorted_idx] + sample['bboxes'] = sample['bboxes'][sorted_idx] + sample['segments'] = masks + sample['segment_format'] = 'overlap' + else: + masks = polygons2masks((h, w), segments, color=1, downsample_ratio=mask_ratio) + sample['segments'] = masks + sample['segment_format'] = 'mask' - cv2.drawContours(mask, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) - sample_masks.append(mask[box[1] : box[3], box[0] : box[2], :]) - - result = cv2.bitwise_and(src1=img, src2=mask) - i = result > 0 # pixels to replace - mask[i] = result[i] # cv2.imwrite('debug.jpg', img) # debug - sample_images.append(mask[box[1] : box[3], box[0] : box[2], :]) - - return sample_labels, sample_images, sample_masks + return sample def _img2label_paths(self, img_paths): # Define label paths as a function of image paths @@ -745,77 +1138,33 @@ def _exif_size(self, img): return s - def _segments2boxes(self, segments): - # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh) - boxes = [] - for s in segments: - x, y = s.T # segment xy - boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy - return xyxy2xywh(np.array(boxes)) # cls, xywh - - @staticmethod - def train_collate_fn(imgs, labels, path, batch_info): - for i, l in enumerate(labels): - l[:, 0] = i # add target image index for build_targets() - return np.stack(imgs, 0), np.stack(labels, 0), path - - @staticmethod - def test_collate_fn(imgs, labels, path, hw_ori, hw_scale, pad, batch_info): - for i, l in enumerate(labels): - l[:, 0] = i # add target image index for build_targets() + def train_collate_fn(self, batch_samples, batch_info): + imgs = [sample.pop('img') for sample in batch_samples] + labels = [] + for i, sample in batch_samples: + cls, bboxes = sample.pop('cls'), sample.pop('bboxes') + labels.append(np.concatenate((np.full_like(cls, i), cls, bboxes), axis=-1)) + return_items = [np.stack(imgs, 0), np.stack(labels, 0)] + + if self.return_segments: + masks = [sample.pop('segments', None) for sample in batch_samples] + return_items.append(np.stack(masks, 0)) + if self.return_keypoints: + keypoints = [sample.pop('keypoints', None) for sample in batch_samples] + return_items.append(np.stack(keypoints, 0)) + + return return_items + + def test_collate_fn(self, batch_samples, batch_info): + imgs = [sample.pop('img') for sample in batch_samples] + path = [sample.pop('im_file') for sample in batch_samples] + hw_ori = [sample.pop('ori_shape') for sample in batch_samples] + hw_scale = [sample.pop('hw_scale') for sample in batch_samples] + pad = [sample.pop('hw_pad') for sample in batch_samples] return ( np.stack(imgs, 0), - np.stack(labels, 0), path, np.stack(hw_ori, 0), np.stack(hw_scale, 0), np.stack(pad, 0), ) - - -def bbox_ioa(box1, box2): - # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. 
boxes are x1y1x2y2 - box2 = box2.transpose() - - # Get the coordinates of bounding boxes - b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] - b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] - - # Intersection area - inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * ( - np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1) - ).clip(0) - - # box2 area - box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16 - - # Intersection over box2 area - return inter_area / box2_area - - -def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): - # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right - y = np.copy(x) - y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x - y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y - y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x - y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y - return y - - -def xyxy2xywh(x): - # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right - y = np.copy(x) - y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center - y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center - y[:, 2] = x[:, 2] - x[:, 0] # width - y[:, 3] = x[:, 3] - x[:, 1] # height - return y - - -def xyn2xy(x, w=640, h=640, padw=0, padh=0): - # Convert normalized segments into pixel segments, shape (n,2) - y = np.copy(x) - y[:, 0] = w * x[:, 0] + padw # top left x - y[:, 1] = h * x[:, 1] + padh # top left y - return y diff --git a/mindyolo/data/loader.py b/mindyolo/data/loader.py index cf93a781..c94799f3 100644 --- a/mindyolo/data/loader.py +++ b/mindyolo/data/loader.py @@ -14,7 +14,8 @@ def create_loader( dataset, batch_collate_fn, - dataset_column_names, + column_names_getitem, + column_names_collate, batch_size, epoch_size=1, rank=0, @@ -52,7 +53,7 @@ def create_loader( if rank_size > 1: ds = de.GeneratorDataset( dataset, - column_names=dataset_column_names, + column_names=column_names_getitem, num_parallel_workers=min(8, num_parallel_workers), shuffle=shuffle, python_multiprocessing=python_multiprocessing, @@ -62,13 +63,14 @@ def create_loader( else: ds = de.GeneratorDataset( dataset, - column_names=dataset_column_names, + column_names=column_names_getitem, num_parallel_workers=min(32, num_parallel_workers), shuffle=shuffle, python_multiprocessing=python_multiprocessing, ) ds = ds.batch( - batch_size, per_batch_map=batch_collate_fn, input_columns=dataset_column_names, drop_remainder=drop_remainder + batch_size, per_batch_map=batch_collate_fn, + input_columns=column_names_getitem, output_columns=column_names_collate, drop_remainder=drop_remainder ) ds = ds.repeat(epoch_size) diff --git a/mindyolo/data/perspective.py b/mindyolo/data/perspective.py deleted file mode 100644 index 232cb48c..00000000 --- a/mindyolo/data/perspective.py +++ /dev/null @@ -1,110 +0,0 @@ -import math -import random - -import cv2 -import numpy as np - - -def random_perspective( - img, targets=(), segments=(), degrees=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0) -): - height = img.shape[0] + border[0] * 2 # shape(h,w,c) - width = img.shape[1] + border[1] * 2 - - # Center - C = np.eye(3) - C[0, 2] = -img.shape[1] / 2 # x translation (pixels) - C[1, 2] = -img.shape[0] / 2 # y translation (pixels) - - # Perspective - P = np.eye(3) - P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y) - P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x) - - # Rotation and Scale - R = 
np.eye(3) - a = random.uniform(-degrees, degrees) - s = random.uniform(1 - scale, 1.1 + scale) - R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) - - # Shear - S = np.eye(3) - S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) - S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) - - # Translation - T = np.eye(3) - T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels) - T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels) - - # Combined rotation matrix - M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT - if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed - if perspective: - img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=(114, 114, 114)) - else: # affine - img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114)) - - # Transform label coordinates - n = len(targets) - if n: - use_segments = any(x.any() for x in segments) - new = np.zeros((n, 4)) - if use_segments: # warp segments - segments = _resample_segments(segments) # upsample - for i, segment in enumerate(segments): - xy = np.ones((len(segment), 3)) - xy[:, :2] = segment - xy = xy @ M.T # transform - xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine - - # clip - new[i] = _segment2box(xy, width, height) - - else: # warp boxes - xy = np.ones((n * 4, 3)) - xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 - xy = xy @ M.T # transform - xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine - - # create new boxes - x = xy[:, [0, 2, 4, 6]] - y = xy[:, [1, 3, 5, 7]] - new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T - - # clip - new[:, [0, 2]] = new[:, [0, 2]].clip(0, width) - new[:, [1, 3]] = new[:, [1, 3]].clip(0, height) - - # filter candidates - i = _box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10) - targets = targets[i] - targets[:, 1:5] = new[i] - - return img, targets - - -def _resample_segments(segments, n=1000): - # Up-sample an (n,2) segment - for i, s in enumerate(segments): - s = np.concatenate((s, s[0:1, :]), axis=0) - x = np.linspace(0, len(s) - 1, n) - xp = np.arange(len(s)) - segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy - return segments - - -def _segment2box(segment, width=640, height=640): - # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) 
to (xyxy) - x, y = segment.T # segment xy - inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) - x, y, = x[inside], y[inside] - return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy - - -def _box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) - # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio - w1, h1 = box1[2] - box1[0], box1[3] - box1[1] - w2, h2 = box2[2] - box2[0], box2[3] - box2[1] - ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio - return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates diff --git a/mindyolo/data/utils.py b/mindyolo/data/utils.py new file mode 100644 index 00000000..a0211f2c --- /dev/null +++ b/mindyolo/data/utils.py @@ -0,0 +1,129 @@ +import numpy as np +import cv2 + + +def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): + """Return a (640, 640) overlap mask.""" + masks = np.zeros((imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), + dtype=np.int32 if len(segments) > 255 else np.uint8) + areas = [] + ms = [] + for si in range(len(segments)): + mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1) + ms.append(mask) + areas.append(mask.sum()) + areas = np.asarray(areas) + index = np.argsort(-areas) + ms = np.array(ms)[index] + for i in range(len(segments)): + mask = ms[i] * (i + 1) + masks = masks + mask + masks = np.clip(masks, a_min=0, a_max=i + 1) + return masks, index + + +def polygons2masks(imgsz, polygons, color, downsample_ratio=1): + """ + Args: + imgsz (tuple): The image size. + polygons (Union(np.ndarray, list[np.ndarray])): each polygon is [N, M], N is number of polygons, M is number of points (M % 2 = 0) + color (int): color + downsample_ratio (int): downsample ratio + """ + masks = [] + for si in range(len(polygons)): + mask = polygon2mask(imgsz, [polygons[si].reshape(-1)], color, downsample_ratio) + masks.append(mask) + return np.array(masks) + + +def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1): + """ + Args: + imgsz (tuple): The image size. + polygons (list[np.ndarray]): [N, M], N is the number of polygons, M is the number of points(Be divided by 2). + color (int): color + downsample_ratio (int): downsample ratio + """ + mask = np.zeros(imgsz, dtype=np.uint8) + polygons = np.asarray(polygons) + polygons = polygons.astype(np.int32) + shape = polygons.shape + polygons = polygons.reshape(shape[0], -1, 2) + cv2.fillPoly(mask, polygons, color=color) + nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) + # NOTE: fillPoly firstly then resize is trying the keep the same way + # of loss calculation when mask-ratio=1. + mask = cv2.resize(mask, (nw, nh)) + return mask + + +def bbox_ioa(box1, box2): + # Returns the intersection over box2 area given box1, box2. box1 is 4, box2 is nx4. 
boxes are x1y1x2y2 + box2 = box2.transpose() + + # Get the coordinates of bounding boxes + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + + # Intersection area + inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * ( + np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1) + ).clip(0) + + # box2 area + box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + 1e-16 + + # Intersection over box2 area + return inter_area / box2_area + + +def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): + # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = np.copy(x) + y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x + y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y + y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x + y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y + return y + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xyn2xy(x, w=640, h=640, padw=0, padh=0): + # Convert normalized segments into pixel segments, shape (n,2) + y = np.copy(x) + y[..., 0] = w * x[..., 0] + padw # top left x + y[..., 1] = h * x[..., 1] + padh # top left y + return y + + +def segments2boxes(segments): + # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh) + boxes = [] + for s in segments: + x, y = s.T # segment xy + boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy + return xyxy2xywh(np.array(boxes)) # cls, xywh + + +def segment2box(segment): + # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) 
to (xyxy) + x, y = segment.T # segment xy + return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4) # xyxy + + +def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) + # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio + w1, h1 = box1[2] - box1[0], box1[3] - box1[1] + w2, h2 = box2[2] - box2[0], box2[3] - box2[1] + ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio + return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates diff --git a/mindyolo/models/heads/__init__.py b/mindyolo/models/heads/__init__.py index e1d6eac8..593e3df8 100644 --- a/mindyolo/models/heads/__init__.py +++ b/mindyolo/models/heads/__init__.py @@ -6,4 +6,12 @@ from .yolov8_head import * from .yolox_head import * -__all__ = ["YOLOv3Head", "YOLOv4Head", "YOLOv5Head", "YOLOv7Head", "YOLOv7AuxHead", "YOLOv8Head", "YOLOXHead"] + +__all__ = [ + "YOLOv3Head", + "YOLOv4Head", + "YOLOv5Head", + "YOLOv7Head", "YOLOv7AuxHead", + "YOLOv8Head", "YOLOv8SegHead", + "YOLOXHead" +] diff --git a/mindyolo/models/heads/yolov8_head.py b/mindyolo/models/heads/yolov8_head.py index 482ddc76..c821409f 100644 --- a/mindyolo/models/heads/yolov8_head.py +++ b/mindyolo/models/heads/yolov8_head.py @@ -109,3 +109,47 @@ def initialize_biases(self): b_np = b[-1].bias.data.asnumpy() b_np[: m.nc] = math.log(5 / m.nc / (640 / int(s)) ** 2) b[-1].bias = ops.assign(b[-1].bias, Tensor(b_np, ms.float32)) + + +class YOLOv8SegHead(YOLOv8Head): + """YOLOv8 Segment head for segmentation models.""" + + def __init__(self, nc=80, reg_max=16, nm=32, npr=256, stride=(), ch=()): + """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.""" + super().__init__(nc, reg_max, stride, ch) + self.nm = nm # number of masks + self.npr = npr # number of protos + self.proto = Proto(ch[0], self.npr, self.nm) # protos + self.detect = YOLOv8Head.construct + + c4 = max(ch[0] // 4, self.nm) + self.cv4 = nn.CellList([nn.SequentialCell(ConvNormAct(x, c4, 3), ConvNormAct(c4, c4, 3), nn.Conv2d(c4, self.nm, 1, has_bias=True)) for x in ch]) + + def construct(self, x): + """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.""" + p = self.proto(x[0]) # mask protos + bs = p.shape[0] # batch size + + mc = ops.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients + x = self.detect(self, x) # x: out if self.training else (p, out) + if self.training: + return x, mc, p + + mc = ops.transpose(mc, (0, 2, 1)) # (bs, 32, nbox) -> (bs, nbox, 32) + # cat: (bs, nbox, no-84), (bs, nbox, 32) -> (bs, nbox, 84+32) + return ops.cat([x[0], mc], 2), (x[1], mc, p) + + +class Proto(nn.Cell): + """YOLOv8 mask Proto module for segmentation models.""" + + def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks + super().__init__() + self.cv1 = ConvNormAct(c1, c_, k=3) + self.upsample = nn.Conv2dTranspose(c_, c_, 2, 2, padding=0, has_bias=True) # nn.Upsample(scale_factor=2, mode='nearest') + self.cv2 = ConvNormAct(c_, c_, k=3) + self.cv3 = ConvNormAct(c_, c2) + + def construct(self, x): + """Performs a forward pass through layers using an upsampled input image.""" + return self.cv3(self.cv2(self.upsample(self.cv1(x)))) diff --git a/mindyolo/models/losses/label_assignment.py b/mindyolo/models/losses/label_assignment.py deleted file mode 100644 index 
fef8ca93..00000000 --- a/mindyolo/models/losses/label_assignment.py +++ /dev/null @@ -1 +0,0 @@ -# TODO: General label assign method diff --git a/mindyolo/models/losses/yolov8_loss.py b/mindyolo/models/losses/yolov8_loss.py index 2e668d33..6e85efe6 100644 --- a/mindyolo/models/losses/yolov8_loss.py +++ b/mindyolo/models/losses/yolov8_loss.py @@ -9,7 +9,7 @@ CLIP_VALUE = 1000.0 EPS = 1e-7 -__all__ = ["YOLOv8Loss"] +__all__ = ["YOLOv8Loss", "YOLOv8SegLoss"] @register_model @@ -153,6 +153,165 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): return ops.concat(anchor_points), ops.concat(stride_tensor) +@register_model +class YOLOv8SegLoss(YOLOv8Loss): + def __init__(self, box, cls, dfl, stride, nc, reg_max=16, nm=32, overlap=True, max_object_num=600, **kwargs): + super(YOLOv8SegLoss, self).__init__(box, cls, dfl, stride, nc, reg_max) + + self.overlap = overlap + self.nm = nm + self.max_object_num = max_object_num + + # branch name returned by lossitem for print + self.loss_item_name = ["loss", "lbox", "lseg", "lcls", "dfl"] + + def construct(self, preds, target_box, target_seg): + """YOLOv8 Loss + Args: + feats: list of tensor, feats[i] shape: (bs, nc+reg_max*4, hi, wi) + targets: [image_idx,cls,x,y,w,h], shape: (bs, gt_max, 6) + """ + loss = ops.zeros(4, ms.float32) # box, cls, dfl, mask + # (bs, nc+reg_max*4, hi, wi), (bs, k, hi*wi), (bs, k, 138, 138); k = 32; + feats, pred_masks, proto = preds # x, mc, p; + batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width + + _x = () + for xi in feats: + _x += (xi.view(batch_size, self.no, -1),) + _x = ops.concat(_x, 2) + pred_distri, pred_scores = _x[:, :self.reg_max * 4, :], _x[:, -self.nc:, :] # (bs, nc, h*w) + + # b, grids, .. + pred_scores = pred_scores.transpose(0, 2, 1) # (bs, h*w, nc) + pred_distri = pred_distri.transpose(0, 2, 1) # (bs, h*w, regmax * 4) + pred_masks = pred_masks.transpose(0, 2, 1) # (bs, h*w, k) + + dtype = pred_scores.dtype + imgsz = get_tensor(feats[0].shape[2:], dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = self.make_anchors(feats, self.stride, 0.5) + + # targets + target_box, mask_gt = self.preprocess(target_box, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = target_box[:, :, :1], target_box[:, :, 1:5] # cls, xyxy + + # pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, shape: (b, h*w, 4) + + _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( + self.sigmoid(pred_scores), + (pred_bboxes * stride_tensor).astype(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + # stop gradient + target_bboxes, target_scores, fg_mask, target_gt_idx = ( + ops.stop_gradient(target_bboxes), + ops.stop_gradient(target_scores), + ops.stop_gradient(fg_mask), + ops.stop_gradient(target_gt_idx) + ) + + target_scores_sum = ops.maximum(target_scores.sum(), 1) + + # cls loss + loss[2] = self.bce(pred_scores, ops.cast(target_scores, dtype)).sum() / target_scores_sum # BCE + + # bbox loss + loss[0], loss[3] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor, target_scores, target_scores_sum, fg_mask + ) + + # FIXME: mask target reshape, dynamic shape feature required. 
+ # masks = target_seg # (b, 1, mask_h, mask_w) if overlap else (bs, N, mask_h, mask_w) + # if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample + # masks = ops.interpolate(ops.expand_dims(masks, 0), size=(mask_h, mask_w), mode="nearest")[0] + + for i in range(batch_size): + _fg_mask, _fg_mask_index = ops.topk(fg_mask[i].astype(ms.float16), self.max_object_num) + _mask = target_seg[i] # (mask_h, mask_w) if overlap else (n_gt, mask_h, mask_w) + _mask_idx = target_gt_idx[i] # (b, N) -> (N,) + _mask_idx = ops.gather(_mask_idx, _fg_mask_index, axis=0) # (max_object_num,) + + if self.overlap: + _cond = _mask[None, :, :] == (_mask_idx[:, None, None] + 1) + gt_mask = ops.where( + _cond, + ops.ones(_cond.shape, pred_masks.dtype), + ops.zeros(_cond.shape, pred_masks.dtype) + ) + else: + gt_mask = _mask[_mask_idx] # (n_gt, mask_h, mask_w) -> (N, mask_h, mask_w)/(max_object_num, mask_h, mask_w) + + xyxyn = target_bboxes[i] / imgsz[[1, 0, 1, 0]] + marea = xyxy2xywh(xyxyn)[:, 2:].prod(1) + mxyxy = xyxyn * get_tensor((mask_w, mask_h, mask_w, mask_h), xyxyn.dtype) + + _loss_1 = self.single_mask_loss( + gt_mask, pred_masks[i], proto[i], mxyxy, marea, _fg_mask, _fg_mask_index + ) + loss[1] += _loss_1 + + loss[0] *= self.hyp_box # box gain + loss[1] *= self.hyp_box / batch_size # seg gain + loss[2] *= self.hyp_cls # cls gain + loss[3] *= self.hyp_dfl # dfl gain + + return loss.sum() * batch_size, ops.stop_gradient( + ops.concat((loss.sum(keepdims=True), loss)) + ) # loss, lbox, lseg, lcls, ldfl + + def single_mask_loss(self, gt_mask, pred, proto, xyxy, area, _fg_mask, _fg_mask_index): + """Mask loss for one image.""" + pred = ops.gather(pred, _fg_mask_index, axis=0) + xyxy = ops.gather(xyxy, _fg_mask_index, axis=0) + area = ops.gather(area, _fg_mask_index, axis=0) + + _dtype = pred.dtype + pred_mask = ops.matmul( + pred.astype(ms.float16), + proto.astype(ms.float16).view(self.nm, -1) + ).view(-1, *proto.shape[1:]).astype(_dtype) # (n, 32) @ (32,80,80) -> (n,80,80) + + loss = ops.binary_cross_entropy_with_logits( + pred_mask, gt_mask, reduction='none', + weight=ops.ones(1, pred_mask.dtype), + pos_weight=ops.ones(1, pred_mask.dtype) + ) + + single_loss = (self.crop_mask(loss, xyxy).mean(axis=(1, 2)) / ops.clip(area, min=1e-4)) + single_loss *= _fg_mask + + num_seg = ops.clip(_fg_mask.sum(), min=1.0) + + return single_loss.sum() / num_seg + + @staticmethod + def crop_mask(masks, boxes): + """ + It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box + + Args: + masks (Tensor): [h, w, n] tensor of masks + boxes (Tensor): [n, 4] tensor of bbox coordinates in relative point form + + Returns: + (Tensor): The masks are being cropped to the bounding box. 
+ """ + n, h, w = masks.shape + x1, y1, x2, y2 = ops.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) + r = ops.arange(w, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) + c = ops.arange(h, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) + + return masks * ops.logical_and( + ops.logical_and((r >= x1), (r < x2)), + ops.logical_and((c >= y1), (c < y2)) + ).astype(x1.dtype) + + class BboxLoss(nn.Cell): def __init__(self, reg_max, use_dfl=False): super().__init__() @@ -388,7 +547,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_gt): fg_mask = mask_pos.sum(-2) # (b, n_gt, N) -> (b, N) # if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes - mask_multi_gts = ops.tile(ops.expand_dims(fg_mask, 1), (1, n_gt, 1)) # (b, n_gt, N) + mask_multi_gts = ops.tile(ops.expand_dims(fg_mask > 1, 1), (1, n_gt, 1)) # (b, n_gt, N) max_overlaps_idx = overlaps.argmax(1) # (b, n_gt, N) -> (b, N) is_max_overlaps = ops.one_hot( max_overlaps_idx, n_gt, on_value=ops.ones(1, ms.int32), off_value=ops.zeros(1, ms.int32) @@ -396,7 +555,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_gt): is_max_overlaps = ops.cast( ops.transpose(is_max_overlaps, (0, 2, 1)), overlaps.dtype ) # (b, N, n_gt) -> (b, n_gt, N) - mask_pos = mnp.where(mask_multi_gts > 0, is_max_overlaps, mask_pos) + mask_pos = mnp.where(mask_multi_gts, is_max_overlaps, mask_pos) fg_mask = mask_pos.sum(-2) # find each grid serve which gt(index) @@ -414,6 +573,23 @@ def xywh2xyxy(x): return y +def xyxy2xywh(x): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format. + + Args: + x (Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + Returns: + y (Tensor): The bounding box coordinates in (x, y, width, height) format. + """ + y = ops.Identity()(x) + y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center + y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + @ops.constexpr def get_tensor(x, dtype=ms.float32): return Tensor(x, dtype) diff --git a/mindyolo/models/model_factory.py b/mindyolo/models/model_factory.py index 49f43dc7..5d7c5751 100644 --- a/mindyolo/models/model_factory.py +++ b/mindyolo/models/model_factory.py @@ -8,6 +8,7 @@ from .heads import * from .layers import * from .registry import is_model, model_entrypoint +from .initializer import initialize_defult __all__ = ["create_model", "build_model_from_cfg"] @@ -59,6 +60,7 @@ def __init__(self, model_cfg, in_channels=3, num_classes=80, sync_bn=False): f"Turn on recompute, and the results of the first {model_cfg.recompute_layers} layers " f"will be recomputed." 
) + initialize_defult(self) def construct(self, x): y, dt = (), () # outputs @@ -183,8 +185,10 @@ def parse_model(d, ch, nc, sync_bn=False): # model_dict, input_channels(3) args.append([ch[x] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) - elif m in (YOLOv8Head, YOLOXHead): # head of anchor free + elif m in (YOLOv8Head, YOLOv8SegHead, YOLOXHead): # head of anchor free args.append([ch[x] for x in f]) + if m in (YOLOv8SegHead,): + args[3] = math.ceil(min(args[3], max_channels) * gw / 8) * 8 elif m is ReOrg: c2 = ch[f] * 4 else: diff --git a/mindyolo/models/yolov3.py b/mindyolo/models/yolov3.py index 5d870a8c..247064eb 100644 --- a/mindyolo/models/yolov3.py +++ b/mindyolo/models/yolov3.py @@ -4,7 +4,6 @@ from mindspore import Tensor, nn from .heads.yolov3_head import YOLOv3Head -from .initializer import initialize_defult from .model_factory import build_model_from_cfg from .registry import register_model @@ -30,15 +29,12 @@ def __init__(self, cfg, in_channels=3, num_classes=None, sync_bn=False): self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn) self.names = [str(i) for i in range(nc)] # default names - self.reset_parameter() + self.initialize_weights() def construct(self, x): return self.model(x) - def reset_parameter(self): - # init default - initialize_defult(self) - + def initialize_weights(self): # reset parameter for Detect Head m = self.model.model[-1] if isinstance(m, YOLOv3Head): @@ -47,8 +43,7 @@ def reset_parameter(self): @register_model def yolov3(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOv3: - """Get GoogLeNet model. - Refer to the base class `models.GoogLeNet` for more details.""" + """Get yolov3 model.""" model = YOLOv3(cfg=cfg, in_channels=in_channels, num_classes=num_classes, **kwargs) return model diff --git a/mindyolo/models/yolov4.py b/mindyolo/models/yolov4.py index c54865d8..b70a261b 100644 --- a/mindyolo/models/yolov4.py +++ b/mindyolo/models/yolov4.py @@ -1,6 +1,5 @@ import mindspore.nn as nn -from .initializer import initialize_defult from .model_factory import build_model_from_cfg from .registry import register_model @@ -24,19 +23,12 @@ def __init__(self, cfg, in_channels=3, num_classes=None, sync_bn=False): self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn) self.names = [str(i) for i in range(nc)] # default names - self.reset_parameter() - def construct(self, x): return self.model(x) - def reset_parameter(self): - # init default - initialize_defult(self) - @register_model def yolov4(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOv4: - """Get GoogLeNet model. 
- Refer to the base class `models.GoogLeNet` for more details.""" + """Get yolov4 model.""" model = YOLOv4(cfg=cfg, in_channels=in_channels, num_classes=num_classes, **kwargs) return model diff --git a/mindyolo/models/yolov5.py b/mindyolo/models/yolov5.py index 0f50acbe..6fee9f00 100644 --- a/mindyolo/models/yolov5.py +++ b/mindyolo/models/yolov5.py @@ -4,7 +4,6 @@ from mindspore import Tensor, nn from .heads.yolov7_head import YOLOv7AuxHead, YOLOv7Head -from .initializer import initialize_defult from .model_factory import build_model_from_cfg from .registry import register_model @@ -30,15 +29,12 @@ def __init__(self, cfg, in_channels=3, num_classes=None, sync_bn=False): self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn) self.names = [str(i) for i in range(nc)] # default names - self.reset_parameter() + self.initialize_weights() def construct(self, x): return self.model(x) - def reset_parameter(self): - # init default - initialize_defult(self) - + def initialize_weights(self): # reset parameter for Detect Head m = self.model.model[-1] if isinstance(m, YOLOv7Head): @@ -49,8 +45,7 @@ def reset_parameter(self): @register_model def yolov5(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOv5: - """Get GoogLeNet model. - Refer to the base class `models.GoogLeNet` for more details.""" + """Get yolov5 model.""" model = YOLOv5(cfg=cfg, in_channels=in_channels, num_classes=num_classes, **kwargs) return model diff --git a/mindyolo/models/yolov7.py b/mindyolo/models/yolov7.py index 966071c7..393b0b0e 100644 --- a/mindyolo/models/yolov7.py +++ b/mindyolo/models/yolov7.py @@ -4,7 +4,6 @@ from mindspore import Tensor, nn from .heads.yolov7_head import YOLOv7AuxHead, YOLOv7Head -from .initializer import initialize_defult from .model_factory import build_model_from_cfg from .registry import register_model @@ -30,15 +29,12 @@ def __init__(self, cfg, in_channels=3, num_classes=None, sync_bn=False): self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn) self.names = [str(i) for i in range(nc)] # default names - self.reset_parameter() + self.initialize_weights() def construct(self, x): return self.model(x) - def reset_parameter(self): - # init default - initialize_defult(self) - + def initialize_weights(self): # reset parameter for Detect Head m = self.model.model[-1] if isinstance(m, YOLOv7Head): @@ -49,8 +45,7 @@ def reset_parameter(self): @register_model def yolov7(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOv7: - """Get GoogLeNet model. 
- Refer to the base class `models.GoogLeNet` for more details.""" + """Get yolov7 model.""" model = YOLOv7(cfg=cfg, in_channels=in_channels, num_classes=num_classes, **kwargs) return model diff --git a/mindyolo/models/yolov8.py b/mindyolo/models/yolov8.py index 23ee2e7d..befd1a70 100644 --- a/mindyolo/models/yolov8.py +++ b/mindyolo/models/yolov8.py @@ -4,7 +4,6 @@ from mindspore import Tensor, nn from .heads.yolov8_head import YOLOv8Head -from .initializer import initialize_defult from .model_factory import build_model_from_cfg from .registry import register_model @@ -30,15 +29,12 @@ def __init__(self, cfg, in_channels=3, num_classes=None, sync_bn=False): self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn) self.names = [str(i) for i in range(nc)] # default names - self.reset_parameter() + self.initialize_weights() def construct(self, x): return self.model(x) - def reset_parameter(self): - # init default - initialize_defult(self) - + def initialize_weights(self): # reset parameter for Detect Head m = self.model.model[-1] if isinstance(m, YOLOv8Head): @@ -48,8 +44,7 @@ def reset_parameter(self): @register_model def yolov8(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOv8: - """Get GoogLeNet model. - Refer to the base class `models.GoogLeNet` for more details.""" + """Get yolov8 model.""" model = YOLOv8(cfg=cfg, in_channels=in_channels, num_classes=num_classes, **kwargs) return model diff --git a/mindyolo/models/yolox.py b/mindyolo/models/yolox.py index 4f713ab9..f7463c1d 100644 --- a/mindyolo/models/yolox.py +++ b/mindyolo/models/yolox.py @@ -4,7 +4,6 @@ from mindspore import Tensor, nn from mindyolo.models.registry import register_model -from .initializer import initialize_defult from .heads import YOLOXHead from .model_factory import build_model_from_cfg @@ -29,15 +28,12 @@ def __init__(self, cfg, in_channels=3, num_classes=80, sync_bn=False): self.model = build_model_from_cfg(model_cfg=cfg, in_channels=ch, num_classes=nc, sync_bn=sync_bn) self.names = [str(i) for i in range(nc)] - self.reset_parameter() + self.initialize_weights() def construct(self, x): return self.model(x) - def reset_parameter(self): - # init default - initialize_defult(self) - + def initialize_weights(self): # reset parameter for Detect Head m = self.model.model[-1] assert isinstance(m, YOLOXHead) @@ -46,7 +42,6 @@ def reset_parameter(self): @register_model def yolox(cfg, in_channels=3, num_classes=None, **kwargs) -> YOLOX: - """Get GoogLeNet model. - Refer to the base class `models.GoogLeNet` for more details.""" + """Get yolox model.""" model = YOLOX(cfg, in_channels=in_channels, num_classes=num_classes, **kwargs) return model diff --git a/mindyolo/utils/callback.py b/mindyolo/utils/callback.py index 8d86b37c..df7c61de 100644 --- a/mindyolo/utils/callback.py +++ b/mindyolo/utils/callback.py @@ -154,7 +154,8 @@ def on_train_epoch_begin(self, run_context: RunContext): if self.is_switch_loss and cur_epoch_index == self.switch_epoch_index: logger.info(f"\nAdding L1 loss starts from epoch {self.switch_epoch_index}. 
Graph recompiling\n") trainer.loss_fn.use_l1 = True - trainer.train_step_fn = create_train_step_fn(network=trainer.network, + trainer.train_step_fn = create_train_step_fn(task='detect', + network=trainer.network, loss_fn=trainer.loss_fn, optimizer=trainer.optimizer, loss_ratio=loss_ratio, diff --git a/mindyolo/utils/metrics.py b/mindyolo/utils/metrics.py index 1565b7ab..8b9745cb 100644 --- a/mindyolo/utils/metrics.py +++ b/mindyolo/utils/metrics.py @@ -1,12 +1,16 @@ import time - +import cv2 import numpy as np +import mindspore as ms +from mindspore import ops, Tensor + __all__ = ["non_max_suppression", "scale_coords", "xyxy2xywh", "xywh2xyxy"] def non_max_suppression( prediction, + mask_coefficient=None, conf_thres=0.25, iou_thres=0.45, conf_free=False, @@ -37,6 +41,14 @@ def non_max_suppression( (prediction[..., :4], prediction[..., 4:].max(-1, keepdims=True), prediction[..., 4:]), axis=-1 ) + nm = 0 + if mask_coefficient is not None: + assert mask_coefficient.shape[:2] == prediction.shape[:2], \ + f"mask_coefficient shape {mask_coefficient.shape[:2]} and " \ + f"prediction.shape {prediction.shape[:2]} are not equal." + nm = mask_coefficient.shape[2] + prediction = np.concatenate((prediction, mask_coefficient), axis=-1) + # Settings min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height max_det = 300 # maximum number of detections per image @@ -47,7 +59,7 @@ def non_max_suppression( merge = False # use merge-NMS t = time.time() - output = [np.zeros((0, 6))] * prediction.shape[0] + output = [np.zeros((0, 6+nm))] * prediction.shape[0] for xi, x in enumerate(prediction): # image index, image inference # Apply constraints # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height @@ -60,20 +72,22 @@ def non_max_suppression( # Scale class with conf if not conf_free: if nc == 1: - x[:, 5:] = x[:, 4:5] # signle cls no need to multiplicate. + x[:, 5:5+nc] = x[:, 4:5] # signle cls no need to multiplicate. else: - x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + x[:, 5:5+nc] *= x[:, 4:5] # conf = obj_conf * cls_conf # Box (center x, center y, width, height) to (x1, y1, x2, y2) box = xywh2xyxy(x[:, :4]) # Detections matrix nx6 (xyxy, conf, cls) if multi_label: - i, j = (x[:, 5:] > conf_thres).nonzero() - x = np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32)), 1) + i, j = (x[:, 5:5+nc] > conf_thres).nonzero() + x = np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32)), 1) if nm == 0 else \ + np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32), x[i, -nm:]), 1) else: # best class only - conf, j = x[:, 5:].max(1, keepdim=True) - x = np.concatenate((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + conf, j = x[:, 5:5+nc].max(1, keepdim=True) + x = np.concatenate((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] if nm == 0 else \ + np.concatenate((box, conf, j.float(), x[:, -nm:]), 1)[conf.view(-1) > conf_thres] # Filter by class if classes is not None: @@ -221,3 +235,121 @@ def xyxy2xywh(x): y[:, 2] = x[:, 2] - x[:, 0] # width y[:, 3] = x[:, 3] - x[:, 1] # height return y + + +#------------------------for segment------------------------ + +def scale_image(masks, img0_shape, pad=None): + """ + Takes a mask, and resizes it to the original image size + Args: + masks (numpy.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3]. + img0_shape (tuple): the original image shape + ratio_pad (tuple): the ratio of the padding to the original image. 
+ Returns: + masks (numpy.ndarray): The masks that are being returned. + """ + + # Rescale coordinates (xyxy) from img1_shape to img0_shape + img1_shape = masks.shape + if (np.array(img1_shape[:2]) == np.array(img0_shape[:2])).all(): + return masks + + if pad is None: + ratio = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # ratio = old / new + pad = (img1_shape[0] - img0_shape[0] * ratio) / 2, (img1_shape[1] - img0_shape[1] * ratio) / 2 + + top, left = int(pad[0]), int(pad[1]) # y, x + bottom, right = int(img1_shape[0] - pad[0]), int(img1_shape[1] - pad[1]) + + if len(masks.shape) < 2: + raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') + masks = masks[top:bottom, left:right] + masks = cv2.resize(masks, dsize=(img0_shape[1], img0_shape[0]), interpolation=cv2.INTER_LINEAR) + # masks = ops.interpolate(Tensor(masks, dtype=ms.float32)[None], shape, mode='bilinear', align_corners=False)[0].asnumpy() # CHW + if len(masks.shape) == 2: + masks = masks[:, :, None] + + return masks + + +def crop_mask(masks, boxes): + """ + It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box + Args: + masks (numpy.ndarray): [h, w, n] array of masks + boxes (numpy.ndarray): [n, 4] array of bbox coordinates in relative point form + Returns: + (numpy.ndarray): The masks are being cropped to the bounding box. + """ + n, h, w = masks.shape + x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) + r = np.arange(w, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) + c = np.arange(h, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def process_mask_upsample(protos, masks_in, bboxes, shape): + """ + It takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher + quality but is slower. + Args: + protos (numpy.ndarray): [mask_dim, mask_h, mask_w] + masks_in (numpy.ndarray): [n, mask_dim], n is number of masks after nms + bboxes (numpy.ndarray): [n, 4], n is number of masks after nms + shape (tuple): the size of the input image (h,w) + Returns: + (numpy.ndarray): The upsampled masks. + """ + assert len(shape) == 2, f"The length of the shape is {len(shape)}, expected to be 2." + c, mh, mw = protos.shape # CHW + masks = sigmoid((np.matmul(masks_in, protos.reshape(c, -1)))).reshape(-1, mh, mw) + + # interpolate bilinear + # (n, mh, mw) -> (mh, mw, n) -> (*shape, n) -> (n, *shape) + # masks = cv2.resize(masks.transpose(1, 2, 0), dsize=shape, interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1) + masks = ops.interpolate(Tensor(masks, dtype=ms.float32)[None], shape, mode='bilinear', align_corners=False)[0].asnumpy() # CHW + + masks = crop_mask(masks, bboxes) # CHW + return masks > 0.5 + + +def process_mask(protos, masks_in, bboxes, shape, upsample=False): + """ + Apply masks to bounding boxes using the output of the mask head. + + Args: + protos (numpy.ndarray): A array of shape [mask_dim, mask_h, mask_w]. + masks_in (numpy.ndarray): A array of shape [n, mask_dim], where n is the number of masks after NMS. + bboxes (numpy.ndarray): A array of shape [n, 4], where n is the number of masks after NMS. + shape (tuple): A tuple of integers representing the size of the input image in the format (h, w). + upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False. 
+ + Returns: + (numpy.ndarray): A binary mask array of shape [n, h, w], where n is the number of masks after NMS, and h and w + are the height and width of the input image. The mask is applied to the bounding boxes. + """ + + assert len(shape) == 2, f"The length of the shape is {len(shape)}, expected to be 2." + c, mh, mw = protos.shape # CHW + ih, iw = shape + masks = sigmoid(np.matmul(masks_in, protos.view(c, -1))).reshape(-1, mh, mw) # CHW + + downsampled_bboxes = np.copy(bboxes) + downsampled_bboxes[:, 0] *= mw / iw + downsampled_bboxes[:, 2] *= mw / iw + downsampled_bboxes[:, 3] *= mh / ih + downsampled_bboxes[:, 1] *= mh / ih + + masks = crop_mask(masks, downsampled_bboxes) # CHW + if upsample: + # masks = cv2.resize(masks.transpose(1, 2, 0), dsize=shape, interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1) + masks = ops.interpolate(Tensor(masks, dtype=ms.float32)[None], shape, mode='bilinear', align_corners=False)[0].asnumpy() # CHW + return masks > 0.5 + + +def sigmoid(x): + return 1 / (1 + np.exp(-x)) + +#---------------------------------------------------------- \ No newline at end of file diff --git a/mindyolo/data/poly.py b/mindyolo/utils/poly.py similarity index 100% rename from mindyolo/data/poly.py rename to mindyolo/utils/poly.py diff --git a/mindyolo/utils/train_step_factory.py b/mindyolo/utils/train_step_factory.py index f0e7640d..b7eec98b 100644 --- a/mindyolo/utils/train_step_factory.py +++ b/mindyolo/utils/train_step_factory.py @@ -34,41 +34,87 @@ def get_loss_scaler(ms_loss_scaler="static", scale_value=1024, scale_factor=2, s return loss_scaler -def create_train_step_fn(network, loss_fn, optimizer, loss_ratio, scaler, reducer, - ema=None, overflow_still_update=False, ms_jit=False): +def create_train_step_fn(task, network, loss_fn, optimizer, loss_ratio, scaler, reducer, + ema=None, overflow_still_update=False, ms_jit=False, clip_grad=False, clip_grad_value=10.): from mindspore.amp import all_finite use_ema = True if ema else False - def forward_func(x, label): - pred = network(x) - loss, loss_items = loss_fn(pred, label, x) - loss *= loss_ratio - return scaler.scale(loss), ops.stop_gradient(loss_items) - - grad_fn = ops.value_and_grad(forward_func, grad_position=None, weights=optimizer.parameters, has_aux=True) - - def train_step_func(x, label, optimizer_update=True): - (loss, loss_items), grads = grad_fn(x, label) - grads = reducer(grads) - unscaled_grads = scaler.unscale(grads) - grads_finite = all_finite(unscaled_grads) - - if optimizer_update: - if grads_finite: - loss = ops.depend(loss, optimizer(unscaled_grads)) - if use_ema: - loss = ops.depend(loss, ema.update()) - else: - if overflow_still_update: + if task == "detect": + + def forward_func(x, label): + pred = network(x) + loss, loss_items = loss_fn(pred, label, x) + loss *= loss_ratio + return scaler.scale(loss), ops.stop_gradient(loss_items) + + grad_fn = ops.value_and_grad(forward_func, grad_position=None, weights=optimizer.parameters, has_aux=True) + + def train_step_func(x, label, optimizer_update=True): + (loss, loss_items), grads = grad_fn(x, label) + grads = reducer(grads) + unscaled_grads = scaler.unscale(grads) + grads_finite = all_finite(unscaled_grads) + + if clip_grad: + unscaled_grads = ops.clip_by_global_norm(unscaled_grads, clip_norm=clip_grad_value) + + if optimizer_update: + if grads_finite: loss = ops.depend(loss, optimizer(unscaled_grads)) if use_ema: loss = ops.depend(loss, ema.update()) + else: + if overflow_still_update: + loss = ops.depend(loss, optimizer(unscaled_grads)) + if use_ema: + 
+                            loss = ops.depend(loss, ema.update())
+
+            return scaler.unscale(loss), loss_items, unscaled_grads, grads_finite
+
+        @ms.jit
+        def jit_wrapper(*args):
+            return train_step_func(*args)
+
+        return train_step_func if not ms_jit else jit_wrapper
+
+    elif task == "segment":
+
+        def forward_func(x, label, seg):
+            pred = network(x)
+            loss, loss_items = loss_fn(pred, label, seg)
+            loss *= loss_ratio
+            return scaler.scale(loss), ops.stop_gradient(loss_items)
 
-        return scaler.unscale(loss), loss_items, unscaled_grads, grads_finite
+        grad_fn = ops.value_and_grad(forward_func, grad_position=None, weights=optimizer.parameters, has_aux=True)
 
-    @ms.jit
-    def jit_warpper(*args):
-        return train_step_func(*args)
+        def train_step_func(x, label, seg, optimizer_update=True):
+            (loss, loss_items), grads = grad_fn(x, label, seg)
+            grads = reducer(grads)
+            unscaled_grads = scaler.unscale(grads)
+            grads_finite = all_finite(unscaled_grads)
 
-    return train_step_func if not ms_jit else jit_warpper
+            if clip_grad:
+                unscaled_grads = ops.clip_by_global_norm(unscaled_grads, clip_norm=clip_grad_value)
+
+            if optimizer_update:
+                if grads_finite:
+                    loss = ops.depend(loss, optimizer(unscaled_grads))
+                    if use_ema:
+                        loss = ops.depend(loss, ema.update())
+                else:
+                    if overflow_still_update:
+                        loss = ops.depend(loss, optimizer(unscaled_grads))
+                        if use_ema:
+                            loss = ops.depend(loss, ema.update())
+
+            return scaler.unscale(loss), loss_items, unscaled_grads, grads_finite
+
+        @ms.jit
+        def jit_wrapper(*args):
+            return train_step_func(*args)
+
+        return train_step_func if not ms_jit else jit_wrapper
+
+    else:
+        raise NotImplementedError(f"task '{task}' is not supported, expected 'detect' or 'segment'")
\ No newline at end of file
diff --git a/mindyolo/utils/trainer_factory.py b/mindyolo/utils/trainer_factory.py
index ea07aa2c..6eac11fa 100644
--- a/mindyolo/utils/trainer_factory.py
+++ b/mindyolo/utils/trainer_factory.py
@@ -165,8 +165,10 @@ def train(
             self.optimizer.momentum = Tensor(warmup_momentum[i], dtype)
 
         imgs, labels = data["image"], data["labels"]
+        segments = None if 'segment' not in data else data["segment"]
         self._on_train_step_begin(run_context)
-        run_context.loss, run_context.lr = self.train_step(imgs, labels, cur_step=cur_step,cur_epoch=cur_epoch)
+        run_context.loss, run_context.lr = self.train_step(imgs, labels, segments,
+                                                           cur_step=cur_step, cur_epoch=cur_epoch)
         self._on_train_step_end(run_context)
 
         # train log
@@ -352,9 +354,12 @@ def train_with_datasink(
         self._on_train_end(run_context)
         logger.info("End Train.")
 
-    def train_step(self, imgs, labels, cur_step=0, cur_epoch=0):
+    def train_step(self, imgs, labels, segments=None, cur_step=0, cur_epoch=0):
         if self.accumulate == 1:
-            loss, loss_item, _, grads_finite = self.train_step_fn(imgs, labels, True)
+            if segments is None:
+                loss, loss_item, _, grads_finite = self.train_step_fn(imgs, labels, True)
+            else:
+                loss, loss_item, _, grads_finite = self.train_step_fn(imgs, labels, segments, True)
             self.scaler.adjust(grads_finite)
             if not grads_finite and (cur_step % self.log_interval == 0):
                 if self.overflow_still_update:
@@ -362,7 +367,10 @@
             else:
                 logger.warning(f"overflow, drop step, loss scale adjust to {self.scaler.scale_value.asnumpy()}")
         else:
-            loss, loss_item, grads, grads_finite = self.train_step_fn(imgs, labels, False)
+            if segments is None:
+                loss, loss_item, grads, grads_finite = self.train_step_fn(imgs, labels, False)
+            else:
+                loss, loss_item, grads, grads_finite = self.train_step_fn(imgs, labels, segments, False)
             self.scaler.adjust(grads_finite)
             if grads_finite or self.overflow_still_update:
                self.accumulate_cur_step += 1
diff --git a/mindyolo/utils/utils.py b/mindyolo/utils/utils.py
index a9ba5bd3..ff98e827 100644
--- a/mindyolo/utils/utils.py
+++ b/mindyolo/utils/utils.py
@@ -140,15 +140,15 @@ def freeze_layers(network, freeze=[]):
 
 def draw_result(img_path, result_dict, data_names, is_coco_dataset=True, save_path="./detect_results"):
     import random
-    import cv2
-    from mindyolo.data import COCO80_TO_COCO91_CLASS
 
     os.makedirs(save_path, exist_ok=True)
     save_result_path = os.path.join(save_path, img_path.split("/")[-1])
     im = cv2.imread(img_path)
     category_id, bbox, score = result_dict["category_id"], result_dict["bbox"], result_dict["score"]
+    seg = result_dict.get("segmentation", None)
+    mask = None if seg is None else np.zeros_like(im, dtype=np.float32)
     for i in range(len(bbox)):
         # draw box
         x_l, y_t, w, h = bbox[i][:]
@@ -156,6 +156,9 @@ def draw_result(img_path, result_dict, data_names, is_coco_dataset=True, save_pa
         x_l, y_t, x_r, y_b = int(x_l), int(y_t), int(x_r), int(y_b)
         _color = [random.randint(0, 255) for _ in range(3)]
         cv2.rectangle(im, (x_l, y_t), (x_r, y_b), tuple(_color), 2)
+        if seg is not None:
+            _color_seg = np.array([random.randint(0, 255) for _ in range(3)], np.float32)
+            mask += seg[i][:, :, None] * _color_seg[None, None, :]
 
         # draw label
         if is_coco_dataset:
@@ -169,6 +172,8 @@ def draw_result(img_path, result_dict, data_names, is_coco_dataset=True, save_pa
             cv2.putText(im, text, (x_l, y_t - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
 
     # save results
+    if seg is not None:
+        im = (0.7 * im + 0.3 * mask).astype(np.uint8)
     cv2.imwrite(save_result_path, im)
 
 
diff --git a/mindyolo/version.py b/mindyolo/version.py
index 4dc10e12..456f9644 100644
--- a/mindyolo/version.py
+++ b/mindyolo/version.py
@@ -1,2 +1,2 @@
 """version init"""
-__version__ = "0.1.0"
+__version__ = "0.3.0-dev"
diff --git a/test.py b/test.py
index 0c2fdde0..b9085121 100644
--- a/test.py
+++ b/test.py
@@ -4,27 +4,29 @@ import json
 import os
 import time
 
-from typing import Union
-
 import yaml
-from pathlib import Path
 import numpy as np
-from mindspore.communication import init, get_rank, get_group_size
+from typing import Union
+from pathlib import Path
+from multiprocessing.pool import ThreadPool
 from pycocotools.coco import COCO
+from pycocotools.mask import encode
 
 import mindspore as ms
 from mindspore import Tensor, context, nn, ParallelMode
+from mindspore.communication import init, get_rank, get_group_size
 
 from mindyolo.data import COCO80_TO_COCO91_CLASS, COCODataset, create_loader
 from mindyolo.models.model_factory import create_model
 from mindyolo.utils import logger, get_logger
 from mindyolo.utils.config import parse_args
-from mindyolo.utils.metrics import non_max_suppression, scale_coords, xyxy2xywh
+from mindyolo.utils.metrics import non_max_suppression, scale_coords, xyxy2xywh, scale_image, process_mask_upsample
 from mindyolo.utils.utils import set_seed, get_broadcast_datetime, Synchronizer
 
 
 def get_parser_test(parents=None):
     parser = argparse.ArgumentParser(description="Test", parents=[parents] if parents else [])
+    parser.add_argument("--task", type=str, default="detect", choices=["detect", "segment"], help="task type, detect/segment")
     parser.add_argument("--device_target", type=str, default="Ascend", help="device target, Ascend/GPU/CPU")
     parser.add_argument("--ms_mode", type=int, default=0, help="train mode, graph/pynative")
     parser.add_argument("--ms_amp_level", type=str, default="O0", help="amp level, O0/O1/O2")
@@ -114,13 +116,23 @@ def set_default_test(args):
 
     args.weight = args.ckpt_dir if args.ckpt_dir else ""
 
-def test(
+def test(task, **kwargs):
+    if task == "detect":
+        return test_detect(**kwargs)
+    elif task == "segment":
+        return test_segment(**kwargs)
+    else:
+        raise NotImplementedError(f"task '{task}' is not supported, expected 'detect' or 'segment'")
+
+
+def test_detect(
     network: nn.Cell,
     dataloader: ms.dataset.Dataset,
     anno_json_path: str,
     conf_thres: float = 0.001,
     iou_thres: float = 0.65,
     conf_free: bool = False,
+    num_class: int = 80,
     nms_time_limit: float = -1.0,
     is_coco_dataset: bool = True,
     imgIds: list = [],
@@ -147,9 +157,8 @@ def test(
 
     result_dicts = []
     for i, data in enumerate(loader):
-        imgs, _, paths, ori_shape, pad, hw_scale = (
+        imgs, paths, ori_shape, pad, hw_scale = (
             data["image"],
-            data["labels"],
             data["img_files"],
             data["hw_ori"],
             data["pad"],
@@ -253,6 +262,178 @@ def test(
     return map, map50
 
 
+def test_segment(
+    network: nn.Cell,
+    dataloader: ms.dataset.Dataset,
+    anno_json_path: str,
+    conf_thres: float = 0.001,
+    iou_thres: float = 0.65,
+    conf_free: bool = False,
+    num_class: int = 80,
+    nms_time_limit: float = -1.0,
+    is_coco_dataset: bool = True,
+    imgIds: list = [],
+    per_batch_size: int = -1,
+    rank: int = 0,
+    rank_size: int = 1,
+    save_dir: str = '',
+    synchronizer: Synchronizer = None,
+    cur_epoch: Union[str, int] = 0,  # distinguishes saving directories between epochs in eval-while-run mode
+):
+    try:
+        from mindyolo.csrc import COCOeval_fast as COCOeval
+    except ImportError:
+        logger.warning('unable to load fast_coco_eval api, use normal one instead')
+        from pycocotools.cocoeval import COCOeval
+
+    steps_per_epoch = dataloader.get_dataset_size()
+    loader = dataloader.create_dict_iterator(output_numpy=True, num_epochs=1)
+    coco91class = COCO80_TO_COCO91_CLASS
+
+    sample_num = 0
+    infer_times = 0.0
+    nms_times = 0.0
+    result_dicts = []
+
+    for i, data in enumerate(loader):
+        imgs, paths, ori_shape, pad, hw_scale = (
+            data["image"],
+            data["img_files"],
+            data["hw_ori"],
+            data["pad"],
+            data["hw_scale"],
+        )
+        nb, _, height, width = imgs.shape
+        imgs = Tensor(imgs, ms.float32)
+
+        # Run infer
+        _t = time.time()
+        out, (_, _, prototypes) = network(imgs)  # inference and training outputs
+        infer_times += time.time() - _t
+
+        # Run NMS
+        t = time.time()
+        _c = num_class + 4 if conf_free else num_class + 5
+        out = out.asnumpy()
+        bboxes, mask_coefficient = out[:, :, :_c], out[:, :, _c:]
+        out = non_max_suppression(
+            bboxes,
+            mask_coefficient,
+            conf_thres=conf_thres,
+            iou_thres=iou_thres,
+            conf_free=conf_free,
+            multi_label=True,
+            time_limit=nms_time_limit,
+        )
+        nms_times += time.time() - t
+
+        p = prototypes.asnumpy()
+
+        # Statistics per image
+        for si, (pred, proto) in enumerate(zip(out, p)):
+            path = Path(str(paths[si]))
+            sample_num += 1
+            if len(pred) == 0:
+                continue
+
+            # Predictions
+            pred_masks = process_mask_upsample(proto, pred[:, 6:], pred[:, :4], shape=imgs[si].shape[1:])
+            pred_masks = pred_masks.astype('float32')
+            pred_masks = scale_image(pred_masks.transpose(1, 2, 0), ori_shape[si], pad=pad[si])
+            predn = np.copy(pred)
+            scale_coords(
+                imgs[si].shape[1:], predn[:, :4], ori_shape[si], ratio=hw_scale[si], pad=pad[si]
+            )  # native-space pred
+
+            def single_encode(x):
+                """Encode a single predicted mask as COCO RLE."""
+                rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
+                rle['counts'] = rle['counts'].decode('utf-8')
+                return rle
+
+            image_id = int(path.stem) if path.stem.isnumeric() else path.stem
+            box = xyxy2xywh(predn[:, :4])  # xywh
+            box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+            pred_masks = np.transpose(pred_masks, (2, 0, 1))
+            rles = []
+            for _i in range(pred_masks.shape[0]):
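+                # RLE-encode each instance mask for the COCO segm evaluator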
+                rles.append(single_encode(pred_masks[_i]))
+            for j, (p, b) in enumerate(zip(pred.tolist(), box.tolist())):
+                result_dicts.append(
+                    {
+                        "image_id": image_id,
+                        "category_id": coco91class[int(p[5])] if is_coco_dataset else int(p[5]),
+                        "bbox": [round(x, 3) for x in b],
+                        "score": round(p[4], 5),
+                        "segmentation": rles[j]
+                    }
+                )
+        logger.info(f"Sample {i + 1}/{steps_per_epoch}, time cost: {(time.time() - _t) * 1000:.2f} ms.")
+
+    # save and load result file for distributed case
+    if rank_size > 1:
+        # save result to file
+        # each epoch has a unique directory in eval while run mode
+        infer_dir = os.path.join(save_dir, 'infer', str(cur_epoch))
+        os.makedirs(infer_dir, exist_ok=True)
+        infer_path = os.path.join(infer_dir, f'det_result_rank{rank}_{rank_size}.json')
+        with open(infer_path, 'w') as f:
+            json.dump(result_dicts, f)
+        # synchronize
+        assert synchronizer is not None
+        synchronizer()
+
+        # load file to result_dicts
+        f_names = os.listdir(infer_dir)
+        f_paths = [os.path.join(infer_dir, f) for f in f_names]
+        logger.info(f"Loading {len(f_names)} eval files from directory {infer_dir}: {sorted(f_names)}.")
+        assert len(f_names) == rank_size, f'number of eval files ({len(f_names)}) should be equal to rank size ({rank_size})'
+        result_dicts = []
+        for path in f_paths:
+            with open(path, 'r') as fp:
+                result_dicts += json.load(fp)
+
+    # Compute mAP
+    if not result_dicts:
+        logger.warning('Got 0 bbox after NMS, skip computing map')
+        map_bbox, map50_bbox, map_mask, map50_mask = 0.0, 0.0, 0.0, 0.0
+    else:
+        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+            print("Object detection:")
+            with contextlib.redirect_stdout(get_logger()):  # redirect stdout to logger
+                anno = COCO(anno_json_path)  # init annotations api
+                pred = anno.loadRes(result_dicts)  # init predictions api
+                eval = COCOeval(anno, pred, "bbox")
+                if is_coco_dataset:
+                    eval.params.imgIds = imgIds
+                eval.evaluate()
+                eval.accumulate()
+                eval.summarize()
+                map_bbox, map50_bbox = eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
+            print('\n')
+            print("Instance segmentation:")
+            with contextlib.redirect_stdout(get_logger()):  # redirect stdout to logger
+                anno = COCO(anno_json_path)  # init annotations api
+                pred = anno.loadRes(result_dicts)  # init predictions api
+                eval = COCOeval(anno, pred, "segm")
+                if is_coco_dataset:
+                    eval.params.imgIds = imgIds
+                eval.evaluate()
+                eval.accumulate()
+                eval.summarize()
+                map_mask, map50_mask = eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
+        except Exception as e:
+            logger.error(f"pycocotools unable to run: {e}")
+            raise e
+
+    t = tuple(x / sample_num * 1E3 for x in (infer_times, nms_times, infer_times + nms_times)) + \
+        (height, width, per_batch_size)
+    logger.info('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g;' % t)
+
+    return map_bbox, map50_bbox, map_mask, map50_mask
+
+
 def main(args):
     # Init
     s_time = time.time()
@@ -288,7 +468,8 @@ def main(args):
     dataloader = create_loader(
         dataset=dataset,
         batch_collate_fn=dataset.test_collate_fn,
-        dataset_column_names=dataset.dataset_column_names,
+        column_names_getitem=dataset.column_names_getitem,
+        column_names_collate=dataset.column_names_collate,
         batch_size=args.per_batch_size,
         epoch_size=1,
         rank=args.rank,
@@ -301,6 +482,7 @@ def main(args):
 
     # Run test
     test(
+        task=args.task,
        network=network,
        dataloader=dataloader,
        anno_json_path=os.path.join(
@@ -309,6 +491,8 @@ def main(args):
         conf_thres=args.conf_thres,
         iou_thres=args.iou_thres,
         conf_free=args.conf_free,
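+        # number of classes; test_segment uses it to split box scores from mask coefficients before NMS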
+        num_class=args.data.nc,
         nms_time_limit=args.nms_time_limit,
         is_coco_dataset=is_coco_dataset,
         imgIds=None if not is_coco_dataset else dataset.imgIds,
diff --git a/tests/dataset_plots.py b/tests/dataset_plots.py
index 278a0529..facb1c32 100644
--- a/tests/dataset_plots.py
+++ b/tests/dataset_plots.py
@@ -2,7 +2,7 @@
 
 from mindyolo.data.dataset import COCODataset
 from mindyolo.data.loader import create_loader
-from mindyolo.data.poly import show_img_with_bbox
+from mindyolo.utils.poly import show_img_with_bbox
 from mindyolo.utils.config import parse_args
 
 if __name__ == "__main__":
@@ -20,7 +20,8 @@
     dataloader = create_loader(
         dataset=dataset,
         batch_collate_fn=dataset.test_collate_fn,
-        dataset_column_names=dataset.dataset_column_names,
+        column_names_getitem=dataset.column_names_getitem,
+        column_names_collate=dataset.column_names_collate,
         batch_size=cfg.per_batch_size * 2,
         epoch_size=1,
         rank=0,
diff --git a/tests/modules/test_create_loader.py b/tests/modules/test_create_loader.py
index c42b46f9..140a6b9b 100644
--- a/tests/modules/test_create_loader.py
+++ b/tests/modules/test_create_loader.py
@@ -39,8 +39,12 @@ def test_create_loader(mode, drop_remainder, shuffle, batch_size):
     dataset_path = './coco128'
     ms.set_context(mode=mode)
     transforms_dict = [
-        {'func_name': 'mosaic', 'prob': 1.0, 'mosaic9_prob': 0.0, 'translate': 0.1, 'scale': 0.9},
-        {'func_name': 'mixup', 'prob': 0.1, 'alpha': 8.0, 'beta': 8.0, 'needed_mosaic': True},
+        {'func_name': 'mosaic', 'prob': 1.0},
+        {'func_name': 'random_perspective', 'prob': 1.0, 'translate': 0.1, 'scale': 0.9},
+        {'func_name': 'mixup', 'prob': 0.1, 'alpha': 8.0, 'beta': 8.0, 'pre_transform': [
+            {'func_name': 'mosaic', 'prob': 1.0},
+            {'func_name': 'random_perspective', 'prob': 1.0, 'translate': 0.1, 'scale': 0.9},
+        ]},
         {'func_name': 'hsv_augment', 'prob': 1.0, 'hgain': 0.015, 'sgain': 0.7, 'vgain': 0.4},
         {'func_name': 'label_norm', 'xyxy2xywh_': True},
         {'func_name': 'albumentations'},
@@ -62,7 +66,8 @@ def test_create_loader(mode, drop_remainder, shuffle, batch_size):
     dataloader = create_loader(
         dataset=dataset,
         batch_collate_fn=dataset.train_collate_fn,
-        dataset_column_names=dataset.dataset_column_names,
+        column_names_getitem=dataset.column_names_getitem,
+        column_names_collate=dataset.column_names_collate,
         batch_size=batch_size,
         epoch_size=1,
         shuffle=shuffle,
diff --git a/tests/modules/test_create_trainer.py b/tests/modules/test_create_trainer.py
index a603a818..0f933450 100644
--- a/tests/modules/test_create_trainer.py
+++ b/tests/modules/test_create_trainer.py
@@ -73,6 +73,7 @@ def test_create_trainer(yaml_name, mode):
     # Create train_step_fn
     scaler = StaticLossScaler(1.0)
     train_step_fn = create_train_step_fn(
+        task="detect",
         network=network,
         loss_fn=loss_fn,
         optimizer=optimizer,
diff --git a/train.py b/train.py
index 710e1727..1dbb05d6 100644
--- a/train.py
+++ b/train.py
@@ -20,6 +20,7 @@
 
 def get_parser_train(parents=None):
     parser = argparse.ArgumentParser(description="Train", parents=[parents] if parents else [])
+    parser.add_argument("--task", type=str, default="detect", choices=["detect", "segment"], help="task type, detect/segment")
     parser.add_argument("--device_target", type=str, default="Ascend", help="device target, Ascend/GPU/CPU")
     parser.add_argument("--save_dir", type=str, default="./runs", help="save dir")
     parser.add_argument("--device_per_servers", type=int, default=8, help="device number on a server")
@@ -32,12 +33,13 @@ def get_parser_train(parents=None):
                         help="Whether to maintain loss using fp32/O0-level calculation")
     parser.add_argument("--ms_loss_scaler", type=str, default="static", help="train loss scaler, static/dynamic/none")
     parser.add_argument("--ms_loss_scaler_value", type=float, default=1024.0, help="static loss scale value")
-    parser.add_argument("--ms_grad_sens", type=float, default=1024.0, help="gard sens")
     parser.add_argument("--ms_jit", type=ast.literal_eval, default=True, help="use jit or not")
     parser.add_argument("--ms_enable_graph_kernel", type=ast.literal_eval, default=False,
                         help="use enable_graph_kernel or not")
     parser.add_argument("--ms_datasink", type=ast.literal_eval, default=False, help="Train with datasink.")
     parser.add_argument("--overflow_still_update", type=ast.literal_eval, default=True, help="overflow still update")
+    parser.add_argument("--clip_grad", type=ast.literal_eval, default=False, help="use gradient clipping or not")
+    parser.add_argument("--clip_grad_value", type=float, default=10.0, help="global norm threshold for gradient clipping")
     parser.add_argument("--ema", type=ast.literal_eval, default=True, help="ema")
     parser.add_argument("--weight", type=str, default="", help="initial weight path")
     parser.add_argument("--ema_weight", type=str, default="", help="initial ema weight path")
@@ -137,11 +139,13 @@ def train(args):
             single_cls=args.single_cls,
             batch_size=args.total_batch_size,
             stride=max(args.network.stride),
+            return_segments=(args.task == "segment"),
         )
         _dataloader = create_loader(
             dataset=_dataset,
             batch_collate_fn=_dataset.train_collate_fn,
-            dataset_column_names=_dataset.dataset_column_names,
+            column_names_getitem=_dataset.column_names_getitem,
+            column_names_collate=_dataset.column_names_collate,
             batch_size=args.per_batch_size,
             epoch_size=stage_epochs[stage],
             rank=args.rank,
@@ -171,7 +175,8 @@ def train(args):
         eval_dataloader = create_loader(
             dataset=eval_dataset,
             batch_collate_fn=eval_dataset.test_collate_fn,
-            dataset_column_names=eval_dataset.dataset_column_names,
+            column_names_getitem=eval_dataset.column_names_getitem,
+            column_names_collate=eval_dataset.column_names_collate,
             batch_size=args.per_batch_size,
             epoch_size=1,
             rank=args.rank,
@@ -201,6 +206,7 @@ def train(args):
     reducer = get_gradreducer(args.is_parallel, optimizer.parameters)
     scaler = get_loss_scaler(args.ms_loss_scaler, scale_value=args.ms_loss_scaler_value)
     train_step_fn = create_train_step_fn(
+        task=args.task,
         network=network,
         loss_fn=loss_fn,
         optimizer=optimizer,
@@ -210,6 +216,8 @@ def train(args):
         ema=ema,
         overflow_still_update=args.overflow_still_update,
         ms_jit=args.ms_jit,
+        clip_grad=args.clip_grad,
+        clip_grad_value=args.clip_grad_value,
     )
 
     # Create callbacks
@@ -224,6 +232,7 @@ def train(args):
     is_coco_dataset = "coco" in args.data.dataset_name
     test_fn = partial(
         test,
+        task=args.task,
         dataloader=eval_dataloader,
         anno_json_path=os.path.join(
             args.data.val_set[: -len(args.data.val_set.split("/")[-1])], "annotations/instances_val2017.json"
@@ -231,6 +240,8 @@ def train(args):
         conf_thres=args.conf_thres,
         iou_thres=args.iou_thres,
         conf_free=args.conf_free,
+        num_class=args.data.nc,
         nms_time_limit=args.nms_time_limit,
         is_coco_dataset=is_coco_dataset,
         imgIds=None if not is_coco_dataset else eval_dataset.imgIds,
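+        # imgIds above limits COCO evaluation to the image ids of this validation split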