Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed May 7, 2024
1 parent e2e4f39 commit 5cc5ca3
Show file tree
Hide file tree
Showing 9 changed files with 545 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.3.0
+++++

* :pr:`180`: add MaskedScatterNDOfShape custom operator
* :pr:`175`: adds custom operator MulSub and SubMul on CUDA
* :pr:`173`: adds custom operator AddSharedInput, MulSharedInput on CUDA
* :pr:`170`: adds custom operator TriMatrix on CUDA
Expand Down
1 change: 1 addition & 0 deletions _cmake/targets/ortops_optim_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ if(CUDA_AVAILABLE)
../onnx_extended/ortops/optim/cuda/replace_zero.cu
../onnx_extended/ortops/optim/cuda/rotary.cu
../onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu
../onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.cu
../onnx_extended/ortops/optim/cuda/submul.cu
../onnx_extended/ortops/optim/cuda/transpose_cast_2d.cu
../onnx_extended/ortops/optim/cuda/tri_matrix.cu
Expand Down
109 changes: 109 additions & 0 deletions _unittests/ut_ortops/test_optim_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,115 @@ def test_scatternd_of_shape_standalone_cuda(self):
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)

def _masked_scatternd_of_shape_cuda(self, reduction, line, itype):
import onnxruntime
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16

model1 = oh.make_model(
oh.make_graph(
[
oh.make_node("Equal", ["indices", "mone"], ["masked_indices"]),
oh.make_node(
"Where",
["masked_indices", "zero", "updates"],
["masked_updates"],
),
oh.make_node(
"ScatterND",
inputs=["data", "indices", "masked_updates"],
outputs=["y"],
reduction=reduction,
),
],
"nd",
[
oh.make_tensor_value_info("data", itype, [None, None]),
oh.make_tensor_value_info(
"indices", TensorProto.INT64, [None, None, 1]
),
oh.make_tensor_value_info("updates", itype, [None, None, None]),
],
[oh.make_tensor_value_info("y", itype, [None, None])],
[
onh.from_array(np.array([-1], dtype=np.int64), name="mone"),
onh.from_array(np.array([0], dtype=dtype), name="zero"),
],
),
opset_imports=[oh.make_opsetid("", 18)],
ir_version=9,
)

model2 = oh.make_model(
oh.make_graph(
[
oh.make_node(
"MaskedScatterNDOfShape",
inputs=["shape", "indices", "updates"],
outputs=["y"],
reduction=reduction,
maskedValue=-1,
domain="onnx_extended.ortops.optim.cuda",
)
],
"nd",
[
oh.make_tensor_value_info("shape", TensorProto.INT64, [None]),
oh.make_tensor_value_info(
"indices", TensorProto.INT64, [None, None, 1]
),
oh.make_tensor_value_info("updates", itype, [None, None, None]),
],
[oh.make_tensor_value_info("y", itype, [None, None])],
),
opset_imports=[
oh.make_opsetid("", 18),
oh.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
],
ir_version=9,
)

data = np.zeros((32, 16), dtype=dtype)
indices = np.array(
[
[0, 1, 2],
[2, 3, 4],
[-1, 30, 31],
[-1, 7, 8],
[10, 11, -1],
[20, -1, 21],
],
dtype=np.int64,
)
indices = indices[..., np.newaxis]
shape = (6, 3, data.shape[-1])
updates = (np.arange(np.prod(shape)).reshape(shape) + 1).astype(dtype)

feeds1 = dict(data=data, indices=indices, updates=updates)
feeds2 = dict(
shape=np.array(data.shape, dtype=np.int64), indices=indices, updates=updates
)
ref = CReferenceEvaluator(model1)
expected = ref.run(None, feeds1)[0]

opts = onnxruntime.SessionOptions()
opts.register_custom_ops_library(get_ort_ext_libs()[0])
# opts.log_severity_level = 0
# opts.log_verbosity_level = 0
sess = onnxruntime.InferenceSession(
model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]
)
got = sess.run(None, feeds2)[0]
self.assertEqual(expected.tolist(), got.tolist())

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_masked_scatternd_of_shape_standalone_cuda(self):
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT)
self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)

def _addaddmulmul_cuda(self, itype, op_type, broad=False):
import onnxruntime
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs
Expand Down
31 changes: 31 additions & 0 deletions onnx_extended/ortops/optim/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,37 @@ def documentation() -> List[str]:
**Constraints**
* T: float, float16
""",
"""
onnx_extended.ortops.optim.cuda.MaskedScatterNDOfShape
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ConstantOfShape + Where + ScatterND,
updates a null matrix with updates if only indices are not
equal to a value (usually -1)
**Provider**
CUDAExecutionProvider
**Attributes**
* maskedValue (int): updates are ignore the indices are equal to this value.
**Inputs**
* shape (I): tensor of type I
* indices (I): tensor of type I
* updates (T): tensor of type T
**Outputs**
* Z (T): updated tensor
**Constraints**
* I: int64
* T: float, float16
""",
"""
Expand Down
7 changes: 7 additions & 0 deletions onnx_extended/ortops/optim/cuda/ort_optim_cuda_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "replace_zero.h"
#include "rotary.h"
#include "scatter_nd_of_shape.h"
#include "scatter_nd_of_shape_masked.h"
#include "submul.h"
#include "transpose_cast_2d.h"
#include "tri_matrix.h"
Expand Down Expand Up @@ -76,6 +77,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
static ortops::ScatterNDOfShapeOp<float> c_ScatterNDOfShapeOp32;
static ortops::ScatterNDOfShapeOp<half> c_ScatterNDOfShapeOp16;

static ortops::MaskedScatterNDOfShapeOp<float> c_MaskedScatterNDOfShapeOp32;
static ortops::MaskedScatterNDOfShapeOp<half> c_MaskedScatterNDOfShapeOp16;

static ortops::Transpose2DCastOp c_Transpose2DCast16(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
static ortops::Transpose2DCastOp c_Transpose2DCast32(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
Expand Down Expand Up @@ -128,6 +132,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
domain.Add(&c_ScatterNDOfShapeOp32);
domain.Add(&c_ScatterNDOfShapeOp16);

domain.Add(&c_MaskedScatterNDOfShapeOp32);
domain.Add(&c_MaskedScatterNDOfShapeOp16);

domain.Add(&c_Transpose2DCast16);
domain.Add(&c_Transpose2DCast32);

Expand Down
18 changes: 1 addition & 17 deletions onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,11 @@

#include "common/common_kernels.h"
#include "cublas_v2.h"
#include "scatter_nd_of_shape_common.h"
#include <cuda_runtime.h>

namespace ortops {

enum class Reduction : int {
None = 0,
Add = 1,
Mul = 2,
Min = 3,
Max = 4,
};

enum class Strategy : int {
None = 0,
Optimize = 1,
};

struct Shape2 {
int64_t dims[12];
};

template <typename T> struct ScatterNDOfShapeKernel {
ScatterNDOfShapeKernel(const OrtApi &api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context);
Expand Down
22 changes: 22 additions & 0 deletions onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

namespace ortops {

enum class Reduction : int {
None = 0,
Add = 1,
Mul = 2,
Min = 3,
Max = 4,
};

enum class Strategy : int {
None = 0,
Optimize = 1,
};

struct Shape2 {
int64_t dims[12];
};

} // namespace ortops
Loading

0 comments on commit 5cc5ca3

Please sign in to comment.