Merge pull request #197 from zhanghuiyao/master
fix yolov8x-seg bug
SamitHuang authored Sep 6, 2023
2 parents 323a042 + 9992ccd commit 0313ab2
Showing 11 changed files with 35 additions and 25 deletions.
14 changes: 8 additions & 6 deletions README.md
@@ -14,19 +14,21 @@

MindYOLO is [MindSpore Lab](https://github.com/mindspore-lab)'s software toolbox that implements state-of-the-art YOLO series algorithms, [support list and benchmark](MODEL_ZOO.md). It is written in Python and powered by the [MindSpore](https://mindspore.cn/) AI framework.

The master branch supports **MindSpore 2.0**.
The master branch supports **MindSpore 2.0/2.1**.

<img src="https://raw.githubusercontent.com/mindspore-lab/mindyolo/master/.github/000000137950.jpg" />


## What is New

- 2023/06/15
- 2023/09/05

1. Support YOLOv3/v4/v5/X/v7/v8 6 models and release 23 corresponding weights, see [MODEL ZOO](MODEL_ZOO.md) for details.
2. Support MindSpore 2.0.
3. Support deployment on MindSpore lite 2.0.
4. New online documents are available!
1. Add the YOLOv8-X segmentation model.
2. Dataset pipeline reconstruction (currently supports seg/detect tasks).
3. Add an example of custom IoU operators on GPU.
4. Add a distributed evaluation function.
5. Add a fast COCO eval API.
6. Tutorial and [Docs](https://mindspore-lab.github.io/mindyolo/) updates (e.g. [Write a new model](https://mindspore-lab.github.io/mindyolo/zh/how_to_guides/write_a_new_model/), [Train Process Tutorial](https://mindspore-lab.github.io/mindyolo/zh/tutorials/quick_start/), ...).

## Benchmark and Model Zoo

6 changes: 4 additions & 2 deletions configs/yolov8/seg/yolov8-seg-base.yaml
@@ -1,10 +1,12 @@
task: segment
epochs: 500 # total train epochs
per_batch_size: 16 # 16 * 8 = 128
img_size: 640
conf_thres: 0.001
iou_thres: 0.7
conf_free: True
sync_bn: True
clip_grad: True
clip_grad_value: 10.0
opencv_threads_num: 0 # opencv: disable threading optimizations

network:
@@ -46,4 +48,4 @@ network:
- [[-1, 9], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 21 (P5/32-large)

- [[15, 18, 21], 1, YOLOv8Head, [nc, reg_max, stride]] # Detect(P3, P4, P5)
- [[15, 18, 21], 1, YOLOv8SegHead, [nc, reg_max, 32, 256, stride]] # Seg(P3, P4, P5)
4 changes: 2 additions & 2 deletions configs/yolov8/seg/yolov8x-seg.yaml
@@ -1,7 +1,7 @@
__BASE__: [
'../../coco.yaml',
'./hyp.scratch.high.seg.yaml',
'./yolov8-seg-base.yaml'
'./yolov8-seg-base.yaml',
'./hyp.scratch.high.seg.yaml'
]

recompute: True
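The actual yolov8x-seg fix is this reordering of the `__BASE__` list: base configs are merged sequentially, so their order decides which file's keys win on conflicts. Below is a minimal sketch of that kind of order-sensitive merge, assuming later `__BASE__` entries override earlier ones and the leaf file overrides them all; it is an illustration, not MindYOLO's actual config loader.

```python
# Illustrative sketch only: a simplified, order-sensitive config merge.
# Assumes later __BASE__ entries override earlier ones; MindYOLO's real
# loader may differ (nested-key handling, relative path resolution, etc.).
import yaml


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`; `override` wins on conflicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def load_config(path: str) -> dict:
    with open(path, "r") as f:
        cfg = yaml.safe_load(f) or {}
    result = {}
    for base_path in cfg.pop("__BASE__", []):
        # Real loaders resolve base_path relative to `path`; omitted here.
        result = deep_merge(result, load_config(base_path))  # order matters
    return deep_merge(result, cfg)  # the leaf config has the final say
```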
2 changes: 1 addition & 1 deletion deploy/test.py
@@ -72,7 +72,7 @@ def test_detect(args):
result_dicts = []
for i, data in enumerate(loader):
imgs, paths, ori_shape, pad, hw_scale = (
data["image"],
data["images"],
data["img_files"],
data["hw_ori"],
data["pad"],
4 changes: 2 additions & 2 deletions mindyolo/data/albumentations.py
@@ -10,7 +10,7 @@
class Albumentations:
# Implement Albumentations augmentation https://github.com/ultralytics/yolov5
# YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self, size=640, random_resized_crop=True):
def __init__(self, size=640, random_resized_crop=True, **kwargs):
self.transform = None
prefix = _colorstr("albumentations: ")
try:
@@ -42,7 +42,7 @@ def __init__(self, size=640, random_resized_crop=True):
print(f"{prefix}{e}", flush=True)
print("[WARNING] albumentations load failed", flush=True)

def __call__(self, sample, p=1.0):
def __call__(self, sample, p=1.0, **kwargs):
if self.transform and random.random() < p:
im, bboxes, cls, bbox_format = sample['img'], sample['bboxes'], sample['cls'], sample['bbox_format']
assert bbox_format in ("ltrb", "xywhn")
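Adding `**kwargs` to both `__init__` and `__call__` lets the reconstructed dataset pipeline hand every transform its full config dict, even when some keys (e.g. the sampling probability) are only consumed by the dispatcher. A minimal sketch of that tolerant-dispatch pattern is below; the registry and stub transform are illustrative, not MindYOLO's actual classes.

```python
# Illustrative sketch: name-based transform dispatch where each transform
# tolerates unknown config keys via **kwargs. The registry and stub below
# are placeholders, not MindYOLO's actual pipeline.
import random


class ResizeStub:
    def __init__(self, size=640, **kwargs):   # extra config keys are ignored
        self.size = size

    def __call__(self, sample, **kwargs):     # extra config keys are ignored
        sample["img_size"] = self.size
        return sample


TRANSFORMS = {"resize": ResizeStub}


def run_pipeline(sample, transforms_cfg):
    for cfg in transforms_cfg:
        cfg = dict(cfg)                        # don't mutate the caller's config
        name = cfg.pop("func_name")
        prob = cfg.pop("prob", 1.0)            # consumed here, not by the transform
        if random.random() < prob:
            sample = TRANSFORMS[name](**cfg)(sample, **cfg)
    return sample


sample = run_pipeline({"img_size": None}, [{"func_name": "resize", "size": 640, "prob": 1.0}])
print(sample)  # {'img_size': 640}
```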
9 changes: 4 additions & 5 deletions mindyolo/data/dataset.py
@@ -112,7 +112,7 @@ def __init__(
self.img_files = sorted([x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in self.img_formats])
assert self.img_files, f"No images found"
except Exception as e:
raise Exception(f"Error loading data from {self.path}: {e}\nSee {self.help_url}")
raise Exception(f"Error loading data from {self.path}: {e}\n")

# Check cache
self.label_files = self._img2label_paths(self.img_files) # labels
@@ -305,7 +305,7 @@ def __getitem__(self, index):
sample = self.copy_paste(sample, prob)
elif random.random() < prob:
if func_name == "albumentations" and getattr(self, "albumentations", None) is None:
self.albumentations = Albumentations(size=self.img_size)
self.albumentations = Albumentations(size=self.img_size, **_trans)
if func_name == "letterbox":
new_shape = self.img_size if not self.rect else self.batch_shapes[self.batch[index]]
sample = self.letterbox(sample, new_shape, **_trans)
@@ -697,7 +697,7 @@ def mixup(self, sample, alpha: 32.0, beta: 32.0, pre_transform=None):
sample2 = self.copy_paste(sample2, prob)
elif random.random() < prob:
if func_name == "albumentations" and getattr(self, "albumentations", None) is None:
self.albumentations = Albumentations(size=self.img_size)
self.albumentations = Albumentations(size=self.img_size, **_trans)
sample2 = getattr(self, func_name)(sample2, **_trans)

assert isinstance(sample['segments'], np.ndarray), \
@@ -1098,10 +1098,9 @@ def segment_poly2mask(self, sample, mask_overlap, mask_ratio):
h, w = sample['img'].shape[:2]
if mask_overlap:
masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=mask_ratio)
masks = masks[None] # (1, h/mask_ratio, w/mask_ratio)
sample['cls'] = sample['cls'][sorted_idx]
sample['bboxes'] = sample['bboxes'][sorted_idx]
sample['segments'] = masks
sample['segments'] = masks # (h/mask_ratio, w/mask_ratio)
sample['segment_format'] = 'overlap'
else:
masks = polygons2masks((h, w), segments, color=1, downsample_ratio=mask_ratio)
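In the overlap segment format all instance polygons are rasterized into a single downsampled index map, which is why the extra leading axis (`masks[None]`) is dropped and `sample['segments']` is kept as a `(h/mask_ratio, w/mask_ratio)` array. A rough sketch of how such an overlap map is typically built is shown below (largest instances drawn first so smaller ones stay visible); it follows the common YOLOv5/YOLOv8 convention and is not MindYOLO's exact `polygons2masks_overlap`.

```python
# Illustrative sketch: encode N polygons into one "overlap" index map.
# Pixel value 0 = background, k = k-th instance after sorting by area
# (largest drawn first so smaller instances are not fully covered).
import cv2
import numpy as np


def polygons_to_overlap_mask(img_shape, segments, downsample_ratio=4):
    h, w = img_shape
    mh, mw = h // downsample_ratio, w // downsample_ratio
    dtype = np.uint8 if len(segments) <= 255 else np.int32
    overlap = np.zeros((mh, mw), dtype=dtype)

    # Rasterize each polygon at the downsampled resolution.
    masks, areas = [], []
    for seg in segments:  # seg: (num_points, 2) array in image coordinates
        m = np.zeros((mh, mw), dtype=np.uint8)
        pts = np.asarray(seg, dtype=np.float64) / downsample_ratio
        cv2.fillPoly(m, [pts.astype(np.int32)], color=1)
        masks.append(m)
        areas.append(m.sum())

    sorted_idx = np.argsort(-np.asarray(areas))   # largest area first
    for order, idx in enumerate(sorted_idx, start=1):
        overlap[masks[idx] == 1] = order          # later (smaller) instances overwrite
    # cls/bboxes must be reindexed by sorted_idx, as done in the dataset code above.
    return overlap, sorted_idx
```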
12 changes: 9 additions & 3 deletions mindyolo/utils/trainer_factory.py
@@ -164,8 +164,8 @@ def train(
dtype = self.optimizer.momentum.dtype
self.optimizer.momentum = Tensor(warmup_momentum[i], dtype)

imgs, labels = data["image"], data["labels"]
segments = None if 'segment' not in data else data["segment"]
imgs, labels = data["images"], data["labels"]
segments = None if 'masks' not in data else data["masks"]
self._on_train_step_begin(run_context)
run_context.loss, run_context.lr = self.train_step(imgs, labels, segments,
cur_step=cur_step,cur_epoch=cur_epoch)
@@ -212,6 +212,7 @@ def train(

def train_with_datasink(
self,
task: str,
epochs: int,
main_device: bool,
warmup_epoch: int = 0,
@@ -230,7 +231,12 @@
profiler_step_num: int = 1
):
# Modify dataset columns name for data sink mode, because dataloader could not send string data to device.
loader = self.dataloader.project(["image", "labels"])
if task == "detect":
loader = self.dataloader.project(["images", "labels"])
elif task == "segment":
loader = self.dataloader.project(["images", "labels", "masks"])
else:
raise NotImplementedError

# to be compatible with old interface
has_eval_mask = list(isinstance(c, EvalWhileTrain) for c in self.callback)
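Data-sink mode can only ship numeric columns to the device, so the loader is projected down to the columns each task actually consumes; segmentation additionally keeps the `masks` column. A minimal, self-contained sketch of `Dataset.project` with the renamed columns is below; the toy arrays stand in for the real MindYOLO loader output.

```python
# Illustrative sketch: keep only device-friendly columns before data sink.
# Uses mindspore.dataset's project(); toy arrays replace the real loader.
import numpy as np
import mindspore.dataset as de

images = np.zeros((8, 3, 640, 640), dtype=np.float32)
labels = np.zeros((8, 160, 6), dtype=np.float32)
masks = np.zeros((8, 160, 160), dtype=np.float32)

dataset = de.NumpySlicesDataset(
    data=(images, labels, masks),
    column_names=["images", "labels", "masks"],
    shuffle=False,
)

task = "segment"  # "detect" would drop the masks column
columns = ["images", "labels"] if task == "detect" else ["images", "labels", "masks"]
loader = dataset.project(columns)  # string columns (e.g. img_files) must not reach the device
print(loader.get_col_names())
```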
2 changes: 1 addition & 1 deletion mindyolo/utils/utils.py
@@ -188,7 +188,7 @@ def get_broadcast_datetime(rank_size=1, root_rank=0):
x = x[0].asnumpy().tolist()
return x

@ms.ms_function
@ms.jit
def broadcast(x, root_rank):
return ops.Broadcast(root_rank=root_rank)(x)

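`ms.ms_function` is the pre-2.x name of MindSpore's graph-compilation decorator; MindSpore 2.x exposes it as `ms.jit`, which this change adopts. A standalone sketch of the decorator on a toy function follows (the real `broadcast` additionally needs distributed communication to be initialized):

```python
# Illustrative sketch: ms.jit compiles a Python function to graph mode;
# it replaces the older ms.ms_function alias used before MindSpore 2.x.
import mindspore as ms
from mindspore import Tensor


@ms.jit
def scaled_add(x, y):
    return x * 2.0 + y


x = Tensor([1.0, 2.0], ms.float32)
y = Tensor([3.0, 4.0], ms.float32)
print(scaled_add(x, y))  # [5. 8.]
```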
4 changes: 2 additions & 2 deletions test.py
@@ -158,7 +158,7 @@ def test_detect(

for i, data in enumerate(loader):
imgs, paths, ori_shape, pad, hw_scale = (
data["image"],
data["images"],
data["img_files"],
data["hw_ori"],
data["pad"],
@@ -297,7 +297,7 @@ def test_segment(

for i, data in enumerate(loader):
imgs, paths, ori_shape, pad, hw_scale = (
data["image"],
data["images"],
data["img_files"],
data["hw_ori"],
data["pad"],
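The same column rename (`image` → `images`) shows up in every eval loop, which reads each batch as a dictionary keyed by column name. A toy sketch of dict-style iteration with the renamed column is below; the `output_numpy=True` host-side setup is an assumption, not necessarily how the real loader is built.

```python
# Illustrative sketch: iterate an eval loader as dictionaries keyed by the
# renamed column names. Toy arrays stand in for the real COCO loader.
import numpy as np
import mindspore.dataset as de

images = np.zeros((4, 3, 640, 640), dtype=np.float32)
hw_ori = np.tile(np.array([480, 640], dtype=np.int32), (4, 1))

dataset = de.NumpySlicesDataset(
    data=(images, hw_ori),
    column_names=["images", "hw_ori"],
    shuffle=False,
).batch(2)

for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True, num_epochs=1)):
    imgs, ori_shape = data["images"], data["hw_ori"]
    print(i, imgs.shape, ori_shape.shape)
```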
2 changes: 1 addition & 1 deletion tests/modules/test_create_trainer.py
@@ -52,7 +52,7 @@ def test_create_trainer(yaml_name, mode):
l[:, 0] = i # add target image index for build_targets()

data = (x, y)
dataset = de.NumpySlicesDataset(data=data, column_names=["image", "labels"])
dataset = de.NumpySlicesDataset(data=data, column_names=["images", "labels"])
dataset = dataset.batch(batch_size=bs)
dataloader = dataset.repeat(10)

1 change: 1 addition & 0 deletions train.py
@@ -297,6 +297,7 @@ def train(args):
logger.warning("Train with data sink mode.")
assert args.accumulate == 1, "datasink mode not support grad accumulate."
trainer.train_with_datasink(
task=args.task,
epochs=args.epochs,
main_device=main_device,
warmup_epoch=max(args.optimizer.warmup_epochs, args.optimizer.min_warmup_step // steps_per_epoch),
