diff --git a/docs/en/installation.md b/docs/en/installation.md index 3086abcf..5d68687b 100644 --- a/docs/en/installation.md +++ b/docs/en/installation.md @@ -64,8 +64,9 @@ cd mindyolo/csrc sh build.sh ``` -We also provide fused GPU operators which are built upon MindSpore [ops.Custom](https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html) API. The fused GPU operators are able to improve train speed. The source code is provided in C++ and CUDA and is in the folder `mindyolo/models/losses/fused_op`. Before using it, you shall try compiling the source code to dynamic link libraries with the following commands **(This operation is optional)** : +We also provide fused GPU operators which are built upon MindSpore [ops.Custom](https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html) API. The fused GPU operators are able to improve train speed. The source code is provided in C++ and CUDA and is in the folder `examples/custom_gpu_op/`. To enable this feature in the GPU training process, you shall modify the method `bbox_iou` in `mindyolo/models/losses/iou_loss.py` by referring to the demo script `examples/custom_gpu_op/iou_loss_fused.py`. Before runing `iou_loss_fused.py`, you shall compile the C++ and CUDA source code to dynamic link libraries with the following commands **(This operation is optional)** : ```shell +cp -rf examples/custom_gpu_op/* mindyolo/models/losses/ bash mindyolo/models/losses/fused_op/build.sh ``` diff --git a/docs/zh/installation.md b/docs/zh/installation.md index 744defae..5759dc8b 100644 --- a/docs/zh/installation.md +++ b/docs/zh/installation.md @@ -63,8 +63,8 @@ cd mindyolo/csrc sh build.sh ``` -我们还提供了基于MindSpore [Custom自定义算子](https://www.mindspore.cn/tutorials/experts/zh-CN/master/operation/op_custom.html) 的GPU融合算子,用于提升训练过程的速度。代码采用C++和CUDA开发,位于`mindyolo/models/losses/fused_op`路径下。使用该特性前,需要使用以下的命令,编译生成GPU融合算子运行所依赖的动态库 **(此操作是可选的)** : +我们还提供了基于MindSpore [Custom自定义算子](https://www.mindspore.cn/tutorials/experts/zh-CN/master/operation/op_custom.html) 的GPU融合算子,用于提升训练过程的速度。代码采用C++和CUDA开发,位于`examples/custom_gpu_op/`路径下。您可参考示例脚本`examples/custom_gpu_op/iou_loss_fused.py`,修改`mindyolo/models/losses/iou_loss.py`的`bbox_iou`方法,在GPU训练过程中使用该特性。运行`iou_loss_fused.py`前,需要使用以下的命令,编译生成GPU融合算子运行所依赖的动态库 **(此操作是可选的)** : ```shell -bash mindyolo/models/losses/fused_op/build.sh +bash examples/custom_gpu_op/fused_op/build.sh ``` diff --git a/mindyolo/models/losses/fused_op/__init__.py b/examples/custom_gpu_op/fused_op/__init__.py similarity index 100% rename from mindyolo/models/losses/fused_op/__init__.py rename to examples/custom_gpu_op/fused_op/__init__.py diff --git a/mindyolo/models/losses/fused_op/build.sh b/examples/custom_gpu_op/fused_op/build.sh similarity index 100% rename from mindyolo/models/losses/fused_op/build.sh rename to examples/custom_gpu_op/fused_op/build.sh diff --git a/mindyolo/models/losses/fused_op/elementswise_op_impl.cu b/examples/custom_gpu_op/fused_op/elementswise_op_impl.cu similarity index 100% rename from mindyolo/models/losses/fused_op/elementswise_op_impl.cu rename to examples/custom_gpu_op/fused_op/elementswise_op_impl.cu diff --git a/mindyolo/models/losses/fused_op/fused_get_boundding_boxes_coord_kernel.cu b/examples/custom_gpu_op/fused_op/fused_get_boundding_boxes_coord_kernel.cu similarity index 100% rename from mindyolo/models/losses/fused_op/fused_get_boundding_boxes_coord_kernel.cu rename to examples/custom_gpu_op/fused_op/fused_get_boundding_boxes_coord_kernel.cu diff --git a/mindyolo/models/losses/fused_op/fused_get_center_dist_kernel.cu b/examples/custom_gpu_op/fused_op/fused_get_center_dist_kernel.cu similarity index 100% rename from mindyolo/models/losses/fused_op/fused_get_center_dist_kernel.cu rename to examples/custom_gpu_op/fused_op/fused_get_center_dist_kernel.cu diff --git a/mindyolo/models/losses/fused_op/fused_get_ciou_diagonal_angle_kernel.cu b/examples/custom_gpu_op/fused_op/fused_get_ciou_diagonal_angle_kernel.cu similarity index 100% rename from mindyolo/models/losses/fused_op/fused_get_ciou_diagonal_angle_kernel.cu rename to examples/custom_gpu_op/fused_op/fused_get_ciou_diagonal_angle_kernel.cu diff --git a/mindyolo/models/losses/fused_op/fused_get_ciou_kernel.cu b/examples/custom_gpu_op/fused_op/fused_get_ciou_kernel.cu similarity index 100% rename from mindyolo/models/losses/fused_op/fused_get_ciou_kernel.cu rename to examples/custom_gpu_op/fused_op/fused_get_ciou_kernel.cu diff --git a/mindyolo/models/losses/fused_op/fused_get_convex_diagonal_squared_kernel.cu b/examples/custom_gpu_op/fused_op/fused_get_convex_diagonal_squared_kernel.cu similarity index 100% rename from mindyolo/models/losses/fused_op/fused_get_convex_diagonal_squared_kernel.cu rename to examples/custom_gpu_op/fused_op/fused_get_convex_diagonal_squared_kernel.cu diff --git a/mindyolo/models/losses/fused_op/fused_get_intersection_area_kernel.cu b/examples/custom_gpu_op/fused_op/fused_get_intersection_area_kernel.cu similarity index 100% rename from mindyolo/models/losses/fused_op/fused_get_intersection_area_kernel.cu rename to examples/custom_gpu_op/fused_op/fused_get_intersection_area_kernel.cu diff --git a/mindyolo/models/losses/fused_op/fused_get_iou_kernel.cu b/examples/custom_gpu_op/fused_op/fused_get_iou_kernel.cu similarity index 100% rename from mindyolo/models/losses/fused_op/fused_get_iou_kernel.cu rename to examples/custom_gpu_op/fused_op/fused_get_iou_kernel.cu diff --git a/examples/custom_gpu_op/iou_loss_fused.py b/examples/custom_gpu_op/iou_loss_fused.py new file mode 100644 index 00000000..673a863c --- /dev/null +++ b/examples/custom_gpu_op/iou_loss_fused.py @@ -0,0 +1,57 @@ +import numpy as np + +from mindspore import ops, Tensor + +from .fused_op import fused_get_ciou, fused_get_center_dist, fused_get_iou, \ + fused_get_convex_diagonal_squared, fused_get_ciou_diagonal_angle, \ + fused_get_boundding_boxes_coord, fused_get_intersection_area + + +def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): + """ + Return intersection-over-union (IoU) of boxes. + Arguments: + box1 (Tensor[N, 4]) or (Tensor[bs, N, 4]) + box2 (Tensor[N, 4]) or (Tensor[bs, N, 4]) + xywh (bool): Whether the box format is (x_center, y_center, w, h) or (x1, y1, x2, y2). Default: True. + GIoU (bool): Whether to use GIoU. Default: False. + DIoU (bool): Whether to use DIoU. Default: False. + CIoU (bool): Whether to use CIoU. Default: False. + Returns: + iou (Tensor[N,]): the IoU values for every element in boxes1 and boxes2 + """ + # Get the coordinates of bounding boxes + if xywh: # transform from xywh to xyxy + x1, y1, w1, h1 = ops.split(box1, split_size_or_sections=1, axis=-1) + x2, y2, w2, h2 = ops.split(box2, split_size_or_sections=1, axis=-1) + b1_x1, b1_x2, b1_y1, b1_y2,b2_x1, b2_x2, b2_y1, b2_y2=fused_get_boundding_boxes_coord(x1, y1, w1, h1,x2, y2, w2, h2) + else: # x1, y1, x2, y2 = box1 + b1_x1, b1_y1, b1_x2, b1_y2 = ops.split(box1, split_size_or_sections=1, axis=-1) + b2_x1, b2_y1, b2_x2, b2_y2 = ops.split(box2, split_size_or_sections=1, axis=-1) + + # Intersection area + inter = fused_get_intersection_area(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2) + + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + iou, union = fused_get_iou(w1, h1, w2, h2, inter) + + if CIoU or DIoU or GIoU: + cw = ops.maximum(b1_x2, b2_x2) - ops.minimum(b1_x1, b2_x1) # convex (smallest enclosing box) width + ch = ops.maximum(b1_y2, b2_y2) - ops.minimum(b1_y1, b2_y1) # convex height + if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 + c2 = fused_get_convex_diagonal_squared(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2) + rho2 = fused_get_center_dist(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2) + if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = fused_get_ciou_diagonal_angle(w1, h1, w2, h2) + _, res = fused_get_ciou(v, iou, rho2, c2) + return res + return iou - rho2 / c2 # DIoU + c_area = cw * ch + eps # convex area + return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf + return iou # IoU + +if __name__ =="__main__": + box1 = Tensor(np.random.rand(32, 4).astype(np.float32)) + box2 = Tensor(np.random.rand(32, 4).astype(np.float32)) + iou = bbox_iou(box1, box2, xywh=True, CIoU=True) diff --git a/mindyolo/models/losses/iou_loss.py b/mindyolo/models/losses/iou_loss.py index 1cb466a0..f3227778 100644 --- a/mindyolo/models/losses/iou_loss.py +++ b/mindyolo/models/losses/iou_loss.py @@ -5,10 +5,6 @@ from mindyolo.models.layers.utils import box_cxcywh_to_xyxy -from .fused_op import fused_get_ciou, fused_get_center_dist, fused_get_iou, \ - fused_get_convex_diagonal_squared, fused_get_ciou_diagonal_angle, \ - fused_get_boundding_boxes_coord, fused_get_intersection_area - PI = Tensor(math.pi, ms.float32) EPS = 1e-7 @@ -101,7 +97,7 @@ def batch_box_iou(batch_box1, batch_box2, xywh=False): ) # iou = inter / (area1 + area2 - inter) -def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7, use_fused_op=False): +def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): """ Return intersection-over-union (IoU) of boxes. Arguments: @@ -111,64 +107,45 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 GIoU (bool): Whether to use GIoU. Default: False. DIoU (bool): Whether to use DIoU. Default: False. CIoU (bool): Whether to use CIoU. Default: False. - use_fused_op(bool): Whether to use fused operator built upon aot customized operator. Default: False. Returns: iou (Tensor[N,]): the IoU values for every element in boxes1 and boxes2 """ + # Get the coordinates of bounding boxes if xywh: # transform from xywh to xyxy x1, y1, w1, h1 = ops.split(box1, split_size_or_sections=1, axis=-1) x2, y2, w2, h2 = ops.split(box2, split_size_or_sections=1, axis=-1) - if use_fused_op: - b1_x1, b1_x2, b1_y1, b1_y2,b2_x1, b2_x2, b2_y1, b2_y2=fused_get_boundding_boxes_coord(x1, y1, w1, h1,x2, y2, w2, h2) - else: - w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 - b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ - b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ + w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 + b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ + b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ else: # x1, y1, x2, y2 = box1 b1_x1, b1_y1, b1_x2, b1_y2 = ops.split(box1, split_size_or_sections=1, axis=-1) b2_x1, b2_y1, b2_x2, b2_y2 = ops.split(box2, split_size_or_sections=1, axis=-1) # Intersection area - if use_fused_op: - inter = fused_get_intersection_area(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2) - else: - inter = (ops.minimum(b1_x2, b2_x2) - ops.maximum(b1_x1, b2_x1)).clip(0., None) * \ - (ops.minimum(b1_y2, b2_y2) - ops.maximum(b1_y1, b2_y1)).clip(0., None) + inter = (ops.minimum(b1_x2, b2_x2) - ops.maximum(b1_x1, b2_x1)).clip(0., None) * \ + (ops.minimum(b1_y2, b2_y2) - ops.maximum(b1_y1, b2_y1)).clip(0., None) + # Union Area w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps - if use_fused_op: - iou, union = fused_get_iou(w1, h1, w2, h2, inter) - else: - union = w1 * h1 + w2 * h2 - inter + eps # Union Area - iou = inter / union # IoU + union = w1 * h1 + w2 * h2 - inter + eps + + # IoU + iou = inter / union if CIoU or DIoU or GIoU: cw = ops.maximum(b1_x2, b2_x2) - ops.minimum(b1_x1, b2_x1) # convex (smallest enclosing box) width ch = ops.maximum(b1_y2, b2_y2) - ops.minimum(b1_y1, b2_y1) # convex height if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 - if use_fused_op: - c2 = fused_get_convex_diagonal_squared(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2) - else: - c2 = cw**2 + ch**2 + eps # convex diagonal squared - if use_fused_op: - rho2 = fused_get_center_dist(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2) - else: - rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 + c2 = cw**2 + ch**2 + eps # convex diagonal squared + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 - if use_fused_op: - v = fused_get_ciou_diagonal_angle(w1, h1, w2, h2) - else: - # v = (4 / get_pi(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) - v = (4 / PI.astype(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) - if use_fused_op: - _, res = fused_get_ciou(v, iou, rho2, c2) - else: - alpha = v / (v - iou + (1 + eps)) - alpha = ops.stop_gradient(alpha) - res = iou - (rho2 / c2 + v * alpha) # CIoU - return res + # v = (4 / get_pi(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) + v = (4 / PI.astype(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) + alpha = v / (v - iou + (1 + eps)) + alpha = ops.stop_gradient(alpha) + return iou - (rho2 / c2 + v * alpha) # CIoU return iou - rho2 / c2 # DIoU c_area = cw * ch + eps # convex area return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf diff --git a/mindyolo/models/losses/yolov3_loss.py b/mindyolo/models/losses/yolov3_loss.py index d846eeb9..516f06bc 100644 --- a/mindyolo/models/losses/yolov3_loss.py +++ b/mindyolo/models/losses/yolov3_loss.py @@ -17,7 +17,7 @@ @register_model class YOLOv3Loss(nn.Cell): def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs ): super(YOLOv3Loss, self).__init__() self.hyp_box = box @@ -63,7 +63,6 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by lossitem for print - self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): lcls, lbox, lobj = 0.0, 0.0, 0.0 @@ -94,7 +93,7 @@ def construct(self, p, targets, imgs): pxy = ops.Sigmoid()(pxy) * 2 - 0.5 pwh = (ops.Sigmoid()(pwh) * 2) ** 2 * anchors[layer_index] pbox = ops.concat((pxy, pwh), 1) # predicted box - iou = bbox_iou(pbox, tbox[layer_index], CIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target) + iou = bbox_iou(pbox, tbox[layer_index], CIoU=True).squeeze() # iou(prediction, target) # iou = iou * tmask # lbox += ((1.0 - iou) * tmask).mean() # iou loss lbox += (((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None)).astype(iou.dtype) diff --git a/mindyolo/models/losses/yolov4_loss.py b/mindyolo/models/losses/yolov4_loss.py index 7fa5d180..51abf2cb 100644 --- a/mindyolo/models/losses/yolov4_loss.py +++ b/mindyolo/models/losses/yolov4_loss.py @@ -31,7 +31,7 @@ def construct(self, object_mask, predict_confidence, ignore_mask): @register_model class YOLOv4Loss(nn.Cell): - def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_threshold, anchors, nc, use_fused_op=False, **kwargs): + def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_threshold, anchors, nc, **kwargs): super(YOLOv4Loss, self).__init__() self.ignore_threshold = ignore_threshold self.iou = Iou() @@ -57,7 +57,6 @@ def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_thresho self.concat = ops.Concat(axis=-1) self.reduce_max = ops.ReduceMax(keep_dims=False) - self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): image_shape = imgs.shape @@ -95,7 +94,7 @@ def construct(self, p, targets, imgs): # Regression pbox = ops.concat((pxy, pwh), 1) # predicted box - iou = bbox_iou(pbox, tbox, GIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target) + iou = bbox_iou(pbox, tbox, GIoU=True).squeeze() # iou(prediction, target) # iou = iou * tmask # lbox += ((1.0 - iou) * tmask).mean() # iou loss box_loss_scale = 2 - tbox[:, 2] * tbox[:, 3] / gain[0] / gain[1] @@ -161,7 +160,7 @@ def build_targets(self, p, targets, imgs): anchor_shapes = ops.zeros((na, 1, 4), ms.float32) anchor_shapes[..., 2:] = ops.ExpandDims()(self.anchors, 1) - anch_ious = bbox_iou(gt_box, anchor_shapes, use_fused_op=self.use_fused_op).squeeze() + anch_ious = bbox_iou(gt_box, anchor_shapes).squeeze() j = anch_ious == anch_ious.max(axis=0) l = anch_ious > self.iou_threshold diff --git a/mindyolo/models/losses/yolov5_loss.py b/mindyolo/models/losses/yolov5_loss.py index a1321d06..890cefef 100644 --- a/mindyolo/models/losses/yolov5_loss.py +++ b/mindyolo/models/losses/yolov5_loss.py @@ -15,7 +15,7 @@ class YOLOv5Loss(nn.Cell): # Compute losses def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs ): super(YOLOv5Loss, self).__init__() @@ -64,7 +64,6 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by loss for print - self.use_fused_op = use_fused_op def scatter_index_tensor(self, x, index): x_tmp = ops.transpose(x.reshape((-1, x.shape[-1])), (1, 0)) @@ -102,7 +101,7 @@ def construct(self, p, targets, imgs): # predictions, targets pxy = ops.Sigmoid()(pxy) * 2 - 0.5 pwh = (ops.Sigmoid()(pwh) * 2) ** 2 * anchors[layer_index] pbox = ops.concat((pxy, pwh), 1) # predicted box - iou = bbox_iou(pbox, tbox[layer_index], CIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target) + iou = bbox_iou(pbox, tbox[layer_index], CIoU=True).squeeze() # iou(prediction, target) lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum() # iou loss # Objectness diff --git a/mindyolo/models/losses/yolov7_loss.py b/mindyolo/models/losses/yolov7_loss.py index 204d1dca..46258369 100644 --- a/mindyolo/models/losses/yolov7_loss.py +++ b/mindyolo/models/losses/yolov7_loss.py @@ -17,7 +17,7 @@ @register_model class YOLOv7Loss(nn.Cell): def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs ): super(YOLOv7Loss, self).__init__() self.hyp_box = box @@ -63,7 +63,6 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by lossitem for print - self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): lcls, lbox, lobj = 0.0, 0.0, 0.0 @@ -99,7 +98,7 @@ def construct(self, p, targets, imgs): pbox = ops.concat((pxy, pwh), 1) # predicted box selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i] selected_tbox[:, :2] -= grid - iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1) + iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True).view(-1) lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None) # iou loss # Objectness @@ -365,7 +364,7 @@ def find_3_positive(self, p, targets): @register_model class YOLOv7AuxLoss(nn.Cell): def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs ): super(YOLOv7AuxLoss, self).__init__() self.hyp_box = box @@ -417,7 +416,6 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by loss for print - self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): lcls, lbox, lobj = 0.0, 0.0, 0.0 @@ -473,7 +471,7 @@ def construct(self, p, targets, imgs): pbox = ops.concat((pxy, pwh), 1) # predicted box selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i] selected_tbox[:, :2] -= grid - iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1) + iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True).view(-1) lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None) # iou loss # 1.2. Objectness tobj[b, a, gj, gi] = ((1.0 - self.gr) + self.gr * ops.stop_gradient(iou).clip(0, None)) * tmask # iou ratio @@ -496,7 +494,7 @@ def construct(self, p, targets, imgs): pbox_aux = ops.concat((pxy_aux, pwh_aux), 1) # predicted box selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains[i] selected_tbox_aux[:, :2] -= grid_aux - iou_aux = bbox_iou(pbox_aux, selected_tbox_aux, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1) + iou_aux = bbox_iou(pbox_aux, selected_tbox_aux, xywh=True, CIoU=True).view(-1) lbox += ( 0.25 * ((1.0 - iou_aux) * tmask_aux).sum() / tmask_aux.astype(iou_aux.dtype).sum().clip(1, None) ) # iou loss diff --git a/mindyolo/models/losses/yolov8_loss.py b/mindyolo/models/losses/yolov8_loss.py index 2ab47863..2e668d33 100644 --- a/mindyolo/models/losses/yolov8_loss.py +++ b/mindyolo/models/losses/yolov8_loss.py @@ -14,7 +14,7 @@ @register_model class YOLOv8Loss(nn.Cell): - def __init__(self, box, cls, dfl, stride, nc, reg_max=16, use_fused_op=False, **kwargs): + def __init__(self, box, cls, dfl, stride, nc, reg_max=16, **kwargs): super(YOLOv8Loss, self).__init__() self.bce = nn.BCEWithLogitsLoss(reduction="none") @@ -27,8 +27,8 @@ def __init__(self, box, cls, dfl, stride, nc, reg_max=16, use_fused_op=False, ** self.reg_max = reg_max self.use_dfl = reg_max > 1 - self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0, use_fused_op=use_fused_op) - self.bbox_loss = BboxLoss(reg_max, use_dfl=self.use_dfl, use_fused_op=use_fused_op) + self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = BboxLoss(reg_max, use_dfl=self.use_dfl) self.proj = mnp.arange(reg_max) # ops @@ -154,11 +154,10 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): class BboxLoss(nn.Cell): - def __init__(self, reg_max, use_dfl=False, use_fused_op=False): + def __init__(self, reg_max, use_dfl=False): super().__init__() self.reg_max = reg_max self.use_dfl = use_dfl - self.use_fused_op = use_fused_op def construct( self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask @@ -175,7 +174,7 @@ def construct( """ # IoU loss weight = target_scores.sum(-1).expand_dims(-1) # (bs, N, num_classes) -> (bs, N) -> (bs, N, 1) - iou = bbox_iou(pred_bboxes, target_bboxes, xywh=False, CIoU=True, use_fused_op=self.use_fused_op) + iou = bbox_iou(pred_bboxes, target_bboxes, xywh=False, CIoU=True) loss_iou = ((1.0 - iou) * weight * fg_mask.expand_dims(2)).sum() / target_scores_sum # DFL loss @@ -220,7 +219,7 @@ def _df_loss(pred_dist, target): class TaskAlignedAssigner(nn.Cell): - def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, use_fused_op=False): + def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): super().__init__() self.topk = topk self.num_classes = num_classes @@ -228,7 +227,6 @@ def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, use_f self.alpha = alpha self.beta = beta self.eps = eps - self.use_fused_op=use_fused_op def construct(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): """This code referenced to @@ -312,7 +310,7 @@ def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes): # (b, n_gt, 1, 4), (b, 1, N, 4) -> (b, n_gt, N) overlaps = ( - bbox_iou(gt_bboxes.expand_dims(2), pd_bboxes.expand_dims(1), xywh=False, CIoU=True, use_fused_op=self.use_fused_op).squeeze(3).clip(0, None) + bbox_iou(gt_bboxes.expand_dims(2), pd_bboxes.expand_dims(1), xywh=False, CIoU=True).squeeze(3).clip(0, None) ) align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) return align_metric, overlaps diff --git a/mindyolo/models/losses/yolox_loss.py b/mindyolo/models/losses/yolox_loss.py index 55c3ae56..1ebc77d1 100644 --- a/mindyolo/models/losses/yolox_loss.py +++ b/mindyolo/models/losses/yolox_loss.py @@ -23,7 +23,6 @@ def __init__( strides=(8, 16, 32), use_l1=False, use_summary=False, - use_fused_op=False, **kwargs ): super(YOLOXLoss, self).__init__() @@ -53,7 +52,6 @@ def __init__( self.assign = ops.Assign() self.loss_item_name = ["loss", "lbox", "lobj", "lcls", "lboxl1"] # branch name returned by lossitem for print - self.use_fused_op = use_fused_op def _get_anchor_center_and_stride(self, norm=False): """ @@ -257,7 +255,7 @@ def construct(self, preds, targets, imgs=None): loss_l1 = ops.reduce_sum(self.l1_loss(l1_preds, l1_target), -1) * obj_target loss_l1 = ops.reduce_sum(loss_l1) # calculate target -----------END------------------------------------------------------------------------------- - iou = bbox_iou(bbox_preds.reshape(-1, 4), reg_target.reshape(-1, 4), xywh=True, use_fused_op=self.use_fused_op).reshape(batch_size, -1) + iou = bbox_iou(bbox_preds.reshape(-1, 4), reg_target.reshape(-1, 4), xywh=True).reshape(batch_size, -1) loss_iou = (1 - iou * iou) * obj_target # (bs, num_total_anchor) loss_iou = ops.reduce_sum(loss_iou) diff --git a/mindyolo/utils/utils.py b/mindyolo/utils/utils.py index b78086b9..53b74b7a 100644 --- a/mindyolo/utils/utils.py +++ b/mindyolo/utils/utils.py @@ -87,19 +87,6 @@ def set_default(args): args.data.test_set = os.path.join(args.data_dir, args.data.test_set) args.weight = args.ckpt_dir if args.ckpt_dir else "" args.ema_weight = os.path.join(args.ckpt_dir, args.ema_weight) if args.ema_weight else "" - - # Check Custom operator settings. - if args.use_fused_op: - if args.device_target != "GPU": - logger.warning( - "mindyolo only support aot custom operator on GPU currently, please check configurations" - ) - args.use_fused_op = False - else: - logger.warning( - "aot Custom operator enabled, please confirm that the compilation script \ - (mindyolo\\models\\losses\\fused_op\\build.sh) has been executed properly." - ) def load_pretrain(network, weight, ema=None, ema_weight=None): diff --git a/setup.py b/setup.py index 13d2ea0f..5a210fc1 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,7 @@ #!/usr/bin/env python import os.path -import subprocess import pathlib import sys -import glob from setuptools import find_packages, setup @@ -30,26 +28,6 @@ def parse_requirements(path=here / "requirements.txt"): return pkgs -def compile_fused_op(path=here / "mindyolo/models/losses/fused_op"): - try: - check_txt = subprocess.run(["nvcc", "--version"], timeout=3, capture_output=True, check=True).stdout - if "command not found" in str(check_txt): - print("nvcc not found, skipped compiling fused operator.") - return - for fused_op_src in glob.glob(str(path / "*_kernel.cu")): - fused_op_so = f"{fused_op_src[:-3]}.so" - so_path = str(path / fused_op_so) - nvcc_cmd = "nvcc --shared -Xcompiler -fPIC -o " + so_path + " " + fused_op_src - print("nvcc compiler cmd: {}".format(nvcc_cmd)) - os.system(nvcc_cmd) - except FileNotFoundError: - print("nvcc not found, skipped compiling fused operator.") - return - except subprocess.CalledProcessError as e: - print("nvcc execute failed, skipped compiling fused operator: ", e) - return - - # add c++ extension ext_modules = [] try: @@ -65,7 +43,6 @@ def compile_fused_op(path=here / "mindyolo/models/losses/fused_op"): ] except ImportError: pass -compile_fused_op() setup( name="mindyolo", author="MindSpore Ecosystem", diff --git a/train.py b/train.py index 12247d22..393ba490 100644 --- a/train.py +++ b/train.py @@ -81,8 +81,6 @@ def get_parser_train(parents=None): help="ModelArts: local device path to dataset folder") parser.add_argument("--ckpt_dir", type=str, default="/cache/pretrain_ckpt/", help="ModelArts: local device path to checkpoint folder") - parser.add_argument("--use_fused_op", type=ast.literal_eval, default=False, - help="Whether to use aot custom operator to accelerate GPU computation") return parser @@ -187,8 +185,7 @@ def train(args): # Create Loss loss_fn = create_loss( - **args.loss, anchors=args.network.get("anchors", 1), stride=args.network.stride, nc=args.data.nc, - use_fused_op=args.use_fused_op + **args.loss, anchors=args.network.get("anchors", 1), stride=args.network.stride, nc=args.data.nc ) ms.amp.auto_mixed_precision(loss_fn, amp_level="O0" if args.keep_loss_fp32 else args.ms_amp_level)