Skip to content

Commit

Permalink
Feature: support yolov10
Browse files Browse the repository at this point in the history
  • Loading branch information
WongGawa committed Oct 18, 2024
1 parent eeebe04 commit 584c73d
Show file tree
Hide file tree
Showing 13 changed files with 512 additions and 14 deletions.
60 changes: 60 additions & 0 deletions configs/yolov10/hyp.scratch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
optimizer:
optimizer: momentum
lr_init: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
momentum: 0.937 # SGD momentum/Adam beta1
nesterov: True # update gradients with NAG(Nesterov Accelerated Gradient) algorithm
loss_scale: 1.0 # loss scale for optimizer
warmup_epochs: 3 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.1 # warmup initial bias lr
min_warmup_step: 1000 # minimum warmup step
group_param: yolov8 # group param strategy
gp_weight_decay: 0.0005 # group param weight decay 5e-4
start_factor: 1.0
end_factor: 0.01

loss:
name: YOLOv10Loss
box: 7.5 # box loss gain
cls: 0.5 # cls loss gain
dfl: 1.5 # dfl loss gain
reg_max: 16

data:
num_parallel_workers: 4

# multi-stage data augment
train_transforms: {
stage_epochs: [ 490, 10 ],
trans_list: [
[
{func_name: mosaic, prob: 1.0},
{func_name: resample_segments},
{func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.5, shear: 0.0},
{func_name: albumentations},
{func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4},
{func_name: fliplr, prob: 0.5},
{func_name: label_norm, xyxy2xywh_: True},
{func_name: label_pad, padding_size: 160, padding_value: -1},
{func_name: image_norm, scale: 255.},
{func_name: image_transpose, bgr2rgb: True, hwc2chw: True}
],
[
{func_name: letterbox, scaleup: True},
{func_name: resample_segments},
{func_name: random_perspective, prob: 1.0, degrees: 0.0, translate: 0.1, scale: 0.5, shear: 0.0},
{func_name: albumentations},
{func_name: hsv_augment, prob: 1.0, hgain: 0.015, sgain: 0.7, vgain: 0.4},
{func_name: fliplr, prob: 0.5},
{func_name: label_norm, xyxy2xywh_: True},
{func_name: label_pad, padding_size: 160, padding_value: -1},
{func_name: image_norm, scale: 255.},
{func_name: image_transpose, bgr2rgb: True, hwc2chw: True}
]]
}

test_transforms: [
{func_name: letterbox, scaleup: False, only_image: True},
{func_name: image_norm, scale: 255.},
{func_name: image_transpose, bgr2rgb: True, hwc2chw: True}
]
57 changes: 57 additions & 0 deletions configs/yolov10/yolov10n.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
__BASE__: [
'../coco.yaml',
'./hyp.scratch.yaml',
]

epochs: 500 # total train epochs
per_batch_size: 16 # 16 * 8 = 128
img_size: 640
iou_thres: 0.7
overflow_still_update: False
opencv_threads_num: 0 # opencv: disable threading optimizations

network:
model_name: yolov10
nc: 80 # number of classes
reg_max: 16

depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
max_channels: 1024
stride: [8, 16, 32]

# YOLOv10n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, ConvNormAct, [64, 3, 2]] # 0-P1/2
- [-1, 1, ConvNormAct, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, ConvNormAct, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024, 5]] # 10

# YOLOv10n head
head:
- [-1, 1, Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13

- [-1, 1, Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1] ] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)

- [-1, 1, ConvNormAct, [256, 3, 2]]
- [[ -1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 19 (P4/16-medium)

- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)

- [[16, 19, 22], 1, YOLOv10Head, [nc, reg_max, stride]] # Detect(P3, P4, P5)

4 changes: 3 additions & 1 deletion mindyolo/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from . import (heads, initializer, layers, losses, model_factory, yolov3,
yolov4, yolov5, yolov7, yolov8)
yolov4, yolov5, yolov7, yolov8, yolov10)

__all__ = []
__all__.extend(heads.__all__)
__all__.extend(layers.__all__)
__all__.extend(losses.__all__)
__all__.extend(yolov10.__all__)
__all__.extend(yolov8.__all__)
__all__.extend(yolov7.__all__)
__all__.extend(yolov5.__all__)
Expand All @@ -25,4 +26,5 @@
from .yolov5 import *
from .yolov7 import *
from .yolov8 import *
from .yolov10 import *
from .yolox import *
5 changes: 3 additions & 2 deletions mindyolo/models/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from .yolov7_head import *
from .yolov8_head import *
from .yolox_head import *

from .yolov10_head import *

# Public names re-exported by the heads package.
# Every entry is its own comma-terminated literal: two adjacent string literals
# without a comma are implicitly concatenated by Python, which would silently
# replace "YOLOXHead" with a bogus "YOLOXHeadYOLOXHead" export.
__all__ = [
    "YOLOv3Head",
    "YOLOv4Head",
    "YOLOv5Head",
    "YOLOv7Head", "YOLOv7AuxHead",
    "YOLOv8Head", "YOLOv8SegHead",
    "YOLOXHead",
    "YOLOv10Head",
]
194 changes: 194 additions & 0 deletions mindyolo/models/heads/yolov10_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import math
import numpy as np
from copy import deepcopy

import mindspore as ms
import mindspore.numpy as mnp
from mindspore import Parameter, Tensor, nn, ops

from ..layers import DFL, ConvNormAct, Identity
from ..layers.utils import meshgrid

class YOLOv10Head(nn.Cell):
    """YOLOv10 detection head with one-to-many and one-to-one prediction branches.

    Each input feature level gets a box-regression branch (``cv2``) and a
    classification branch (``cv3``). ``one2one_cv2`` / ``one2one_cv3`` are deep
    copies of those branches; in end-to-end mode they are fed with
    gradient-stopped features and their output is the one decoded at inference.

    Args:
        nc (int): Number of classes.
        reg_max (int): Number of DFL bins per box side.
        stride (tuple | list): Stride of every detection level; must be non-empty.
        ch (tuple | list): Input channels of every detection level; must be non-empty.
        sync_bn (bool): Forwarded to every ``ConvNormAct`` of the head.
    """

    def __init__(self, nc=80, reg_max=16, stride=(), ch=(), sync_bn=False):  # detection layer
        super().__init__()

        assert isinstance(stride, (tuple, list)) and len(stride) > 0
        assert isinstance(ch, (tuple, list)) and len(ch) > 0

        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = reg_max  # DFL channels
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = Parameter(Tensor(stride, ms.int32), requires_grad=False)
        self.max_det = 300  # maximum detections kept per image by postprocess
        self.end2end = True  # use the one-to-one branch and skip the plain Detect path
        self.export = False  # when True, inference returns only the decoded predictions

        # Hidden widths of the box (c2) and class (c3) branches.
        c2 = max((16, ch[0] // 4, self.reg_max * 4))
        c3 = max(ch[0], min(self.nc, 100))

        self.cv2 = nn.CellList(
            [
                nn.SequentialCell(
                    [
                        ConvNormAct(x, c2, 3, sync_bn=sync_bn),
                        ConvNormAct(c2, c2, 3, sync_bn=sync_bn),
                        nn.Conv2d(c2, 4 * self.reg_max, 1, has_bias=True),
                    ]
                )
                for x in ch
            ]
        )
        # sync_bn is forwarded here as well so that both branches honour the
        # flag; with the default sync_bn=False nothing changes.
        self.cv3 = nn.CellList(
            [
                nn.SequentialCell(
                    [
                        nn.SequentialCell(
                            [
                                ConvNormAct(x, x, 3, g=x, sync_bn=sync_bn),
                                ConvNormAct(x, c3, 1, sync_bn=sync_bn),
                            ]
                        ),
                        nn.SequentialCell(
                            [
                                ConvNormAct(c3, c3, 3, g=c3, sync_bn=sync_bn),
                                ConvNormAct(c3, c3, 1, sync_bn=sync_bn),
                            ]
                        ),
                        nn.Conv2d(c3, self.nc, 1, has_bias=True),
                    ]
                )
                for x in ch
            ]
        )
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else Identity()

        # Independent copies of both branches for the one-to-one assignment.
        self.one2one_cv2 = deepcopy(self.cv2)
        self.one2one_cv3 = deepcopy(self.cv3)

    def construct(self, x):
        """Run the head on the per-level feature maps ``x``.

        Returns:
            * end2end mode: whatever :meth:`construct_end2end` returns.
            * otherwise, training: tuple with one raw ``(bs, no, h, w)`` map per level.
            * otherwise, inference: the decoded predictions if ``self.export``
              else ``(decoded predictions, raw maps)``.
        """
        if self.end2end:
            return self.construct_end2end(x)

        # Accumulate into a separate tuple: rebinding ``x`` here would discard
        # the input features before they are read.
        outs = ()
        for i in range(self.nl):
            outs += (ops.concat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1),)
        if self.training:  # Training path
            return outs
        y = self._inference(outs)
        return y if self.export else (y, outs)

    def construct_end2end(self, x):
        """Forward pass with both the one-to-many and the one-to-one branch.

        Args:
            x (tuple | list): One feature map per detection level.

        Returns:
            * training: ``(one2many, one2one)``, each a tuple of raw per-level maps.
            * inference: the post-processed one-to-one detections if
              ``self.export`` else ``(detections, {"one2many": ..., "one2one": ...})``.
        """
        # The one-to-one branch must not back-propagate into the shared features.
        x_detach = [ops.stop_gradient(xi) for xi in x]
        one2one = ()
        for i in range(self.nl):
            one2one += (ops.concat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1),)
        one2many = ()
        for i in range(self.nl):
            one2many += (ops.concat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1),)
        if self.training:  # Training path
            return (one2many, one2one)
        y = self._inference(one2one)
        y = self.postprocess(y, self.max_det, self.nc)
        return y if self.export else (y, {"one2many": one2many, "one2one": one2one})

    def _inference(self, x):
        """Decode raw per-level maps into a ``(bs, nbox, 4 + nc)`` prediction tensor.

        Boxes are decoded with DFL, converted from distances to boxes (xyxy in
        end2end mode, xywh otherwise), scaled by the level stride, and
        concatenated with the sigmoid class scores.
        """
        shape = x[0].shape  # BCHW
        _anchors, _strides = self.make_anchors(x, self.stride, 0.5)
        _anchors, _strides = _anchors.swapaxes(0, 1), _strides.swapaxes(0, 1)

        # Flatten the spatial dims of every level and join all levels along the box axis.
        _x = ()
        for i in range(len(x)):
            _x += (x[i].view(shape[0], self.no, -1),)
        _x = ops.concat(_x, 2)
        box, cls = _x[:, : self.reg_max * 4, :], _x[:, self.reg_max * 4 : self.reg_max * 4 + self.nc, :]
        dbox = self.dist2bbox(self.dfl(box), ops.expand_dims(_anchors, 0), xywh=not self.end2end, axis=1) * _strides

        y = ops.concat((dbox, ops.Sigmoid()(cls)), 1)
        # (bs, 4 + nc, nbox) -> (bs, nbox, 4 + nc)
        # Only the decoded tensor is returned: both callers pass it on as a
        # tensor and already hold the raw maps themselves.
        return ops.transpose(y, (0, 2, 1))

    @staticmethod
    def make_anchors(feats, strides, grid_cell_offset=0.5):
        """Generate anchor points and matching stride values from the feature maps.

        Returns:
            ``(anchor_points, stride_tensor)``: the per-level ``(h * w, 2)`` and
            ``(h * w, 1)`` tensors concatenated over all levels.
        """
        anchor_points, stride_tensor = (), ()
        dtype = feats[0].dtype
        for i, stride in enumerate(strides):
            _, _, h, w = feats[i].shape
            sx = mnp.arange(w, dtype=dtype) + grid_cell_offset  # shift x
            sy = mnp.arange(h, dtype=dtype) + grid_cell_offset  # shift y
            # FIXME: Not supported on a specific model of machine
            sy, sx = meshgrid((sy, sx), indexing="ij")
            anchor_points += (ops.stack((sx, sy), -1).view(-1, 2),)
            stride_tensor += (ops.ones((h * w, 1), dtype) * stride,)
        return ops.concat(anchor_points), ops.concat(stride_tensor)

    @staticmethod
    def dist2bbox(distance, anchor_points, xywh=True, axis=-1):
        """Transform distance(ltrb) to box(xywh or xyxy)."""
        lt, rb = ops.split(distance, split_size_or_sections=2, axis=axis)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        if xywh:
            c_xy = (x1y1 + x2y2) / 2
            wh = x2y2 - x1y1
            return ops.concat((c_xy, wh), axis)  # xywh bbox
        return ops.concat((x1y1, x2y2), axis)  # xyxy bbox

    @staticmethod
    def postprocess(preds, max_det, nc=80):
        """Keep the ``max_det`` best (anchor, class) pairs of every image.

        Args:
            preds (Tensor): Decoded predictions of shape
                ``(batch_size, num_anchors, 4 + nc)`` whose last dimension is
                ``[box (4 values), class scores (nc values)]``.
            max_det (int): Maximum detections per image.
            nc (int, optional): Number of classes. Default: 80.

        Returns:
            Tensor: Shape ``(batch_size, min(max_det, num_anchors), 6)`` with last
            dimension ``[box (4 values), max_class_prob, class_index]``. The box
            values are passed through unchanged from ``preds``.
        """
        num_anchors = preds.shape[1]
        k = min(max_det, num_anchors)  # topk cannot return more rows than exist

        boxes, scores = preds[..., :4], preds[..., 4 : 4 + nc]

        # Stage 1: keep the k anchors with the highest best-class score.
        best_per_anchor = ops.amax(scores, axis=-1)
        _, index = ops.topk(best_per_anchor, k, dim=-1)
        index = ops.expand_dims(index, -1)
        boxes = ops.gather_elements(boxes, 1, ops.tile(index, (1, 1, boxes.shape[-1])))
        scores = ops.gather_elements(scores, 1, ops.tile(index, (1, 1, nc)))

        # Stage 2: among the k * nc remaining (anchor, class) pairs keep the k best.
        scores, index = ops.topk(ops.flatten(scores, start_dim=1), k, dim=-1)
        anchor_idx = ops.expand_dims(index // nc, -1)  # row of the surviving anchor
        labels = ops.expand_dims(index % nc, -1)  # class id of the pair
        boxes = ops.gather_elements(boxes, 1, ops.tile(anchor_idx, (1, 1, boxes.shape[-1])))

        # Cast the labels to the score dtype so all concatenated parts share one dtype.
        return ops.concat((boxes, ops.expand_dims(scores, -1), labels.astype(scores.dtype)), -1)

    def initialize_biases(self):
        """Initialize the final-conv biases of the head; requires ``self.stride``."""
        self._init_branch_biases(self.cv2, self.cv3, self.stride, self.nc)
        if self.end2end:
            self._init_branch_biases(self.one2one_cv2, self.one2one_cv3, self.stride, self.nc)

    @staticmethod
    def _init_branch_biases(box_branches, cls_branches, strides, nc):
        """Set box biases to 1 and class biases to a stride-dependent log prior."""
        for a, b, s in zip(box_branches, cls_branches, strides):
            s = s.asnumpy()
            a[-1].bias = ops.assign(a[-1].bias, Tensor(np.ones(a[-1].bias.shape), ms.float32))
            b_np = b[-1].bias.data.asnumpy()
            b_np[:nc] = math.log(5 / nc / (640 / int(s)) ** 2)
            b[-1].bias = ops.assign(b[-1].bias, Tensor(b_np, ms.float32))
3 changes: 3 additions & 0 deletions mindyolo/models/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@
"SPPF",
"Upsample",
"Residualblock",
"SCDown",
"PSA",
"C2fCIB",
]
Loading

0 comments on commit 584c73d

Please sign in to comment.