Add custom op MulMulSigmoid #185

Merged (2 commits, Jun 5, 2024)
3 changes: 2 additions & 1 deletion CHANGELOGS.rst
@@ -4,8 +4,9 @@ Change Logs
0.3.0
+++++

+ * :pr:`185`: adds custom operator MulMulSigmoid on CUDA
  * :pr:`184`: use onnxruntime==1.18.0 as default
- * :pr:`181`: add MaskedScatterNDOfShape custom operator
+ * :pr:`181`: adds MaskedScatterNDOfShape custom operator
* :pr:`175`: adds custom operator MulSub and SubMul on CUDA
* :pr:`173`: adds custom operator AddSharedInput, MulSharedInput on CUDA
* :pr:`170`: adds custom operator TriMatrix on CUDA
5 changes: 0 additions & 5 deletions README.rst
@@ -7,22 +7,17 @@ onnx-extended: extensions for onnx and onnxruntime

.. image:: https://dev.azure.com/xavierdupre3/onnx-extended/_apis/build/status/sdpython.onnx-extended
    :target: https://dev.azure.com/xavierdupre3/onnx-extended/
-
.. image:: https://badge.fury.io/py/onnx-extended.svg
    :target: http://badge.fury.io/py/onnx-extended
-
.. image:: http://img.shields.io/github/issues/sdpython/onnx-extended.png
    :alt: GitHub Issues
    :target: https://github.com/sdpython/onnx-extended/issues
-
.. image:: https://img.shields.io/badge/license-MIT-blue.svg
    :alt: MIT License
    :target: https://opensource.org/license/MIT/
-
.. image:: https://img.shields.io/github/repo-size/sdpython/onnx-extended
    :target: https://github.com/sdpython/onnx-extended/
    :alt: size
-
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
    :target: https://github.com/psf/black

1 change: 1 addition & 0 deletions _cmake/targets/ortops_optim_cuda.cmake
@@ -16,6 +16,7 @@ if(CUDA_AVAILABLE)
  ../onnx_extended/ortops/optim/cuda/addmul.cu
  ../onnx_extended/ortops/optim/cuda/add_or_mul_shared_input.cu
  ../onnx_extended/ortops/optim/cuda/mul_sigmoid.cu
+ ../onnx_extended/ortops/optim/cuda/mul_mul_sigmoid.cu
  ../onnx_extended/ortops/optim/cuda/negxplus1.cu
  ../onnx_extended/ortops/optim/cuda/replace_zero.cu
  ../onnx_extended/ortops/optim/cuda/rotary.cu
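The new translation unit only takes effect once the extension is rebuilt. As a quick sanity check (a sketch; `get_ort_ext_libs` is the repository helper used in the tests below and returns the paths of the compiled custom-op libraries):

# prints the compiled library paths for the optim/cuda custom ops;
# an import error or empty result suggests the extension was not rebuilt
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs

print(get_ort_ext_libs())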
76 changes: 76 additions & 0 deletions _unittests/ut_ortops/test_optim_cuda.py
@@ -1232,6 +1232,82 @@ def test_addmul_transpose_numpy(self):
        y = yr.reshape(tuple(new_shape))
        self.assertEqualArray(t, y)

    def _mulmulsigmoid_cuda(self, itype, broad=False, atol=1e-5, rtol=1e-3):
        import onnxruntime
        from onnx_extended.ortops.optim.cuda import get_ort_ext_libs

        model1 = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("Mul", ["X", "Y"], ["xy"]),
                    oh.make_node("Sigmoid", ["Y"], ["sy"]),
                    oh.make_node("Mul", ["xy", "sy"], ["final"]),
                ],
                "nd",
                [
                    oh.make_tensor_value_info("X", itype, [None, None, None]),
                    oh.make_tensor_value_info("Y", itype, [None, None, None]),
                ],
                [oh.make_tensor_value_info("final", itype, [None, None, None])],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )

        model2 = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node(
                        "MulMulSigmoid",
                        ["X", "Y"],
                        ["final"],
                        domain="onnx_extended.ortops.optim.cuda",
                    )
                ],
                "nd",
                [
                    oh.make_tensor_value_info("X", itype, [None, None, None]),
                    oh.make_tensor_value_info("Y", itype, [None, None, None]),
                ],
                [oh.make_tensor_value_info("final", itype, [None, None, None])],
            ),
            opset_imports=[
                oh.make_opsetid("", 18),
                oh.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
            ],
            ir_version=9,
        )

        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
        shapex = (1, 2, 3) if broad else (3, 2, 3)
        shapey = (3, 2, 3)
        x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype)
        y = (np.arange(np.prod(shapey)) + 2).reshape(shapey).astype(dtype)
        x /= x.size
        y /= y.size

        feeds1 = dict(X=x, Y=y)
        ref = CReferenceEvaluator(model1)
        expected = ref.run(None, feeds1)[0]

        opts = onnxruntime.SessionOptions()
        opts.register_custom_ops_library(get_ort_ext_libs()[0])
        sess = onnxruntime.InferenceSession(
            model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]
        )
        got = sess.run(None, feeds1)[0]
        self.assertEqualArray(expected, got, atol=atol, rtol=rtol)

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_mulmulsigmoid_cuda(self):
        self._mulmulsigmoid_cuda(TensorProto.FLOAT)
        self._mulmulsigmoid_cuda(TensorProto.FLOAT16)

    @unittest.skipIf(not has_cuda(), reason="cuda not available")
    def test_mulmulsigmoid_cuda_broadcast(self):
        self._mulmulsigmoid_cuda(TensorProto.FLOAT, True)
        self._mulmulsigmoid_cuda(TensorProto.FLOAT16, True)


if __name__ == "__main__":
    unittest.main(verbosity=2)
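For reference, the fused operator computes X * Y * Sigmoid(Y), which model1 above spells out with three standard ONNX nodes. A minimal NumPy sketch of the same computation (the function name is illustrative, not part of the package):

import numpy as np

def mul_mul_sigmoid_ref(x, y):
    # Mul(X, Mul(Y, Sigmoid(Y))) with sigmoid(y) = 1 / (1 + exp(-y))
    return x * y / (1.0 + np.exp(-y))

x = ((np.arange(6) + 1) / 6).reshape(1, 2, 3).astype(np.float32)
y = ((np.arange(18) + 2) / 18).reshape(3, 2, 3).astype(np.float32)
print(mul_mul_sigmoid_ref(x, y))  # broadcasts (1, 2, 3) against (3, 2, 3)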
35 changes: 29 additions & 6 deletions onnx_extended/ortops/optim/cuda/__init__.py
@@ -250,10 +250,33 @@ def documentation() -> List[str]:
* T: float, float16
""",
"""
- onnx_extended.ortops.optim.cuda.MulSoftmax
+ onnx_extended.ortops.optim.cuda.MulMulSigmoid
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Equivalent to Mul(X, Mul(Y, Sigmoid(Y)))

**Provider**

CUDAExecutionProvider

**Inputs**

* X (T): tensor
* Y (T): tensor

**Outputs**

* Z (T): result

**Constraints**

* T: float, float16
""",
"""
onnx_extended.ortops.optim.cuda.MulSigmoid
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- MulSoftmax, equivalent to Mul(X, Softmax(X))
+ Equivalent to Mul(X, Sigmoid(X))

**Provider**

@@ -334,7 +357,7 @@ def documentation() -> List[str]:
onnx_extended.ortops.optim.cuda.NegXplus1
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- 1 - X
+ Equivalent to 1 - X

**Provider**

@@ -356,7 +379,7 @@ def documentation() -> List[str]:
onnx_extended.ortops.optim.cuda.ReplaceZero
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- MulSoftmax, equivalent to Where(X == 0, cst, X)
+ Equivalent to Where(X == 0, cst, X)

**Provider**

@@ -378,7 +401,7 @@ def documentation() -> List[str]:
onnx_extended.ortops.optim.cuda.Rotary
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- Rotary, equivalent to (side=="RIGHT")
+ Equivalent to (side=="RIGHT")

* Split(X, axis=-1) -> X1, X2
* Concat(-X2, X1)
@@ -407,7 +430,7 @@ def documentation() -> List[str]:
onnx_extended.ortops.optim.cuda.ScatterNDOfShape
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- ConstantOfShape + ScatterND
+ Equivalent to ConstantOfShape + ScatterND

**Provider**

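Condensing the unit test above into a standalone sketch that registers the compiled library and runs the fused node on CUDA (graph name, tensor names, and random inputs are illustrative; the domain and registration calls are those the test uses):

import numpy as np
import onnx.helper as oh
import onnxruntime
from onnx import TensorProto
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs

# a graph containing only the fused custom node
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("MulMulSigmoid", ["X", "Y"], ["Z"],
                      domain="onnx_extended.ortops.optim.cuda")],
        "g",
        [oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None, None]),
         oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None, None, None])],
        [oh.make_tensor_value_info("Z", TensorProto.FLOAT, [None, None, None])],
    ),
    opset_imports=[oh.make_opsetid("", 18),
                   oh.make_opsetid("onnx_extended.ortops.optim.cuda", 1)],
    ir_version=9,
)

opts = onnxruntime.SessionOptions()
opts.register_custom_ops_library(get_ort_ext_libs()[0])
sess = onnxruntime.InferenceSession(
    model.SerializeToString(), opts, providers=["CUDAExecutionProvider"]
)
x = np.random.rand(3, 2, 3).astype(np.float32)
y = np.random.rand(3, 2, 3).astype(np.float32)
z = sess.run(None, {"X": x, "Y": y})[0]  # == x * y * sigmoid(y)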
165 changes: 165 additions & 0 deletions onnx_extended/ortops/optim/cuda/mul_mul_sigmoid.cu
@@ -0,0 +1,165 @@
#include "common/c_op_helpers.h"
#include "common/common_kernels.h"
#include "cuda/common_kernels_cuda.h"
#include "mul_mul_sigmoid.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

namespace ortops {

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

struct GridDim {
  enum : CUDA_LONG {
    maxThreadsPerBlock = 256, // max threads per block
    maxElementsPerThread = 4, // max element processed per thread
  };
};

template <typename T> __device__ __inline__ T _exp_typed(const T x);

template <> __device__ __inline__ float _exp_typed(const float x) { return expf(x); }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half _exp_typed(const half x) {
  return __float2half(expf(__half2float(x)));
}
#else
template <> __device__ __inline__ half _exp_typed(const half x) { return hexp(x); }
#endif

// numerically stable sigmoid: each branch keeps the exp() argument negative,
// so _exp_typed never overflows for large |a|
template <typename T> __device__ __inline__ T sigmoid(const T a) {
  return a > T(0) ? (T)1 / ((T)1. + _exp_typed<T>(-a))
                  : (T)1 - (T)1 / ((T)1 + _exp_typed<T>(a));
}

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half sigmoid(const half a) {
  return __float2half(sigmoid(__half2float(a)));
}
#endif

// the fused operator itself: X * Y * Sigmoid(Y)
template <typename T> __device__ __inline__ T mul_mul_sigmoid(const T x, const T y) {
  return x * y * sigmoid(y);
}

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half mul_mul_sigmoid(const half x, const half y) {
  float hy = __half2float(y);
  return __float2half(__half2float(x) * hy * sigmoid(hy));
}
#endif

template <typename T>
__global__ void _MulMulSigmoidKernel(T *output_data, const T *px, const T *py, CUDA_LONG N,
                                     CUDA_LONG Nx, CUDA_LONG Ny) {
  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
  if (id >= N)
    return;
  // flat modulo indexing: the smaller input is read cyclically, which implements
  // broadcasting when it tiles along leading axes (e.g. (1,2,3) against (3,2,3))
  output_data[id] = mul_mul_sigmoid(px[id % Nx], py[id % Ny]);
}

template <typename T>
void MulMulSigmoidImpl(cudaStream_t stream, T *output_data, const T *px, const T *py,
                       size_t count_x, size_t count_y) {
  if (count_x == 0 || count_y == 0)
    // special case where there's a dim value of 0 in the output shape
    return;

  CUDA_LONG N = static_cast<CUDA_LONG>(std::max(count_x, count_y));

  const int num_threads_per_block = GridDim::maxThreadsPerBlock;
  // one thread per output element, rounded up to a whole number of blocks
  const int num_blocks = (N + num_threads_per_block - 1) / num_threads_per_block;

  _MulMulSigmoidKernel<T><<<num_blocks, num_threads_per_block, 0, stream>>>(
      output_data, px, py, N, static_cast<CUDA_LONG>(count_x), static_cast<CUDA_LONG>(count_y));
}

//////////////////
// MulMulSigmoidOp...
//////////////////

template <typename T>
void *MulMulSigmoidOp<T>::CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
  return std::make_unique<MulMulSigmoidKernel<T>>(api, info).release();
}

template <typename T> const char *MulMulSigmoidOp<T>::GetName() const {
  return "MulMulSigmoid";
}

template <typename T> const char *MulMulSigmoidOp<T>::GetExecutionProviderType() const {
  return "CUDAExecutionProvider";
}

template <typename T> size_t MulMulSigmoidOp<T>::GetInputTypeCount() const { return 2; }

template <typename T>
ONNXTensorElementDataType MulMulSigmoidOp<T>::GetInputType(std::size_t /* index */) const {
  return CTypeToOnnxType<T>().onnx_type();
}

template <typename T>
ONNXTensorElementDataType MulMulSigmoidOp<T>::GetOutputType(std::size_t /* index */) const {
  return CTypeToOnnxType<T>().onnx_type();
}

template <typename T>
OrtCustomOpInputOutputCharacteristic
MulMulSigmoidOp<T>::GetInputCharacteristic(std::size_t /* index */) const {
  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}

template <typename T> size_t MulMulSigmoidOp<T>::GetOutputTypeCount() const { return 1; }

template <typename T>
OrtCustomOpInputOutputCharacteristic
MulMulSigmoidOp<T>::GetOutputCharacteristic(std::size_t /* index */) const {
  return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}

///////////////////
// MulMulSigmoidKernel
///////////////////

template <typename T>
MulMulSigmoidKernel<T>::MulMulSigmoidKernel(const OrtApi &api, const OrtKernelInfo *info) {}

template <typename T> void MulMulSigmoidKernel<T>::Compute(OrtKernelContext *context) {
  Ort::KernelContext ctx(context);

  int n_inputs = ctx.GetInputCount();
  EXT_ENFORCE(n_inputs == 2, "Expected 2 inputs not ", n_inputs, ".");
  Ort::ConstValue A = ctx.GetInput(0);
  Ort::ConstValue B = ctx.GetInput(1);
  Ort::UnownedValue output;

  std::vector<int64_t> dimsA = A.GetTensorTypeAndShapeInfo().GetShape();
  auto memi = A.GetTensorMemoryInfo();
  EXT_ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
              "first input is not on GPU");

  std::vector<int64_t> dimsB = B.GetTensorTypeAndShapeInfo().GetShape();
  memi = B.GetTensorMemoryInfo();
  EXT_ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
              "second input is not on GPU");

  cudaStream_t cuda_stream = (cudaStream_t)ctx.GetGPUComputeStream();
  // CUDA_THROW_IF_ERROR(cudaStreamSynchronize(cuda_stream));
  size_t input_size_a = static_cast<size_t>(onnx_c_ops::flattened_dimension(dimsA));
  size_t input_size_b = static_cast<size_t>(onnx_c_ops::flattened_dimension(dimsB));

  // the output takes the shape of the larger input; the smaller one is broadcast
  output = ctx.GetOutput(0, input_size_a < input_size_b ? dimsB : dimsA);

  MulMulSigmoidImpl(cuda_stream, output.GetTensorMutableData<T>(), A.GetTensorData<T>(),
                    B.GetTensorData<T>(), input_size_a, input_size_b);
}

static MulMulSigmoidOp<float> _kernel_f32;
static MulMulSigmoidOp<half> _kernel_f16;

} // namespace ortops
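The kernel broadcasts by reading the smaller input modulo its flat length rather than by computing per-axis strides. A NumPy sketch of that indexing scheme (the function name is illustrative); it coincides with ONNX broadcasting exactly when the smaller tensor repeats along leading axes, as in the (1, 2, 3) versus (3, 2, 3) test case:

import numpy as np

def mul_mul_sigmoid_modulo(x, y):
    # mirrors _MulMulSigmoidKernel: out[i] = f(x[i % Nx], y[i % Ny]) on flat arrays
    xf, yf = x.ravel(), y.ravel()
    n = max(xf.size, yf.size)
    i = np.arange(n)
    a, b = xf[i % xf.size], yf[i % yf.size]
    out = a * b / (1.0 + np.exp(-b))
    return out.reshape(x.shape if x.size >= y.size else y.shape)

x = np.random.rand(1, 2, 3).astype(np.float32)
y = np.random.rand(3, 2, 3).astype(np.float32)
assert np.allclose(mul_mul_sigmoid_modulo(x, y), x * y / (1.0 + np.exp(-y)), atol=1e-6)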
31 changes: 31 additions & 0 deletions onnx_extended/ortops/optim/cuda/mul_mul_sigmoid.h
@@ -0,0 +1,31 @@
#pragma once

#include "common/common_kernels.h"
#include "cublas_v2.h"
#include <cuda_runtime.h>

namespace ortops {

template <typename T> struct MulMulSigmoidKernel {
  MulMulSigmoidKernel(const OrtApi &api, const OrtKernelInfo *info);
  void Compute(OrtKernelContext *context);
};

template <typename T>
struct MulMulSigmoidOp : Ort::CustomOpBase<MulMulSigmoidOp<T>, MulMulSigmoidKernel<T>> {
  typedef Ort::CustomOpBase<MulMulSigmoidOp<T>, MulMulSigmoidKernel<T>> parent_type;
  MulMulSigmoidOp() : parent_type() {}
  void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const;
  const char *GetName() const;
  const char *GetExecutionProviderType() const;

  std::size_t GetInputTypeCount() const;
  ONNXTensorElementDataType GetInputType(std::size_t index) const;
  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(std::size_t index) const;

  std::size_t GetOutputTypeCount() const;
  ONNXTensorElementDataType GetOutputType(std::size_t index) const;
  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const;
};

} // namespace ortops