From 1fb2453afc75d81497a7ff19ac2a8b7fb76f2cc1 Mon Sep 17 00:00:00 2001 From: DexterJZ Date: Thu, 10 Aug 2023 15:31:57 +0800 Subject: [PATCH] feat: add ssd detector (#704) --- examples/det/ssd/README.md | 143 ++++ examples/det/ssd/callbacks.py | 130 ++++ examples/det/ssd/create_data.py | 211 ++++++ examples/det/ssd/data.py | 196 ++++++ examples/det/ssd/eval.py | 82 +++ examples/det/ssd/model.py | 722 +++++++++++++++++++++ examples/det/ssd/ssd_mobilenetv2.yaml | 80 +++ examples/det/ssd/ssd_mobilenetv2_gpu.yaml | 80 +++ examples/det/ssd/ssd_mobilenetv3.yaml | 80 +++ examples/det/ssd/ssd_mobilenetv3_gpu.yaml | 80 +++ examples/det/ssd/ssd_resnet50_fpn.yaml | 80 +++ examples/det/ssd/ssd_resnet50_fpn_gpu.yaml | 80 +++ examples/det/ssd/train.py | 128 ++++ examples/det/ssd/utils.py | 423 ++++++++++++ mindcv/models/mobilenetv2.py | 198 ++---- 15 files changed, 2580 insertions(+), 133 deletions(-) create mode 100644 examples/det/ssd/README.md create mode 100644 examples/det/ssd/callbacks.py create mode 100644 examples/det/ssd/create_data.py create mode 100644 examples/det/ssd/data.py create mode 100644 examples/det/ssd/eval.py create mode 100644 examples/det/ssd/model.py create mode 100644 examples/det/ssd/ssd_mobilenetv2.yaml create mode 100644 examples/det/ssd/ssd_mobilenetv2_gpu.yaml create mode 100644 examples/det/ssd/ssd_mobilenetv3.yaml create mode 100644 examples/det/ssd/ssd_mobilenetv3_gpu.yaml create mode 100644 examples/det/ssd/ssd_resnet50_fpn.yaml create mode 100644 examples/det/ssd/ssd_resnet50_fpn_gpu.yaml create mode 100644 examples/det/ssd/train.py create mode 100644 examples/det/ssd/utils.py diff --git a/examples/det/ssd/README.md b/examples/det/ssd/README.md new file mode 100644 index 00000000..c4599f45 --- /dev/null +++ b/examples/det/ssd/README.md @@ -0,0 +1,143 @@ +# SSD Based on MindCV Backbones + +> [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325) + +## Introduction + +SSD is an single-staged object detector. It discretizes the output space of bounding boxes into a set of default boxes over different aspect ratios and scales per feature map location, and combines predictions from multi-scale feature maps to detect objects with various sizes. At prediction time, SSD generates scores for the presence of each object category in each default box and produces adjustments to the box to better match the object shape. + +

+ +

+

+ Figure 1. Architecture of SSD [1] +

+ +In this example, by leveraging [the multi-scale feature extraction of MindCV](https://github.com/mindspore-lab/mindcv/blob/main/docs/en/how_to_guides/feature_extraction.md), we demonstrate that using backbones from MindCV much simplifies the implementation of SSD. + +## Configurations + +Here, we provide three configurations of SSD. +* Using [MobileNetV2](https://github.com/mindspore-lab/mindcv/tree/main/configs/mobilenetv2) as the backbone and the original detector described in the paper. +* Using [ResNet50](https://github.com/mindspore-lab/mindcv/tree/main/configs/resnet) as the backbone with a FPN and a shared-weight-based detector. +* Using [MobileNetV3](https://github.com/mindspore-lab/mindcv/tree/main/configs/mobilenetv3) as the backbone and the original detector described in the paper. + +## Dataset + +We train and test SSD using [COCO 2017 Dataset](https://cocodataset.org/#download). The dataset contains +* 118000 images about 18 GB for training, and +* 5000 images about 1 GB for testing. + +## Quick Start + +### Preparation + +1. Clone MindCV repository by running +``` +git clone https://github.com/mindspore-lab/mindcv.git +``` + +2. Install dependencies as shown [here](https://mindspore-lab.github.io/mindcv/installation/). + +3. Download [COCO 2017 Dataset](https://cocodataset.org/#download), prepare the dataset as follows. +``` +. +└─cocodataset + ├─annotations + ├─instance_train2017.json + └─instance_val2017.json + ├─val2017 + └─train2017 +``` +Run the following commands to preprocess the dataset and convert it to [MindRecord format](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.mindrecord.html) for reducing preprocessing time during training and testing. +``` +cd mindcv # change directory to the root of MindCV repository +python examples/det/ssd/create_data.py coco --data_path [root of COCO 2017 Dataset] --out_path [directory for storing MindRecord files] +``` +Specify the path of the preprocessed dataset at keyword `data_dir` in the config file. + +4. Download the pretrained backbone weights from the table below, and specify the path to the backbone weights at keyword `backbone_ckpt_path` in the config file. +
+ +| MobileNetV2 | ResNet50 | MobileNetV3 | +|:----------------:|:----------------:|:----------------:| +| [backbone weights](https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv2/mobilenet_v2_100-d5532038.ckpt) | [backbone weights](https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt) | [backbone weights](https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3/mobilenet_v3_large_100-1279ad5f.ckpt) | + +
+ +### Train + +It is highly recommended to use **distributed training** for this SSD implementation. + +For distributed training using **OpenMPI's `mpirun`**, simply run +``` +cd mindcv # change directory to the root of MindCV repository +mpirun -n [# of devices] python examples/det/ssd/train.py --config [the path to the config file] +``` +For example, if train SSD distributively with the `MobileNetV2` configuration on 8 devices, run +``` +cd mindcv # change directory to the root of MindCV repository +mpirun -n 8 python examples/det/ssd/train.py --config examples/det/ssd/ssd_mobilenetv2.yaml +``` + +For distributed training with [Ascend rank table](https://github.com/mindspore-lab/mindocr/blob/main/docs/en/tutorials/distribute_train.md#12-configure-rank_table_file-for-training), configure `ascend8p.sh` as follows +``` +#!/bin/bash +export DEVICE_NUM=8 +export RANK_SIZE=8 +export RANK_TABLE_FILE="./hccl_8p_01234567_127.0.0.1.json" + +for ((i = 0; i < ${DEVICE_NUM}; i++)); do + export DEVICE_ID=$i + export RANK_ID=$i + echo "Launching rank: ${RANK_ID}, device: ${DEVICE_ID}" + if [ $i -eq 0 ]; then + echo 'i am 0' + python examples/det/ssd/train.py --config [the path to the config file] &> ./train.log & + else + echo 'not 0' + python -u examples/det/ssd/train.py --config [the path to the config file] &> /dev/null & + fi +done +``` +and start training by running +``` +cd mindcv # change directory to the root of MindCV repository +bash ascend8p.sh +``` + +For single-device training, please run +``` +cd mindcv # change directory to the root of MindCV repository +python examples/det/ssd/train.py --config [the path to the config file] +``` + +### Test + +For testing the trained model, first specify the path to the model checkpoint at keyword `ckpt_path` in the config file, then run +``` +cd mindcv # change directory to the root of MindCV repository +python examples/det/ssd/eval.py --config [the path to the config file] +``` +For example, for testing SSD with the `MobileNetV2` configuration, run +``` +cd mindcv # change directory to the root of MindCV repository +python examples/det/ssd/eval.py --config examples/det/ssd/ssd_mobilenetv2.yaml +``` + +## Performance + +Here are the performance resutls and the pretrained model weights for each configuration. +
+ +| Configuration | Mixed Precision | mAP | Config | Download | +|:-----------------:|:---------------:|:----:|:------:|:--------:| +| MobileNetV2 | O2 | 23.2 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_mobilenetv2.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_mobilenetv2-5bbd7411.ckpt) | +| ResNet50 with FPN | O3 | 38.3 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_resnet50_fpn.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_resnet50_fpn-ac87ddac.ckpt) | +| MobileNetV3 | O2 | 23.8 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_mobilenetv3.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_mobilenetv3-53d9f6e9.ckpt) | + +
+ +## References + +[1] Liu, W., Anguelov, D., Erhan, D., Szegedy, C., Reed, S., Fu, C. Y., & Berg, A. C. (2016). SSD: Single Shot Multibox Detector. In Computer Vision–ECCV 2016: 14th European Conference, Amsterdam, The Netherlands, October 11–14, 2016, Proceedings, Part I 14 (pp. 21-37). Springer International Publishing. diff --git a/examples/det/ssd/callbacks.py b/examples/det/ssd/callbacks.py new file mode 100644 index 00000000..7c8076e5 --- /dev/null +++ b/examples/det/ssd/callbacks.py @@ -0,0 +1,130 @@ +import os +import stat + +from utils import apply_eval + +from mindspore import log as logger +from mindspore import save_checkpoint +from mindspore.train.callback import Callback, CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor + + +class EvalCallBack(Callback): + """ + Evaluation callback when training. + + Args: + eval_function (function): evaluation function. + eval_param_dict (dict): evaluation parameters' configure dict. + interval (int): run evaluation interval, default is 1. + eval_start_epoch (int): evaluation start epoch, default is 1. + save_best_ckpt (bool): Whether to save best checkpoint, default is True. + best_ckpt_name (str): best checkpoint name, default is `best.ckpt`. + metrics_name (str): evaluation metrics name, default is `acc`. + + Returns: + None + + Examples: + >>> EvalCallBack(eval_function, eval_param_dict) + """ + + def __init__( + self, + eval_function, + eval_param_dict, + interval=1, + eval_start_epoch=1, + save_best_ckpt=True, + ckpt_directory="./", + best_ckpt_name="best.ckpt", + metrics_name="acc", + ): + super(EvalCallBack, self).__init__() + self.eval_function = eval_function + self.eval_param_dict = eval_param_dict + self.eval_start_epoch = eval_start_epoch + + if interval < 1: + raise ValueError("interval should >= 1.") + + self.interval = interval + self.save_best_ckpt = save_best_ckpt + self.best_res = 0 + self.best_epoch = 0 + + if not os.path.isdir(ckpt_directory): + os.makedirs(ckpt_directory) + + self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name) + self.metrics_name = metrics_name + + def remove_ckpoint_file(self, file_name): + """Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" + try: + os.chmod(file_name, stat.S_IWRITE) + os.remove(file_name) + except OSError: + logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) + except ValueError: + logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) + + def on_train_epoch_end(self, run_context): + """Callback when epoch end.""" + cb_params = run_context.original_args() + cur_epoch = cb_params.cur_epoch_num + + if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0: + res = self.eval_function(self.eval_param_dict) + print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True) + + if res >= self.best_res: + self.best_res = res + self.best_epoch = cur_epoch + print("update best result: {}".format(res), flush=True) + + if self.save_best_ckpt: + if os.path.exists(self.best_ckpt_path): + self.remove_ckpoint_file(self.best_ckpt_path) + + save_checkpoint(cb_params.train_network, self.best_ckpt_path) + print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True) + + def on_train_end(self, run_context): + print( + "End training, the best {0} is: {1}, the best {0} epoch is {2}".format( + self.metrics_name, self.best_res, self.best_epoch + ), + flush=True, + ) + + +def get_ssd_callbacks(args, steps_per_epoch, rank_id): + ckpt_config = CheckpointConfig(keep_checkpoint_max=args.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="ssd", directory=args.ckpt_save_dir, config=ckpt_config) + + if rank_id == 0: + return [TimeMonitor(data_size=steps_per_epoch), LossMonitor(), ckpt_cb] + + return [TimeMonitor(data_size=steps_per_epoch), LossMonitor()] + + +def get_ssd_eval_callback(eval_net, eval_dataset, args): + if args.dataset == "coco": + anno_json = os.path.join(args.data_dir, "annotations/instances_val2017.json") + else: + raise NotImplementedError + + eval_param_dict = {"net": eval_net, "dataset": eval_dataset, "anno_json": anno_json, "args": args} + + eval_cb = EvalCallBack( + apply_eval, + eval_param_dict, + interval=args.eval_interval, + eval_start_epoch=args.eval_start_epoch, + save_best_ckpt=True, + ckpt_directory=args.ckpt_save_dir, + best_ckpt_name="best.ckpt", + metrics_name="mAP", + ) + + return eval_cb diff --git a/examples/det/ssd/create_data.py b/examples/det/ssd/create_data.py new file mode 100644 index 00000000..4530c056 --- /dev/null +++ b/examples/det/ssd/create_data.py @@ -0,0 +1,211 @@ +import argparse +import os + +import numpy as np + +from mindspore.mindrecord import FileWriter + +coco_classes = [ + "background", + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +] + + +def create_coco_label(data_path, is_training): + """Get image path and annotation from COCO.""" + from pycocotools.coco import COCO + + coco_root = data_path + + if is_training: + data_type = "train2017" + else: + data_type = "val2017" + + # Classes need to train or test. + train_cls = coco_classes + train_cls_dict = {} + for i, cls in enumerate(train_cls): + train_cls_dict[cls] = i + + anno_json = os.path.join(coco_root, f"annotations/instances_{data_type}.json") + + coco = COCO(anno_json) + classs_dict = {} + cat_ids = coco.loadCats(coco.getCatIds()) + for cat in cat_ids: + classs_dict[cat["id"]] = cat["name"] + + image_ids = coco.getImgIds() + images = [] + image_path_dict = {} + image_anno_dict = {} + for img_id in image_ids: + image_info = coco.loadImgs(img_id) + file_name = image_info[0]["file_name"] + anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = coco.loadAnns(anno_ids) + image_path = os.path.join(coco_root, data_type, file_name) + annos = [] + iscrowd = False + for label in anno: + bbox = label["bbox"] + class_name = classs_dict[label["category_id"]] + iscrowd = iscrowd or label["iscrowd"] + if class_name in train_cls: + x_min, x_max = bbox[0], bbox[0] + bbox[2] + y_min, y_max = bbox[1], bbox[1] + bbox[3] + annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]]) + + if not is_training and iscrowd: + continue + if len(annos) >= 1: + images.append(img_id) + image_path_dict[img_id] = image_path + image_anno_dict[img_id] = np.array(annos) + + return images, image_path_dict, image_anno_dict + + +def data_to_mindrecord_byte_image(dataset="coco", data_path="", out_path="", is_training=True, file_num=8): + """Create MindRecord file.""" + if is_training: + os.mkdir(os.path.join(out_path, "train")) + mindrecord_path = os.path.join(out_path, "train", dataset) + else: + os.mkdir(os.path.join(out_path, "val")) + mindrecord_path = os.path.join(out_path, "val", dataset) + + writer = FileWriter(mindrecord_path, file_num) + + if dataset == "coco": + images, image_path_dict, image_anno_dict = create_coco_label(data_path, is_training) + else: + raise NotImplementedError + + ssd_json = { + "img_id": {"type": "int32", "shape": [1]}, + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 5]}, + } + writer.add_schema(ssd_json, "ssd_json") + + for img_id in images: + image_path = image_path_dict[img_id] + + with open(image_path, "rb") as f: + img = f.read() + + annos = np.array(image_anno_dict[img_id], dtype=np.int32) + img_id = np.array([img_id], dtype=np.int32) + row = {"img_id": img_id, "image": img, "annotation": annos} + writer.write_raw_data([row]) + + writer.commit() + + +def convert_dataset(dataset="coco", data_path="", out_path=""): + if dataset == "coco": + if os.path.isdir(data_path): + print("Start converting training dataset...") + data_to_mindrecord_byte_image(dataset=dataset, data_path=data_path, out_path=out_path, is_training=True) + print("Training dataset conversion done.") + print("Start converting evaluation dataset...") + data_to_mindrecord_byte_image(dataset=dataset, data_path=data_path, out_path=out_path, is_training=False) + print("Evaluation dataset conversion done.") + else: + print("data path not exits.") + else: + raise NotImplementedError + + +parser = argparse.ArgumentParser(description="Data converter arg parser") +parser.add_argument("dataset", metavar="coco", help="name of the dataset") +parser.add_argument("--data_path", type=str, default="./data/coco/", help="specify the root path of dataset") +parser.add_argument( + "--out_path", type=str, default="./data/coco/", required=False, help="specify the path of the coverted dataset" +) +args = parser.parse_args() + + +if __name__ == "__main__": + convert_dataset(dataset=args.dataset, data_path=args.data_path, out_path=args.out_path) diff --git a/examples/det/ssd/data.py b/examples/det/ssd/data.py new file mode 100644 index 00000000..31d82ca0 --- /dev/null +++ b/examples/det/ssd/data.py @@ -0,0 +1,196 @@ +import os + +import cv2 +import numpy as np +from utils import jaccard_numpy, ssd_bboxes_encode + +import mindspore.dataset as de + + +def _rand(a=0.0, b=1.0): + """Generate random.""" + return np.random.rand() * (b - a) + a + + +def random_sample_crop(image, boxes): + """Random Crop the image and boxes""" + height, width, _ = image.shape + min_iou = np.random.choice([None, 0.1, 0.3, 0.5, 0.7, 0.9]) + + if min_iou is None: + return image, boxes + + # max trails (50) + for _ in range(50): + image_t = image + + w = _rand(0.3, 1.0) * width + h = _rand(0.3, 1.0) * height + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = _rand() * (width - w) + top = _rand() * (height - h) + + rect = np.array([int(top), int(left), int(top + h), int(left + w)]) + overlap = jaccard_numpy(boxes, rect) + + # dropout some boxes + drop_mask = overlap > 0 + if not drop_mask.any(): + continue + + if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2): + continue + + image_t = image_t[rect[0] : rect[2], rect[1] : rect[3], :] + + centers = (boxes[:, :2] + boxes[:, 2:4]) / 2.0 + + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 * drop_mask + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # take only matching gt boxes + boxes_t = boxes[mask, :].copy() + + boxes_t[:, :2] = np.maximum(boxes_t[:, :2], rect[:2]) + boxes_t[:, :2] -= rect[:2] + boxes_t[:, 2:4] = np.minimum(boxes_t[:, 2:4], rect[2:4]) + boxes_t[:, 2:4] -= rect[:2] + + return image_t, boxes_t + return image, boxes + + +def preprocess_fn(img_id, image, box, is_training, args): + """Preprocess function for dataset.""" + cv2.setNumThreads(2) + + def _infer_data(image, input_shape): + img_h, img_w, _ = image.shape + input_h, input_w = input_shape + + image = cv2.resize(image, (input_w, input_h)) + + # When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + return img_id, image, np.array((img_h, img_w), np.float32) + + def _data_aug(image, box, is_training, args): + """Data augmentation function.""" + ih, iw, _ = image.shape + h, w = args.image_size + + if not is_training: + return _infer_data(image, args.image_size) + + # Random crop + box = box.astype(np.float32) + image, box = random_sample_crop(image, box) + ih, iw, _ = image.shape + + # Resize image + image = cv2.resize(image, (w, h)) + + # Flip image or not + flip = _rand() < 0.5 + if flip: + image = cv2.flip(image, 1, dst=None) + + # When the channels of image is 1 + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = np.concatenate([image, image, image], axis=-1) + + box[:, [0, 2]] = box[:, [0, 2]] / ih + box[:, [1, 3]] = box[:, [1, 3]] / iw + + if flip: + box[:, [1, 3]] = 1 - box[:, [3, 1]] + + box, label, num_match = ssd_bboxes_encode(box, args) + return image, box, label, num_match + + return _data_aug(image, box, is_training, args) + + +def create_ssd_dataset( + name, + root, + shuffle, + batch_size, + python_multiprocessing, + num_parallel_workers, + drop_remainder, + args, + num_shards=1, + shard_id=0, + is_training=True, +): + """Create SSD dataset with MindDataset.""" + if name == "coco": + if is_training: + mindrecord_file = os.path.join(root, "train", "coco0") + else: + mindrecord_file = os.path.join(root, "val", "coco0") + + ds = de.MindDataset( + mindrecord_file, + columns_list=["img_id", "image", "annotation"], + num_shards=num_shards, + shard_id=shard_id, + num_parallel_workers=num_parallel_workers, + shuffle=shuffle, + ) + + decode = de.vision.Decode() + ds = ds.map(operations=decode, input_columns=["image"]) + change_swap_op = de.vision.HWC2CHW() + + # Computed from random subset of ImageNet training images + normalize_op = de.vision.Normalize( + mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255] + ) + color_adjust_op = de.vision.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + + def compose_map_func(img_id, image, annotation): + return preprocess_fn(img_id, image, annotation, is_training, args) + + if is_training: + output_columns = ["image", "box", "label", "num_match"] + trans = [color_adjust_op, normalize_op, change_swap_op] + else: + output_columns = ["img_id", "image", "image_shape"] + trans = [normalize_op, change_swap_op] + + ds = ds.map( + operations=compose_map_func, + input_columns=["img_id", "image", "annotation"], + output_columns=output_columns, + column_order=output_columns, + python_multiprocessing=python_multiprocessing, + num_parallel_workers=num_parallel_workers, + ) + ds = ds.map( + operations=trans, + input_columns=["image"], + python_multiprocessing=python_multiprocessing, + num_parallel_workers=num_parallel_workers, + ) + ds = ds.batch(batch_size, drop_remainder=drop_remainder) + + return ds + else: + raise NotImplementedError diff --git a/examples/det/ssd/eval.py b/examples/det/ssd/eval.py new file mode 100644 index 00000000..2b79cadb --- /dev/null +++ b/examples/det/ssd/eval.py @@ -0,0 +1,82 @@ +import argparse +import os +import sys + +import yaml +from addict import Dict +from data import create_ssd_dataset +from model import SSD, SSDInferWithDecoder +from utils import apply_eval + +from mindspore import load_checkpoint, load_param_into_net + +sys.path.append(".") + +from mindcv.models import create_model + + +def eval(args): + eval_dataset = create_ssd_dataset( + name=args.dataset, + root=args.data_dir, + shuffle=False, + batch_size=args.batch_size, + python_multiprocessing=True, + num_parallel_workers=args.num_parallel_workers, + drop_remainder=False, + args=args, + is_training=False, + ) + + backbone = create_model( + args.backbone, + features_only=args.backbone_features_only, + out_indices=args.backbone_out_indices, + ) + + ssd = SSD(backbone, args, is_training=False) + eval_model = SSDInferWithDecoder(ssd, args) + eval_model.init_parameters_data() + + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(eval_model, param_dict) + + eval_model.set_train(False) + + print("\n========================================\n") + print("Processing, please wait a moment.") + + if args.dataset == "coco": + anno_json = os.path.join(args.data_dir, "annotations/instances_val2017.json") + else: + raise NotImplementedError + + eval_param_dict = {"net": eval_model, "dataset": eval_dataset, "anno_json": anno_json, "args": args} + mAP = apply_eval(eval_param_dict) + + print("\n========================================\n") + print(f"mAP: {mAP}") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Training Config", add_help=False) + parser.add_argument( + "-c", "--config", type=str, default="", help="YAML config file specifying default arguments (default=" ")" + ) + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + yaml_fp = args.config + + with open(yaml_fp) as fp: + args = yaml.safe_load(fp) + + args = Dict(args) + + # core evaluation + eval(args) diff --git a/examples/det/ssd/model.py b/examples/det/ssd/model.py new file mode 100644 index 00000000..137ba716 --- /dev/null +++ b/examples/det/ssd/model.py @@ -0,0 +1,722 @@ +from utils import GeneratDefaultBoxes, GridAnchorGenerator + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore.common.initializer import TruncatedNormal, initializer +from mindspore.communication.management import get_group_size +from mindspore.context import ParallelMode +from mindspore.parallel._auto_parallel_context import auto_parallel_context + + +def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod="same"): + return nn.Conv2d( + in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=0, pad_mode=pad_mod, has_bias=True + ) + + +def _bn(channel): + return nn.BatchNorm2d( + channel, eps=1e-3, momentum=0.97, gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1 + ) + + +def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod="same", pad=0): + in_channels = in_channel + out_channels = in_channel + depthwise_conv = nn.Conv2d( + in_channels, out_channels, kernel_size, stride, pad_mode=pad_mod, padding=pad, group=in_channels + ) + conv = _conv2d(in_channel, out_channel, kernel_size=1) + return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv]) + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + shared_conv(Cell): Use the weight shared conv, default: None. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, shared_conv=None): + super(ConvBNReLU, self).__init__() + padding = 0 + in_channels = in_planes + out_channels = out_planes + if shared_conv is None: + if groups == 1: + conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode="same", padding=padding) + else: + out_channels = in_planes + conv = nn.Conv2d( + in_channels, out_channels, kernel_size, stride, pad_mode="same", padding=padding, group=in_channels + ) + layers = [conv, _bn(out_planes), nn.ReLU6()] + else: + layers = [shared_conv, _bn(out_planes), nn.ReLU6()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + +class InvertedResidual(nn.Cell): + """ + Residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + + def __init__(self, inp, oup, stride, expand_ratio, last_relu=False): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend( + [ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), + _bn(oup), + ] + ) + self.conv = nn.SequentialCell(layers) + self.cast = ops.Cast() + self.last_relu = last_relu + self.relu = nn.ReLU6() + + def construct(self, x): + identity = x + x = self.conv(x) + + if self.use_res_connect: + x = identity + x + + if self.last_relu: + x = self.relu(x) + + return x + + +class MobileNetV2Wrapper(nn.Cell): + def __init__(self, backbone, args): + super(MobileNetV2Wrapper, self).__init__() + self.backbone = backbone + feature1_output_channels = backbone.out_channels[0] + self.feature1_expand_layer = ConvBNReLU( + feature1_output_channels, int(round(feature1_output_channels * 6)), kernel_size=1 + ) + + in_channels = args.extras_in_channels + out_channels = args.extras_out_channels + ratios = args.extras_ratio + strides = args.extras_strides + residual_list = [] + + for i in range(2, len(in_channels)): + residual = InvertedResidual( + in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i], last_relu=True + ) + residual_list.append(residual) + + self.multi_residual = nn.CellList(residual_list) + + self._initialize_weights() + + def _initialize_weights(self) -> None: + params = self.feature1_expand_layer.trainable_params() + params.extend(self.multi_residual.trainable_params()) + + for p in params: + if "beta" not in p.name and "gamma" not in p.name and "bias" not in p.name: + p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype)) + + def construct(self, x): + feature1, feature2 = self.backbone(x) + layer_out = self.feature1_expand_layer(feature1) + multi_feature = (layer_out, feature2) + feature = feature2 + + for residual in self.multi_residual: + feature = residual(feature) + multi_feature += (feature,) + + return multi_feature + + +class FPNTopDown(nn.Cell): + """ + Fpn to extract features + """ + + def __init__(self, in_channel_list, out_channels): + super(FPNTopDown, self).__init__() + self.lateral_convs_list_ = [] + self.fpn_convs_ = [] + + for channel in in_channel_list: + l_conv = nn.Conv2d( + channel, out_channels, kernel_size=1, stride=1, has_bias=True, padding=0, pad_mode="same" + ) + fpn_conv = ConvBNReLU(out_channels, out_channels, kernel_size=3, stride=1) + self.lateral_convs_list_.append(l_conv) + self.fpn_convs_.append(fpn_conv) + + self.lateral_convs_list = nn.layer.CellList(self.lateral_convs_list_) + self.fpn_convs_list = nn.layer.CellList(self.fpn_convs_) + self.num_layers = len(in_channel_list) + + def construct(self, inputs): + image_features = () + + for i, feature in enumerate(inputs): + image_features = image_features + (self.lateral_convs_list[i](feature),) + + features = (image_features[-1],) + + for i in range(len(inputs) - 1): + top = len(inputs) - i - 1 + down = top - 1 + size = ops.shape(inputs[down]) + top_down = ops.ResizeBilinear((size[2], size[3]))(features[-1]) + top_down = top_down + image_features[down] + features = features + (top_down,) + + extract_features = () + num_features = len(features) + + for i in range(num_features): + extract_features = extract_features + (self.fpn_convs_list[i](features[num_features - i - 1]),) + + return extract_features + + +class BottomUp(nn.Cell): + """ + Bottom Up feature extractor + """ + + def __init__(self, levels, channels, kernel_size, stride): + super(BottomUp, self).__init__() + self.levels = levels + bottom_up_cells = [ConvBNReLU(channels, channels, kernel_size, stride) for x in range(self.levels)] + self.blocks = nn.CellList(bottom_up_cells) + + def construct(self, features): + for block in self.blocks: + features = features + (block(features[-1]),) + + return features + + +class ResNet50FPNWrapper(nn.Cell): + def __init__(self, backbone, args): + super(ResNet50FPNWrapper, self).__init__() + self.backbone = backbone + self.fpn = FPNTopDown([512, 1024, 2048], 256) + self.bottom_up = BottomUp(2, 256, 3, 2) + + self._initialize_weights() + + def _initialize_weights(self) -> None: + params = self.fpn.trainable_params() + params.extend(self.bottom_up.trainable_params()) + + for p in params: + if "beta" not in p.name and "gamma" not in p.name and "bias" not in p.name: + p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype)) + + def construct(self, x): + feature1, feature2, feature3 = self.backbone(x) + features = self.fpn((feature1, feature2, feature3)) + features = self.bottom_up(features) + + return features + + +class MobileNetV3Wrapper(nn.Cell): + def __init__(self, backbone, args): + super(MobileNetV3Wrapper, self).__init__() + self.backbone = backbone + + feature1_output_channels = backbone.out_channels[0] + + self.feature1_expand_layer = nn.SequentialCell( + [ + nn.Conv2d(feature1_output_channels, 672, 1, 1, pad_mode="pad", padding=0, has_bias=False), + nn.BatchNorm2d(672), + nn.HSwish(), + ] + ) + + in_channels = args.extras_in_channels + out_channels = args.extras_out_channels + ratios = args.extras_ratio + strides = args.extras_strides + residual_list = [] + + for i in range(2, len(in_channels)): + residual = InvertedResidual( + in_channels[i], out_channels[i], stride=strides[i], expand_ratio=ratios[i], last_relu=True + ) + residual_list.append(residual) + + self.multi_residual = nn.CellList(residual_list) + + self._initialize_weights() + + def _initialize_weights(self) -> None: + params = self.feature1_expand_layer.trainable_params() + params.extend(self.multi_residual.trainable_params()) + + for p in params: + if "beta" not in p.name and "gamma" not in p.name and "bias" not in p.name: + p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype)) + + def construct(self, x): + feature1, feature2 = self.backbone(x) + layer_out = self.feature1_expand_layer(feature1) + multi_feature = (layer_out, feature2) + feature = feature2 + + for residual in self.multi_residual: + feature = residual(feature) + multi_feature += (feature,) + + return multi_feature + + +backbone_wrapper = { + "mobilenet_v2_100": MobileNetV2Wrapper, + "resnet50": ResNet50FPNWrapper, + "mobilenet_v3_large_100": MobileNetV3Wrapper, +} + + +class FlattenConcat(nn.Cell): + """ + Concatenate predictions into a single tensor. + + Args: + config (dict): The default config of SSD. + + Returns: + Tensor, flatten predictions. + """ + + def __init__(self, args): + super(FlattenConcat, self).__init__() + self.num_ssd_boxes = args.num_ssd_boxes + self.concat = ops.Concat(axis=1) + self.transpose = ops.Transpose() + + def construct(self, inputs): + output = () + batch_size = ops.shape(inputs[0])[0] + + for x in inputs: + x = self.transpose(x, (0, 2, 3, 1)) + output += (ops.reshape(x, (batch_size, -1)),) + + res = self.concat(output) + + return ops.reshape(res, (batch_size, self.num_ssd_boxes, -1)) + + +class MultiBox(nn.Cell): + """ + Multibox conv layers. Each multibox layer contains class conf scores and localization predictions. + + Args: + config (dict): The default config of SSD. + + Returns: + Tensor, localization predictions. + Tensor, class conf scores. + """ + + def __init__(self, args): + super(MultiBox, self).__init__() + num_classes = args.num_classes + out_channels = args.extras_out_channels + num_default = args.num_default + + loc_layers = [] + cls_layers = [] + + for k, out_channel in enumerate(out_channels): + loc_layers += [ + _last_conv2d(out_channel, 4 * num_default[k], kernel_size=3, stride=1, pad_mod="same", pad=0) + ] + cls_layers += [ + _last_conv2d(out_channel, num_classes * num_default[k], kernel_size=3, stride=1, pad_mod="same", pad=0) + ] + + self.multi_loc_layers = nn.CellList(loc_layers) + self.multi_cls_layers = nn.CellList(cls_layers) + self.flatten_concat = FlattenConcat(args) + + def construct(self, inputs): + loc_outputs = () + cls_outputs = () + + for i in range(len(self.multi_loc_layers)): + loc_outputs += (self.multi_loc_layers[i](inputs[i]),) + cls_outputs += (self.multi_cls_layers[i](inputs[i]),) + + return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) + + +class WeightSharedMultiBox(nn.Cell): + """ + Weight shared Multi-box conv layers. Each multi-box layer contains class conf scores and localization predictions. + All box predictors shares the same conv weight in different features. + + Args: + config (dict): The default config of SSD. + loc_cls_shared_addition(bool): Whether the location predictor and classifier prediction share the + same addition layer. + Returns: + Tensor, localization predictions. + Tensor, class conf scores. + """ + + def __init__(self, args, loc_cls_shared_addition=False): + super(WeightSharedMultiBox, self).__init__() + num_classes = args.num_classes + out_channels = args.extras_out_channels[0] + num_default = args.num_default[0] + num_features = len(args.feature_size) + num_addition_layers = args.num_addition_layers + self.loc_cls_shared_addition = loc_cls_shared_addition + + if not loc_cls_shared_addition: + loc_convs = [_conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers)] + cls_convs = [_conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers)] + addition_loc_layer_list = [] + addition_cls_layer_list = [] + + for _ in range(num_features): + addition_loc_layer = [ + ConvBNReLU(out_channels, out_channels, 3, 1, 1, loc_convs[x]) for x in range(num_addition_layers) + ] + addition_cls_layer = [ + ConvBNReLU(out_channels, out_channels, 3, 1, 1, cls_convs[x]) for x in range(num_addition_layers) + ] + addition_loc_layer_list.append(nn.SequentialCell(addition_loc_layer)) + addition_cls_layer_list.append(nn.SequentialCell(addition_cls_layer)) + + self.addition_layer_loc = nn.CellList(addition_loc_layer_list) + self.addition_layer_cls = nn.CellList(addition_cls_layer_list) + else: + convs = [_conv2d(out_channels, out_channels, 3, 1) for x in range(num_addition_layers)] + addition_layer_list = [] + + for _ in range(num_features): + addition_layers = [ + ConvBNReLU(out_channels, out_channels, 3, 1, 1, convs[x]) for x in range(num_addition_layers) + ] + addition_layer_list.append(nn.SequentialCell(addition_layers)) + + self.addition_layer = nn.SequentialCell(addition_layer_list) + + loc_layers = [_conv2d(out_channels, 4 * num_default, kernel_size=3, stride=1, pad_mod="same")] + cls_layers = [_conv2d(out_channels, num_classes * num_default, kernel_size=3, stride=1, pad_mod="same")] + + self.loc_layers = nn.SequentialCell(loc_layers) + self.cls_layers = nn.SequentialCell(cls_layers) + self.flatten_concat = FlattenConcat(args) + + def construct(self, inputs): + loc_outputs = () + cls_outputs = () + num_heads = len(inputs) + + for i in range(num_heads): + if self.loc_cls_shared_addition: + features = self.addition_layer[i](inputs[i]) + loc_outputs += (self.loc_layers(features),) + cls_outputs += (self.cls_layers(features),) + else: + features = self.addition_layer_loc[i](inputs[i]) + loc_outputs += (self.loc_layers(features),) + features = self.addition_layer_cls[i](inputs[i]) + cls_outputs += (self.cls_layers(features),) + + return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs) + + +class SSD(nn.Cell): + """ + SSD300 Network. Default backbone is resnet34. + + Args: + backbone (Cell): Backbone Network. + config (dict): The default config of SSD. + + Returns: + Tensor, localization predictions. + Tensor, class conf scores. + + Examples:backbone + SSD300(backbone=resnet34(num_classes=None), + config=config). + """ + + def __init__(self, backbone, args, is_training=True): + super(SSD, self).__init__() + self.backbone_wrapper = backbone_wrapper[args.backbone](backbone, args) + + if args.get("use_fpn", False): + self.multi_box = WeightSharedMultiBox(args) + else: + self.multi_box = MultiBox(args) + + self.is_training = is_training + + if not is_training: + self.activation = ops.Sigmoid() + + self._initialize_weights() + + def _initialize_weights(self) -> None: + params = self.multi_box.trainable_params() + + for p in params: + if "beta" not in p.name and "gamma" not in p.name and "bias" not in p.name: + p.set_data(initializer(TruncatedNormal(0.02), p.data.shape, p.data.dtype)) + + def construct(self, x): + multi_feature = self.backbone_wrapper(x) + + pred_loc, pred_label = self.multi_box(multi_feature) + + if not self.is_training: + pred_label = self.activation(pred_label) + + pred_loc = ops.cast(pred_loc, ms.float32) + pred_label = ops.cast(pred_label, ms.float32) + + return pred_loc, pred_label + + +class SigmoidFocalClassificationLoss(nn.Cell): + """ " + Sigmoid focal-loss for classification. + + Args: + gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0 + alpha (float): Hyper-parameter to balance the positive and negative example. Default: 0.25 + + Returns: + Tensor, the focal loss. + """ + + def __init__(self, gamma=2.0, alpha=0.25): + super(SigmoidFocalClassificationLoss, self).__init__() + self.sigmiod_cross_entropy = ops.SigmoidCrossEntropyWithLogits() + self.sigmoid = ops.Sigmoid() + self.pow = ops.Pow() + self.onehot = ops.OneHot() + self.on_value = Tensor(1.0, ms.float32) + self.off_value = Tensor(0.0, ms.float32) + self.gamma = gamma + self.alpha = alpha + + def construct(self, logits, label): + label = self.onehot(label, ops.shape(logits)[-1], self.on_value, self.off_value) + sigmiod_cross_entropy = self.sigmiod_cross_entropy(logits, label) + sigmoid = self.sigmoid(logits) + label = ops.cast(label, ms.float32) + p_t = label * sigmoid + (1 - label) * (1 - sigmoid) + modulating_factor = self.pow(1 - p_t, self.gamma) + alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha) + focal_loss = modulating_factor * alpha_weight_factor * sigmiod_cross_entropy + return focal_loss + + +class SSDWithLossCell(nn.Cell): + """ " + Provide SSD training loss through network. + + Args: + network (Cell): The training network. + config (dict): SSD config. + + Returns: + Tensor, the loss of the network. + """ + + def __init__(self, network, args): + super(SSDWithLossCell, self).__init__(auto_prefix=False) + self.network = network + self.less = ops.Less() + self.tile = ops.Tile() + self.reduce_sum = ops.ReduceSum() + self.expand_dims = ops.ExpandDims() + self.class_loss = SigmoidFocalClassificationLoss(args.gamma, args.alpha) + self.loc_loss = nn.SmoothL1Loss() + + def construct(self, x, gt_loc, gt_label, num_matched_boxes): + pred_loc, pred_label = self.network(x) + mask = ops.cast(self.less(0, gt_label), ms.float32) + num_matched_boxes = self.reduce_sum(ops.cast(num_matched_boxes, ms.float32)) + + # Localization Loss + mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4)) + smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc + loss_loc = self.reduce_sum(self.reduce_sum(smooth_l1, -1), -1) + + # Classification Loss + loss_cls = self.class_loss(pred_label, gt_label) + loss_cls = self.reduce_sum(loss_cls, (1, 2)) + + return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes) + + +grad_scale = ops.MultitypeFuncGraph("grad_scale") + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * ops.Reciprocal()(scale) + + +class TrainingWrapper(nn.Cell): + """ + Encapsulation class of SSD network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + sens (Number): The adjust parameter. Default: 1.0. + use_global_nrom(bool): Whether apply global norm before optimizer. Default: False + """ + + def __init__(self, network, optimizer, sens=1.0, use_global_norm=False): + super(TrainingWrapper, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.weights = ms.ParameterTuple(network.trainable_params()) + self.optimizer = optimizer + self.grad = ops.GradOperation(get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.use_global_norm = use_global_norm + self.parallel_mode = ms.get_auto_parallel_context("parallel_mode") + + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + + if self.reducer_flag: + mean = ms.get_auto_parallel_context("gradients_mean") + + if auto_parallel_context().get_device_num_is_set(): + degree = ms.get_auto_parallel_context("device_num") + else: + degree = get_group_size() + + self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) + + self.hyper_map = ops.HyperMap() + + def construct(self, *args): + weights = self.weights + loss = self.network(*args) + sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(*args, sens) + + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + + if self.use_global_norm: + grads = self.hyper_map(ops.partial(grad_scale, ops.scalar_to_tensor(self.sens)), grads) + grads = ops.clip_by_global_norm(grads) + + self.optimizer(grads) + + return loss + + +class SSDInferWithDecoder(nn.Cell): + """ + SSD Infer wrapper to decode the bbox locations. + + Args: + network (Cell): the origin ssd infer network without bbox decoder. + default_boxes (Tensor): the default_boxes from anchor generator + config (dict): ssd config + Returns: + Tensor, the locations for bbox after decoder representing (y0,x0,y1,x1) + Tensor, the prediction labels. + + """ + + def __init__(self, network, args): + super(SSDInferWithDecoder, self).__init__(auto_prefix=False) + self.network = network + + if hasattr(args, "use_anchor_generator") and args.use_anchor_generator: + self.default_boxes, _ = GridAnchorGenerator(args.image_size, 4, 2, [1.0, 2.0, 0.5]).generate_multi_levels( + args.steps + ) + self.default_boxes = Tensor(self.default_boxes) + else: + self.default_boxes = Tensor(GeneratDefaultBoxes(args).default_boxes) + + self.prior_scaling_xy = args.prior_scaling[0] + self.prior_scaling_wh = args.prior_scaling[1] + + def construct(self, x): + pred_loc, pred_label = self.network(x) + + default_bbox_xy = self.default_boxes[..., :2] + default_bbox_wh = self.default_boxes[..., 2:] + pred_xy = pred_loc[..., :2] * self.prior_scaling_xy * default_bbox_wh + default_bbox_xy + pred_wh = ops.Exp()(pred_loc[..., 2:] * self.prior_scaling_wh) * default_bbox_wh + + pred_xy_0 = pred_xy - pred_wh / 2.0 + pred_xy_1 = pred_xy + pred_wh / 2.0 + pred_xy = ops.Concat(-1)((pred_xy_0, pred_xy_1)) + pred_xy = ops.Maximum()(pred_xy, 0) + pred_xy = ops.Minimum()(pred_xy, 1) + return pred_xy, pred_label + + +def get_ssd_trainer(model, optimizer, args): + return ms.Model(TrainingWrapper(model, optimizer, args.loss_scale)) diff --git a/examples/det/ssd/ssd_mobilenetv2.yaml b/examples/det/ssd/ssd_mobilenetv2.yaml new file mode 100644 index 00000000..899be2ab --- /dev/null +++ b/examples/det/ssd/ssd_mobilenetv2.yaml @@ -0,0 +1,80 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +enable_modelarts: False +eval_while_train: False + +# dataset +dataset: "coco" +data_dir: "/root/zjd/coco_ori" +shuffle: True +batch_size: 32 +drop_remainder: True +num_classes: 81 +classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'] + +# Training options +image_size: [300, 300] +num_ssd_boxes: 1917 +match_threshold: 0.5 +nms_threshold: 0.6 +min_score: 0.1 +max_boxes: 100 +all_reduce_fusion_config: [29, 58, 89] + +# model +backbone: "mobilenet_v2_100" +backbone_ckpt_path: "./checkpoints/mobilenetv2/mobilenet_v2_100-d5532038.ckpt" +backbone_ckpt_auto_mapping: True +backbone_features_only: True +backbone_out_indices: [13, 18] + +ckpt_path: "./ckpt/best.ckpt" + +num_default: [3, 6, 6, 6, 6, 6] +extras_in_channels: [256, 576, 1280, 512, 256, 256] +extras_out_channels: [576, 1280, 512, 256, 256, 128] +extras_strides: [1, 1, 2, 2, 2, 2] +extras_ratio: [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] +feature_size: [19, 10, 5, 3, 2, 1] +min_scale: 0.2 +max_scale: 0.95 +aspect_ratios: [[], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]] +steps: [16, 32, 64, 100, 150, 300] +prior_scaling: [0.1, 0.2] +gamma: 2.0 +alpha: 0.75 +epoch_size: 500 + +ckpt_save_dir: "./ckpt" +keep_checkpoint_max: 20 +eval_interval: 1 +eval_start_epoch: 350 +dataset_sink_mode: True +amp_level: "O2" + +# scheduler +lr: 0.05 +lr_init: 0.001 +lr_end_rate: 0.001 +warmup_epochs: 2 + +# optimizer +loss_scale: 1024.0 +weight_decay: 0.00015 +momentum: 0.9 diff --git a/examples/det/ssd/ssd_mobilenetv2_gpu.yaml b/examples/det/ssd/ssd_mobilenetv2_gpu.yaml new file mode 100644 index 00000000..c9cf43d2 --- /dev/null +++ b/examples/det/ssd/ssd_mobilenetv2_gpu.yaml @@ -0,0 +1,80 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +enable_modelarts: False +eval_while_train: False + +# dataset +dataset: "coco" +data_dir: "/data1/coco_ori" +shuffle: True +batch_size: 32 +drop_remainder: True +num_classes: 81 +classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'] + +# Training options +image_size: [300, 300] +num_ssd_boxes: 1917 +match_threshold: 0.5 +nms_threshold: 0.6 +min_score: 0.1 +max_boxes: 100 +all_reduce_fusion_config: [29, 58, 89] + +# model +backbone: "mobilenet_v2_100" +backbone_ckpt_path: "./checkpoints/mobilenetv2/mobilenet_v2_100-d5532038.ckpt" +backbone_ckpt_auto_mapping: True +backbone_features_only: True +backbone_out_indices: [13, 18] + +ckpt_path: "./ckpt/best.ckpt" + +num_default: [3, 6, 6, 6, 6, 6] +extras_in_channels: [256, 576, 1280, 512, 256, 256] +extras_out_channels: [576, 1280, 512, 256, 256, 128] +extras_strides: [1, 1, 2, 2, 2, 2] +extras_ratio: [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] +feature_size: [19, 10, 5, 3, 2, 1] +min_scale: 0.2 +max_scale: 0.95 +aspect_ratios: [[], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]] +steps: [16, 32, 64, 100, 150, 300] +prior_scaling: [0.1, 0.2] +gamma: 2.0 +alpha: 0.75 +epoch_size: 500 + +ckpt_save_dir: "./ckpt" +keep_checkpoint_max: 20 +eval_interval: 1 +eval_start_epoch: 350 +dataset_sink_mode: True +amp_level: "O0" + +# scheduler +lr: 0.05 +lr_init: 0.001 +lr_end_rate: 0.001 +warmup_epochs: 2 + +# optimizer +loss_scale: 1024.0 +weight_decay: 0.00015 +momentum: 0.9 diff --git a/examples/det/ssd/ssd_mobilenetv3.yaml b/examples/det/ssd/ssd_mobilenetv3.yaml new file mode 100644 index 00000000..89807c71 --- /dev/null +++ b/examples/det/ssd/ssd_mobilenetv3.yaml @@ -0,0 +1,80 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +enable_modelarts: False +eval_while_train: False + +# dataset +dataset: "coco" +data_dir: "/root/zjd/coco_ori" +shuffle: True +batch_size: 32 +drop_remainder: True +num_classes: 81 +classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'] + +# Training options +image_size: [300, 300] +num_ssd_boxes: 1917 +match_threshold: 0.5 +nms_threshold: 0.6 +min_score: 0.1 +max_boxes: 100 +all_reduce_fusion_config: [29, 58, 89] + +# model +backbone: "mobilenet_v3_large_100" +backbone_ckpt_path: "./checkpoints/mobilenetv3/mobilenet_v3_large_100-1279ad5f.ckpt" +backbone_ckpt_auto_mapping: True +backbone_features_only: True +backbone_out_indices: [12, 16] + +ckpt_path: "./ckpt/best.ckpt" + +num_default: [3, 6, 6, 6, 6, 6] +extras_in_channels: [256, 672, 960, 512, 256, 256] +extras_out_channels: [672, 960, 512, 256, 256, 128] +extras_strides: [1, 1, 2, 2, 2, 2] +extras_ratio: [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] +feature_size: [19, 10, 5, 3, 2, 1] +min_scale: 0.2 +max_scale: 0.95 +aspect_ratios: [[], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]] +steps: [16, 32, 64, 100, 150, 300] +prior_scaling: [0.1, 0.2] +gamma: 2.0 +alpha: 0.75 +epoch_size: 500 + +ckpt_save_dir: "./ckpt" +keep_checkpoint_max: 20 +eval_interval: 1 +eval_start_epoch: 350 +dataset_sink_mode: True +amp_level: "O2" + +# scheduler +lr: 0.05 +lr_init: 0.001 +lr_end_rate: 0.001 +warmup_epochs: 2 + +# optimizer +loss_scale: 1024.0 +weight_decay: 0.00015 +momentum: 0.9 diff --git a/examples/det/ssd/ssd_mobilenetv3_gpu.yaml b/examples/det/ssd/ssd_mobilenetv3_gpu.yaml new file mode 100644 index 00000000..1f79a908 --- /dev/null +++ b/examples/det/ssd/ssd_mobilenetv3_gpu.yaml @@ -0,0 +1,80 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +enable_modelarts: False +eval_while_train: False + +# dataset +dataset: "coco" +data_dir: "/data1/coco_ori" +shuffle: True +batch_size: 32 +drop_remainder: True +num_classes: 81 +classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'] + +# Training options +image_size: [300, 300] +num_ssd_boxes: 1917 +match_threshold: 0.5 +nms_threshold: 0.6 +min_score: 0.1 +max_boxes: 100 +all_reduce_fusion_config: [29, 58, 89] + +# model +backbone: "mobilenet_v3_large_100" +backbone_ckpt_path: "./checkpoints/mobilenetv3/mobilenet_v3_large_100-1279ad5f.ckpt" +backbone_ckpt_auto_mapping: True +backbone_features_only: True +backbone_out_indices: [12, 16] + +ckpt_path: "./ckpt/best.ckpt" + +num_default: [3, 6, 6, 6, 6, 6] +extras_in_channels: [256, 672, 960, 512, 256, 256] +extras_out_channels: [672, 960, 512, 256, 256, 128] +extras_strides: [1, 1, 2, 2, 2, 2] +extras_ratio: [0.2, 0.2, 0.2, 0.25, 0.5, 0.25] +feature_size: [19, 10, 5, 3, 2, 1] +min_scale: 0.2 +max_scale: 0.95 +aspect_ratios: [[], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]] +steps: [16, 32, 64, 100, 150, 300] +prior_scaling: [0.1, 0.2] +gamma: 2.0 +alpha: 0.75 +epoch_size: 500 + +ckpt_save_dir: "./ckpt" +keep_checkpoint_max: 20 +eval_interval: 1 +eval_start_epoch: 350 +dataset_sink_mode: True +amp_level: "O0" + +# scheduler +lr: 0.05 +lr_init: 0.001 +lr_end_rate: 0.001 +warmup_epochs: 2 + +# optimizer +loss_scale: 1024.0 +weight_decay: 0.00015 +momentum: 0.9 diff --git a/examples/det/ssd/ssd_resnet50_fpn.yaml b/examples/det/ssd/ssd_resnet50_fpn.yaml new file mode 100644 index 00000000..cced6b97 --- /dev/null +++ b/examples/det/ssd/ssd_resnet50_fpn.yaml @@ -0,0 +1,80 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +enable_modelarts: False +eval_while_train: False + +# dataset +dataset: "coco" +data_dir: "/root/zjd/coco_ori" +shuffle: True +batch_size: 32 +drop_remainder: True +num_classes: 81 +classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'] + +# Training options +image_size: [640, 640] +num_ssd_boxes: 51150 +match_threshold: 0.5 +nms_threshold: 0.6 +min_score: 0.1 +max_boxes: 100 +all_reduce_fusion_config: [90, 183, 279] + +# model +backbone: "resnet50" +backbone_ckpt_path: "./checkpoints/resnet/resnet50-e0733ab8.ckpt" +backbone_ckpt_auto_mapping: False +backbone_features_only: True +backbone_out_indices: [2, 3, 4] + +ckpt_path: "./ckpt/best.ckpt" + +use_fpn: True +num_default: [6, 6, 6, 6, 6] +extras_out_channels: [256, 256, 256, 256, 256] +feature_size: [80, 40, 20, 10, 5] +min_scale: 0.2 +max_scale: 0.95 +aspect_ratios: [[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]] +steps: [8, 16, 32, 64, 128] +prior_scaling: [0.1, 0.2] +gamma: 2.0 +alpha: 0.25 +num_addition_layers: 4 +use_anchor_generator: True +epoch_size: 500 + +ckpt_save_dir: "./ckpt" +keep_checkpoint_max: 20 +eval_interval: 1 +eval_start_epoch: 350 +dataset_sink_mode: True +amp_level: "O3" + +# scheduler +lr: 0.05 +lr_init: 0.01333 +lr_end_rate: 0.0 +warmup_epochs: 2 + +# optimizer +loss_scale: 1024.0 +weight_decay: 0.0004 +momentum: 0.9 diff --git a/examples/det/ssd/ssd_resnet50_fpn_gpu.yaml b/examples/det/ssd/ssd_resnet50_fpn_gpu.yaml new file mode 100644 index 00000000..01e5973a --- /dev/null +++ b/examples/det/ssd/ssd_resnet50_fpn_gpu.yaml @@ -0,0 +1,80 @@ +# system +mode: 0 +distribute: True +num_parallel_workers: 8 +enable_modelarts: False +eval_while_train: False + +# dataset +dataset: "coco" +data_dir: "/data1/coco_ori" +shuffle: True +batch_size: 32 +drop_remainder: True +num_classes: 81 +classes: ['background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', + 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', + 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', + 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', + 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', + 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', + 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', + 'refrigerator', 'book', 'clock', 'vase', 'scissors', + 'teddy bear', 'hair drier', 'toothbrush'] + +# Training options +image_size: [640, 640] +num_ssd_boxes: 51150 +match_threshold: 0.5 +nms_threshold: 0.6 +min_score: 0.1 +max_boxes: 100 +all_reduce_fusion_config: [90, 183, 279] + +# model +backbone: "resnet50" +backbone_ckpt_path: "./checkpoints/resnet/resnet50-e0733ab8.ckpt" +backbone_ckpt_auto_mapping: False +backbone_features_only: True +backbone_out_indices: [2, 3, 4] + +ckpt_path: "./ckpt/best.ckpt" + +use_fpn: True +num_default: [6, 6, 6, 6, 6] +extras_out_channels: [256, 256, 256, 256, 256] +feature_size: [80, 40, 20, 10, 5] +min_scale: 0.2 +max_scale: 0.95 +aspect_ratios: [[2, 3], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]] +steps: [8, 16, 32, 64, 128] +prior_scaling: [0.1, 0.2] +gamma: 2.0 +alpha: 0.25 +num_addition_layers: 4 +use_anchor_generator: True +epoch_size: 500 + +ckpt_save_dir: "./ckpt" +keep_checkpoint_max: 20 +eval_interval: 1 +eval_start_epoch: 350 +dataset_sink_mode: True +amp_level: "O0" + +# scheduler +lr: 0.05 +lr_init: 0.01333 +lr_end_rate: 0.0 +warmup_epochs: 2 + +# optimizer +loss_scale: 1024.0 +weight_decay: 0.0004 +momentum: 0.9 diff --git a/examples/det/ssd/train.py b/examples/det/ssd/train.py new file mode 100644 index 00000000..a58a6820 --- /dev/null +++ b/examples/det/ssd/train.py @@ -0,0 +1,128 @@ +import argparse +import os +import sys + +import yaml +from addict import Dict +from callbacks import get_ssd_callbacks, get_ssd_eval_callback +from data import create_ssd_dataset +from model import SSD, SSDInferWithDecoder, SSDWithLossCell, get_ssd_trainer +from utils import get_ssd_lr_scheduler, get_ssd_optimizer + +import mindspore as ms +from mindspore.communication import get_group_size, get_rank, init + +sys.path.append(".") + +from mindcv.models import create_model +from mindcv.utils import set_seed + + +def train(args): + """main train function""" + + ms.set_context(mode=args.mode) + + if args.distribute: + init() + device_num = get_group_size() + rank_id = get_rank() + ms.set_auto_parallel_context( + device_num=device_num, + parallel_mode="data_parallel", + gradients_mean=True, + all_reduce_fusion_config=args.all_reduce_fusion_config, + ) + else: + device_num = None + rank_id = None + + set_seed(args.seed) + + dataset = create_ssd_dataset( + name=args.dataset, + root=args.data_dir, + shuffle=args.shuffle, + batch_size=args.batch_size, + python_multiprocessing=True, + num_parallel_workers=args.num_parallel_workers, + drop_remainder=args.drop_remainder, + args=args, + num_shards=device_num, + shard_id=rank_id, + is_training=True, + ) + + steps_per_epoch = dataset.get_dataset_size() + + # use mindcv models as backbone for SSD + backbone = create_model( + args.backbone, + checkpoint_path=args.backbone_ckpt_path, + auto_mapping=args.get("backbone_ckpt_auto_mapping", False), + features_only=args.backbone_features_only, + out_indices=args.backbone_out_indices, + ) + + ssd = SSD(backbone, args) + ms.amp.auto_mixed_precision(ssd, amp_level=args.amp_level) + model = SSDWithLossCell(ssd, args) + + lr_scheduler = get_ssd_lr_scheduler(args, steps_per_epoch) + optimizer = get_ssd_optimizer(model, lr_scheduler, args) + + trainer = get_ssd_trainer(model, optimizer, args) + + callbacks = get_ssd_callbacks(args, steps_per_epoch, rank_id) + + if args.eval_while_train and rank_id == 0: + eval_model = SSDInferWithDecoder(ssd, args) + eval_dataset = create_ssd_dataset( + name=args.dataset, + root=args.data_dir, + shuffle=False, + batch_size=args.batch_size, + python_multiprocessing=True, + num_parallel_workers=args.num_parallel_workers, + drop_remainder=False, + args=args, + is_training=False, + ) + eval_callback = get_ssd_eval_callback(eval_model, eval_dataset, args) + callbacks.append(eval_callback) + + trainer.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=args.dataset_sink_mode) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Training Config", add_help=False) + parser.add_argument( + "-c", "--config", type=str, default="", help="YAML config file specifying default arguments (default=" ")" + ) + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + yaml_fp = args.config + + with open(yaml_fp) as fp: + args = yaml.safe_load(fp) + + args = Dict(args) + + # data sync for cloud platform if enabled + if args.enable_modelarts: + import moxing as mox + + args.data_dir = f"/cache/{args.data_url}" + mox.file.copy_parallel(src_url=os.path.join(args.data_url, args.dataset), dst_url=args.data_dir) + + # core training + train(args) + + if args.enable_modelarts: + mox.file.copy_parallel(src_url=args.ckpt_save_dir, dst_url=args.train_url) diff --git a/examples/det/ssd/utils.py b/examples/det/ssd/utils.py new file mode 100644 index 00000000..ddaad037 --- /dev/null +++ b/examples/det/ssd/utils.py @@ -0,0 +1,423 @@ +import itertools as it +import json +import math + +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +import mindspore.nn as nn +from mindspore import Tensor + + +class GridAnchorGenerator: + """ + Anchor Generator + """ + + def __init__(self, image_shape, scale, scales_per_octave, aspect_ratios): + super(GridAnchorGenerator, self).__init__() + self.scale = scale + self.scales_per_octave = scales_per_octave + self.aspect_ratios = aspect_ratios + self.image_shape = image_shape + + def generate(self, step): + scales = np.array( + [2 ** (float(scale) / self.scales_per_octave) for scale in range(self.scales_per_octave)] + ).astype(np.float32) + aspects = np.array(list(self.aspect_ratios)).astype(np.float32) + + scales_grid, aspect_ratios_grid = np.meshgrid(scales, aspects) + scales_grid = scales_grid.reshape([-1]) + aspect_ratios_grid = aspect_ratios_grid.reshape([-1]) + + feature_size = [self.image_shape[0] / step, self.image_shape[1] / step] + grid_height, grid_width = feature_size + + base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32) + anchor_offset = step / 2.0 + + ratio_sqrt = np.sqrt(aspect_ratios_grid) + heights = scales_grid / ratio_sqrt * base_size[0] + widths = scales_grid * ratio_sqrt * base_size[1] + + y_centers = np.arange(grid_height).astype(np.float32) + y_centers = y_centers * step + anchor_offset + x_centers = np.arange(grid_width).astype(np.float32) + x_centers = x_centers * step + anchor_offset + x_centers, y_centers = np.meshgrid(x_centers, y_centers) + + x_centers_shape = x_centers.shape + y_centers_shape = y_centers.shape + + widths_grid, x_centers_grid = np.meshgrid(widths, x_centers.reshape([-1])) + heights_grid, y_centers_grid = np.meshgrid(heights, y_centers.reshape([-1])) + + x_centers_grid = x_centers_grid.reshape(*x_centers_shape, -1) + y_centers_grid = y_centers_grid.reshape(*y_centers_shape, -1) + widths_grid = widths_grid.reshape(-1, *x_centers_shape) + heights_grid = heights_grid.reshape(-1, *y_centers_shape) + + bbox_centers = np.stack([y_centers_grid, x_centers_grid], axis=3) + bbox_sizes = np.stack([heights_grid, widths_grid], axis=3) + bbox_centers = bbox_centers.reshape([-1, 2]) + bbox_sizes = bbox_sizes.reshape([-1, 2]) + bbox_corners = np.concatenate([bbox_centers - 0.5 * bbox_sizes, bbox_centers + 0.5 * bbox_sizes], axis=1) + self.bbox_corners = bbox_corners / np.array([*self.image_shape, *self.image_shape]).astype(np.float32) + self.bbox_centers = np.concatenate([bbox_centers, bbox_sizes], axis=1) + self.bbox_centers = self.bbox_centers / np.array([*self.image_shape, *self.image_shape]).astype(np.float32) + + return self.bbox_centers, self.bbox_corners + + def generate_multi_levels(self, steps): + bbox_centers_list = [] + bbox_corners_list = [] + + for step in steps: + bbox_centers, bbox_corners = self.generate(step) + bbox_centers_list.append(bbox_centers) + bbox_corners_list.append(bbox_corners) + + self.bbox_centers = np.concatenate(bbox_centers_list, axis=0) + self.bbox_corners = np.concatenate(bbox_corners_list, axis=0) + + return self.bbox_centers, self.bbox_corners + + +class GeneratDefaultBoxes: + """ + Generate Default boxes for SSD, follows the order of (W, H, archor_sizes). + `self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w]. + `self.default_boxes_tlbr` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2]. + """ + + def __init__(self, args): + fk = args.image_size[0] / np.array(args.steps) + scale_rate = (args.max_scale - args.min_scale) / (len(args.num_default) - 1) + scales = [args.min_scale + scale_rate * i for i in range(len(args.num_default))] + [1.0] + self.default_boxes = [] + + for idex, feature_size in enumerate(args.feature_size): + sk1 = scales[idex] + sk2 = scales[idex + 1] + sk3 = math.sqrt(sk1 * sk2) + + if idex == 0 and not args.aspect_ratios[idex]: + w, h = sk1 * math.sqrt(2), sk1 / math.sqrt(2) + all_sizes = [(0.1, 0.1), (w, h), (h, w)] + else: + all_sizes = [(sk1, sk1)] + + for aspect_ratio in args.aspect_ratios[idex]: + w, h = sk1 * math.sqrt(aspect_ratio), sk1 / math.sqrt(aspect_ratio) + all_sizes.append((w, h)) + all_sizes.append((h, w)) + + all_sizes.append((sk3, sk3)) + + assert len(all_sizes) == args.num_default[idex] + + for i, j in it.product(range(feature_size), repeat=2): + for w, h in all_sizes: + cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex] + self.default_boxes.append([cy, cx, h, w]) + + def to_tlbr(cy, cx, h, w): + return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2 + + # For IoU calculation + self.default_boxes_tlbr = np.array(tuple(to_tlbr(*i) for i in self.default_boxes), dtype="float32") + self.default_boxes = np.array(self.default_boxes, dtype="float32") + + +def ssd_bboxes_encode(boxes, args): + """ + Labels anchors with ground truth inputs. + + Args: + boxex: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls]. + + Returns: + gt_loc: location ground truth with shape [num_anchors, 4]. + gt_label: class ground truth with shape [num_anchors, 1]. + num_matched_boxes: number of positives in an image. + """ + if hasattr(args, "use_anchor_generator") and args.use_anchor_generator: + generator = GridAnchorGenerator(args.image_size, 4, 2, [1.0, 2.0, 0.5]) + default_boxes, default_boxes_tlbr = generator.generate_multi_levels(args.steps) + else: + generator = GeneratDefaultBoxes(args) + default_boxes_tlbr = generator.default_boxes_tlbr + default_boxes = generator.default_boxes + + y1, x1, y2, x2 = np.split(default_boxes_tlbr[:, :4], 4, axis=-1) + vol_anchors = (x2 - x1) * (y2 - y1) + + def jaccard_with_anchors(bbox): + """Compute jaccard score a box and the anchors.""" + # Intersection bbox and volume. + ymin = np.maximum(y1, bbox[0]) + xmin = np.maximum(x1, bbox[1]) + ymax = np.minimum(y2, bbox[2]) + xmax = np.minimum(x2, bbox[3]) + w = np.maximum(xmax - xmin, 0.0) + h = np.maximum(ymax - ymin, 0.0) + + # Volumes. + inter_vol = h * w + union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol + jaccard = inter_vol / union_vol + return np.squeeze(jaccard) + + pre_scores = np.zeros((args.num_ssd_boxes), dtype=np.float32) + t_boxes = np.zeros((args.num_ssd_boxes, 4), dtype=np.float32) + t_label = np.zeros((args.num_ssd_boxes), dtype=np.int64) + + for bbox in boxes: + label = int(bbox[4]) + scores = jaccard_with_anchors(bbox) + idx = np.argmax(scores) + scores[idx] = 2.0 + mask = scores > args.match_threshold + mask = mask & (scores > pre_scores) + pre_scores = np.maximum(pre_scores, scores * mask) + t_label = mask * label + (1 - mask) * t_label + + for i in range(4): + t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i] + + index = np.nonzero(t_label) + + # Transform to tlbr. + bboxes = np.zeros((args.num_ssd_boxes, 4), dtype=np.float32) + bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2 + bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]] + + # Encode features. + bboxes_t = bboxes[index] + default_boxes_t = default_boxes[index] + bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * args.prior_scaling[0]) + tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001) + bboxes_t[:, 2:4] = np.log(tmp) / args.prior_scaling[1] + bboxes[index] = bboxes_t + + num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) + return bboxes, t_label.astype(np.int32), num_match + + +def ssd_bboxes_decode(boxes, args): + """Decode predict boxes to [y, x, h, w]""" + if hasattr(args, "use_anchor_generator") and args.use_anchor_generator: + generator = GridAnchorGenerator(args.image_size, 4, 2, [1.0, 2.0, 0.5]) + default_boxes, _ = generator.generate_multi_levels(args.steps) + else: + default_boxes = GeneratDefaultBoxes(args).default_boxes + + boxes_t = boxes.copy() + # default_boxes_t = default_boxes.copy() + boxes_t[:, :2] = boxes_t[:, :2] * args.prior_scaling[0] * default_boxes[:, 2:] + default_boxes[:, :2] + boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * args.prior_scaling[1]) * default_boxes[:, 2:4] + + bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32) + + bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2 + bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2 + + return np.clip(bboxes, 0, 1) + + +def intersect(box_a, box_b): + """Compute the intersect of two sets of boxes.""" + max_yx = np.minimum(box_a[:, 2:4], box_b[2:4]) + min_yx = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes.""" + inter = intersect(box_a, box_b) + area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1]) + area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]) + union = area_a + area_b - inter + return inter / union + + +def get_ssd_lr_scheduler(args, steps_per_epoch): + """ + generate learning rate array for training + """ + lr_init = args.lr_init + lr_end = args.lr_end_rate * args.lr + lr_max = args.lr + warmup_epochs = args.warmup_epochs + total_epochs = args.epoch_size + + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + else: + lr = ( + lr_end + + (lr_max - lr_end) + * (1.0 + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) + / 2.0 + ) + + if lr < 0.0: + lr = 0.0 + + lr_each_step.append(lr) + + learning_rate = np.array(lr_each_step).astype(np.float32) + + return learning_rate + + +def get_ssd_optimizer(model, lr, args): + optimizer = nn.Momentum( + filter(lambda x: x.requires_grad, model.get_parameters()), lr, args.momentum, args.weight_decay, args.loss_scale + ) + return optimizer + + +def apply_nms(all_boxes, all_scores, thres, max_boxes): + """Apply NMS to bboxes.""" + y1 = all_boxes[:, 0] + x1 = all_boxes[:, 1] + y2 = all_boxes[:, 2] + x2 = all_boxes[:, 3] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + + order = all_scores.argsort()[::-1] + keep = [] + + while order.size > 0: + i = order[0] + keep.append(i) + + if len(keep) >= max_boxes: + break + + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thres)[0] + + order = order[inds + 1] + + return keep + + +class COCOMetrics: + """Calculate mAP of predicted bboxes.""" + + def __init__(self, anno_json, classes, num_classes, min_score, nms_threshold, max_boxes): + self.num_classes = num_classes + self.classes = classes + self.min_score = min_score + self.nms_threshold = nms_threshold + self.max_boxes = max_boxes + + self.val_cls_dict = {i: cls for i, cls in enumerate(classes)} + self.coco_gt = COCO(anno_json) + cat_ids = self.coco_gt.loadCats(self.coco_gt.getCatIds()) + self.class_dict = {cat["name"]: cat["id"] for cat in cat_ids} + + self.predictions = [] + self.img_ids = [] + + def update(self, batch): + pred_boxes = batch["boxes"] + box_scores = batch["box_scores"] + img_id = batch["img_id"] + h, w = batch["image_shape"] + + final_boxes = [] + final_label = [] + final_score = [] + self.img_ids.append(img_id) + + for c in range(1, self.num_classes): + class_box_scores = box_scores[:, c] + score_mask = class_box_scores > self.min_score + class_box_scores = class_box_scores[score_mask] + class_boxes = pred_boxes[score_mask] * [h, w, h, w] + + if score_mask.any(): + nms_index = apply_nms(class_boxes, class_box_scores, self.nms_threshold, self.max_boxes) + class_boxes = class_boxes[nms_index] + class_box_scores = class_box_scores[nms_index] + + final_boxes += class_boxes.tolist() + final_score += class_box_scores.tolist() + final_label += [self.class_dict[self.val_cls_dict[c]]] * len(class_box_scores) + + for loc, label, score in zip(final_boxes, final_label, final_score): + res = {} + res["image_id"] = img_id + res["bbox"] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]] + res["score"] = score + res["category_id"] = label + self.predictions.append(res) + + def get_metrics(self): + with open("predictions.json", "w") as f: + json.dump(self.predictions, f) + + coco_dt = self.coco_gt.loadRes("predictions.json") + E = COCOeval(self.coco_gt, coco_dt, iouType="bbox") + E.params.imgIds = self.img_ids + E.evaluate() + E.accumulate() + E.summarize() + return E.stats[0] + + +def apply_eval(eval_param_dict): + net = eval_param_dict["net"] + net.set_train(False) + ds = eval_param_dict["dataset"] + anno_json = eval_param_dict["anno_json"] + args = eval_param_dict["args"] + coco_metrics = COCOMetrics( + anno_json=anno_json, + classes=args.classes, + num_classes=args.num_classes, + max_boxes=args.max_boxes, + nms_threshold=args.nms_threshold, + min_score=args.min_score, + ) + + for data in ds.create_dict_iterator(output_numpy=True, num_epochs=1): + img_id = data["img_id"] + img_np = data["image"] + image_shape = data["image_shape"] + + output = net(Tensor(img_np)) + + for batch_idx in range(img_np.shape[0]): + pred_batch = { + "boxes": output[0].asnumpy()[batch_idx], + "box_scores": output[1].asnumpy()[batch_idx], + "img_id": int(np.squeeze(img_id[batch_idx])), + "image_shape": image_shape[batch_idx], + } + coco_metrics.update(pred_batch) + + eval_metrics = coco_metrics.get_metrics() + + return eval_metrics diff --git a/mindcv/models/mobilenetv2.py b/mindcv/models/mobilenetv2.py index 780841f6..e81817d0 100644 --- a/mindcv/models/mobilenetv2.py +++ b/mindcv/models/mobilenetv2.py @@ -8,7 +8,7 @@ import mindspore.common.initializer as init from mindspore import Tensor, nn -from .helpers import load_pretrained, make_divisible +from .helpers import build_model_with_cfg, make_divisible from .layers.compatibility import Dropout from .layers.pooling import GlobalAvgPooling from .registry import register_model @@ -203,6 +203,13 @@ def __init__( nn.BatchNorm2d(input_channels), nn.ReLU6(), ] + + total_reduction = 2 + self.feature_info = [] + self.flatten_sequential = True + self.feature_info.append(dict(chs=input_channels, reduction=total_reduction, + name=f'features.{len(features) - 1}')) + # Building inverted residual blocks. for t, c, n, s in inverted_residual_setting: output_channel = make_divisible(c * alpha, round_nearest) @@ -210,12 +217,21 @@ def __init__( stride = s if i == 0 else 1 features.append(InvertedResidual(input_channels, output_channel, stride, expand_ratio=t)) input_channels = output_channel + + total_reduction *= stride + self.feature_info.append(dict(chs=output_channel, reduction=total_reduction, + name=f'features.{len(features) - 1}')) + # Building last point-wise layers. features.extend([ nn.Conv2d(input_channels, last_channels, 1, 1, pad_mode="pad", padding=0, has_bias=False), nn.BatchNorm2d(last_channels), nn.ReLU6(), ]) + + self.feature_info.append(dict(chs=last_channels, reduction=total_reduction, + name=f'features.{len(features) - 1}')) + self.features = nn.SequentialCell(features) self.pool = GlobalAvgPooling() @@ -259,18 +275,18 @@ def construct(self, x: Tensor) -> Tensor: return x +def _create_mobilenet_v2(pretrained=False, **kwargs): + return build_model_with_cfg(MobileNetV2, pretrained, **kwargs) + + @register_model def mobilenet_v2_140(pretrained: bool = False, num_classes: int = 1000, in_channels=3, **kwargs) -> MobileNetV2: """Get MobileNetV2 model with width scaled by 1.4 and input image size of 224. Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_140"] - model = MobileNetV2(alpha=1.4, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=1.4, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -279,12 +295,8 @@ def mobilenet_v2_130_224(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_130_224"] - model = MobileNetV2(alpha=1.3, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=1.3, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -293,12 +305,8 @@ def mobilenet_v2_100(pretrained: bool = False, num_classes: int = 1000, in_chann Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_100"] - model = MobileNetV2(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -307,12 +315,8 @@ def mobilenet_v2_100_192(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_100_192"] - model = MobileNetV2(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -321,12 +325,8 @@ def mobilenet_v2_100_160(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_100_160"] - model = MobileNetV2(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -335,12 +335,8 @@ def mobilenet_v2_100_128(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_100_128"] - model = MobileNetV2(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -349,12 +345,8 @@ def mobilenet_v2_100_96(pretrained: bool = False, num_classes: int = 1000, in_ch Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_100_96"] - model = MobileNetV2(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=1.0, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -363,12 +355,8 @@ def mobilenet_v2_075(pretrained: bool = False, num_classes: int = 1000, in_chann Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_075"] - model = MobileNetV2(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -377,12 +365,8 @@ def mobilenet_v2_075_192(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_075_192"] - model = MobileNetV2(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -391,12 +375,8 @@ def mobilenet_v2_075_160(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_075_160"] - model = MobileNetV2(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -405,12 +385,8 @@ def mobilenet_v2_075_128(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_075_128"] - model = MobileNetV2(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -419,12 +395,8 @@ def mobilenet_v2_075_96(pretrained: bool = False, num_classes: int = 1000, in_ch Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_075_96"] - model = MobileNetV2(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.75, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -433,12 +405,8 @@ def mobilenet_v2_050_224(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_050_224"] - model = MobileNetV2(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -447,12 +415,8 @@ def mobilenet_v2_050_192(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_050_192"] - model = MobileNetV2(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -461,12 +425,8 @@ def mobilenet_v2_050_160(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_050_160"] - model = MobileNetV2(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -475,12 +435,8 @@ def mobilenet_v2_050_128(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_050_128"] - model = MobileNetV2(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -489,12 +445,8 @@ def mobilenet_v2_050_96(pretrained: bool = False, num_classes: int = 1000, in_ch Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_050_96"] - model = MobileNetV2(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.5, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -503,12 +455,8 @@ def mobilenet_v2_035_224(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_035_224"] - model = MobileNetV2(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -517,12 +465,8 @@ def mobilenet_v2_035_192(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_035_192"] - model = MobileNetV2(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -531,12 +475,8 @@ def mobilenet_v2_035_160(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_035_160"] - model = MobileNetV2(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -545,12 +485,8 @@ def mobilenet_v2_035_128(pretrained: bool = False, num_classes: int = 1000, in_c Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_035_128"] - model = MobileNetV2(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args)) @register_model @@ -559,9 +495,5 @@ def mobilenet_v2_035_96(pretrained: bool = False, num_classes: int = 1000, in_ch Refer to the base class `models.MobileNetV2` for more details. """ default_cfg = default_cfgs["mobilenet_v2_035_96"] - model = MobileNetV2(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) - - if pretrained: - load_pretrained(model, default_cfg, num_classes=num_classes, in_channels=in_channels) - - return model + model_args = dict(alpha=0.35, num_classes=num_classes, in_channels=in_channels, **kwargs) + return _create_mobilenet_v2(pretrained, **dict(default_cfg=default_cfg, **model_args))