Skip to content

Commit

Permalink
avoid the problem that ops.Custom call error on Ascend310
Browse files Browse the repository at this point in the history
  • Loading branch information
panshaowu committed Aug 31, 2023
1 parent 0f6de64 commit 3855f0a
Show file tree
Hide file tree
Showing 23 changed files with 101 additions and 114 deletions.
3 changes: 2 additions & 1 deletion docs/en/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ cd mindyolo/csrc
sh build.sh
```

We also provide fused GPU operators which are built upon MindSpore [ops.Custom](https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html) API. The fused GPU operators are able to improve train speed. The source code is provided in C++ and CUDA and is in the folder `mindyolo/models/losses/fused_op`. Before using it, you shall try compiling the source code to dynamic link libraries with the following commands **(This operation is optional)** :
We also provide fused GPU operators which are built upon MindSpore [ops.Custom](https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html) API. The fused GPU operators are able to improve train speed. The source code is provided in C++ and CUDA and is in the folder `examples/custom_gpu_op/`. To enable this feature in the GPU training process, you shall modify the method `bbox_iou` in `mindyolo/models/losses/iou_loss.py` by referring to the demo script `examples/custom_gpu_op/iou_loss_fused.py`. Before runing `iou_loss_fused.py`, you shall compile the C++ and CUDA source code to dynamic link libraries with the following commands **(This operation is optional)** :

```shell
cp -rf examples/custom_gpu_op/* mindyolo/models/losses/
bash mindyolo/models/losses/fused_op/build.sh
```
4 changes: 2 additions & 2 deletions docs/zh/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ cd mindyolo/csrc
sh build.sh
```

我们还提供了基于MindSpore [Custom自定义算子](https://www.mindspore.cn/tutorials/experts/zh-CN/master/operation/op_custom.html) 的GPU融合算子,用于提升训练过程的速度。代码采用C++和CUDA开发,位于`mindyolo/models/losses/fused_op`路径下。使用该特性前,需要使用以下的命令,编译生成GPU融合算子运行所依赖的动态库 **(此操作是可选的)** :
我们还提供了基于MindSpore [Custom自定义算子](https://www.mindspore.cn/tutorials/experts/zh-CN/master/operation/op_custom.html) 的GPU融合算子,用于提升训练过程的速度。代码采用C++和CUDA开发,位于`examples/custom_gpu_op/`路径下。您可参考示例脚本`examples/custom_gpu_op/iou_loss_fused.py`,修改`mindyolo/models/losses/iou_loss.py``bbox_iou`方法,在GPU训练过程中使用该特性。运行`iou_loss_fused.py`,需要使用以下的命令,编译生成GPU融合算子运行所依赖的动态库 **(此操作是可选的)** :

```shell
bash mindyolo/models/losses/fused_op/build.sh
bash examples/custom_gpu_op/fused_op/build.sh
```
File renamed without changes.
File renamed without changes.
57 changes: 57 additions & 0 deletions examples/custom_gpu_op/iou_loss_fused.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import numpy as np

from mindspore import ops, Tensor

from .fused_op import fused_get_ciou, fused_get_center_dist, fused_get_iou, \
fused_get_convex_diagonal_squared, fused_get_ciou_diagonal_angle, \
fused_get_boundding_boxes_coord, fused_get_intersection_area


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
"""
Return intersection-over-union (IoU) of boxes.
Arguments:
box1 (Tensor[N, 4]) or (Tensor[bs, N, 4])
box2 (Tensor[N, 4]) or (Tensor[bs, N, 4])
xywh (bool): Whether the box format is (x_center, y_center, w, h) or (x1, y1, x2, y2). Default: True.
GIoU (bool): Whether to use GIoU. Default: False.
DIoU (bool): Whether to use DIoU. Default: False.
CIoU (bool): Whether to use CIoU. Default: False.
Returns:
iou (Tensor[N,]): the IoU values for every element in boxes1 and boxes2
"""
# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
x1, y1, w1, h1 = ops.split(box1, split_size_or_sections=1, axis=-1)
x2, y2, w2, h2 = ops.split(box2, split_size_or_sections=1, axis=-1)
b1_x1, b1_x2, b1_y1, b1_y2,b2_x1, b2_x2, b2_y1, b2_y2=fused_get_boundding_boxes_coord(x1, y1, w1, h1,x2, y2, w2, h2)
else: # x1, y1, x2, y2 = box1
b1_x1, b1_y1, b1_x2, b1_y2 = ops.split(box1, split_size_or_sections=1, axis=-1)
b2_x1, b2_y1, b2_x2, b2_y2 = ops.split(box2, split_size_or_sections=1, axis=-1)

# Intersection area
inter = fused_get_intersection_area(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2)

w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
iou, union = fused_get_iou(w1, h1, w2, h2, inter)

if CIoU or DIoU or GIoU:
cw = ops.maximum(b1_x2, b2_x2) - ops.minimum(b1_x1, b2_x1) # convex (smallest enclosing box) width
ch = ops.maximum(b1_y2, b2_y2) - ops.minimum(b1_y1, b2_y1) # convex height
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = fused_get_convex_diagonal_squared(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2)
rho2 = fused_get_center_dist(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2)
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = fused_get_ciou_diagonal_angle(w1, h1, w2, h2)
_, res = fused_get_ciou(v, iou, rho2, c2)
return res
return iou - rho2 / c2 # DIoU
c_area = cw * ch + eps # convex area
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
return iou # IoU

if __name__ =="__main__":
box1 = Tensor(np.random.rand(32, 4).astype(np.float32))
box2 = Tensor(np.random.rand(32, 4).astype(np.float32))
iou = bbox_iou(box1, box2, xywh=True, CIoU=True)
61 changes: 19 additions & 42 deletions mindyolo/models/losses/iou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

from mindyolo.models.layers.utils import box_cxcywh_to_xyxy

from .fused_op import fused_get_ciou, fused_get_center_dist, fused_get_iou, \
fused_get_convex_diagonal_squared, fused_get_ciou_diagonal_angle, \
fused_get_boundding_boxes_coord, fused_get_intersection_area

PI = Tensor(math.pi, ms.float32)
EPS = 1e-7

Expand Down Expand Up @@ -101,7 +97,7 @@ def batch_box_iou(batch_box1, batch_box2, xywh=False):
) # iou = inter / (area1 + area2 - inter)


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7, use_fused_op=False):
def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
"""
Return intersection-over-union (IoU) of boxes.
Arguments:
Expand All @@ -111,64 +107,45 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7
GIoU (bool): Whether to use GIoU. Default: False.
DIoU (bool): Whether to use DIoU. Default: False.
CIoU (bool): Whether to use CIoU. Default: False.
use_fused_op(bool): Whether to use fused operator built upon aot customized operator. Default: False.
Returns:
iou (Tensor[N,]): the IoU values for every element in boxes1 and boxes2
"""

# Get the coordinates of bounding boxes
if xywh: # transform from xywh to xyxy
x1, y1, w1, h1 = ops.split(box1, split_size_or_sections=1, axis=-1)
x2, y2, w2, h2 = ops.split(box2, split_size_or_sections=1, axis=-1)
if use_fused_op:
b1_x1, b1_x2, b1_y1, b1_y2,b2_x1, b2_x2, b2_y1, b2_y2=fused_get_boundding_boxes_coord(x1, y1, w1, h1,x2, y2, w2, h2)
else:
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
else: # x1, y1, x2, y2 = box1
b1_x1, b1_y1, b1_x2, b1_y2 = ops.split(box1, split_size_or_sections=1, axis=-1)
b2_x1, b2_y1, b2_x2, b2_y2 = ops.split(box2, split_size_or_sections=1, axis=-1)

# Intersection area
if use_fused_op:
inter = fused_get_intersection_area(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2)
else:
inter = (ops.minimum(b1_x2, b2_x2) - ops.maximum(b1_x1, b2_x1)).clip(0., None) * \
(ops.minimum(b1_y2, b2_y2) - ops.maximum(b1_y1, b2_y1)).clip(0., None)
inter = (ops.minimum(b1_x2, b2_x2) - ops.maximum(b1_x1, b2_x1)).clip(0., None) * \
(ops.minimum(b1_y2, b2_y2) - ops.maximum(b1_y1, b2_y1)).clip(0., None)

# Union Area
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
if use_fused_op:
iou, union = fused_get_iou(w1, h1, w2, h2, inter)
else:
union = w1 * h1 + w2 * h2 - inter + eps # Union Area
iou = inter / union # IoU
union = w1 * h1 + w2 * h2 - inter + eps

# IoU
iou = inter / union

if CIoU or DIoU or GIoU:
cw = ops.maximum(b1_x2, b2_x2) - ops.minimum(b1_x1, b2_x1) # convex (smallest enclosing box) width
ch = ops.maximum(b1_y2, b2_y2) - ops.minimum(b1_y1, b2_y1) # convex height
if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
if use_fused_op:
c2 = fused_get_convex_diagonal_squared(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2)
else:
c2 = cw**2 + ch**2 + eps # convex diagonal squared
if use_fused_op:
rho2 = fused_get_center_dist(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2)
else:
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
c2 = cw**2 + ch**2 + eps # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
if use_fused_op:
v = fused_get_ciou_diagonal_angle(w1, h1, w2, h2)
else:
# v = (4 / get_pi(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2)
v = (4 / PI.astype(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2)
if use_fused_op:
_, res = fused_get_ciou(v, iou, rho2, c2)
else:
alpha = v / (v - iou + (1 + eps))
alpha = ops.stop_gradient(alpha)
res = iou - (rho2 / c2 + v * alpha) # CIoU
return res
# v = (4 / get_pi(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2)
v = (4 / PI.astype(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2)
alpha = v / (v - iou + (1 + eps))
alpha = ops.stop_gradient(alpha)
return iou - (rho2 / c2 + v * alpha) # CIoU
return iou - rho2 / c2 # DIoU
c_area = cw * ch + eps # convex area
return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
Expand Down
5 changes: 2 additions & 3 deletions mindyolo/models/losses/yolov3_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@register_model
class YOLOv3Loss(nn.Cell):
def __init__(
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs
):
super(YOLOv3Loss, self).__init__()
self.hyp_box = box
Expand Down Expand Up @@ -63,7 +63,6 @@ def __init__(
)

self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by lossitem for print
self.use_fused_op = use_fused_op

def construct(self, p, targets, imgs):
lcls, lbox, lobj = 0.0, 0.0, 0.0
Expand Down Expand Up @@ -94,7 +93,7 @@ def construct(self, p, targets, imgs):
pxy = ops.Sigmoid()(pxy) * 2 - 0.5
pwh = (ops.Sigmoid()(pwh) * 2) ** 2 * anchors[layer_index]
pbox = ops.concat((pxy, pwh), 1) # predicted box
iou = bbox_iou(pbox, tbox[layer_index], CIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target)
iou = bbox_iou(pbox, tbox[layer_index], CIoU=True).squeeze() # iou(prediction, target)
# iou = iou * tmask
# lbox += ((1.0 - iou) * tmask).mean() # iou loss
lbox += (((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None)).astype(iou.dtype)
Expand Down
7 changes: 3 additions & 4 deletions mindyolo/models/losses/yolov4_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def construct(self, object_mask, predict_confidence, ignore_mask):

@register_model
class YOLOv4Loss(nn.Cell):
def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_threshold, anchors, nc, use_fused_op=False, **kwargs):
def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_threshold, anchors, nc, **kwargs):
super(YOLOv4Loss, self).__init__()
self.ignore_threshold = ignore_threshold
self.iou = Iou()
Expand All @@ -57,7 +57,6 @@ def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_thresho

self.concat = ops.Concat(axis=-1)
self.reduce_max = ops.ReduceMax(keep_dims=False)
self.use_fused_op = use_fused_op

def construct(self, p, targets, imgs):
image_shape = imgs.shape
Expand Down Expand Up @@ -95,7 +94,7 @@ def construct(self, p, targets, imgs):

# Regression
pbox = ops.concat((pxy, pwh), 1) # predicted box
iou = bbox_iou(pbox, tbox, GIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target)
iou = bbox_iou(pbox, tbox, GIoU=True).squeeze() # iou(prediction, target)
# iou = iou * tmask
# lbox += ((1.0 - iou) * tmask).mean() # iou loss
box_loss_scale = 2 - tbox[:, 2] * tbox[:, 3] / gain[0] / gain[1]
Expand Down Expand Up @@ -161,7 +160,7 @@ def build_targets(self, p, targets, imgs):

anchor_shapes = ops.zeros((na, 1, 4), ms.float32)
anchor_shapes[..., 2:] = ops.ExpandDims()(self.anchors, 1)
anch_ious = bbox_iou(gt_box, anchor_shapes, use_fused_op=self.use_fused_op).squeeze()
anch_ious = bbox_iou(gt_box, anchor_shapes).squeeze()

j = anch_ious == anch_ious.max(axis=0)
l = anch_ious > self.iou_threshold
Expand Down
5 changes: 2 additions & 3 deletions mindyolo/models/losses/yolov5_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class YOLOv5Loss(nn.Cell):
# Compute losses
def __init__(
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs
):
super(YOLOv5Loss, self).__init__()

Expand Down Expand Up @@ -64,7 +64,6 @@ def __init__(
)

self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by loss for print
self.use_fused_op = use_fused_op

def scatter_index_tensor(self, x, index):
x_tmp = ops.transpose(x.reshape((-1, x.shape[-1])), (1, 0))
Expand Down Expand Up @@ -102,7 +101,7 @@ def construct(self, p, targets, imgs): # predictions, targets
pxy = ops.Sigmoid()(pxy) * 2 - 0.5
pwh = (ops.Sigmoid()(pwh) * 2) ** 2 * anchors[layer_index]
pbox = ops.concat((pxy, pwh), 1) # predicted box
iou = bbox_iou(pbox, tbox[layer_index], CIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target)
iou = bbox_iou(pbox, tbox[layer_index], CIoU=True).squeeze() # iou(prediction, target)
lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum() # iou loss

# Objectness
Expand Down
12 changes: 5 additions & 7 deletions mindyolo/models/losses/yolov7_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@register_model
class YOLOv7Loss(nn.Cell):
def __init__(
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs
):
super(YOLOv7Loss, self).__init__()
self.hyp_box = box
Expand Down Expand Up @@ -63,7 +63,6 @@ def __init__(
)

self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by lossitem for print
self.use_fused_op = use_fused_op

def construct(self, p, targets, imgs):
lcls, lbox, lobj = 0.0, 0.0, 0.0
Expand Down Expand Up @@ -99,7 +98,7 @@ def construct(self, p, targets, imgs):
pbox = ops.concat((pxy, pwh), 1) # predicted box
selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
selected_tbox[:, :2] -= grid
iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1)
iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True).view(-1)
lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None) # iou loss

# Objectness
Expand Down Expand Up @@ -365,7 +364,7 @@ def find_3_positive(self, p, targets):
@register_model
class YOLOv7AuxLoss(nn.Cell):
def __init__(
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op, **kwargs
self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs
):
super(YOLOv7AuxLoss, self).__init__()
self.hyp_box = box
Expand Down Expand Up @@ -417,7 +416,6 @@ def __init__(
)

self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by loss for print
self.use_fused_op = use_fused_op

def construct(self, p, targets, imgs):
lcls, lbox, lobj = 0.0, 0.0, 0.0
Expand Down Expand Up @@ -473,7 +471,7 @@ def construct(self, p, targets, imgs):
pbox = ops.concat((pxy, pwh), 1) # predicted box
selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
selected_tbox[:, :2] -= grid
iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1)
iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True).view(-1)
lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None) # iou loss
# 1.2. Objectness
tobj[b, a, gj, gi] = ((1.0 - self.gr) + self.gr * ops.stop_gradient(iou).clip(0, None)) * tmask # iou ratio
Expand All @@ -496,7 +494,7 @@ def construct(self, p, targets, imgs):
pbox_aux = ops.concat((pxy_aux, pwh_aux), 1) # predicted box
selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains[i]
selected_tbox_aux[:, :2] -= grid_aux
iou_aux = bbox_iou(pbox_aux, selected_tbox_aux, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1)
iou_aux = bbox_iou(pbox_aux, selected_tbox_aux, xywh=True, CIoU=True).view(-1)
lbox += (
0.25 * ((1.0 - iou_aux) * tmask_aux).sum() / tmask_aux.astype(iou_aux.dtype).sum().clip(1, None)
) # iou loss
Expand Down
Loading

0 comments on commit 3855f0a

Please sign in to comment.