From 3abaf62c1e4a0f228a96590cee3f331c121ca54e Mon Sep 17 00:00:00 2001 From: panshaowu Date: Sun, 30 Jul 2023 17:52:18 +0800 Subject: [PATCH] add aot custom op to accelerate computing bbox_iou on GPU --- docs/en/installation.md | 6 + docs/zh/installation.md | 8 +- mindyolo/models/losses/fused_op/__init__.py | 272 +++++++++++++++ mindyolo/models/losses/fused_op/build.sh | 7 + .../losses/fused_op/elementswise_op_impl.cu | 323 ++++++++++++++++++ .../fused_get_boundding_boxes_coord_kernel.cu | 178 ++++++++++ .../fused_op/fused_get_center_dist_kernel.cu | 154 +++++++++ .../fused_get_ciou_diagonal_angle_kernel.cu | 120 +++++++ .../losses/fused_op/fused_get_ciou_kernel.cu | 120 +++++++ ...used_get_convex_diagonal_squared_kernel.cu | 161 +++++++++ .../fused_get_intersection_area_kernel.cu | 181 ++++++++++ .../losses/fused_op/fused_get_iou_kernel.cu | 126 +++++++ mindyolo/models/losses/iou_loss.py | 169 ++++++++- mindyolo/models/losses/yolov3_loss.py | 5 +- mindyolo/models/losses/yolov4_loss.py | 7 +- mindyolo/models/losses/yolov5_loss.py | 5 +- mindyolo/models/losses/yolov7_loss.py | 12 +- mindyolo/models/losses/yolov8_loss.py | 16 +- mindyolo/models/losses/yolox_loss.py | 4 +- mindyolo/utils/utils.py | 5 + setup.py | 28 +- train.py | 5 +- 22 files changed, 1868 insertions(+), 44 deletions(-) create mode 100644 mindyolo/models/losses/fused_op/__init__.py create mode 100644 mindyolo/models/losses/fused_op/build.sh create mode 100644 mindyolo/models/losses/fused_op/elementswise_op_impl.cu create mode 100644 mindyolo/models/losses/fused_op/fused_get_boundding_boxes_coord_kernel.cu create mode 100644 mindyolo/models/losses/fused_op/fused_get_center_dist_kernel.cu create mode 100644 mindyolo/models/losses/fused_op/fused_get_ciou_diagonal_angle_kernel.cu create mode 100644 mindyolo/models/losses/fused_op/fused_get_ciou_kernel.cu create mode 100644 mindyolo/models/losses/fused_op/fused_get_convex_diagonal_squared_kernel.cu create mode 100644 
mindyolo/models/losses/fused_op/fused_get_intersection_area_kernel.cu create mode 100644 mindyolo/models/losses/fused_op/fused_get_iou_kernel.cu diff --git a/docs/en/installation.md b/docs/en/installation.md index ab0a5b57..debd8500 100644 --- a/docs/en/installation.md +++ b/docs/en/installation.md @@ -63,3 +63,9 @@ In addition, we provide an optional [fast coco api](https://github.com/facebookr cd mindyolo/csrc sh build.sh ``` + +We also provide fused GPU operators built on the MindSpore [ops.Custom](https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html) API. These fused GPU operators can improve training speed. The source code is written in C++ and CUDA and is located in the folder `mindyolo/models/losses/fused_op`. You can compile the source code into dynamic link libraries with the following command **(this operation is optional)**: + +```shell +bash mindyolo/models/losses/fused_op/build.sh +``` diff --git a/docs/zh/installation.md b/docs/zh/installation.md index 80996a82..f30806f5 100644 --- a/docs/zh/installation.md +++ b/docs/zh/installation.md @@ -56,9 +56,15 @@ cd mindyolo pip install -e . 
``` -另外, 我们提供了一个可选的 [fast coco api](https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/fast_eval_api.py) 接口用于提升验证过程的速度。代码是以C++形式提供的,可以尝试用以下的命令进行安装 **(此操作是可选的)** : +我们提供了一个可选的 [fast coco api](https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/fast_eval_api.py) 接口用于提升验证过程的速度。代码是以C++形式提供的,可以尝试用以下的命令进行安装 **(此操作是可选的)** : ```shell cd mindyolo/csrc sh build.sh ``` + +我们还提供了基于MindSpore [Custom自定义算子](https://www.mindspore.cn/tutorials/experts/zh-CN/master/operation/op_custom.html) 的GPU融合算子,用于提升训练过程的速度。代码采用C++和CUDA开发,位于`mindyolo/models/losses/fused_op`路径下。可以使用以下的命令,编译生成GPU融合算子运行所依赖的动态库,用于调测 **(此操作是可选的)** : + +```shell +bash mindyolo/models/losses/fused_op/build.sh +``` diff --git a/mindyolo/models/losses/fused_op/__init__.py b/mindyolo/models/losses/fused_op/__init__.py new file mode 100644 index 00000000..b10d830f --- /dev/null +++ b/mindyolo/models/losses/fused_op/__init__.py @@ -0,0 +1,272 @@ +import os + +from mindspore.ops import DataType, CustomRegOp + + +fused_op_list = ['fused_get_ciou_kernel', 'fused_get_center_dist_kernel', 'fused_get_convex_diagonal_squared_kernel', + 'fused_get_ciou_diagonal_angle_kernel','fused_get_boundding_boxes_coord_kernel', + 'fused_get_intersection_area_kernel'] +fused_ops_dir = os.path.dirname(__file__) +for fused_op_item in fused_op_list: + so_path = fused_ops_dir + '/' + fused_op_item + '.so' + if not os.path.exists(so_path): + cu_path = fused_ops_dir + '/' + fused_op_item + '.cu' + nvcc_cmd = 'nvcc --shared -Xcompiler -fPIC -o ' + so_path + ' ' + cu_path + print("nvcc compiler cmd: {}".format(nvcc_cmd)) + os.system(nvcc_cmd) + +fused_get_ciou_op_path = fused_ops_dir + "/fused_get_ciou_kernel.so" + ":FusedGetCiou" +fused_get_ciou_op_bprop_path = fused_ops_dir + "/fused_get_ciou_kernel.so" + ":FusedGetCiouBprop" +fused_get_center_dist_op_path = fused_ops_dir + "/fused_get_center_dist_kernel.so" + ":FusedGetCenterDist" +fused_get_center_dist_op_bprop_path = fused_ops_dir + 
"/fused_get_center_dist_kernel.so" + ":FusedGetCenterDistBprop" +fused_get_convex_diagonal_squared_path = fused_ops_dir + "/fused_get_convex_diagonal_squared_kernel.so" + ":FusedGetConvexDiagonalSquared" +fused_get_convex_diagonal_squared_grad_path = fused_ops_dir + "/fused_get_convex_diagonal_squared_kernel.so" + ":FusedGetConvexDiagonalSquaredGrad" +fused_get_ciou_diagonal_angle_path = fused_ops_dir + "/fused_get_ciou_diagonal_angle_kernel.so" + ":FusedGetCiouDiagonalAngle" +fused_get_ciou_diagonal_angle_grad_path = fused_ops_dir + "/fused_get_ciou_diagonal_angle_kernel.so" + ":FusedGetCiouDiagonalAngleGrad" +fused_get_boundding_boxes_coord_path = fused_ops_dir + "/fused_get_boundding_boxes_coord_kernel.so" + ":FusedGetBounddingBoxesCoord" +fused_get_boundding_boxes_coord_grad_path = fused_ops_dir+"/fused_get_boundding_boxes_coord_kernel.so" + ":FusedGetBounddingBoxesCoordGrad" +fused_get_intersection_area_path = fused_ops_dir + "/fused_get_intersection_area_kernel.so" + ":FusedGetIntersectionArea" +fused_get_intersection_area_grad_path = fused_ops_dir + "/fused_get_intersection_area_kernel.so" + ":FusedGetIntersectionAreaGrad" + + +fuse_get_ciou_gpu_info = CustomRegOp() \ + .input(0, "v") \ + .input(1, "iou") \ + .input(2, "rho2") \ + .input(3, "c2") \ + .output(0, "alpha") \ + .output(1, "out") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fuse_get_ciou_bprop_gpu_info = CustomRegOp() \ + .input(0, "v") \ + .input(1, "iou") \ + .input(2, "rho2") \ + .input(3, "c2") \ + .input(4, "d_alpha") \ + .input(5, "d_out") \ + .output(0, "d_v") \ + .output(1, "d_iou") \ + .output(2, "d_rho2") \ + .output(3, "d_c2") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + 
DataType.F32_Default, DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fuse_get_center_dist_gpu_info = CustomRegOp() \ + .input(0, "b1_x1") \ + .input(1, "b1_x2") \ + .input(2, "b1_y1") \ + .input(3, "b1_y2") \ + .input(4, "b2_x1") \ + .input(5, "b2_x2") \ + .input(6, "b2_y1") \ + .input(7, "b2_y2") \ + .output(0, "out") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fuse_get_center_dist_bprop_gpu_info = CustomRegOp() \ + .input(0, "b1_x1") \ + .input(1, "b1_x2") \ + .input(2, "b1_y1") \ + .input(3, "b1_y2") \ + .input(4, "b2_x1") \ + .input(5, "b2_x2") \ + .input(6, "b2_y1") \ + .input(7, "b2_y2") \ + .input(8, "d_out") \ + .output(0, "d_b1_x1") \ + .output(1, "d_b1_x2") \ + .output(2, "d_b1_y1") \ + .output(3, "d_b1_y2") \ + .output(4, "d_b2_x1") \ + .output(5, "d_b2_x2") \ + .output(6, "d_b2_y1") \ + .output(7, "d_b2_y2") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fused_get_convex_diagonal_squared_info = CustomRegOp() \ + .input(0, "b1_x1") \ + .input(1, "b1_x2") \ + .input(2, "b2_x1") \ + .input(3, "b2_x2") \ + .input(4, "b1_y1") \ + .input(5, "b1_y2") \ + .input(6, "b2_y1") \ + .input(7, "b2_y2") \ + .output(8, "out") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + 
.target("GPU") \ + .get_op_info() + + +fused_get_convex_diagonal_squared_grad_info = CustomRegOp() \ + .input(0, "b1_x1") \ + .input(1, "b1_x2") \ + .input(2, "b2_x1") \ + .input(3, "b2_x2") \ + .input(4, "b1_y1") \ + .input(5, "b1_y2") \ + .input(6, "b2_y1") \ + .input(7, "b2_y2") \ + .input(8, "dout") \ + .output(9, "d_b1_x1") \ + .output(10, "d_b1_x2") \ + .output(11, "d_b2_x1") \ + .output(12, "d_b2_x2") \ + .output(13, "d_b1_y1") \ + .output(14, "d_b1_y2") \ + .output(15, "d_b2_y1") \ + .output(16, "d_b2_y2") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fused_get_ciou_diagonal_angle_info = CustomRegOp() \ + .input(0, "w1") \ + .input(1, "h1") \ + .input(2, "w2") \ + .input(3, "h2") \ + .output(4, "out") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fused_get_ciou_diagonal_angle_grad_info = CustomRegOp() \ + .input(0, "w1") \ + .input(1, "h1") \ + .input(2, "w2") \ + .input(3, "h2") \ + .input(4, "out") \ + .output(5, "w1_diff") \ + .output(6, "h1_diff") \ + .output(7, "w2_diff") \ + .output(8, "h2_diff") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fused_get_boundding_boxes_coord_gpu_info = CustomRegOp() \ + .input(0, "x1") \ + .input(1, "y1") \ + .input(2, "w1") \ + .input(3, "h1") \ + .input(4, "x2") \ + .input(5, "y2") \ + 
.input(6, "w2") \ + .input(7, "h2") \ + .output(0, "b1_x1") \ + .output(1, "b1_y1") \ + .output(2, "b1_x2") \ + .output(3, "b1_y2") \ + .output(4, "b2_x1") \ + .output(5, "b2_y1") \ + .output(6, "b2_x2") \ + .output(7, "b2_y2") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fused_get_boundding_boxes_coord_bprop_gpu_info = CustomRegOp() \ + .input(0, "d_b1_x1") \ + .input(1, "d_b1_x2") \ + .input(2, "d_b1_y1") \ + .input(3, "d_b1_y2") \ + .input(4, "d_b2_x1") \ + .input(5, "d_b2_x2") \ + .input(6, "d_b2_y1") \ + .input(7, "d_b2_y2") \ + .output(0, "d_x1") \ + .output(1, "d_y1") \ + .output(2, "d_w1") \ + .output(3, "d_h1") \ + .output(4, "d_x2") \ + .output(5, "d_y2") \ + .output(6, "d_w2") \ + .output(7, "d_h2") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + + +fused_get_intersection_area_gpu_info = CustomRegOp() \ + .input(0, "b1_x1") \ + .input(1, "b1_x2") \ + .input(2, "b2_x1") \ + .input(3, "b2_x2") \ + .input(4, "b1_y1") \ + .input(5, "b1_y2") \ + .input(6, "b2_y1") \ + .input(7, "b2_y2") \ + .output(8, "inter") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ + 
.target("GPU") \ + .get_op_info() + + +fused_get_intersection_area_gpu_grad_info = CustomRegOp() \ + .input(0, "b1_x1") \ + .input(1, "b1_x2") \ + .input(2, "b2_x1") \ + .input(3, "b2_x2") \ + .input(4, "b1_y1") \ + .input(5, "b1_y2") \ + .input(6, "b2_y1") \ + .input(7, "b2_y2") \ + .input(8, "d_inter") \ + .output(9, "d_b1_x1") \ + .output(10, "d_b1_x2") \ + .output(11, "d_b2_x1") \ + .output(12, "d_b2_x2") \ + .output(13, "d_b1_y1") \ + .output(14, "d_b1_y2") \ + .output(15, "d_b2_y1") \ + .output(16, "d_b2_y2") \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default) \ + .target("GPU") \ + .get_op_info() + \ No newline at end of file diff --git a/mindyolo/models/losses/fused_op/build.sh b/mindyolo/models/losses/fused_op/build.sh new file mode 100644 index 00000000..d43c6e55 --- /dev/null +++ b/mindyolo/models/losses/fused_op/build.sh @@ -0,0 +1,7 @@ +nvcc --shared -Xcompiler -fPIC -o $(dirname $0)/fused_get_intersection_area_kernel.so $(dirname $0)/fused_get_intersection_area_kernel.cu +nvcc --shared -Xcompiler -fPIC -o $(dirname $0)/fused_get_ciou_kernel.so $(dirname $0)/fused_get_ciou_kernel.cu +nvcc --shared -Xcompiler -fPIC -o $(dirname $0)/fused_get_ciou_diagonal_angle_kernel.so $(dirname $0)/fused_get_ciou_diagonal_angle_kernel.cu +nvcc --shared -Xcompiler -fPIC -o $(dirname $0)/fused_get_center_dist_kernel.so $(dirname $0)/fused_get_center_dist_kernel.cu +nvcc --shared -Xcompiler -fPIC -o $(dirname $0)/fused_get_boundding_boxes_coord_kernel.so $(dirname $0)/fused_get_boundding_boxes_coord_kernel.cu +nvcc --shared -Xcompiler -fPIC -o $(dirname $0)/fused_get_iou_kernel.so $(dirname $0)/fused_get_iou_kernel.cu +nvcc 
--shared -Xcompiler -fPIC -o $(dirname $0)/fused_get_convex_diagonal_squared_kernel.so $(dirname $0)/fused_get_convex_diagonal_squared_kernel.cu \ No newline at end of file diff --git a/mindyolo/models/losses/fused_op/elementswise_op_impl.cu b/mindyolo/models/losses/fused_op/elementswise_op_impl.cu new file mode 100644 index 00000000..85889725 --- /dev/null +++ b/mindyolo/models/losses/fused_op/elementswise_op_impl.cu @@ -0,0 +1,323 @@ +#include +#include +#include +namespace cuda +{ + namespace elementwise + { + // An empirical parameter + // In the mainstream GPU architecture, the maximum number of registers per block is 64K, + // the maximum number of registers that can be used by each thread is 255. + // So, kThreadsPerBlock = 64 * 1024 / 255 = 256. + // Refer from https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities + constexpr uint kThreadsPerBlock = 256; + // An empirical parameter + constexpr uint kWaves = 32; + constexpr uint kStride = 2; + + struct CudaConfig + { + int dev_{0}; + int sm_nums_{1}; + int max_threads_{1}; + }; + + // Get some necessary hardware config. + inline cudaError_t GetCurrentConfig(CudaConfig *config) + { + // 1. Get current device. + // 2. Get current sm_nums + // 3. Get the maximum resident threads in per multiprocessor. + int dev; + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) + { + return err; + } + int sm_nums; + err = cudaDeviceGetAttribute(&sm_nums, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) + { + return err; + } + int max_threads; + err = cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerMultiProcessor, dev); + if (err != cudaSuccess) + { + return err; + } + config->dev_ = dev; + config->sm_nums_ = sm_nums; + config->max_threads_ = max_threads; + return err; + } + + // Get best blocks basing on parallel data size for current hardware, adaptively. 
+ inline uint GetBestBlocks(uint n, const CudaConfig &config) + { + uint best_blocks = + std::max(1, std::min((n + kThreadsPerBlock - 1) / kThreadsPerBlock, + config.sm_nums_ * config.max_threads_ / kThreadsPerBlock * kWaves)); + return best_blocks; + } + + template + struct VectorizedTraitType + { + using type = typename std::aligned_storage::type; + }; + + template + using VectorizedType = typename VectorizedTraitType::type; + + template + union Vec + { + static_assert(sizeof(VectorizedType) == sizeof(T) * VecSize, "data can not be aligned."); + __device__ Vec() {} + VectorizedType storage_; + T elements_[VecSize]; + }; + + template + struct alignas(sizeof(T) * VecSize) AlignVec + { + T elements_[VecSize]; + }; + + constexpr uint kMaxVecBytes = 128 / 8; + constexpr uint kMaxVecSize = 8; + + constexpr uint MsMin(uint a, uint b) { return a < b ? a : b; } + + template + constexpr uint VecSize() + { + return MsMin(kMaxVecBytes / sizeof(T), kMaxVecSize); + } + + template + constexpr uint VecSize() + { + return MsMin(VecSize(), VecSize()); + } + + template + class CheckApply2 + { + typedef char apply_unit; + struct apply_struct + { + char x_[2]; + }; + + template + static apply_unit check(decltype(&IN3::Apply2)); + template + static apply_struct check(...); + + public: + enum + { + value = sizeof(check(0)) == sizeof(char) + }; + }; + + template + bool IsAligned() + { + return true; + } + + template + bool IsAligned(const T *ptr, const Args *...others) + { + return reinterpret_cast(ptr) % sizeof(Vec) == 0 && IsAligned(others...); + } + + template + __device__ typename std::enable_if::value == true && vec_size % kStride == 0, + AlignVec>::type + ApplyVec(const FunctorT &functor, const IN... 
in[vec_size]) + { + AlignVec ret; + +#pragma unroll + for (uint j = 0; j < vec_size; j += kStride) + { + functor.Apply2(ret.elements_ + j, (in + j)...); + } + return ret; + } + + template + __device__ typename std::enable_if::value == false || vec_size % kStride != 0, + AlignVec>::type + ApplyVec(const FunctorT &functor, const IN... in[vec_size]) + { + AlignVec ret; +#pragma unroll + for (uint j = 0; j < vec_size; ++j) + { + ret.elements_[j] = functor((in[j])...); + } + return ret; + } + + template + __global__ void __launch_bounds__(kThreadsPerBlock) + DoApply(Factory factory, uint vec_nums, AlignVec *vec_out, const AlignVec *...vec_in, + uint tail_nums, OUT *tail_out, const IN *...tail_in) + { + auto functor = factory(); + const uint global_tid = blockIdx.x * kThreadsPerBlock + threadIdx.x; + for (uint i = global_tid; i < vec_nums; i += blockDim.x * gridDim.x) + { + vec_out[i] = ApplyVec(functor, (vec_in[i].elements_)...); + } + if (tail && global_tid < tail_nums) + { + tail_out[global_tid] = functor((tail_in[global_tid])...); + } + } + + template + cudaError_t LaunchKernel(Factory factory, uint nums, OUT *out, const IN *...in, cudaStream_t stream) + { + const uint vec_nums = nums / vec_size; + const uint tail_offset = vec_nums * vec_size; + const uint tail_nums = nums - tail_offset; + CudaConfig config; + cudaError_t err = GetCurrentConfig(&config); + if (err != cudaSuccess) + { + return err; + } + uint num_blocks = GetBestBlocks(vec_nums, config); + dim3 block{kThreadsPerBlock}; + dim3 grid{uint(num_blocks)}; + if (tail_nums > 0) + { + auto func = DoApply; + func<<>>(factory, vec_nums, reinterpret_cast *>(out), + (reinterpret_cast *>(in))..., tail_nums, + out + tail_offset, (in + tail_offset)...); + } + else + { + auto func = DoApply; + func<<>>(factory, vec_nums, reinterpret_cast *>(out), + (reinterpret_cast *>(in))..., tail_nums, + out + tail_offset, (in + tail_offset)...); + } + return cudaPeekAtLastError(); + } + + template + struct DoLaunch + { + static 
cudaError_t Launch(Factory factory, uint n, OUT *out, const IN *...in, cudaStream_t stream) + { + constexpr uint max_pack_size = VecSize(); + if (IsAligned(out, in...)) + { + return LaunchKernel(factory, n, out, in..., stream); + } + return LaunchKernel<1, Factory, OUT, IN...>(factory, n, out, in..., stream); + } + }; + + template + struct TransitFactory + { + explicit TransitFactory(FunctorT functor) : transit_impl_(functor) {} + __device__ FunctorT operator()() const { return transit_impl_; } + + private: + FunctorT transit_impl_; + }; + + // API elementwise for input: a, output: out. + template + inline cudaError_t UnaryTransit(Factory factory, uint n, OUT *out, const IN *in, cudaStream_t stream) + { + return DoLaunch::Launch(factory, n, out, in, stream); + } + + template + inline cudaError_t Unary(FunctorT functor, uint n, OUT *out, const IN *in, cudaStream_t stream) + { + return UnaryTransit(TransitFactory(functor), n, out, in, stream); + } + + template + inline cudaError_t BinaryTransit(Factory factory, uint n, OUT *out, const IN *in, const IN2 *in2, cudaStream_t stream) + { + return DoLaunch::Launch(factory, n, out, in, in2, stream); + } + + // API elementwise for input: [a, b], output: out. + template + inline cudaError_t Binary(FunctorT functor, uint n, OUT *out, const IN *in, const IN2 *in2, cudaStream_t stream) + { + return BinaryTransit(TransitFactory(functor), n, out, in, in2, stream); + } + + template + inline cudaError_t TernaryTransit(Factory factory, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, + cudaStream_t stream) + { + return DoLaunch::Launch(factory, n, out, in, in2, in3, stream); + } + + // API elementwise for input: [a, b, c], output: out. 
+ template + inline cudaError_t Ternary(FunctorT functor, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, + cudaStream_t stream) + { + return TernaryTransit(TransitFactory(functor), n, out, in, in2, in3, stream); + } + + template + inline cudaError_t EightInputsTransit(Factory factory, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, const IN4 *in4, const IN5 *in5, const IN6 *in6, const IN7 *in7, const IN8 *in8, + cudaStream_t stream) + { + return DoLaunch::Launch(factory, n, out, in, in2, in3, in4, in5, in6, in7, in8, stream); + } + + template + inline cudaError_t EightInputs(FunctorT functor, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, const IN4 *in4, const IN5 *in5, const IN6 *in6, const IN7 *in7, const IN8 *in8, + cudaStream_t stream) + { + return EightInputsTransit(TransitFactory(functor), n, out, in, in2, in3, in4, in5, in6, in7, in8, stream); + } + + template + inline cudaError_t FourInputsTransit(Factory factory, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, const IN4 *in4, + cudaStream_t stream) + { + return DoLaunch::Launch(factory, n, out, in, in2, in3, in4, stream); + } + + template + inline cudaError_t FourInputs(FunctorT functor, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, const IN4 *in4, + cudaStream_t stream) + { + return FourInputsTransit(TransitFactory(functor), n, out, in, in2, in3, in4, stream); + } + + template + inline cudaError_t FiveInputsTransit(Factory factory, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, const IN4 *in4, const IN5 *in5, + cudaStream_t stream) + { + return DoLaunch::Launch(factory, n, out, in, in2, in3, in4, in5, stream); + } + + template + inline cudaError_t FiveInputs(FunctorT functor, uint n, OUT *out, const IN *in, const IN2 *in2, const IN3 *in3, const IN4 *in4, const IN5 *in5, + cudaStream_t stream) + { + return FiveInputsTransit(TransitFactory(functor), n, out, in, in2, in3, in4, in5, stream); + } + } // 
namespace elementwise +} // namespace cuda \ No newline at end of file diff --git a/mindyolo/models/losses/fused_op/fused_get_boundding_boxes_coord_kernel.cu b/mindyolo/models/losses/fused_op/fused_get_boundding_boxes_coord_kernel.cu new file mode 100644 index 00000000..e746e005 --- /dev/null +++ b/mindyolo/models/losses/fused_op/fused_get_boundding_boxes_coord_kernel.cu @@ -0,0 +1,178 @@ +#include +#include + +constexpr int thread_per_block = 256; + +__global__ void FusedGetBounddingBoxesCoordKernel( + const size_t size, const float *x1, const float *y1, const float *w1, + const float *h1, const float *x2, const float *y2, const float *w2, + const float *h2, float *b1_x1, float *b1_y1, float *b1_x2, float *b1_y2, + float *b2_x1, float *b2_y1, float *b2_x2, float *b2_y2) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += blockDim.x * gridDim.x) + { + const float w1_ = w1[i] / 2.0; + const float h1_ = h1[i] / 2.0; + const float w2_ = w2[i] / 2.0; + const float h2_ = h2[i] / 2.0; + const float x1_i = x1[i], y1_i = y1[i], x2_i = x2[i], y2_i = y2[i]; + b1_x1[i] = x1_i - w1_; + b1_x2[i] = x1_i + w1_; + b1_y1[i] = y1_i - h1_; + b1_y2[i] = y1_i + h1_; + b2_x1[i] = x2_i - w2_; + b2_x2[i] = x2_i + w2_; + b2_y1[i] = y2_i - h2_; + b2_y2[i] = y2_i + h2_; + } +} + +__global__ void FusedGetBounddingBoxesCoordGradKernel( + const size_t size, const float *d_b1_x1, const float *d_b1_x2, + const float *d_b1_y1, const float *d_b1_y2, const float *d_b2_x1, + const float *d_b2_x2, const float *d_b2_y1, const float *d_b2_y2, + float *d_x1, float *d_y1, float *d_w1, float *d_h1, float *d_x2, + float *d_y2, float *d_w2, float *d_h2) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += blockDim.x * gridDim.x) + { + const float d_b1_x1_i = d_b1_x1[i]; + const float d_b1_x2_i = d_b1_x2[i]; + const float d_b1_y1_i = d_b1_y1[i]; + const float d_b1_y2_i = d_b1_y2[i]; + const float d_b2_x1_i = d_b2_x1[i]; + const float d_b2_x2_i = d_b2_x2[i]; + 
const float d_b2_y1_i = d_b2_y1[i]; + const float d_b2_y2_i = d_b2_y2[i]; + d_x1[i] = d_b1_x1_i + d_b1_x2_i; + d_y1[i] = d_b1_y1_i + d_b1_y2_i; + d_w1[i] = 0.0; + d_h1[i] = 0.0; + d_x2[i] = d_b2_x1_i + d_b2_x2_i; + d_y2[i] = d_b2_y1_i + d_b2_y2_i; + d_w2[i] = 0.0; + d_h2[i] = 0.0; + } +} + +extern "C" int FusedGetBounddingBoxesCoord(int nparam, void **params, + int *ndims, int64_t **shapes, + const char **dtypes, void *stream, + void *extra) +{ + cudaStream_t custream = static_cast(stream); + int input_num = 16; + int output_index = 8; + // check input number + if (nparam != input_num) + { + printf( + "For FusedGetBounddingBoxesCoord, the number of input should be %d, " + "but got %d.", + input_num, nparam); + return 1; + } + // check dytpe + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + printf( + "For FusedGetBounddingBoxesCoord, the dtype of input should be %s, " + "but got %s.", + "float32", dtypes[i]); + return 2; + } + } + // read input and output parameters + const float *x1 = static_cast(params[0]); + const float *y1 = static_cast(params[1]); + const float *w1 = static_cast(params[2]); + const float *h1 = static_cast(params[3]); + const float *x2 = static_cast(params[4]); + const float *y2 = static_cast(params[5]); + const float *w2 = static_cast(params[6]); + const float *h2 = static_cast(params[7]); + float *b1_x1 = static_cast(params[8]); + float *b1_x2 = static_cast(params[9]); + float *b1_y1 = static_cast(params[10]); + float *b1_y2 = static_cast(params[11]); + float *b2_x1 = static_cast(params[12]); + float *b2_x2 = static_cast(params[13]); + float *b2_y1 = static_cast(params[14]); + float *b2_y2 = static_cast(params[15]); + + // calculate the size of output + size_t size = std::accumulate(shapes[output_index], + shapes[output_index] + ndims[output_index], + size_t(1), std::multiplies()); + int block_num = (size + thread_per_block - 1) / thread_per_block; + FusedGetBounddingBoxesCoordKernel<<>>( + size, x1, y1, w1, 
h1, x2, y2, w2, h2, b1_x1, b1_y1, b1_x2, b1_y2, b2_x1, + b2_y1, b2_x2, b2_y2); + return 0; +} + +extern "C" int FusedGetBounddingBoxesCoordGrad(int nparam, void **params, + int *ndims, int64_t **shapes, + const char **dtypes, + void *stream, void *extra) +{ + cudaStream_t custream = static_cast(stream); + int input_num = 16; + int output_index = 8; + // check input number + if (nparam != input_num) + { + printf( + "For FusedGetBounddingBoxesCoordGrad, the number of input should be " + "%d, " + "but got %d.", + input_num, nparam); + return 1; + } + // check dytpe + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + printf( + "For FusedGetBounddingBoxesCoordGrad, the dtype of input should be " + "%s, " + "but got %s.", + "float32", dtypes[i]); + return 2; + } + } + // read input and output parameters + const float *d_b1_x1 = static_cast(params[0]); + const float *d_b1_x2 = static_cast(params[1]); + const float *d_b1_y1 = static_cast(params[2]); + const float *d_b1_y2 = static_cast(params[3]); + const float *d_b2_x1 = static_cast(params[4]); + const float *d_b2_x2 = static_cast(params[5]); + const float *d_b2_y1 = static_cast(params[6]); + const float *d_b2_y2 = static_cast(params[7]); + float *d_x1 = static_cast(params[8]); + float *d_y1 = static_cast(params[9]); + float *d_w1 = static_cast(params[10]); + float *d_h1 = static_cast(params[11]); + float *d_x2 = static_cast(params[12]); + float *d_y2 = static_cast(params[13]); + float *d_w2 = static_cast(params[14]); + float *d_h2 = static_cast(params[15]); + + // calculate the size of output + size_t size = std::accumulate(shapes[output_index], + shapes[output_index] + ndims[output_index], + size_t(1), std::multiplies()); + int block_num = (size + thread_per_block - 1) / thread_per_block; + FusedGetBounddingBoxesCoordGradKernel<<>>( + size, d_b1_x1, d_b1_x2, d_b1_y1, d_b1_y2, d_b2_x1, d_b2_x2, d_b2_y1, d_b2_y2, d_x1, d_y1, d_w1, + d_h1, d_x2, d_y2, d_w2, d_h2); + return 0; +} diff --git 
a/mindyolo/models/losses/fused_op/fused_get_center_dist_kernel.cu b/mindyolo/models/losses/fused_op/fused_get_center_dist_kernel.cu new file mode 100644 index 00000000..06de3c76 --- /dev/null +++ b/mindyolo/models/losses/fused_op/fused_get_center_dist_kernel.cu @@ -0,0 +1,154 @@ +#include +#include +#include "cuda_runtime.h" +#include "elementswise_op_impl.cu" + +constexpr int THREADS = 256; + +template +struct FusedGetCenterDistFunctor +{ + FusedGetCenterDistFunctor() {} + __device__ __forceinline__ T operator()(T b1_x1, T b1_x2, T b1_y1, T b1_y2, T b2_x1, T b2_x2, T b2_y1, T b2_y2) const + { + T a = b2_x1 + b2_x2 - b1_x1 - b1_x2; + T b = b2_y1 + b2_y2 - b1_y1 - b1_y2; + return (a * a + b * b) / 4; + } +}; + +template +void FusedGetCenterDistKernel(const T *b1_x1, const T *b1_x2, const T *b1_y1, const T *b1_y2, + const T *b2_x1, const T *b2_x2, const T *b2_y1, const T *b2_y2, + T *output, const size_t count, cudaStream_t cuda_stream) +{ + FusedGetCenterDistFunctor functor; + cuda::elementwise::EightInputs(functor, (uint)(count), output, b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, cuda_stream); +} + +extern "C" int FusedGetCenterDist(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, + void *stream, void *extra) +{ + cudaStream_t custream = static_cast(stream); + constexpr int OUTPUT_INDEX = 8; + constexpr int TOTAL_PARAM_NUM = 9; + if (nparam != TOTAL_PARAM_NUM) + { + printf("[Error] nparam is %d", nparam); + return 1; + } + // check the data type is float32 + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + return 2; + } + } + // read input & output parameters + void *b1_x1 = params[0]; + void *b1_x2 = params[1]; + void *b1_y1 = params[2]; + void *b1_y2 = params[3]; + void *b2_x1 = params[4]; + void *b2_x2 = params[5]; + void *b2_y1 = params[6]; + void *b2_y2 = params[7]; + void *out = params[8]; + // calculate the size to data to be processed + size_t size = 1; + for (int i = 0; i < 
ndims[OUTPUT_INDEX]; i++) + { + size *= shapes[OUTPUT_INDEX][i]; + } + FusedGetCenterDistKernel(static_cast(b1_x1), static_cast(b1_x2), + static_cast(b1_y1), static_cast(b1_y2), + static_cast(b2_x1), static_cast(b2_x2), + static_cast(b2_y1), static_cast(b2_y2), + static_cast(out), size, custream); + return 0; +} + +__global__ void FusedGetCenterDistBpropKernel(const float *b1_x1, const float *b1_x2, const float *b1_y1, const float *b1_y2, + const float *b2_x1, const float *b2_x2, const float *b2_y1, const float *b2_y2, + const float *d_out, float *d_b1_x1, float *d_b1_x2, float *d_b1_y1, float *d_b1_y2, + float *d_b2_x1, float *d_b2_x2, float *d_b2_y1, float *d_b2_y2, const size_t size) +{ + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; idx < size; + idx += step) + { + const float b1_x1_i = b1_x1[idx]; + const float b1_x2_i = b1_x2[idx]; + const float b1_y1_i = b1_y1[idx]; + const float b1_y2_i = b1_y2[idx]; + const float b2_x1_i = b2_x1[idx]; + const float b2_x2_i = b2_x2[idx]; + const float b2_y1_i = b2_y1[idx]; + const float b2_y2_i = b2_y2[idx]; + const float d_out_i = d_out[idx]; + float dx_i = (b2_x1_i + b2_x2_i - b1_x1_i - b1_x2_i) * d_out_i / 2; + d_b1_x1[idx] = -dx_i; + d_b1_x2[idx] = -dx_i; + d_b2_x1[idx] = dx_i; + d_b2_x2[idx] = dx_i; + float dy_i = (b2_y1_i + b2_y2_i - b1_y1_i - b1_y2_i) * d_out_i / 2; + d_b1_y1[idx] = -dy_i; + d_b1_y2[idx] = -dy_i; + d_b2_y1[idx] = dy_i; + d_b2_y2[idx] = dy_i; + } +} + +extern "C" int FusedGetCenterDistBprop(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, + void *stream, void *extra) +{ + cudaStream_t custream = static_cast(stream); + constexpr int OUTPUT_INDEX = 9; + constexpr int TOTAL_PARAM_NUM = 17; + if (nparam != TOTAL_PARAM_NUM) + return 1; + // check the data type is float32 + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + return 2; + } + } + // read input & output parameters + void *b1_x1 = params[0]; 
+ void *b1_x2 = params[1]; + void *b1_y1 = params[2]; + void *b1_y2 = params[3]; + void *b2_x1 = params[4]; + void *b2_x2 = params[5]; + void *b2_y1 = params[6]; + void *b2_y2 = params[7]; + void *d_out = params[8]; + void *d_b1_x1 = params[9]; + void *d_b1_x2 = params[10]; + void *d_b1_y1 = params[11]; + void *d_b1_y2 = params[12]; + void *d_b2_x1 = params[13]; + void *d_b2_x2 = params[14]; + void *d_b2_y1 = params[15]; + void *d_b2_y2 = params[16]; + // calculate the size to data to be processed + size_t size = 1; + for (int i = 0; i < ndims[OUTPUT_INDEX]; i++) + { + size *= shapes[OUTPUT_INDEX][i]; + } + int n = (size + THREADS - 1) / THREADS; // ceil-divide: plain integer division gave 0 blocks (an invalid launch) whenever size was below THREADS + + FusedGetCenterDistBpropKernel<<>>(static_cast(b1_x1), static_cast(b1_x2), + static_cast(b1_y1), static_cast(b1_y2), + static_cast(b2_x1), static_cast(b2_x2), + static_cast(b2_y1), static_cast(b2_y2), + static_cast(d_out), + static_cast(d_b1_x1), static_cast(d_b1_x2), + static_cast(d_b1_y1), static_cast(d_b1_y2), + static_cast(d_b2_x1), static_cast(d_b2_x2), + static_cast(d_b2_y1), static_cast(d_b2_y2), size); + return 0; +} diff --git a/mindyolo/models/losses/fused_op/fused_get_ciou_diagonal_angle_kernel.cu b/mindyolo/models/losses/fused_op/fused_get_ciou_diagonal_angle_kernel.cu new file mode 100644 index 00000000..a3c389f2 --- /dev/null +++ b/mindyolo/models/losses/fused_op/fused_get_ciou_diagonal_angle_kernel.cu @@ -0,0 +1,120 @@ +#include +#include "elementswise_op_impl.cu" +constexpr int THREADS = 256; + +template +struct FusedGetCiouDiagonalAngleFunctor { + FusedGetCiouDiagonalAngleFunctor() {} + __device__ __forceinline__ T operator()(T w1, T h1, T w2, T h2) const { + T eps = static_cast(1e-7); + const T angle = atan(w2 / (h2 + eps)) - atan(w1 / (h1 + eps)); + return static_cast(4.0 / (M_PI * M_PI) * angle * angle); + } +}; + +template +void FusedGetCiouDiagonalAngleKernel(const T *w1, const T *h1, const T *w2, const T *h2, + T *output, const size_t count, cudaStream_t cuda_stream) { + FusedGetCiouDiagonalAngleFunctor functor; 
+ cuda::elementwise::FourInputs(functor, (uint)(count), output, w1, h1, w2, h2, cuda_stream); +} + + +template +__global__ void FusedGetCiouDiagonalAngleGradKernel( + const int output_num, const T* w1, const T* h1, const T* w2, const T* h2, const T* v_diff, + T* w1_diff, T* h1_diff, T* w2_diff, + T* h2_diff) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + T eps = static_cast(1e-7); + for(int i = idx; i < output_num; i += blockDim.x * gridDim.x) { + const T w1_i = w1[i]; + const T h1_i = h1[i]; + const T w2_i = w2[i]; + const T h2_i = h2[i]; + const T v_diff_i = v_diff[i]; + const T angle_delta = static_cast(8.0) * (atan(w2_i / (h2_i + eps)) - atan(w1_i / (h1_i + eps))) / static_cast((M_PI * M_PI)); + const T angle1 = static_cast(1.0) + (w1_i * w1_i / ((h1_i + eps) * (h1_i + eps))); + const T angle2 = static_cast(1.0) + (w2_i * w2_i / ((h2_i + eps) * (h2_i + eps))); + w1_diff[i] = static_cast(-1.0) * angle_delta / ((h1_i + eps) * angle1) * v_diff_i; + w2_diff[i] = angle_delta / ((h2_i + eps) * angle2) * v_diff_i; + h1_diff[i] = w1_i * angle_delta / ((h1_i + eps) * (h1_i + eps) * angle1) * v_diff_i; + h2_diff[i] = static_cast(-1.0) * w2_i * angle_delta / ((h2_i + eps) * (h2_i + eps) * angle2) * v_diff_i; + } +} + +extern "C" int FusedGetCiouDiagonalAngle(int nparam, void **params, int *ndims, int64_t **shapes, + const char **dtypes, void *stream, void *extra) { + cudaStream_t custream = static_cast(stream); + int expect_num_param = 5; + if (nparam != expect_num_param) { + printf("Param num is not equal to %d \n.", expect_num_param); + return -1; + } + // check the data type is float32 or float16 + for (int i = 0; i < nparam; i++) { + if (strcmp(dtypes[i], "float32") != 0 && strcmp(dtypes[i], "float16") != 0) { + printf("dtypes is not equal to float32 and float16."); + return -1; + } + } + int output_index = 4; + void *w1 = params[0]; + void *h1 = params[1]; + void *w2 = params[2]; + void *h2 = params[3]; + void *v = params[4]; + int output_num = 1; + + for 
(int i = 0; i < ndims[output_index]; i++) { + output_num *= shapes[output_index][i]; + } + FusedGetCiouDiagonalAngleKernel(static_cast(w1), + static_cast(h1), + static_cast(w2), + static_cast(h2), + static_cast(v), output_num, custream); + return 0; +} + + +extern "C" int FusedGetCiouDiagonalAngleGrad(int nparam, void **params, int *ndims, int64_t **shapes, + const char **dtypes, void *stream, void *extra) { + cudaStream_t custream = static_cast(stream); + int expect_num_param = 9; + if (nparam != expect_num_param) { + printf("Param num is not equal to %d \n.", expect_num_param); + return -1; + } + // check the data type is float32 or float16 + for (int i = 0; i < nparam; i++) { + if (strcmp(dtypes[i], "float32") != 0 && strcmp(dtypes[i], "float16") != 0) { + printf("dtypes is not equal to float32 and float16."); + return -1; + } + } + + void *w1 = params[0]; + void *h1 = params[1]; + void *w2 = params[2]; + void *h2 = params[3]; + void *v_diff = params[4]; + void *w1_diff = params[5]; + void *h1_diff = params[6]; + void *w2_diff = params[7]; + void *h2_diff = params[8]; + int output_num = 1; + + int output_index = 4; + for (int i = 0; i < ndims[output_index]; i++) { + output_num *= shapes[output_index][i]; + } + int block_num = output_num / THREADS + 1; + FusedGetCiouDiagonalAngleGradKernel<<>>(output_num, + static_cast(w1), static_cast(h1), + static_cast(w2), static_cast(h2), + static_cast(v_diff), static_cast(w1_diff), + static_cast(h1_diff), static_cast(w2_diff), + static_cast(h2_diff)); + return 0; +} \ No newline at end of file diff --git a/mindyolo/models/losses/fused_op/fused_get_ciou_kernel.cu b/mindyolo/models/losses/fused_op/fused_get_ciou_kernel.cu new file mode 100644 index 00000000..16427997 --- /dev/null +++ b/mindyolo/models/losses/fused_op/fused_get_ciou_kernel.cu @@ -0,0 +1,120 @@ +#include + +constexpr int THREADS = 256; +constexpr float EPS = 1e-7; + +__global__ void FusedGetCiouKernel(const float *v, const float *iou, const float *rho2, const 
float *c2, + float *alpha, float *out, const size_t size) +{ + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; idx < size; + idx += step) + { + const float v_i = v[idx]; + const float iou_i = iou[idx]; + const float alpha_i = v_i / (v_i - iou_i + 1.0 + EPS); + out[idx] = iou_i - (rho2[idx] / c2[idx] + v_i * alpha_i); + alpha[idx] = alpha_i; + } +} + +extern "C" int FusedGetCiou(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, + void *stream, void *extra) +{ + cudaStream_t custream = static_cast(stream); + constexpr int OUTPUT_INDEX = 5; + constexpr int TOTAL_PARAM_NUM = 6; + if (nparam != TOTAL_PARAM_NUM) + { + printf("[Error] nparam is %d", nparam); + return 1; + } + // check the data type is float32 + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + return 2; + } + } + // read input & output parameters + void *v = params[0]; + void *iou = params[1]; + void *rho2 = params[2]; + void *c2 = params[3]; + void *alpha = params[4]; + void *out = params[5]; + // calculate the size to data to be processed + size_t size = 1; + for (int i = 0; i < ndims[OUTPUT_INDEX]; i++) + { + size *= shapes[OUTPUT_INDEX][i]; + } + int n = (size + THREADS - 1) / THREADS; // ceil-divide: plain integer division gave 0 blocks (an invalid launch) whenever size was below THREADS + + FusedGetCiouKernel<<>>(static_cast(v), static_cast(iou), + static_cast(rho2), static_cast(c2), + static_cast(alpha), static_cast(out), size); + return 0; +} + +__global__ void FusedGetCiouBpropKernel(const float *v, const float *iou, const float *rho2, const float *c2, + const float *d_alpha, const float *d_out, float *d_v, float *d_iou, + float *d_rho2, float *d_c2, const size_t size) +{ + for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; idx < size; + idx += step) + { + const float v_i = v[idx]; + const float iou_i = iou[idx]; + const float c2_i = c2[idx]; + const float d_out_i = d_out[idx]; + const float alpha_i = v_i / (v_i - iou_i + 1.0 + EPS); + d_v[idx] = -alpha_i * d_out_i; + 
d_iou[idx] = d_out_i; + d_rho2[idx] = -d_out_i / c2_i; + d_c2[idx] = rho2[idx] / (c2_i * c2_i) * d_out_i; + } +} + +extern "C" int FusedGetCiouBprop(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, + void *stream, void *extra) +{ + cudaStream_t custream = static_cast(stream); + constexpr int OUTPUT_INDEX = 6; + constexpr int TOTAL_PARAM_NUM = 10; + if (nparam != TOTAL_PARAM_NUM) + return 1; + // check the data type is float32 + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + return 2; + } + } + // read input & output parameters + void *v = params[0]; + void *iou = params[1]; + void *rho2 = params[2]; + void *c2 = params[3]; + void *d_alpha = params[4]; + void *d_out = params[5]; + void *d_v = params[6]; + void *d_iou = params[7]; + void *d_rho2 = params[8]; + void *d_c2 = params[9]; + // calculate the size to data to be processed + size_t size = 1; + for (int i = 0; i < ndims[OUTPUT_INDEX]; i++) + { + size *= shapes[OUTPUT_INDEX][i]; + } + int n = (size + THREADS - 1) / THREADS; // ceil-divide: plain integer division gave 0 blocks (an invalid launch) whenever size was below THREADS + + FusedGetCiouBpropKernel<<>>(static_cast(v), static_cast(iou), + static_cast(rho2), static_cast(c2), + static_cast(d_alpha), static_cast(d_out), + static_cast(d_v), static_cast(d_iou), + static_cast(d_rho2), static_cast(d_c2), size); + return 0; +} \ No newline at end of file diff --git a/mindyolo/models/losses/fused_op/fused_get_convex_diagonal_squared_kernel.cu b/mindyolo/models/losses/fused_op/fused_get_convex_diagonal_squared_kernel.cu new file mode 100644 index 00000000..b7b3eb0a --- /dev/null +++ b/mindyolo/models/losses/fused_op/fused_get_convex_diagonal_squared_kernel.cu @@ -0,0 +1,161 @@ +#include +#include "elementswise_op_impl.cu" +constexpr int THREADS = 256; + +template +__device__ T Max(T a, T b){ + if (a > b) { + return a; + } else { + return b; + } +} + +template +__device__ T Min(T a, T b){ + if (a < b) { + return a; + } else { + return b; + } +} + +template +struct FusedGetConvexDiagonalSquaredFunctor { + 
FusedGetConvexDiagonalSquaredFunctor() {} + __device__ __forceinline__ T operator()(T b1_x1, T b1_x2, T b2_x1, T b2_x2, T b1_y1, T b1_y2, T b2_y1, T b2_y2) const { + float eps = 1e-7; + const T cw = Max(b1_x2, b2_x2) - Min(b1_x1, b2_x1); + const T ch = Max(b1_y2, b2_y2) - Min(b1_y1, b2_y1); + return cw * cw + ch * ch + static_cast(eps); + } +}; + +template +void FusedGetConvexDiagonalSquaredKernel(const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, + const T* b1_y2, const T* b2_y1, const T* b2_y2, + T *output, const size_t count, cudaStream_t cuda_stream) { + FusedGetConvexDiagonalSquaredFunctor functor; + cuda::elementwise::EightInputs(functor, (uint)(count), output, b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, cuda_stream); +} + + +template +__global__ void FusedGetConvexDiagonalSquaredGradKernel( + const int output_num, const T* b1_x1, const T* b1_x2, const T* b2_x1, const T* b2_x2, const T* b1_y1, + const T* b1_y2, const T* b2_y1, const T* b2_y2, const T* d_c2, T* d_b1_x1, T* d_b1_x2, + T* d_b2_x1, T* d_b2_x2, T* d_b1_y1, T* d_b1_y2, T* d_b2_y1, T* d_b2_y2) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + const T zero = static_cast(0); + const T one = static_cast(1); + const T two = static_cast(2); + for(int i = idx; i < output_num; i += blockDim.x * gridDim.x) { + const T cw = Max(b1_x2[i], b2_x2[i]) - Min(b1_x1[i], b2_x1[i]); + const T ch = Max(b1_y2[i], b2_y2[i]) - Min(b1_y1[i], b2_y1[i]); + const T d_c2_cw = two * cw * d_c2[i]; + const T d_c2_ch = two * ch * d_c2[i]; + d_b1_x2[i] = d_c2_cw * (b1_x2[i] > b2_x2[i] ? one : zero); + d_b2_x2[i] = d_c2_cw * (b1_x2[i] > b2_x2[i] ? zero : one); + d_b1_x1[i] = -d_c2_cw * (b1_x1[i] < b2_x1[i] ? one : zero); + d_b2_x1[i] = -d_c2_cw * (b1_x1[i] < b2_x1[i] ? zero : one); + d_b1_y2[i] = d_c2_ch * (b1_y2[i] > b2_y2[i] ? one : zero); + d_b2_y2[i] = d_c2_ch * (b1_y2[i] > b2_y2[i] ? zero : one); + d_b1_y1[i] = -d_c2_ch * (b1_y1[i] < b2_y1[i] ? 
one : zero); + d_b2_y1[i] = -d_c2_ch * (b1_y1[i] < b2_y1[i] ? zero : one); + } +} + +extern "C" int FusedGetConvexDiagonalSquared(int nparam, void **params, int *ndims, int64_t **shapes, + const char **dtypes, void *stream, void *extra) { + cudaStream_t custream = static_cast(stream); + int expect_num_param = 9; + if (nparam != expect_num_param) { + printf("Param num is not equal to %d \n.", expect_num_param); + return -1; + } + // check the data type is float32 or float16 + for (int i = 0; i < nparam; i++) { + if (strcmp(dtypes[i], "float32") != 0 && strcmp(dtypes[i], "float16") != 0) { + printf("dtypes is not equal to float32 and float16."); + return -1; + } + } + int output_index = 8; + void *b1_x1 = params[0]; + void *b1_x2 = params[1]; + void *b2_x1 = params[2]; + void *b2_x2 = params[3]; + void *b1_y1 = params[4]; + void *b1_y2 = params[5]; + void *b2_y1 = params[6]; + void *b2_y2 = params[7]; + void *c2 = params[8]; + int output_num = 1; + + for (int i = 0; i < ndims[output_index]; i++) { + output_num *= shapes[output_index][i]; + } + FusedGetConvexDiagonalSquaredKernel(static_cast(b1_x1), + static_cast(b1_x2), + static_cast(b2_x1), + static_cast(b2_x2), + static_cast(b1_y1), + static_cast(b1_y2), + static_cast(b2_y1), + static_cast(b2_y2), + static_cast(c2), output_num, custream); + return 0; +} + + +extern "C" int FusedGetConvexDiagonalSquaredGrad(int nparam, void **params, int *ndims, int64_t **shapes, + const char **dtypes, void *stream, void *extra) { + cudaStream_t custream = static_cast(stream); + int expect_num_param = 17; + if (nparam != expect_num_param) { + printf("Param num is not equal to %d \n.", expect_num_param); + return -1; + } + // check the data type is float32 or float16 + for (int i = 0; i < nparam; i++) { + if (strcmp(dtypes[i], "float32") != 0 && strcmp(dtypes[i], "float16") != 0) { + printf("dtypes is not equal to float32 and float16."); + return -1; + } + } + void *b1_x1 = params[0]; + void *b1_x2 = params[1]; + void *b2_x1 = 
params[2]; + void *b2_x2 = params[3]; + void *b1_y1 = params[4]; + void *b1_y2 = params[5]; + void *b2_y1 = params[6]; + void *b2_y2 = params[7]; + void *d_c2 = params[8]; + void *d_b1_x1 = params[9]; + void *d_b1_x2 = params[10]; + void *d_b2_x1 = params[11]; + void *d_b2_x2 = params[12]; + void *d_b1_y1 = params[13]; + void *d_b1_y2 = params[14]; + void *d_b2_y1 = params[15]; + void *d_b2_y2 = params[16]; + int output_num = 1; + + int output_index = 9; + for (int i = 0; i < ndims[output_index]; i++) { + output_num *= shapes[output_index][i]; + } + int block_num = output_num / THREADS + 1; + FusedGetConvexDiagonalSquaredGradKernel<<>>(output_num, + static_cast(b1_x1), static_cast(b1_x2), + static_cast(b2_x1), static_cast(b2_x2), + static_cast(b1_y1), static_cast(b1_y2), + static_cast(b2_y1), static_cast(b2_y2), + static_cast(d_c2), static_cast(d_b1_x1), + static_cast(d_b1_x2), static_cast(d_b2_x1), + static_cast(d_b2_x2), static_cast(d_b1_y1), + static_cast(d_b1_y2), static_cast(d_b2_y1), + static_cast(d_b2_y2)); + return 0; +} diff --git a/mindyolo/models/losses/fused_op/fused_get_intersection_area_kernel.cu b/mindyolo/models/losses/fused_op/fused_get_intersection_area_kernel.cu new file mode 100644 index 00000000..3afeed13 --- /dev/null +++ b/mindyolo/models/losses/fused_op/fused_get_intersection_area_kernel.cu @@ -0,0 +1,181 @@ +#include +#include +#include "elementswise_op_impl.cu" + +constexpr int thread_per_block = 256; + +__host__ __device__ float GetDistance(const float &b1_x2_i, const float &b2_x2_i, + const float &b1_x1_i, const float &b2_x1_i) +{ + return min(b1_x2_i, b2_x2_i) - max(b1_x1_i, b2_x1_i); +} + +template +struct FusedGetIntersectionAreaFunctor +{ + FusedGetIntersectionAreaFunctor() {} + __device__ __forceinline__ T operator()(const T b1_x1, const T b1_x2, + const T b2_x1, const T b2_x2, const T b1_y1, + const T b1_y2, const T b2_y1, const T b2_y2) const + { + const T w = GetDistance(b1_x2, b2_x2, b1_x1, b2_x1); + const T h = 
GetDistance(b1_y2, b2_y2, b1_y1, b2_y1); + return (w > 0.0 && h > 0.0) ? w * h : 0.0; + } +}; + +template +void FusedGetIntersectionAreaKernel(const T *b1_x1, const T *b1_x2, + const T *b2_x1, const T *b2_x2, const T *b1_y1, + const T *b1_y2, const T *b2_y1, const T *b2_y2, + T *output, const size_t count, cudaStream_t cuda_stream) +{ + FusedGetIntersectionAreaFunctor functor; + cuda::elementwise::EightInputs(functor, (uint)(count), output, b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, + b1_y2, b2_y1, b2_y2, cuda_stream); +} + +__global__ void FusedGetIntersectionAreaGradKernel( + const size_t size, const float *b1_x1, const float *b1_x2, + const float *b2_x1, const float *b2_x2, const float *b1_y1, + const float *b1_y2, const float *b2_y1, const float *b2_y2, + const float *d_inter, float *d_b1_x1, float *d_b1_x2, float *d_b2_x1, + float *d_b2_x2, float *d_b1_y1, float *d_b1_y2, float *d_b2_y1, + float *d_b2_y2) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += blockDim.x * gridDim.x) + { + float d_inter_i = d_inter[i]; + const float w = GetDistance(b1_x2[i], b2_x2[i], b1_x1[i], b2_x1[i]); + const float h = GetDistance(b1_y2[i], b2_y2[i], b1_y1[i], b2_y1[i]); + d_b1_x1[i] = 0.0; + d_b1_x2[i] = 0.0; + d_b2_x1[i] = 0.0; // was a duplicated d_b2_x2 store, which left d_b2_x1 unwritten when the boxes do not overlap + d_b2_x2[i] = 0.0; + d_b1_y1[i] = 0.0; + d_b1_y2[i] = 0.0; + d_b2_y1[i] = 0.0; + d_b2_y2[i] = 0.0; + if (w > 0.0 && h > 0.0) + { + d_b1_x1[i] = (b1_x1[i] >= b2_x1[i]) ? -d_inter_i * h : 0.0; + d_b2_x1[i] = (b1_x1[i] <= b2_x1[i]) ? -d_inter_i * h : 0.0; + d_b1_x2[i] = (b1_x2[i] <= b2_x2[i]) ? d_inter_i * h : 0.0; + d_b2_x2[i] = (b1_x2[i] >= b2_x2[i]) ? d_inter_i * h : 0.0; + d_b1_y1[i] = (b1_y1[i] >= b2_y1[i]) ? -d_inter_i * w : 0.0; + d_b2_y1[i] = (b1_y1[i] <= b2_y1[i]) ? -d_inter_i * w : 0.0; + d_b1_y2[i] = (b1_y2[i] <= b2_y2[i]) ? d_inter_i * w : 0.0; + d_b2_y2[i] = (b1_y2[i] >= b2_y2[i]) ? 
d_inter_i * w : 0.0; + } + } +} + +extern "C" int FusedGetIntersectionArea(int nparam, void **params, int *ndims, + int64_t **shapes, const char **dtypes, + void *stream, void *extra) +{ + cudaStream_t custream = static_cast(stream); + int input_num = 9; + int output_index = 8; + // check input number + if (nparam != input_num) + { + printf( + "For FusedGetIntersectionArea, the number of input should be %d, " + "but got %d.", + input_num, nparam); + return 1; + } + // check dytpe + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + printf( + "For FusedGetIntersectionArea, the dtype of input should be %s, " + "but got %s.", + "float32", dtypes[i]); + return 2; + } + } + // read input and output parameters + const float *b1_x1 = static_cast(params[0]); + const float *b1_x2 = static_cast(params[1]); + const float *b1_y1 = static_cast(params[2]); + const float *b1_y2 = static_cast(params[3]); + const float *b2_x1 = static_cast(params[4]); + const float *b2_x2 = static_cast(params[5]); + const float *b2_y1 = static_cast(params[6]); + const float *b2_y2 = static_cast(params[7]); + float *inter = static_cast(params[8]); + + // calculate the size of output + size_t size = std::accumulate(shapes[output_index], + shapes[output_index] + ndims[output_index], + size_t(1), std::multiplies()); + FusedGetIntersectionAreaKernel(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, inter, size, custream); + return 0; +} + +extern "C" int FusedGetIntersectionAreaGrad(int nparam, void **params, + int *ndims, int64_t **shapes, + const char **dtypes, void *stream, + void *extra) +{ + cudaStream_t custream = static_cast(stream); + int input_num = 17; + int output_index = 9; + // check input number + if (nparam != input_num) + { + printf( + "For FusedGetIntersectionAreaGrad, the number of input should be " + "%d, " + "but got %d.", + input_num, nparam); + return 1; + } + // check dytpe + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], 
"float32") != 0) + { + printf( + "For FusedGetIntersectionAreaGrad, the dtype of input should be " + "%s, " + "but got %s.", + "float32", dtypes[i]); + return 2; + } + } + // read input and output parameters + const float *b1_x1 = static_cast(params[0]); + const float *b1_x2 = static_cast(params[1]); + const float *b1_y1 = static_cast(params[2]); + const float *b1_y2 = static_cast(params[3]); + const float *b2_x1 = static_cast(params[4]); + const float *b2_x2 = static_cast(params[5]); + const float *b2_y1 = static_cast(params[6]); + const float *b2_y2 = static_cast(params[7]); + const float *d_inter = static_cast(params[8]); + float *d_b1_x1 = static_cast(params[9]); + float *d_b1_x2 = static_cast(params[10]); + float *d_b1_y1 = static_cast(params[11]); + float *d_b1_y2 = static_cast(params[12]); + float *d_b2_x1 = static_cast(params[13]); + float *d_b2_x2 = static_cast(params[14]); + float *d_b2_y1 = static_cast(params[15]); + float *d_b2_y2 = static_cast(params[16]); + + // calculate the size of output + size_t size = std::accumulate(shapes[output_index], + shapes[output_index] + ndims[output_index], + size_t(1), std::multiplies()); + int block_num = (size + thread_per_block - 1) / thread_per_block; + FusedGetIntersectionAreaGradKernel<<>>( + size, b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, d_inter, + d_b1_x1, d_b1_x2, d_b1_y1, d_b1_y2, d_b2_x1, d_b2_x2, d_b2_y1, d_b2_y2); + return 0; +} \ No newline at end of file diff --git a/mindyolo/models/losses/fused_op/fused_get_iou_kernel.cu b/mindyolo/models/losses/fused_op/fused_get_iou_kernel.cu new file mode 100644 index 00000000..05df35fb --- /dev/null +++ b/mindyolo/models/losses/fused_op/fused_get_iou_kernel.cu @@ -0,0 +1,126 @@ +#include +#include "elementswise_op_impl.cu" + +constexpr int THREADS = 256; +constexpr float EPS = 1e-7; + +template +struct FusedGetIouFunctor +{ + FusedGetIouFunctor() {} + __device__ __forceinline__ T operator()(T w1, T h1, T w2, T h2, T inter) const + { + T val_union = w1 
* h1 + w2 * h2 - inter + EPS; + return val_union; // NOTE(review): this returns the union term only, while FusedGetIouBpropKernel differentiates inter / union - confirm which one is intended + } +}; + +template +void FusedGetIouKernel(const T *w1, const T *h1, const T *w2, const T *h2, const T *inter, + T *output, const size_t count, cudaStream_t cuda_stream) +{ + FusedGetIouFunctor functor; + cuda::elementwise::FiveInputs(functor, (uint)(count), output, w1, h1, w2, h2, inter, cuda_stream); +} + +extern "C" int FusedGetIou(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, + void *extra) +{ + cudaStream_t custream = static_cast(stream); + constexpr int OUTPUT_INDEX = 5; + constexpr int TOTAL_PARAM_NUM = 6; + if (nparam != TOTAL_PARAM_NUM) // check the real argument count; the old constant-vs-constant test could never fail + return 1; + void *w1 = params[0]; + void *h1 = params[1]; + void *w2 = params[2]; + void *h2 = params[3]; + void *inter = params[4]; + void *out = params[5]; + size_t size = 1; + + for (int i = 0; i < ndims[OUTPUT_INDEX]; i++) + { + size *= shapes[OUTPUT_INDEX][i]; + } + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + return 2; + } + } + + FusedGetIouKernel(static_cast(w1), static_cast(h1), + static_cast(w2), static_cast(h2), + static_cast(inter), static_cast(out), size, custream); + return 0; +} + +template +__global__ void FusedGetIouBpropKernel(const T *w1, const T *h1, const T *w2, const T *h2, const T *inter, + const T *d_out, T *d_w1, T *d_h1, T *d_w2, T *d_h2, + T *d_inter, size_t size) +{ + + for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < size; + i += step) + { + const T w1_i = w1[i], h1_i = h1[i], w2_i = w2[i], h2_i = h2[i], inter_i = inter[i], diou_i = d_out[i]; + const T w_h_eps = w1_i * h1_i + w2_i * h2_i + static_cast(EPS); + const T w_h_eps_inter_diff = w_h_eps - inter_i; + const T w_h_eps_inter_diff_square = w_h_eps_inter_diff * w_h_eps_inter_diff; + const T common_for_dwh = -inter_i * diou_i / w_h_eps_inter_diff_square; + d_inter[i] = w_h_eps * diou_i / w_h_eps_inter_diff_square; + d_w1[i] = h1_i * common_for_dwh; + d_h1[i] = w1_i * 
common_for_dwh; + + d_w2[i] = h2_i * common_for_dwh; + d_h2[i] = w2_i * common_for_dwh; + } +} + +extern "C" int FusedGetIouBprop(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream, + void *extra) +{ + cudaStream_t custream = static_cast(stream); + constexpr int OUTPUT_INDEX = 6; + constexpr int TOTAL_PARAM_NUM = 11; + if (nparam != TOTAL_PARAM_NUM) + return 1; + // check the data type is float32 + for (int i = 0; i < nparam; i++) + { + if (strcmp(dtypes[i], "float32") != 0) + { + return 2; + } + } + // read input & output parameters + void *w1 = params[0]; + void *h1 = params[1]; + void *w2 = params[2]; + void *h2 = params[3]; + void *inter = params[4]; + void *d_out = params[5]; + void *d_w1 = params[6]; + void *d_h1 = params[7]; + void *d_w2 = params[8]; + void *d_h2 = params[9]; + void *d_inter = params[10]; + // calculate the size to data to be processed + size_t size = 1; + for (int i = 0; i < ndims[OUTPUT_INDEX]; i++) + { + size *= shapes[OUTPUT_INDEX][i]; + } + int n = (size + THREADS - 1) / THREADS; // ceil-divide: plain integer division gave 0 blocks (an invalid launch) whenever size was below THREADS + + FusedGetIouBpropKernel<<>>(static_cast(w1), static_cast(h1), + static_cast(w2), static_cast(h2), + static_cast(inter), static_cast(d_out), + static_cast(d_w1), static_cast(d_h1), + static_cast(d_w2), static_cast(d_h2), + static_cast(d_inter), size); + return 0; +} diff --git a/mindyolo/models/losses/iou_loss.py b/mindyolo/models/losses/iou_loss.py index 1fd68abb..cd2a9199 100644 --- a/mindyolo/models/losses/iou_loss.py +++ b/mindyolo/models/losses/iou_loss.py @@ -2,13 +2,132 @@ import mindspore as ms from mindspore import Tensor, ops +from mindspore.common import dtype as mstype from mindyolo.models.layers.utils import box_cxcywh_to_xyxy +from .fused_op import fused_get_ciou_op_path, fused_get_ciou_op_bprop_path, fused_get_center_dist_op_path, \ + fused_get_center_dist_op_bprop_path, fuse_get_ciou_gpu_info, fuse_get_ciou_bprop_gpu_info, \ + fuse_get_center_dist_gpu_info, fuse_get_center_dist_bprop_gpu_info, 
fused_get_convex_diagonal_squared_info, \ + fused_get_convex_diagonal_squared_path, fused_get_convex_diagonal_squared_grad_path, \ + fused_get_boundding_boxes_coord_path, fused_get_boundding_boxes_coord_grad_path, \ + fused_get_intersection_area_path, fused_get_intersection_area_grad_path, \ + fused_get_convex_diagonal_squared_grad_info, fused_get_ciou_diagonal_angle_info,\ + fused_get_ciou_diagonal_angle_grad_info, fused_get_ciou_diagonal_angle_grad_path, fused_get_ciou_diagonal_angle_path, \ + fused_get_boundding_boxes_coord_gpu_info, fused_get_boundding_boxes_coord_bprop_gpu_info, \ + fused_get_intersection_area_gpu_info, fused_get_intersection_area_gpu_grad_info + PI = Tensor(math.pi, ms.float32) EPS = 1e-7 +def get_ciou_bprop(v, iou, rho2, c2, out, dout): + fuse_get_ciou_bprop = ops.Custom(fused_get_ciou_op_bprop_path, + out_shape=(v.shape, iou.shape, rho2.shape, c2.shape), + out_dtype=(mstype.float32, mstype.float32, mstype.float32, mstype.float32), + func_type="aot", reg_info=fuse_get_ciou_bprop_gpu_info) + res = fuse_get_ciou_bprop(v, iou, rho2, c2, dout[0], dout[1]) + return res + + +fuse_get_ciou = ops.Custom(fused_get_ciou_op_path, + out_shape=lambda v, iou, rho2, c2: (v, v), + out_dtype=(mstype.float32, mstype.float32), + func_type="aot", bprop=get_ciou_bprop, reg_info=fuse_get_ciou_gpu_info) + + +def get_center_dist_bprop(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, out, dout): + fuse_get_center_dist_bprop = ops.Custom(fused_get_center_dist_op_bprop_path, + out_shape=(b1_x1.shape, b1_x2.shape, b1_y1.shape, b1_y2.shape, + b2_x1.shape, b2_x2.shape, b2_y1.shape, b2_y2.shape), + out_dtype=(mstype.float32, mstype.float32, mstype.float32, mstype.float32, + mstype.float32, mstype.float32, mstype.float32, mstype.float32), + func_type="aot", reg_info=fuse_get_center_dist_bprop_gpu_info) + res = fuse_get_center_dist_bprop(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, dout) + return res + + +fuse_get_center_dist = 
ops.Custom(fused_get_center_dist_op_path, + out_shape=lambda b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2: (b1_x1), + out_dtype=(mstype.float32), + func_type="aot", bprop=get_center_dist_bprop, reg_info=fuse_get_center_dist_gpu_info) + + +def fused_get_convex_diagonal_squared_bprop(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, out, dout): + out_shape = (b1_x1.shape, b1_x2.shape, b2_x1.shape, b2_x2.shape, b1_y1.shape, b1_y2.shape, b2_y1.shape, b2_y2.shape) + out_dtype = (b1_x1.dtype, b1_x2.dtype, b2_x1.dtype, b2_x2.dtype, b1_y1.dtype, b1_y2.dtype, b2_y1.dtype, b2_y2.dtype) + op = ops.Custom(fused_get_convex_diagonal_squared_grad_path, + out_shape=out_shape, + out_dtype=out_dtype, + reg_info=fused_get_convex_diagonal_squared_grad_info, + func_type="aot") + return op(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, dout) + + +fused_get_convex_diagonal_squared = ops.Custom( + fused_get_convex_diagonal_squared_path, + out_shape=lambda b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2: b1_x1, + out_dtype=lambda b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2: b1_x1, + bprop=fused_get_convex_diagonal_squared_bprop, + reg_info=fused_get_convex_diagonal_squared_info, + func_type="aot") + + +def fused_get_ciou_diagonal_angle_bprop(w1, h1, w2, h2, out, dout): + out_shape = (w1.shape, h1.shape, w2.shape, h2.shape) + out_dtype = (w1.dtype, h1.dtype, w2.dtype, h2.dtype) + op = ops.Custom(fused_get_ciou_diagonal_angle_grad_path, + out_shape=out_shape, + out_dtype=out_dtype, + reg_info=fused_get_ciou_diagonal_angle_grad_info, + func_type="aot") + return op(w1, h1, w2, h2, dout) + + +fused_get_ciou_diagonal_angle = ops.Custom( + fused_get_ciou_diagonal_angle_path, + out_shape=lambda w1, h1, w2, h2: w1, + out_dtype=lambda w1, h1, w2, h2: w1, + bprop=fused_get_ciou_diagonal_angle_bprop, + reg_info=fused_get_ciou_diagonal_angle_info, + func_type="aot") + + +def fused_get_boundding_boxes_coord_bprop(x1, y1, w1, h1, x2, y2, w2, h2, out, dout): + 
out_shape = (x1.shape, y1.shape, w1.shape, h1.shape, + x2.shape, y2.shape, w2.shape, h2.shape) + out_dtype = (mstype.float32, mstype.float32, mstype.float32, mstype.float32, + mstype.float32, mstype.float32, mstype.float32, mstype.float32) + op = ops.Custom(fused_get_boundding_boxes_coord_grad_path, out_shape=out_shape, out_dtype=out_dtype, + func_type='aot', reg_info=fused_get_boundding_boxes_coord_bprop_gpu_info) + return op(dout[0], dout[1], dout[2], dout[3], dout[4], dout[5], dout[6], dout[7]) + + +fused_get_boundding_boxes_coord = ops.Custom(fused_get_boundding_boxes_coord_path,out_shape=lambda x1, y1, w1, h1, x2, y2, w2, h2: ( + x1, y1, w1, h1, x2, y2, w2, h2), + out_dtype=lambda x1, y1, w1, h1, x2, y2, w2, h2: ( + x1, y1, w1, h1, x2, y2, w2, h2), + func_type='aot', bprop=fused_get_boundding_boxes_coord_bprop, reg_info=fused_get_boundding_boxes_coord_gpu_info) + + +def fused_get_intersection_area_bprop(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, out, dout): + out_shape = (b1_x1.shape, b1_x2.shape, b2_x1.shape, b2_x2.shape, + b1_y1.shape, b1_y2.shape, b2_y1.shape, b2_y2.shape) + out_dtype = (mstype.float32, mstype.float32, mstype.float32, mstype.float32, + mstype.float32, mstype.float32, mstype.float32, mstype.float32) + op=ops.Custom(fused_get_intersection_area_grad_path, out_shape=out_shape, out_dtype=out_dtype, + func_type='aot', reg_info=fused_get_intersection_area_gpu_grad_info) + + return op(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2, dout) + + +fused_get_intersection_area = ops.Custom( + fused_get_intersection_area_path, + out_shape=lambda b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2: b1_x1, + out_dtype=lambda b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2: b1_x1, + func_type='aot', bprop=fused_get_intersection_area_bprop, reg_info=fused_get_intersection_area_gpu_info) + + def box_area(box): """ Return area of boxes. 
@@ -97,7 +216,7 @@ def batch_box_iou(batch_box1, batch_box2, xywh=False): ) # iou = inter / (area1 + area2 - inter) -def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): +def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7, use_fused_op=False): """ Return intersection-over-union (IoU) of boxes. Arguments: @@ -107,26 +226,30 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 GIoU (bool): Whether to use GIoU. Default: False. DIoU (bool): Whether to use DIoU. Default: False. CIoU (bool): Whether to use CIoU. Default: False. + use_fused_op(bool): Whether to use fused operator built upon aot customized operator. Default: False. Returns: iou (Tensor[N,]): the IoU values for every element in boxes1 and boxes2 """ - # Get the coordinates of bounding boxes if xywh: # transform from xywh to xyxy x1, y1, w1, h1 = ops.split(box1, split_size_or_sections=1, axis=-1) x2, y2, w2, h2 = ops.split(box2, split_size_or_sections=1, axis=-1) - w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 - b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ - b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ + if use_fused_op: + b1_x1, b1_x2, b1_y1, b1_y2,b2_x1, b2_x2, b2_y1, b2_y2=fused_get_boundding_boxes_coord(x1, y1, w1, h1,x2, y2, w2, h2) + else: + w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 + b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ + b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ else: # x1, y1, x2, y2 = box1 b1_x1, b1_y1, b1_x2, b1_y2 = ops.split(box1, split_size_or_sections=1, axis=-1) b2_x1, b2_y1, b2_x2, b2_y2 = ops.split(box2, split_size_or_sections=1, axis=-1) # Intersection area - inter = (ops.minimum(b1_x2, b2_x2) - ops.maximum(b1_x1, b2_x1)).clip(0.0, None) * ( - ops.minimum(b1_y2, b2_y2) - ops.maximum(b1_y1, b2_y1) - ).clip(0.0, None) - + if use_fused_op: + inter = fused_get_intersection_area(b1_x1, 
b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2) + else: + inter = (ops.minimum(b1_x2, b2_x2) - ops.maximum(b1_x1, b2_x1)).clip(0., None) * \ + (ops.minimum(b1_y2, b2_y2) - ops.maximum(b1_y1, b2_y1)).clip(0., None) # Union Area w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps @@ -134,18 +257,32 @@ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7 # IoU iou = inter / union + if CIoU or DIoU or GIoU: cw = ops.maximum(b1_x2, b2_x2) - ops.minimum(b1_x1, b2_x1) # convex (smallest enclosing box) width ch = ops.maximum(b1_y2, b2_y2) - ops.minimum(b1_y1, b2_y1) # convex height if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 - c2 = cw**2 + ch**2 + eps # convex diagonal squared - rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 + if use_fused_op: + c2 = fused_get_convex_diagonal_squared(b1_x1, b1_x2, b2_x1, b2_x2, b1_y1, b1_y2, b2_y1, b2_y2) + else: + c2 = cw**2 + ch**2 + eps # convex diagonal squared + if use_fused_op: + rho2 = fuse_get_center_dist(b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2) + else: + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 - # v = (4 / get_pi(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) - v = (4 / PI.astype(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) - alpha = v / (v - iou + (1 + eps)) - alpha = ops.stop_gradient(alpha) - return iou - (rho2 / c2 + v * alpha) # CIoU + if use_fused_op: + v = fused_get_ciou_diagonal_angle(w1, h1, w2, h2) + else: + # v = (4 / get_pi(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) + v = (4 / PI.astype(iou.dtype) ** 2) * ops.pow(ops.atan(w2 / (h2 + eps)) - ops.atan(w1 / (h1 + eps)), 2) 
+ if use_fused_op: + _, res = fuse_get_ciou(v, iou, rho2, c2) + else: + alpha = v / (v - iou + (1 + eps)) + alpha = ops.stop_gradient(alpha) + res = iou - (rho2 / c2 + v * alpha) # CIoU + return res return iou - rho2 / c2 # DIoU c_area = cw * ch + eps # convex area return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf diff --git a/mindyolo/models/losses/yolov3_loss.py b/mindyolo/models/losses/yolov3_loss.py index 516f06bc..d846eeb9 100644 --- a/mindyolo/models/losses/yolov3_loss.py +++ b/mindyolo/models/losses/yolov3_loss.py @@ -17,7 +17,7 @@ @register_model class YOLOv3Loss(nn.Cell): def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs ): super(YOLOv3Loss, self).__init__() self.hyp_box = box @@ -63,6 +63,7 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by lossitem for print + self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): lcls, lbox, lobj = 0.0, 0.0, 0.0 @@ -93,7 +94,7 @@ def construct(self, p, targets, imgs): pxy = ops.Sigmoid()(pxy) * 2 - 0.5 pwh = (ops.Sigmoid()(pwh) * 2) ** 2 * anchors[layer_index] pbox = ops.concat((pxy, pwh), 1) # predicted box - iou = bbox_iou(pbox, tbox[layer_index], CIoU=True).squeeze() # iou(prediction, target) + iou = bbox_iou(pbox, tbox[layer_index], CIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target) # iou = iou * tmask # lbox += ((1.0 - iou) * tmask).mean() # iou loss lbox += (((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None)).astype(iou.dtype) diff --git a/mindyolo/models/losses/yolov4_loss.py b/mindyolo/models/losses/yolov4_loss.py index 51abf2cb..7fa5d180 100644 --- a/mindyolo/models/losses/yolov4_loss.py +++ b/mindyolo/models/losses/yolov4_loss.py @@ -31,7 +31,7 @@ def construct(self, 
object_mask, predict_confidence, ignore_mask): @register_model class YOLOv4Loss(nn.Cell): - def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_threshold, anchors, nc, **kwargs): + def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_threshold, anchors, nc, use_fused_op=False, **kwargs): super(YOLOv4Loss, self).__init__() self.ignore_threshold = ignore_threshold self.iou = Iou() @@ -57,6 +57,7 @@ def __init__(self, box, obj, cls, label_smoothing, ignore_threshold, iou_thresho self.concat = ops.Concat(axis=-1) self.reduce_max = ops.ReduceMax(keep_dims=False) + self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): image_shape = imgs.shape @@ -94,7 +95,7 @@ def construct(self, p, targets, imgs): # Regression pbox = ops.concat((pxy, pwh), 1) # predicted box - iou = bbox_iou(pbox, tbox, GIoU=True).squeeze() # iou(prediction, target) + iou = bbox_iou(pbox, tbox, GIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target) # iou = iou * tmask # lbox += ((1.0 - iou) * tmask).mean() # iou loss box_loss_scale = 2 - tbox[:, 2] * tbox[:, 3] / gain[0] / gain[1] @@ -160,7 +161,7 @@ def build_targets(self, p, targets, imgs): anchor_shapes = ops.zeros((na, 1, 4), ms.float32) anchor_shapes[..., 2:] = ops.ExpandDims()(self.anchors, 1) - anch_ious = bbox_iou(gt_box, anchor_shapes).squeeze() + anch_ious = bbox_iou(gt_box, anchor_shapes, use_fused_op=self.use_fused_op).squeeze() j = anch_ious == anch_ious.max(axis=0) l = anch_ious > self.iou_threshold diff --git a/mindyolo/models/losses/yolov5_loss.py b/mindyolo/models/losses/yolov5_loss.py index 890cefef..a1321d06 100644 --- a/mindyolo/models/losses/yolov5_loss.py +++ b/mindyolo/models/losses/yolov5_loss.py @@ -15,7 +15,7 @@ class YOLOv5Loss(nn.Cell): # Compute losses def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, 
obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs ): super(YOLOv5Loss, self).__init__() @@ -64,6 +64,7 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by loss for print + self.use_fused_op = use_fused_op def scatter_index_tensor(self, x, index): x_tmp = ops.transpose(x.reshape((-1, x.shape[-1])), (1, 0)) @@ -101,7 +102,7 @@ def construct(self, p, targets, imgs): # predictions, targets pxy = ops.Sigmoid()(pxy) * 2 - 0.5 pwh = (ops.Sigmoid()(pwh) * 2) ** 2 * anchors[layer_index] pbox = ops.concat((pxy, pwh), 1) # predicted box - iou = bbox_iou(pbox, tbox[layer_index], CIoU=True).squeeze() # iou(prediction, target) + iou = bbox_iou(pbox, tbox[layer_index], CIoU=True, use_fused_op=self.use_fused_op).squeeze() # iou(prediction, target) lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum() # iou loss # Objectness diff --git a/mindyolo/models/losses/yolov7_loss.py b/mindyolo/models/losses/yolov7_loss.py index 46258369..204d1dca 100644 --- a/mindyolo/models/losses/yolov7_loss.py +++ b/mindyolo/models/losses/yolov7_loss.py @@ -17,7 +17,7 @@ @register_model class YOLOv7Loss(nn.Cell): def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op=False, **kwargs ): super(YOLOv7Loss, self).__init__() self.hyp_box = box @@ -63,6 +63,7 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by lossitem for print + self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): lcls, lbox, lobj = 0.0, 0.0, 0.0 @@ -98,7 +99,7 @@ def construct(self, p, targets, imgs): pbox = ops.concat((pxy, pwh), 1) # predicted box selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i] selected_tbox[:, :2] -= grid - iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True).view(-1) + iou = bbox_iou(pbox, 
selected_tbox, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1) lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None) # iou loss # Objectness @@ -364,7 +365,7 @@ def find_3_positive(self, p, targets): @register_model class YOLOv7AuxLoss(nn.Cell): def __init__( - self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, **kwargs + self, box, obj, cls, anchor_t, label_smoothing, fl_gamma, cls_pw, obj_pw, anchors, stride, nc, use_fused_op, **kwargs ): super(YOLOv7AuxLoss, self).__init__() self.hyp_box = box @@ -416,6 +417,7 @@ def __init__( ) self.loss_item_name = ["loss", "lbox", "lobj", "lcls"] # branch name returned by loss for print + self.use_fused_op = use_fused_op def construct(self, p, targets, imgs): lcls, lbox, lobj = 0.0, 0.0, 0.0 @@ -471,7 +473,7 @@ def construct(self, p, targets, imgs): pbox = ops.concat((pxy, pwh), 1) # predicted box selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i] selected_tbox[:, :2] -= grid - iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True).view(-1) + iou = bbox_iou(pbox, selected_tbox, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1) lbox += ((1.0 - iou) * tmask).sum() / tmask.astype(iou.dtype).sum().clip(1, None) # iou loss # 1.2. 
Objectness tobj[b, a, gj, gi] = ((1.0 - self.gr) + self.gr * ops.stop_gradient(iou).clip(0, None)) * tmask # iou ratio @@ -494,7 +496,7 @@ def construct(self, p, targets, imgs): pbox_aux = ops.concat((pxy_aux, pwh_aux), 1) # predicted box selected_tbox_aux = targets_aux[i][:, 2:6] * pre_gen_gains[i] selected_tbox_aux[:, :2] -= grid_aux - iou_aux = bbox_iou(pbox_aux, selected_tbox_aux, xywh=True, CIoU=True).view(-1) + iou_aux = bbox_iou(pbox_aux, selected_tbox_aux, xywh=True, CIoU=True, use_fused_op=self.use_fused_op).view(-1) lbox += ( 0.25 * ((1.0 - iou_aux) * tmask_aux).sum() / tmask_aux.astype(iou_aux.dtype).sum().clip(1, None) ) # iou loss diff --git a/mindyolo/models/losses/yolov8_loss.py b/mindyolo/models/losses/yolov8_loss.py index 2e668d33..2ab47863 100644 --- a/mindyolo/models/losses/yolov8_loss.py +++ b/mindyolo/models/losses/yolov8_loss.py @@ -14,7 +14,7 @@ @register_model class YOLOv8Loss(nn.Cell): - def __init__(self, box, cls, dfl, stride, nc, reg_max=16, **kwargs): + def __init__(self, box, cls, dfl, stride, nc, reg_max=16, use_fused_op=False, **kwargs): super(YOLOv8Loss, self).__init__() self.bce = nn.BCEWithLogitsLoss(reduction="none") @@ -27,8 +27,8 @@ def __init__(self, box, cls, dfl, stride, nc, reg_max=16, **kwargs): self.reg_max = reg_max self.use_dfl = reg_max > 1 - self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) - self.bbox_loss = BboxLoss(reg_max, use_dfl=self.use_dfl) + self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0, use_fused_op=use_fused_op) + self.bbox_loss = BboxLoss(reg_max, use_dfl=self.use_dfl, use_fused_op=use_fused_op) self.proj = mnp.arange(reg_max) # ops @@ -154,10 +154,11 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): class BboxLoss(nn.Cell): - def __init__(self, reg_max, use_dfl=False): + def __init__(self, reg_max, use_dfl=False, use_fused_op=False): super().__init__() self.reg_max = reg_max self.use_dfl = use_dfl + self.use_fused_op = 
use_fused_op def construct( self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask @@ -174,7 +175,7 @@ def construct( """ # IoU loss weight = target_scores.sum(-1).expand_dims(-1) # (bs, N, num_classes) -> (bs, N) -> (bs, N, 1) - iou = bbox_iou(pred_bboxes, target_bboxes, xywh=False, CIoU=True) + iou = bbox_iou(pred_bboxes, target_bboxes, xywh=False, CIoU=True, use_fused_op=self.use_fused_op) loss_iou = ((1.0 - iou) * weight * fg_mask.expand_dims(2)).sum() / target_scores_sum # DFL loss @@ -219,7 +220,7 @@ def _df_loss(pred_dist, target): class TaskAlignedAssigner(nn.Cell): - def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): + def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, use_fused_op=False): super().__init__() self.topk = topk self.num_classes = num_classes @@ -227,6 +228,7 @@ def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): self.alpha = alpha self.beta = beta self.eps = eps + self.use_fused_op=use_fused_op def construct(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): """This code referenced to @@ -310,7 +312,7 @@ def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes): # (b, n_gt, 1, 4), (b, 1, N, 4) -> (b, n_gt, N) overlaps = ( - bbox_iou(gt_bboxes.expand_dims(2), pd_bboxes.expand_dims(1), xywh=False, CIoU=True).squeeze(3).clip(0, None) + bbox_iou(gt_bboxes.expand_dims(2), pd_bboxes.expand_dims(1), xywh=False, CIoU=True, use_fused_op=self.use_fused_op).squeeze(3).clip(0, None) ) align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) return align_metric, overlaps diff --git a/mindyolo/models/losses/yolox_loss.py b/mindyolo/models/losses/yolox_loss.py index 1ebc77d1..55c3ae56 100644 --- a/mindyolo/models/losses/yolox_loss.py +++ b/mindyolo/models/losses/yolox_loss.py @@ -23,6 +23,7 @@ def __init__( strides=(8, 16, 32), use_l1=False, use_summary=False, + use_fused_op=False, **kwargs ): 
super(YOLOXLoss, self).__init__() @@ -52,6 +53,7 @@ def __init__( self.assign = ops.Assign() self.loss_item_name = ["loss", "lbox", "lobj", "lcls", "lboxl1"] # branch name returned by lossitem for print + self.use_fused_op = use_fused_op def _get_anchor_center_and_stride(self, norm=False): """ @@ -255,7 +257,7 @@ def construct(self, preds, targets, imgs=None): loss_l1 = ops.reduce_sum(self.l1_loss(l1_preds, l1_target), -1) * obj_target loss_l1 = ops.reduce_sum(loss_l1) # calculate target -----------END------------------------------------------------------------------------------- - iou = bbox_iou(bbox_preds.reshape(-1, 4), reg_target.reshape(-1, 4), xywh=True).reshape(batch_size, -1) + iou = bbox_iou(bbox_preds.reshape(-1, 4), reg_target.reshape(-1, 4), xywh=True, use_fused_op=self.use_fused_op).reshape(batch_size, -1) loss_iou = (1 - iou * iou) * obj_target # (bs, num_total_anchor) loss_iou = ops.reduce_sum(loss_iou) diff --git a/mindyolo/utils/utils.py b/mindyolo/utils/utils.py index 53b74b7a..cf056ff2 100644 --- a/mindyolo/utils/utils.py +++ b/mindyolo/utils/utils.py @@ -87,6 +87,11 @@ def set_default(args): args.data.test_set = os.path.join(args.data_dir, args.data.test_set) args.weight = args.ckpt_dir if args.ckpt_dir else "" args.ema_weight = os.path.join(args.ckpt_dir, args.ema_weight) if args.ema_weight else "" + + # Check Custom operator settings. 
+ if args.device_target != "GPU" and args.use_fused_op: + logger.warning(f"mindyolo only support aot custom operator on GPU currently, please check configurations") + args.use_fused_op = False def load_pretrain(network, weight, ema=None, ema_weight=None): diff --git a/setup.py b/setup.py index 16c1f80f..3acab6f3 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,19 @@ #!/usr/bin/env python import os.path +import subprocess import pathlib import sys +import glob from setuptools import find_packages, setup exec(open("mindyolo/version.py").read()) here = pathlib.Path(__file__).parent.resolve() -long_description = (here / 'README.md').read_text(encoding='utf-8') +long_description = (here / "README.md").read_text(encoding="utf-8") -def parse_requirements(path=here / 'requirements.txt'): +def parse_requirements(path=here / "requirements.txt"): """parse requirements in file""" pkgs = [] if not os.path.exists(path): @@ -22,12 +24,25 @@ def parse_requirements(path=here / 'requirements.txt'): if line.isspace(): continue line = line.strip() - if line.startswith('#'): + if line.startswith("#"): continue pkgs.append(line) return pkgs +def compile_fused_op(path=here / "mindyolo/models/losses/fused_op"): + nvcc_result = subprocess.run(["nvcc", "--version"], timeout=3, capture_output=True).stdout + if "command not found" in str(nvcc_result): + print("nvcc not configured properly, skipped compiling fused operator.") + return + for fused_op_src in glob.glob(str(path / "*_kernel.cu")): + fused_op_so = f"{fused_op_src[:-3]}.so" + so_path = str(path / fused_op_so) + nvcc_cmd = "nvcc --shared -Xcompiler -fPIC -o " + so_path + " " + fused_op_src + print("nvcc compiler cmd: {}".format(nvcc_cmd)) + os.system(nvcc_cmd) + + # add c++ extension ext_modules = [] try: @@ -36,14 +51,14 @@ def parse_requirements(path=here / 'requirements.txt'): ext_modules = [ Pybind11Extension( name="mindyolo.csrc.fast_coco_eval.fast_coco_eval", # use relative path - 
sources=['mindyolo/csrc/fast_coco_eval/cocoeval/cocoeval.cpp'], # use relative path - include_dirs=['mindyolo/csrc/fast_coco_eval/cocoeval'], # use relative path + sources=["mindyolo/csrc/fast_coco_eval/cocoeval/cocoeval.cpp"], # use relative path + include_dirs=["mindyolo/csrc/fast_coco_eval/cocoeval"], # use relative path extra_compile_args=args ), ] except ImportError: pass - +compile_fused_op() setup( name="mindyolo", author="MindSpore Ecosystem", @@ -58,6 +73,7 @@ def parse_requirements(path=here / 'requirements.txt'): license="Apache Software License 2.0", include_package_data=True, packages=find_packages(include=["mindyolo", "mindyolo.*"]), + package_data={"mindyolo": ["models/losses/fused_op/*_kernel.so"]}, install_requires=parse_requirements(), python_requires=">=3.7", classifiers=[ diff --git a/train.py b/train.py index 393ba490..12247d22 100644 --- a/train.py +++ b/train.py @@ -81,6 +81,8 @@ def get_parser_train(parents=None): help="ModelArts: local device path to dataset folder") parser.add_argument("--ckpt_dir", type=str, default="/cache/pretrain_ckpt/", help="ModelArts: local device path to checkpoint folder") + parser.add_argument("--use_fused_op", type=ast.literal_eval, default=False, + help="Whether to use aot custom operator to accelerate GPU computation") return parser @@ -185,7 +187,8 @@ def train(args): # Create Loss loss_fn = create_loss( - **args.loss, anchors=args.network.get("anchors", 1), stride=args.network.stride, nc=args.data.nc + **args.loss, anchors=args.network.get("anchors", 1), stride=args.network.stride, nc=args.data.nc, + use_fused_op=args.use_fused_op ) ms.amp.auto_mixed_precision(loss_fn, amp_level="O0" if args.keep_loss_fp32 else args.ms_amp_level)