From 5cc5ca31db3d9eee3dbae71cf11248e5e85852d0 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 7 May 2024 14:49:25 +0200 Subject: [PATCH] First commit --- CHANGELOGS.rst | 1 + _cmake/targets/ortops_optim_cuda.cmake | 1 + _unittests/ut_ortops/test_optim_cuda.py | 109 ++++++ onnx_extended/ortops/optim/cuda/__init__.py | 31 ++ .../ortops/optim/cuda/ort_optim_cuda_lib.cc | 7 + .../ortops/optim/cuda/scatter_nd_of_shape.h | 18 +- .../optim/cuda/scatter_nd_of_shape_common.h | 22 ++ .../optim/cuda/scatter_nd_of_shape_masked.cu | 329 ++++++++++++++++++ .../optim/cuda/scatter_nd_of_shape_masked.h | 44 +++ 9 files changed, 545 insertions(+), 17 deletions(-) create mode 100644 onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_common.h create mode 100644 onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.cu create mode 100644 onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.h diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 9e297f94..aaab8abd 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.3.0 +++++ +* :pr:`180`: add MaskedScatterNDOfShape custom operator * :pr:`175`: adds custom operator MulSub and SubMul on CUDA * :pr:`173`: adds custom operator AddSharedInput, MulSharedInput on CUDA * :pr:`170`: adds custom operator TriMatrix on CUDA diff --git a/_cmake/targets/ortops_optim_cuda.cmake b/_cmake/targets/ortops_optim_cuda.cmake index a7f976ed..2836a032 100644 --- a/_cmake/targets/ortops_optim_cuda.cmake +++ b/_cmake/targets/ortops_optim_cuda.cmake @@ -20,6 +20,7 @@ if(CUDA_AVAILABLE) ../onnx_extended/ortops/optim/cuda/replace_zero.cu ../onnx_extended/ortops/optim/cuda/rotary.cu ../onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu + ../onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.cu ../onnx_extended/ortops/optim/cuda/submul.cu ../onnx_extended/ortops/optim/cuda/transpose_cast_2d.cu ../onnx_extended/ortops/optim/cuda/tri_matrix.cu diff --git a/_unittests/ut_ortops/test_optim_cuda.py b/_unittests/ut_ortops/test_optim_cuda.py index 9b74232e..222048c6 100644 --- a/_unittests/ut_ortops/test_optim_cuda.py +++ b/_unittests/ut_ortops/test_optim_cuda.py @@ -120,6 +120,115 @@ def test_scatternd_of_shape_standalone_cuda(self): self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT) self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16) + def _masked_scatternd_of_shape_cuda(self, reduction, line, itype): + import onnxruntime + from onnx_extended.ortops.optim.cuda import get_ort_ext_libs + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + + model1 = oh.make_model( + oh.make_graph( + [ + oh.make_node("Equal", ["indices", "mone"], ["masked_indices"]), + oh.make_node( + "Where", + ["masked_indices", "zero", "updates"], + ["masked_updates"], + ), + oh.make_node( + "ScatterND", + inputs=["data", "indices", "masked_updates"], + outputs=["y"], + reduction=reduction, + ), + ], + "nd", + [ + oh.make_tensor_value_info("data", itype, [None, None]), + oh.make_tensor_value_info( + "indices", TensorProto.INT64, [None, None, 1] + ), + oh.make_tensor_value_info("updates", itype, [None, None, None]), + ], + [oh.make_tensor_value_info("y", itype, [None, None])], + [ + onh.from_array(np.array([-1], dtype=np.int64), name="mone"), + onh.from_array(np.array([0], dtype=dtype), name="zero"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = oh.make_model( + oh.make_graph( + [ + oh.make_node( + "MaskedScatterNDOfShape", + inputs=["shape", "indices", "updates"], + outputs=["y"], + reduction=reduction, 
+                        maskedValue=-1,
+                        domain="onnx_extended.ortops.optim.cuda",
+                    )
+                ],
+                "nd",
+                [
+                    oh.make_tensor_value_info("shape", TensorProto.INT64, [None]),
+                    oh.make_tensor_value_info(
+                        "indices", TensorProto.INT64, [None, None, 1]
+                    ),
+                    oh.make_tensor_value_info("updates", itype, [None, None, None]),
+                ],
+                [oh.make_tensor_value_info("y", itype, [None, None])],
+            ),
+            opset_imports=[
+                oh.make_opsetid("", 18),
+                oh.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
+            ],
+            ir_version=9,
+        )
+
+        data = np.zeros((32, 16), dtype=dtype)
+        indices = np.array(
+            [
+                [0, 1, 2],
+                [2, 3, 4],
+                [-1, 30, 31],
+                [-1, 7, 8],
+                [10, 11, -1],
+                [20, -1, 21],
+            ],
+            dtype=np.int64,
+        )
+        indices = indices[..., np.newaxis]
+        shape = (6, 3, data.shape[-1])
+        updates = (np.arange(np.prod(shape)).reshape(shape) + 1).astype(dtype)
+
+        feeds1 = dict(data=data, indices=indices, updates=updates)
+        feeds2 = dict(
+            shape=np.array(data.shape, dtype=np.int64), indices=indices, updates=updates
+        )
+        ref = CReferenceEvaluator(model1)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = onnxruntime.SessionOptions()
+        opts.register_custom_ops_library(get_ort_ext_libs()[0])
+        # opts.log_severity_level = 0
+        # opts.log_verbosity_level = 0
+        sess = onnxruntime.InferenceSession(
+            model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]
+        )
+        got = sess.run(None, feeds2)[0]
+        self.assertEqual(expected.tolist(), got.tolist())
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_masked_scatternd_of_shape_standalone_cuda(self):
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT)
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)
+
     def _addaddmulmul_cuda(self, itype, op_type, broad=False):
         import onnxruntime
         from onnx_extended.ortops.optim.cuda import get_ort_ext_libs
diff --git a/onnx_extended/ortops/optim/cuda/__init__.py b/onnx_extended/ortops/optim/cuda/__init__.py
index 5dccf705..8796038f 100644
--- a/onnx_extended/ortops/optim/cuda/__init__.py
+++ b/onnx_extended/ortops/optim/cuda/__init__.py
@@ -133,6 +133,37 @@ def documentation() -> List[str]:
 
     **Constraints**
 
+    * T: float, float16
+    """,
+    """
+    onnx_extended.ortops.optim.cuda.MaskedScatterNDOfShape
+    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+    ConstantOfShape + Where + ScatterND,
+    updates a null matrix with updates only where indices are not
+    equal to a given value (usually -1)
+
+    **Provider**
+
+    CUDAExecutionProvider
+
+    **Attributes**
+
+    * maskedValue (int): updates are ignored when the indices are equal to this value.
+
+    **Inputs**
+
+    * shape (I): tensor of type I
+    * indices (I): tensor of type I
+    * updates (T): tensor of type T
+
+    **Outputs**
+
+    * Z (T): updated tensor
+
+    **Constraints**
+
+    * I: int64
     * T: float, float16
     """,
     """
diff --git a/onnx_extended/ortops/optim/cuda/ort_optim_cuda_lib.cc b/onnx_extended/ortops/optim/cuda/ort_optim_cuda_lib.cc
index 919939b9..103f9abb 100644
--- a/onnx_extended/ortops/optim/cuda/ort_optim_cuda_lib.cc
+++ b/onnx_extended/ortops/optim/cuda/ort_optim_cuda_lib.cc
@@ -16,6 +16,7 @@
 #include "replace_zero.h"
 #include "rotary.h"
 #include "scatter_nd_of_shape.h"
+#include "scatter_nd_of_shape_masked.h"
 #include "submul.h"
 #include "transpose_cast_2d.h"
 #include "tri_matrix.h"
@@ -76,6 +77,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
   static ortops::ScatterNDOfShapeOp<float> c_ScatterNDOfShapeOp32;
   static ortops::ScatterNDOfShapeOp<half> c_ScatterNDOfShapeOp16;
 
+  static ortops::MaskedScatterNDOfShapeOp<float> c_MaskedScatterNDOfShapeOp32;
+  static ortops::MaskedScatterNDOfShapeOp<half> c_MaskedScatterNDOfShapeOp16;
+
   static ortops::Transpose2DCastOp c_Transpose2DCast16(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
                                                        ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
   static ortops::Transpose2DCastOp c_Transpose2DCast32(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
@@ -128,6 +132,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
   domain.Add(&c_ScatterNDOfShapeOp32);
   domain.Add(&c_ScatterNDOfShapeOp16);
 
+  domain.Add(&c_MaskedScatterNDOfShapeOp32);
+  domain.Add(&c_MaskedScatterNDOfShapeOp16);
+
   domain.Add(&c_Transpose2DCast16);
   domain.Add(&c_Transpose2DCast32);
 
diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h
index 2cdce54f..ff0d6efb 100644
--- a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h
+++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h
@@ -2,27 +2,11 @@
 
 #include "common/common_kernels.h"
 #include "cublas_v2.h"
+#include "scatter_nd_of_shape_common.h"
 #include <cuda_runtime.h>
 
 namespace ortops {
 
-enum class Reduction : int {
-  None = 0,
-  Add = 1,
-  Mul = 2,
-  Min = 3,
-  Max = 4,
-};
-
-enum class Strategy : int {
-  None = 0,
-  Optimize = 1,
-};
-
-struct Shape2 {
-  int64_t dims[12];
-};
-
 template <typename T> struct ScatterNDOfShapeKernel {
   ScatterNDOfShapeKernel(const OrtApi &api, const OrtKernelInfo *info);
   void Compute(OrtKernelContext *context);
diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_common.h b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_common.h
new file mode 100644
index 00000000..d5ec4caa
--- /dev/null
+++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_common.h
@@ -0,0 +1,22 @@
+#pragma once
+
+namespace ortops {
+
+enum class Reduction : int {
+  None = 0,
+  Add = 1,
+  Mul = 2,
+  Min = 3,
+  Max = 4,
+};
+
+enum class Strategy : int {
+  None = 0,
+  Optimize = 1,
+};
+
+struct Shape2 {
+  int64_t dims[12];
+};
+
+} // namespace ortops
diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.cu b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.cu
new file mode 100644
index 00000000..a6f8c99f
--- /dev/null
+++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.cu
@@ -0,0 +1,329 @@
+#include "common/c_op_helpers.h"
+#include "common/common_kernels.h"
+#include "cuda/common_kernels_cuda.h"
+#include "scatter_nd_of_shape_masked.h"
+#include <algorithm>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <vector>
+
+namespace ortops {
+
+#ifndef CUDA_LONG
+#define CUDA_LONG int32_t
+#endif
+
+template <typename T> __device__ __forceinline__ void _add_inplace(T &x, const T a) { x += a; }
+
+template <> __device__ __forceinline__ void _add_inplace(half &x, const half a) {
+#if __CUDA_ARCH__ < 700
+  x = __float2half(__half2float(x) + __half2float(a));
+#else
+  x += a;
+#endif
+}
+
+template <typename T>
+__global__ void masked_addition_inplace_kernel(T *__restrict__ output_data,
+                                               const int64_t *__restrict__ indices_data,
+                                               const T *__restrict__ updates_data,
+                                               const CUDA_LONG indice_size,
+                                               const CUDA_LONG nrows, const CUDA_LONG stride,
+                                               const int64_t masked_value) {
+  auto id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id >= stride)
+    return;
+
+  for (size_t i = 0; i < nrows; ++i) {
+    output_data[i * stride + id] = 0;
+  }
+
+  for (size_t i = 0; i < indice_size; ++i) {
+    if (indices_data[i] == masked_value)
+      continue;
+    _add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
+  }
+}
+
+#ifdef ENABLE_NCONT
+
+template <typename T, int NCONT>
+__global__ void masked_addition_inplace_kernelN(T *__restrict__ output_data,
+                                                const int64_t *__restrict__ indices_data,
+                                                const T *__restrict__ updates_data,
+                                                const CUDA_LONG indice_size,
+                                                const CUDA_LONG nrows, const CUDA_LONG stride,
+                                                const int64_t masked_value) {
+  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x * NCONT;
+
+  T *out;
+  for (size_t i = 0; i < nrows; ++i) {
+    out = output_data + i * stride + id;
+#pragma unroll
+    for (int k = 0; k < NCONT; ++k) {
+      out[k] = 0;
+    }
+  }
+
+  const T *up;
+  for (size_t i = 0; i < indice_size; ++i) {
+    if (indices_data[i] == masked_value)
+      continue;
+    out = output_data + (indices_data[i] * stride + id);
+    up = updates_data + (i * stride + id);
+#pragma unroll
+    for (int k = 0; k < NCONT; ++k) {
+      out[k] += up[k];
+    }
+  }
+}
+
+#endif
+
+//////////////////
+// MaskedScatterNDOfShapeOp...
+//////////////////
+
+template <typename T>
+void *MaskedScatterNDOfShapeOp<T>::CreateKernel(const OrtApi &api,
+                                                const OrtKernelInfo *info) const {
+  return std::make_unique<MaskedScatterNDOfShapeKernel<T>>(api, info).release();
+}
+
+template <typename T> const char *MaskedScatterNDOfShapeOp<T>::GetName() const {
+  return "MaskedScatterNDOfShape";
+}
+
+template <typename T>
+const char *MaskedScatterNDOfShapeOp<T>::GetExecutionProviderType() const {
+  return "CUDAExecutionProvider";
+}
+
+template <typename T> size_t MaskedScatterNDOfShapeOp<T>::GetInputTypeCount() const {
+  return 3;
+}
+
+template <>
+ONNXTensorElementDataType
+MaskedScatterNDOfShapeOp<float>::GetInputType(std::size_t index) const {
+  switch (index) {
+  case 0:
+  case 1:
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  case 2:
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  default:
+    EXT_THROW("Input index=", (int64_t)index, " is out of boundary.");
+  }
+}
+
+template <>
+ONNXTensorElementDataType
+MaskedScatterNDOfShapeOp<half>::GetInputType(std::size_t index) const {
+  switch (index) {
+  case 0:
+  case 1:
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+  case 2:
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+  default:
+    EXT_THROW("Input index=", (int64_t)index, " is out of boundary.");
+  }
+}
+
+template <typename T>
+OrtMemType MaskedScatterNDOfShapeOp<T>::GetInputMemoryType(std::size_t index) const {
+  switch (index) {
+  case 0:
+    return OrtMemTypeCPUInput;
+  case 1:
+  case 2:
+    return OrtMemTypeDefault;
+  default:
+    EXT_THROW("Input index=", (int64_t)index, " is out of boundary.");
+  }
+}
+
+template <typename T>
+OrtCustomOpInputOutputCharacteristic
+MaskedScatterNDOfShapeOp<T>::GetInputCharacteristic(std::size_t index) const {
+  switch (index) {
+  case 0:
+  case 1:
+  case 2:
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+  default:
+    EXT_THROW("Input index=", (uint64_t)index, " is out of boundary.");
+  }
+}
+
+template <typename T> size_t MaskedScatterNDOfShapeOp<T>::GetOutputTypeCount() const {
+  return 1;
+}
+
+template <>
+ONNXTensorElementDataType
+MaskedScatterNDOfShapeOp<float>::GetOutputType(std::size_t index) const {
+  switch (index) {
+  case 0:
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  default:
+    EXT_THROW("Output index=", (uint64_t)index, " is out of boundary.");
+  }
+}
+
+template <>
+ONNXTensorElementDataType
+MaskedScatterNDOfShapeOp<half>::GetOutputType(std::size_t index) const {
+  switch (index) {
+  case 0:
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
+  default:
+    EXT_THROW("Output index=", (uint64_t)index, " is out of boundary.");
+  }
+}
+
+template <typename T>
+OrtCustomOpInputOutputCharacteristic
+MaskedScatterNDOfShapeOp<T>::GetOutputCharacteristic(std::size_t index) const {
+  switch (index) {
+  case 0:
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+  default:
+    EXT_THROW("Output index=", (uint64_t)index, " is out of boundary.");
+  }
+}
+
+///////////////////
+// MaskedScatterNDOfShapeKernel
+///////////////////
+
+template <typename T>
+MaskedScatterNDOfShapeKernel<T>::MaskedScatterNDOfShapeKernel(const OrtApi &api,
+                                                              const OrtKernelInfo *info) {
+  char value_string[1000];
+  std::size_t size = 1000;
+  ThrowOnError(api, api.KernelInfoGetAttribute_string(info, "reduction", value_string, &size));
+  std::string value = value_string;
+  if (value == "add")
+    reduction_ = Reduction::Add;
+  else
+    EXT_THROW("unexpected reduction '", value, "'.");
+
+  ThrowOnError(api, api.KernelInfoGetAttribute_int64(info, "maskedValue", &masked_value_));
+
+  cudaDeviceProp prop;
+  int deviceId = 0;
+  cudaGetDeviceProperties(&prop, deviceId);
+  maxThreadPerBlock_ = prop.maxThreadsPerBlock;
+}
+
+template <typename T>
+void MaskedScatterNDOfShapeKernel<T>::Compute(OrtKernelContext *context) {
+  Ort::KernelContext ctx(context);
+
+  int n_inputs = ctx.GetInputCount();
+  EXT_ENFORCE(n_inputs == 3, "Expected 3 inputs not ", n_inputs, ".");
+  Ort::ConstValue shape = ctx.GetInput(0);
+  Ort::ConstValue indices = ctx.GetInput(1);
+  Ort::ConstValue updates = ctx.GetInput(2);
+  Ort::UnownedValue output;
+
+  std::vector<int64_t> dimensions = shape.GetTensorTypeAndShapeInfo().GetShape();
+  std::vector<int64_t> indices_shape = indices.GetTensorTypeAndShapeInfo().GetShape();
+  std::vector<int64_t> updates_shape = updates.GetTensorTypeAndShapeInfo().GetShape();
+  EXT_ENFORCE(dimensions.size() == 1, "shape must be a 1-dimension tensor.");
+
+  cudaStream_t stream = (cudaStream_t)ctx.GetGPUComputeStream();
+
+  auto memi = updates.GetTensorMemoryInfo();
+  EXT_ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
+              "updates are not on GPU");
+
+  auto mem = shape.GetTensorMemoryInfo();
+  EXT_ENFORCE(
+      mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
+      "The shape should be on CPU already, but mem.GetDeviceType()=", mem.GetDeviceType(), ".");
+  const int64_t *X = shape.GetTensorData<int64_t>();
+  std::vector<int64_t> dims(X, X + dimensions[0]);
+  output = ctx.GetOutput(0, dims);
+
+  std::vector<int64_t> input_shape = output.GetTensorTypeAndShapeInfo().GetShape();
+
+  if (reduction_ == Reduction::Add && indices_shape[indices_shape.size() - 1] == 1 &&
+      input_shape.size() == 2) {
+
+    size_t indice_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(indices_shape));
+    size_t update_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(updates_shape));
+
+    EXT_ENFORCE(update_size == indice_size * input_shape[input_shape.size() - 1],
+                "Size mismatch, update_size=", update_size, ", indice_size=", indice_size,
+                ", input_shape[-1]=", input_shape[input_shape.size() - 1], ".");
+
+    ComputeOptimize(stream, input_shape, indices_shape, output.GetTensorMutableData<T>(),
+                    indices.GetTensorData<int64_t>(), updates.GetTensorData<T>());
+  } else {
+    EXT_THROW("Only add reduction and 2D tensors are supported, reduction is ", (int)reduction_,
+              ", input_shape.size()=", static_cast<int64_t>(input_shape.size()),
+              ", indices_shape[indices_shape.size() - 1]=",
+              static_cast<int64_t>(indices_shape[indices_shape.size() - 1]), ".");
+  }
+}
+
+template <typename T>
+void _ComputeOptimize(cudaStream_t stream, const std::vector<int64_t> &input_shape,
+                      const std::vector<int64_t> &indices_shape, T *output_data,
+                      const int64_t *indices_data, const T *updates_data,
+                      int maxThreadPerBlock_, int64_t masked_value_) {
+
+  // The kernel is slow if there are a lot of duplicates.
+  // reduction_ == Reduction::Add
+  // indices_shape[indices_shape.size() - 1] == 1
+  // input_shape.size() == 2
+  size_t indice_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(indices_shape));
+  size_t input_size = static_cast<size_t>(onnx_c_ops::flattened_dimension(input_shape));
+  size_t stride = input_shape[input_shape.size() - 1];
+  size_t nrows = input_size / stride;
+
+  std::vector<size_t> next_batch(indice_size);
+  std::vector<uint8_t> processed(input_shape[0], 0);
+  std::vector<uint8_t> processed_once(input_shape[0], 0);
+
+  int threads_per_block = std::min(256, maxThreadPerBlock_ / 8);
+
+#ifdef ENABLE_NCONT
+#define NCONT 65536
+  if (stride % NCONT == 0 && stride > threads_per_block * NCONT) {
+    int blocks_per_grid = (stride / NCONT + threads_per_block - 1) / threads_per_block;
+    dim3 threads(threads_per_block);
+    dim3 blocks(blocks_per_grid);
+    masked_addition_inplace_kernelN<T, NCONT><<<blocks, threads, 0, stream>>>(
+        output_data, indices_data, updates_data, indice_size, nrows, stride, masked_value_);
+  } else {
+#endif
+    int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
+    dim3 threads(threads_per_block);
+    dim3 blocks(blocks_per_grid);
+    masked_addition_inplace_kernel<T><<<blocks, threads, 0, stream>>>(
+        output_data, indices_data, updates_data, indice_size, nrows, stride, masked_value_);
+#ifdef ENABLE_NCONT
+  }
+#endif
+}
+
+template <typename T>
+void MaskedScatterNDOfShapeKernel<T>::ComputeOptimize(cudaStream_t &stream,
+                                                      const std::vector<int64_t> &input_shape,
+                                                      const std::vector<int64_t> &indices_shape,
+                                                      T *output_data,
+                                                      const int64_t *indices_data,
+                                                      const T *updates_data) const {
+  _ComputeOptimize(stream, input_shape, indices_shape, output_data, indices_data, updates_data,
+                   maxThreadPerBlock_, masked_value_);
+}
+
+static MaskedScatterNDOfShapeOp<float> _op32;
+static MaskedScatterNDOfShapeOp<half> _op16;
+
+} // namespace ortops
diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.h b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.h
new file mode 100644
index 00000000..50afc81e
--- /dev/null
+++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape_masked.h
@@ -0,0 +1,44 @@
+#pragma once
+
+#include "common/common_kernels.h"
+#include "cublas_v2.h"
+#include "scatter_nd_of_shape_common.h"
+#include <cuda_runtime.h>
+
+namespace ortops {
+
+template <typename T> struct MaskedScatterNDOfShapeKernel {
+  MaskedScatterNDOfShapeKernel(const OrtApi &api, const OrtKernelInfo *info);
+  void Compute(OrtKernelContext *context);
+
+private:
+  void ComputeOptimize(cudaStream_t &stream, const std::vector<int64_t> &input_shape,
+                       const std::vector<int64_t> &indices_shape, T *output_data,
+                       const int64_t *indices_data, const T *updates_data) const;
+
+  Reduction reduction_;
+  int maxThreadPerBlock_;
+  int64_t masked_value_;
+};
+
+template <typename T>
+struct MaskedScatterNDOfShapeOp
+    : Ort::CustomOpBase<MaskedScatterNDOfShapeOp<T>, MaskedScatterNDOfShapeKernel<T>> {
+  typedef Ort::CustomOpBase<MaskedScatterNDOfShapeOp<T>, MaskedScatterNDOfShapeKernel<T>>
+      parent_type;
+  MaskedScatterNDOfShapeOp() : parent_type() {}
+  void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const;
+  const char *GetName() const;
+  const char *GetExecutionProviderType() const;
+
+  std::size_t GetInputTypeCount() const;
+  ONNXTensorElementDataType GetInputType(std::size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(std::size_t index) const;
+  OrtMemType GetInputMemoryType(std::size_t index) const;
+
+  std::size_t GetOutputTypeCount() const;
+  ONNXTensorElementDataType GetOutputType(std::size_t index) const;
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(std::size_t index) const;
+};
+
+} // namespace ortops
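
For reference, a minimal NumPy sketch of what MaskedScatterNDOfShape computes, assuming the
configuration supported by the CUDA kernel above: a 2-D output, indices whose last dimension
is 1, reduction="add" and maskedValue=-1. The helper name and signature below are illustrative
only and are not part of the library or of this patch.

    import numpy as np

    def masked_scatter_nd_of_shape_reference(shape, indices, updates, masked_value=-1):
        # ConstantOfShape: start from a null matrix of the requested shape.
        out = np.zeros(shape, dtype=updates.dtype)
        # indices has shape (batch, n, 1); updates has shape (batch, n, last_dim).
        flat_indices = indices.reshape(-1)
        flat_updates = updates.reshape(-1, updates.shape[-1])
        for i, row in enumerate(flat_indices):
            # Where + mask: rows addressed by the masked value are skipped.
            if row == masked_value:
                continue
            # ScatterND with reduction="add": accumulate the update row.
            out[row] += flat_updates[i]
        return out

With the indices and updates built in _masked_scatternd_of_shape_cuda, this sketch should match
both the Equal + Where + ScatterND decomposition used as the reference model in the unit test
and the output of the MaskedScatterNDOfShape session.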