feat: add ssd detector (mindspore-lab#704)
DexterJZ authored Aug 10, 2023
1 parent 94ef850 commit 1fb2453
Showing 15 changed files with 2,580 additions and 133 deletions.
143 changes: 143 additions & 0 deletions examples/det/ssd/README.md
# SSD Based on MindCV Backbones

> [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325)
## Introduction

SSD is a single-stage object detector. It discretizes the output space of bounding boxes into a set of default boxes over different aspect ratios and scales per feature map location, and combines predictions from multi-scale feature maps to detect objects of various sizes. At prediction time, SSD generates scores for the presence of each object category in each default box and produces adjustments to the box to better match the object shape.

<p align="center">
<img src="https://github.com/DexterJZ/mindcv/assets/16130861/50bc9627-c71c-4b1a-9de4-9e6040a43279" width=800 />
</p>
<p align="center">
<em>Figure 1. Architecture of SSD [<a href="#references">1</a>] </em>
</p>
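
For reference, the paper spaces the default-box scales evenly across the m feature maps used for prediction, assigning the k-th map the scale s_k = s_min + (s_max - s_min)(k - 1)/(m - 1), with s_min = 0.2 and s_max = 0.9. A minimal Python sketch of this formula (for illustration only, not the anchor-generation code of this example):

```
# Illustrative only: default-box scales as defined in the SSD paper,
# not the anchor-generation code used by this example.
def default_box_scales(num_feature_maps=6, s_min=0.2, s_max=0.9):
    m = num_feature_maps
    return [s_min + (s_max - s_min) * k / (m - 1) for k in range(m)]

print(default_box_scales())  # scales roughly [0.2, 0.34, 0.48, 0.62, 0.76, 0.9]
```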

In this example, by leveraging [the multi-scale feature extraction of MindCV](https://github.com/mindspore-lab/mindcv/blob/main/docs/en/how_to_guides/feature_extraction.md), we demonstrate that using backbones from MindCV greatly simplifies the implementation of SSD.
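
A rough sketch of what this looks like is shown below; the model name matches the MobileNetV2 checkpoint used later, while `features_only`, `out_indices`, and the chosen indices follow the linked feature-extraction guide and are illustrative rather than the exact settings used by this example.

```
from mindcv.models import create_model

# Sketch: build a backbone that returns intermediate feature maps instead of
# classification logits. The indices below are illustrative.
backbone = create_model(
    "mobilenet_v2_100",
    pretrained=True,
    features_only=True,
    out_indices=[2, 3, 4],
)
# feature_maps = backbone(images)  # a list of feature maps at different spatial scales,
#                                  # which the SSD head consumes to detect objects of different sizes
```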

## Configurations

Here, we provide three configurations of SSD.
* Using [MobileNetV2](https://github.com/mindspore-lab/mindcv/tree/main/configs/mobilenetv2) as the backbone and the original detector described in the paper.
* Using [ResNet50](https://github.com/mindspore-lab/mindcv/tree/main/configs/resnet) as the backbone with an FPN and a shared-weight-based detector.
* Using [MobileNetV3](https://github.com/mindspore-lab/mindcv/tree/main/configs/mobilenetv3) as the backbone and the original detector described in the paper.

## Dataset

We train and test SSD using the [COCO 2017 Dataset](https://cocodataset.org/#download). The dataset contains
* 118,000 images (about 18 GB) for training, and
* 5,000 images (about 1 GB) for testing.

## Quick Start

### Preparation

1. Clone MindCV repository by running
```
git clone https://github.com/mindspore-lab/mindcv.git
```

2. Install dependencies as shown [here](https://mindspore-lab.github.io/mindcv/installation/).

3. Download the [COCO 2017 Dataset](https://cocodataset.org/#download) and organize it as follows.
```
.
└─cocodataset
  ├─annotations
  │  ├─instances_train2017.json
  │  └─instances_val2017.json
  ├─val2017
  └─train2017
```
Run the following commands to preprocess the dataset and convert it to [MindRecord format](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.mindrecord.html), which reduces preprocessing time during training and testing.
```
cd mindcv # change directory to the root of MindCV repository
python examples/det/ssd/create_data.py coco --data_path [root of COCO 2017 Dataset] --out_path [directory for storing MindRecord files]
```
Specify the path of the preprocessed dataset at keyword `data_dir` in the config file. A quick way to verify the generated MindRecord files is sketched right after this list.

4. Download the pretrained backbone weights from the table below, and specify the path to the backbone weights at keyword `backbone_ckpt_path` in the config file.
<div align="center">

| MobileNetV2 | ResNet50 | MobileNetV3 |
|:----------------:|:----------------:|:----------------:|
| [backbone weights](https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv2/mobilenet_v2_100-d5532038.ckpt) | [backbone weights](https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt) | [backbone weights](https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3/mobilenet_v3_large_100-1279ad5f.ckpt) |

</div>
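
To sanity-check the MindRecord files generated in step 3, you can load them back with `mindspore.dataset.MindDataset`; the file name below is hypothetical, so use whatever `create_data.py` wrote to your output directory.

```
import mindspore.dataset as ds

# Hypothetical file name; use the files that create_data.py actually produced.
mindrecord_file = "/path/to/mindrecord/ssd.mindrecord0"
dataset = ds.MindDataset(mindrecord_file)

print("number of samples:", dataset.get_dataset_size())
for sample in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
    print("columns:", list(sample.keys()))
    break
```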

### Train

It is highly recommended to use **distributed training** for this SSD implementation.

For distributed training using **OpenMPI's `mpirun`**, simply run
```
cd mindcv # change directory to the root of MindCV repository
mpirun -n [# of devices] python examples/det/ssd/train.py --config [the path to the config file]
```
For example, to train SSD on 8 devices with the `MobileNetV2` configuration, run
```
cd mindcv # change directory to the root of MindCV repository
mpirun -n 8 python examples/det/ssd/train.py --config examples/det/ssd/ssd_mobilenetv2.yaml
```

For distributed training with [Ascend rank table](https://github.com/mindspore-lab/mindocr/blob/main/docs/en/tutorials/distribute_train.md#12-configure-rank_table_file-for-training), configure `ascend8p.sh` as follows
```
#!/bin/bash
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE="./hccl_8p_01234567_127.0.0.1.json"
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "Launching rank: ${RANK_ID}, device: ${DEVICE_ID}"
    if [ $i -eq 0 ]; then
        echo 'i am 0'
        python examples/det/ssd/train.py --config [the path to the config file] &> ./train.log &
    else
        echo 'not 0'
        python -u examples/det/ssd/train.py --config [the path to the config file] &> /dev/null &
    fi
done
```
and start training by running
```
cd mindcv # change directory to the root of MindCV repository
bash ascend8p.sh
```

For single-device training, please run
```
cd mindcv # change directory to the root of MindCV repository
python examples/det/ssd/train.py --config [the path to the config file]
```

### Test

For testing the trained model, first specify the path to the model checkpoint at keyword `ckpt_path` in the config file, then run
```
cd mindcv # change directory to the root of MindCV repository
python examples/det/ssd/eval.py --config [the path to the config file]
```
For example, to test SSD with the `MobileNetV2` configuration, run
```
cd mindcv # change directory to the root of MindCV repository
python examples/det/ssd/eval.py --config examples/det/ssd/ssd_mobilenetv2.yaml
```
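
The reported metric is COCO-style mAP on `val2017`. For reference, the standard way to compute this metric from a file of detections uses `pycocotools`; the sketch below illustrates the general procedure and is not necessarily the exact code inside `eval.py`.

```
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Sketch of standard COCO bbox evaluation; predictions.json is a hypothetical
# file of detections in the COCO result format.
coco_gt = COCO("annotations/instances_val2017.json")
coco_dt = coco_gt.loadRes("predictions.json")
coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()  # the first printed value is mAP@[0.50:0.95]
```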

## Performance

Here are the performance results and the pretrained model weights for each configuration.
<div align="center">

| Configuration | Mixed Precision | mAP | Config | Download |
|:-----------------:|:---------------:|:----:|:------:|:--------:|
| MobileNetV2 | O2 | 23.2 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_mobilenetv2.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_mobilenetv2-5bbd7411.ckpt) |
| ResNet50 with FPN | O3 | 38.3 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_resnet50_fpn.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_resnet50_fpn-ac87ddac.ckpt) |
| MobileNetV3 | O2 | 23.8 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_mobilenetv3.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_mobilenetv3-53d9f6e9.ckpt) |

</div>
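
The mixed-precision column refers to MindSpore's automatic mixed precision level. A minimal sketch of enabling a given level through `mindspore.Model` is shown below; the network, loss, and optimizer are placeholders standing in for the objects built by `train.py`.

```
import mindspore.nn as nn
from mindspore import Model

# Placeholder network and training objects, standing in for the actual SSD
# network, loss, and optimizer built by train.py.
net = nn.Dense(8, 2)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# amp_level="O2" keeps some ops (e.g. BatchNorm) in float32 while casting the
# rest to float16; "O3" casts the whole network to float16.
model = Model(net, loss_fn=loss_fn, optimizer=optimizer, amp_level="O2")
```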

## References

[1] Liu, W., Anguelov, D., Erhan, D., Szegedy, C., Reed, S., Fu, C. Y., & Berg, A. C. (2016). SSD: Single Shot Multibox Detector. In Computer Vision–ECCV 2016: 14th European Conference, Amsterdam, The Netherlands, October 11–14, 2016, Proceedings, Part I 14 (pp. 21-37). Springer International Publishing.
130 changes: 130 additions & 0 deletions examples/det/ssd/callbacks.py
import os
import stat

from utils import apply_eval

from mindspore import log as logger
from mindspore import save_checkpoint
from mindspore.train.callback import Callback, CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor


class EvalCallBack(Callback):
    """
    Evaluation callback during training.

    Args:
        eval_function (function): evaluation function.
        eval_param_dict (dict): dict of parameters passed to the evaluation function.
        interval (int): run evaluation every `interval` epochs, default is 1.
        eval_start_epoch (int): epoch at which evaluation starts, default is 1.
        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
        ckpt_directory (str): directory in which the best checkpoint is saved, default is "./".
        best_ckpt_name (str): file name of the best checkpoint, default is `best.ckpt`.
        metrics_name (str): name of the evaluation metric, default is `acc`.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(
        self,
        eval_function,
        eval_param_dict,
        interval=1,
        eval_start_epoch=1,
        save_best_ckpt=True,
        ckpt_directory="./",
        best_ckpt_name="best.ckpt",
        metrics_name="acc",
    ):
        super(EvalCallBack, self).__init__()
        self.eval_function = eval_function
        self.eval_param_dict = eval_param_dict
        self.eval_start_epoch = eval_start_epoch

        if interval < 1:
            raise ValueError("interval should be >= 1.")

        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0

        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)

        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from disk."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def on_train_epoch_end(self, run_context):
        """Callback at the end of each training epoch."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num

        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)

            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)

                if self.save_best_ckpt:
                    if os.path.exists(self.best_ckpt_path):
                        self.remove_ckpoint_file(self.best_ckpt_path)

                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
                    print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)

    def on_train_end(self, run_context):
        print(
            "End training, the best {0} is: {1}, the best {0} epoch is {2}".format(
                self.metrics_name, self.best_res, self.best_epoch
            ),
            flush=True,
        )


def get_ssd_callbacks(args, steps_per_epoch, rank_id):
    ckpt_config = CheckpointConfig(keep_checkpoint_max=args.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix="ssd", directory=args.ckpt_save_dir, config=ckpt_config)

    if rank_id == 0:
        return [TimeMonitor(data_size=steps_per_epoch), LossMonitor(), ckpt_cb]

    return [TimeMonitor(data_size=steps_per_epoch), LossMonitor()]


def get_ssd_eval_callback(eval_net, eval_dataset, args):
    if args.dataset == "coco":
        anno_json = os.path.join(args.data_dir, "annotations/instances_val2017.json")
    else:
        raise NotImplementedError

    eval_param_dict = {"net": eval_net, "dataset": eval_dataset, "anno_json": anno_json, "args": args}

    eval_cb = EvalCallBack(
        apply_eval,
        eval_param_dict,
        interval=args.eval_interval,
        eval_start_epoch=args.eval_start_epoch,
        save_best_ckpt=True,
        ckpt_directory=args.ckpt_save_dir,
        best_ckpt_name="best.ckpt",
        metrics_name="mAP",
    )

    return eval_cb
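

# Hedged usage sketch (not part of the original callbacks.py): how these helpers
# might be wired into a MindSpore training loop. `args`, the networks, and the
# datasets are assumed to be built elsewhere (e.g. in train.py), and
# `args.epochs` is a hypothetical field name.
def _example_attach_callbacks(args, train_net, eval_net, train_dataset, eval_dataset, rank_id=0):
    from mindspore import Model

    callbacks = get_ssd_callbacks(args, train_dataset.get_dataset_size(), rank_id)
    callbacks.append(get_ssd_eval_callback(eval_net, eval_dataset, args))

    model = Model(train_net)
    model.train(args.epochs, train_dataset, callbacks=callbacks, dataset_sink_mode=True)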