From 4de2ea2c6edfb56d2fe6a462284ab284321df7f0 Mon Sep 17 00:00:00 2001
From: Xavier Dupre
Date: Fri, 12 Apr 2024 09:29:33 +0200
Subject: [PATCH 1/8] Better implementation for ScatterND

---
 _doc/benchmarks.rst                     |   8 ++
 _doc/examples/plot_op_scatternd_cuda.py | 116 ++++++++++++++++++++++++
 clean_onnx.sh                           |  31 ++++++-
 3 files changed, 151 insertions(+), 4 deletions(-)
 create mode 100644 _doc/examples/plot_op_scatternd_cuda.py

diff --git a/_doc/benchmarks.rst b/_doc/benchmarks.rst
index f647cec0..b12fb5d3 100644
--- a/_doc/benchmarks.rst
+++ b/_doc/benchmarks.rst
@@ -99,6 +99,14 @@ See :ref:`l-example-op-mul_cuda`.
 The benchmark compares two operators Mul profiles
 with their fusion into a single operator.
 
+plot_op_scatternd_cuda
+++++++++++++++++++++++
+
+See :ref:`l-example-op-scatternd_cuda`.
+
+The benchmark compares two implementations of ScatterND,
+one using atomic additions and one without.
+
 No specific provider
 ====================
 
diff --git a/_doc/examples/plot_op_scatternd_cuda.py b/_doc/examples/plot_op_scatternd_cuda.py
new file mode 100644
index 00000000..53060154
--- /dev/null
+++ b/_doc/examples/plot_op_scatternd_cuda.py
@@ -0,0 +1,116 @@
+"""
+.. _l-example-op-scatternd_cuda:
+
+=====================================
+Optimizing ScatterND operator on CUDA
+=====================================
+
+How to parallelize something like the following?
+
+ScatterND
+=========
+
+This configuration happens in a :epkg:`LLAMA` model.
+
+::
+
+    gradient = ScatterND(zeros, indices, updates)
+
+Where the shapes are:
+
+* zeros: 32000x4096
+* indices: 2x1024x1
+* updates: 2x1024x4096
+"""
+
+import numpy as np
+import onnx.helper as oh
+from onnx import TensorProto
+from onnx.reference import ReferenceEvaluator
+from onnx.reference.op_run import OpRun
+from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+
+def get_model(d3=True):
+    indices_shape = ["i", "j", 1] if d3 else ["m", 1]
+    updates_shape = ["i", "j", "b"] if d3 else ["m", "b"]
+    model = oh.make_model(
+        oh.make_graph(
+            [
+                oh.make_node(
+                    "ScatterND", ["X", "indices", "updates"], ["Y"], reduction="add"
+                )
+            ],
+            "g",
+            [
+                oh.make_tensor_value_info("X", TensorProto.FLOAT, ["a", "b"]),
+                oh.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+                oh.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape),
+            ],
+            [oh.make_tensor_value_info("Y", TensorProto.FLOAT, ["a", "b"])],
+        ),
+        opset_imports=[oh.make_opsetid("", 18)],
+        ir_version=9,
+    )
+    return model
+
+
+model = get_model()
+print(onnx_simple_text_plot(model))
+
+
+##########################################
+# Let's see the evaluation by the ReferenceEvaluator.
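+#
+# A word on semantics first: with ``reduction="add"`` and a last index
+# dimension of 1, ScatterND accumulates whole rows of ``updates`` into the
+# rows of the output selected by ``indices``. In pure numpy this boils
+# down to ``np.add.at`` (a sketch only, it is not used below):
+#
+# ::
+#
+#       np.add.at(output, indices.reshape(-1), updates.reshape(-1, updates.shape[-1]))
+#
+# Duplicated indices are what makes a parallel implementation tricky.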
+ + +def _scatter_nd_impl(data, indices, updates, reduction=None): # type: ignore + output = np.copy(data) + for i in np.ndindex(indices.shape[:-1]): + print(f"updates for i={i}, indices={indices[i]}, updates={updates[i]}") + if reduction == "add": + output[tuple(indices[i])] += updates[i] + elif reduction == "mul": + output[tuple(indices[i])] *= updates[i] + elif reduction == "max": + output[tuple(indices[i])] = np.maximum(output[indices[i]], updates[i]) + elif reduction == "min": + output[tuple(indices[i])] = np.minimum(output[indices[i]], updates[i]) + else: + output[tuple(indices[i])] = updates[i] + return output + + +class ScatterND(OpRun): + def _run(self, data, indices, updates, reduction=None): # type: ignore + y = _scatter_nd_impl(data, indices, updates, reduction=reduction) + return (y,) + + +shape = (5, 7) +X = np.zeros(shape, dtype=np.float32) +indices = np.zeros((2, 10, 1)).astype(np.int64) +indices[:, ::2, 0] = 3 +updates = np.ones((2, 10, 7)).astype(np.float32) +feeds = {"X": X, "indices": indices, "updates": updates} + + + +ref = ReferenceEvaluator(model, new_ops=[ScatterND]) +got = ref.run(None, feeds)[0] +print(got) + + +########################################### +# To generalize, let's change the shapes. + +model = get_model(d3=False) +print(onnx_simple_text_plot(model)) + + +new_indices = indices.reshape((-1, 1)) +new_updates = updates.reshape((-1, updates.shape[-1])) +feeds = {"X": X, "indices": indices, "updates": updates} + +ref = ReferenceEvaluator(model, new_ops=[ScatterND]) +got = ref.run(None, feeds)[0] +print(got) diff --git a/clean_onnx.sh b/clean_onnx.sh index 938da293..8135eaa8 100644 --- a/clean_onnx.sh +++ b/clean_onnx.sh @@ -1,9 +1,32 @@ -rm _doc/tutorial/*.json rm *.onnx -rm *.data +rm *.json rm *.png +rm *.csv +rm *.nsys-rep +rm *.sqlite +rm tt_* +rm plot* +rm test* -rf +rm temp* -rf +rm dump* -rf +rm *.sarif +rm *.svg +rm dump_models -rf +rm neural_coder_workspace -rf +rm *.data rm .build_path.txt rm _doc/examples/plot*.onnx -rm _doc/examples/plot*.png -rm _doc/examples/plot*.csv +rm _doc/examples/plot*.txt +rm _doc/examples/ort*.onnx +rm _doc/examples/*.sarif +rm _doc/examples/*.json +rm _doc/examples/*.png +rm _doc/examples/*.csv +rm _doc/examples/*.xlsx +rm _doc/examples/dummy*.onnx +rm _doc/examples/*.opt.onnx +rm _doc/examples/*.dynamo.onnx +rm _doc/examples/*.script.onnx +rm _doc/examples/dump_models -rf +rm _doc/sg_execution* From ac98c953814f4960a2c7a3f01f85f0bc59f7f0be Mon Sep 17 00:00:00 2001 From: xadupre Date: Fri, 12 Apr 2024 13:38:07 +0000 Subject: [PATCH 2/8] improve shape --- _doc/examples/plot_op_scatternd_cuda.py | 1 - _unittests/ut_ortops/test_optim_cuda.py | 81 +++++++ .../ortops/optim/cuda/scatter_nd_of_shape.cu | 207 ++++++++++++++++-- .../ortops/optim/cuda/scatter_nd_of_shape.h | 14 ++ .../reference/c_reference_evaluator.py | 4 + .../other_ops/op_scatternd_of_shape.py | 12 + 6 files changed, 301 insertions(+), 18 deletions(-) create mode 100644 onnx_extended/reference/other_ops/op_scatternd_of_shape.py diff --git a/_doc/examples/plot_op_scatternd_cuda.py b/_doc/examples/plot_op_scatternd_cuda.py index 53060154..eced8d07 100644 --- a/_doc/examples/plot_op_scatternd_cuda.py +++ b/_doc/examples/plot_op_scatternd_cuda.py @@ -94,7 +94,6 @@ def _run(self, data, indices, updates, reduction=None): # type: ignore feeds = {"X": X, "indices": indices, "updates": updates} - ref = ReferenceEvaluator(model, new_ops=[ScatterND]) got = ref.run(None, feeds)[0] print(got) diff --git a/_unittests/ut_ortops/test_optim_cuda.py 
b/_unittests/ut_ortops/test_optim_cuda.py index 5858a1c3..ad60e1cf 100644 --- a/_unittests/ut_ortops/test_optim_cuda.py +++ b/_unittests/ut_ortops/test_optim_cuda.py @@ -193,6 +193,87 @@ def test_addadd_cuda(self): self._addaddmulmul_cuda(TensorProto.FLOAT, "Add") self._addaddmulmul_cuda(TensorProto.FLOAT16, "Add") + def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype): + import onnxruntime + from onnx_extended.ortops.optim.cuda import get_ort_ext_libs + + indices_shape = ["i", "j", 1] if dim3 else ["j", 1] + updates_shape = ["i", "j", "b"] if dim3 else ["j", "b"] + + model = oh.make_model( + oh.make_graph( + [ + oh.make_node( + "ScatterNDOfShape", + inputs=["shape", "indices", "updates"], + outputs=["y"], + reduction="add", + strategy="optimize" if optimize else "none", + domain="onnx_extended.ortops.optim.cuda", + ) + ], + "nd", + [ + oh.make_tensor_value_info("shape", TensorProto.INT64, [2]), + oh.make_tensor_value_info( + "indices", TensorProto.INT64, indices_shape + ), + oh.make_tensor_value_info("updates", itype, updates_shape), + ], + [oh.make_tensor_value_info("y", itype, [None, None])], + ), + opset_imports=[ + oh.make_opsetid("", 18), + oh.make_opsetid("onnx_extended.ortops.optim.cuda", 1), + ], + ir_version=9, + ) + + if dim3: + shape = (128, 1024) + indices = np.zeros((2, 64, 1)).astype(np.int64) + indices[:, ::2, 0] = 87 + indices[:, ::3, 0] = 85 + updates = np.ones((2, 64, 1024)).astype(np.float32) + else: + shape = (128, 1024) + indices = np.zeros((128, 1)).astype(np.int64) + indices[::2, 0] = 87 + indices[::3, 0] = 85 + updates = np.ones((128, 1024)).astype(np.float32) + if itype != 1: + updates = updates.astype(np.float16) + feeds = dict( + shape=np.array(shape, dtype=np.int64), indices=indices, updates=updates + ) + + ref = CReferenceEvaluator(model) + expected = ref.run(None, feeds)[0] + + opts = onnxruntime.SessionOptions() + opts.register_custom_ops_library(get_ort_ext_libs()[0]) + # opts.log_severity_level = 0 + # opts.log_verbosity_level = 0 + sess = onnxruntime.InferenceSession( + model.SerializeToString(), opts, providers=["CUDAExecutionProvider"] + ) + if __name__ == "__main__": + print(f"running itype={itype}, optimize={optimize}, dim3={dim3}") + got = sess.run(None, feeds)[0] + self.assertEqual(expected.tolist(), got.tolist()) + if __name__ == "__main__": + print("done.") + + def test_scatternd_of_shape_optimize_cuda(self): + with self.subTest(optimize=True, dim3=True): + self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT) + self._scatternd_of_shape_optimize_cuda(False, False, TensorProto.FLOAT) + self._scatternd_of_shape_optimize_cuda(False, True, TensorProto.FLOAT) + with self.subTest(optimize=True, dim3=False): + self._scatternd_of_shape_optimize_cuda(True, False, TensorProto.FLOAT) + with self.subTest(optimize=True, dim3=True, itype=TensorProto.FLOAT16): + self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT16) + if __name__ == "__main__": # TestOrtOpTutorialCpu().test_dynamic_quantize_linear() diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu index a9ce5905..97f57bc4 100644 --- a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu +++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu @@ -214,6 +214,30 @@ struct TensorPitches : std::vector { } }; +template +__global__ void addition_inplace_kernel(T *dst, const T *a, const T *b, const CUDA_LONG size) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); + dst[id] = a[id] + b[id]; 
+} + +template +__global__ void set_inplace_kernel(T *a, const T *b, const CUDA_LONG size) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); + a[id] = b[id]; +} + +template +__global__ void set_zero_inplace_kernel(T *a, const T *b, const CUDA_LONG size) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); + a[id] = 0; +} + +template +__global__ void set_weird_zero_inplace_kernel(T *a, const T *b, const CUDA_LONG size) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); + a[id] = static_cast(0); +} + ////////////////// // ScatterNDOfShapeOp... ////////////////// @@ -317,11 +341,24 @@ ScatterNDOfShapeKernel::ScatterNDOfShapeKernel(const OrtApi &api, char value_string[1000]; std::size_t size = 1000; ThrowOnError(api, api.KernelInfoGetAttribute_string(info, "reduction", value_string, &size)); - std::string reduction = value_string; - if (reduction == "add") + std::string value = value_string; + if (value == "add") reduction_ = Reduction::Add; else - EXT_THROW("unexpected reduction '", reduction, "'."); + EXT_THROW("unexpected reduction '", value, "'."); + + value = KernelInfoGetOptionalAttributeString(api, info, "strategy", "none"); + if (value == "none") + strategy_ = Strategy::None; + else if (value == "optimize") + strategy_ = Strategy::Optimize; + else + EXT_THROW("unexpected strategy '", value, "'."); + + cudaDeviceProp prop; + int deviceId = 0; + cudaGetDeviceProperties(&prop, deviceId); + maxThreadPerBlock_ = prop.maxThreadsPerBlock; } template void ScatterNDOfShapeKernel::Compute(OrtKernelContext *context) { @@ -339,7 +376,11 @@ template void ScatterNDOfShapeKernel::Compute(OrtKernelContext * std::vector updates_shape = updates.GetTensorTypeAndShapeInfo().GetShape(); EXT_ENFORCE(dimensions.size() == 1, "shape must be a 1-dimension tensor."); - cudaStream_t cuda_stream = (cudaStream_t)ctx.GetGPUComputeStream(); + cudaStream_t stream = (cudaStream_t)ctx.GetGPUComputeStream(); + + auto memi = updates.GetTensorMemoryInfo(); + EXT_ENFORCE(memi.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, + "updates are not on GPU"); auto mem = shape.GetTensorMemoryInfo(); if (mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) { @@ -355,17 +396,67 @@ template void ScatterNDOfShapeKernel::Compute(OrtKernelContext * dims[i] = X[i]; output = ctx.GetOutput(0, dims); } else { - EXT_THROW("Unexpected device for input 1."); + EXT_THROW("Unexpected device for input 0."); + } + + std::vector input_shape = output.GetTensorTypeAndShapeInfo().GetShape(); + + if (reduction_ == Reduction::Add && strategy_ == Strategy::Optimize && + indices_shape[indices_shape.size() - 1] == 1 && input_shape.size() == 2 && + input_shape[input_shape.size() - 1] >= maxThreadPerBlock_) { + // We need the indices on CPU for this code. 
+ + size_t indice_size = static_cast(onnx_c_ops::flattened_dimension(indices_shape)); + size_t update_size = static_cast(onnx_c_ops::flattened_dimension(updates_shape)); + EXT_ENFORCE(update_size == indice_size * input_shape[input_shape.size() - 1], + "Size mismatch, update_size=", update_size, "indice_size=", indice_size, + "input_shape[-1]=", input_shape[input_shape.size() - 1], "."); + + const int64_t *indices_data; + std::vector indices_buffer; + + auto mem = indices.GetTensorMemoryInfo(); + if (mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) { + indices_buffer.resize(indice_size); + indices_data = indices_buffer.data(); + CUDA_THROW_IF_ERROR(cudaMemcpy(static_cast(indices_buffer.data()), + indices.GetTensorData(), + indice_size * sizeof(int64_t), cudaMemcpyDeviceToHost)); + } else if (mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) { + indices_data = indices.GetTensorData(); + } else { + EXT_THROW("Unexpected device for input 1."); + } + + ComputeOptimize(stream, input_shape, indices_shape, output.GetTensorMutableData(), + indices_data, updates.GetTensorData()); + + auto n_elements = onnx_c_ops::flattened_dimension(input_shape); + + // The kernel does not execute if this line is not present? + set_weird_zero_inplace_kernel<<>>( + output.GetTensorMutableData(), 0, n_elements); + } else { + ComputeNone(stream, input_shape, indices_shape, output.GetTensorMutableData(), + indices.GetTensorData(), updates.GetTensorData()); } +} +template +void ScatterNDOfShapeKernel::ComputeNone(cudaStream_t &stream, + const std::vector &input_shape, + const std::vector &indices_shape, + T *output_data, const int64_t *indices_data, + const T *updates_data) const { int64_t indice_size = onnx_c_ops::flattened_dimension(indices_shape); - if (indice_size == 0) + auto n_elements = onnx_c_ops::flattened_dimension(input_shape); + + if (indice_size == 0) { + CUDA_THROW_IF_ERROR(cudaMemsetAsync(output_data, 0, sizeof(T) * n_elements, stream)); return; + } - void *output_data = output.GetTensorMutableData(); - std::vector input_shape = output.GetTensorTypeAndShapeInfo().GetShape(); - auto n_elements = onnx_c_ops::flattened_dimension(input_shape); - CUDA_THROW_IF_ERROR(cudaMemsetAsync(output_data, 0, sizeof(T) * n_elements, cuda_stream)); + CUDA_THROW_IF_ERROR(cudaMemsetAsync(output_data, 0, sizeof(T) * n_elements, stream)); auto last_index_dimension = indices_shape[indices_shape.size() - 1]; @@ -385,20 +476,18 @@ template void ScatterNDOfShapeKernel::Compute(OrtKernelContext * cudaMalloc((void **)&workspace, element_counts_and_input_dims.size() * sizeof(int64_t))); CUDA_THROW_IF_ERROR(cudaMemcpyAsync(workspace, element_counts_and_input_dims.data(), element_counts_and_input_dims.size() * sizeof(int64_t), - cudaMemcpyHostToDevice, cuda_stream)); + cudaMemcpyHostToDevice, stream)); // Let's synchronize after the initialization of the results. 
- // CUDA_THROW_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + // CUDA_THROW_IF_ERROR(cudaStreamSynchronize(stream)); switch (reduction_) { case Reduction::Add: { auto element_type = CTypeToOnnxType().onnx_type(); ScatterNDImplReduction( - cuda_stream, output_data, element_type, - indice_size / static_cast(last_index_dimension), - indices.GetTensorData(), // only int64_t is supported for indices as per - // the onnx spec - last_index_dimension, workspace, updates.GetTensorData(), + stream, output_data, element_type, + indice_size / static_cast(last_index_dimension), indices_data, + last_index_dimension, workspace, updates_data, onnx_c_ops::SizeFromDimension(input_shape, last_index_dimension, input_shape.size()), reduction_); } break; @@ -410,6 +499,90 @@ template void ScatterNDOfShapeKernel::Compute(OrtKernelContext * CUDA_THROW_IF_ERROR(cudaFree(workspace)); } +template +void _ComputeOptimize(cudaStream_t stream, const std::vector &input_shape, + const std::vector &indices_shape, T *output_data, + const int64_t *indices_data, const T *updates_data, + int maxThreadPerBlock_) { + + // The kernel is slow if there are a lot of duplicates. + // reduction_ == Reduction::add + // indices_shape[indices_shape.size() - 1] == 1 + // input_shape.size() == 2 + size_t indice_size = static_cast(onnx_c_ops::flattened_dimension(indices_shape)); + size_t next_batch_size = 0; + size_t stride = input_shape[input_shape.size() - 1]; + CUDA_LONG stride_ = static_cast(stride); + + std::vector next_batch(indice_size); + std::vector processed(input_shape[0], 0); + std::vector processed_once(input_shape[0], 0); + size_t row; + + int threads_per_block = std::min(256, maxThreadPerBlock_ / 2); + int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block; + dim3 threads(threads_per_block); + dim3 blocks(blocks_per_grid); + + // First iteration. + for (size_t i = 0; i < indice_size; ++i) { + row = static_cast(indices_data[i]); + if (processed[row]) { + next_batch[next_batch_size++] = i; + } else { + set_inplace_kernel<<>>(output_data + row * stride, + updates_data + i * stride, stride_); + processed[row] = 1; + processed_once[row] = 1; + } + } + + // We set to zero all rows not impacted. + for (size_t i = 0; i < processed_once.size(); ++i) { + if (processed_once[i]) + continue; + CUDA_THROW_IF_ERROR( + cudaMemsetAsync(output_data + row * stride, 0, sizeof(T) * stride, stream)); + } + + // We need to synchronize. + memset(processed.data(), 0, processed.size() * sizeof(uint8_t)); + CUDA_THROW_IF_ERROR(cudaStreamSynchronize(stream)); + + // Then the next iterations. + while (next_batch_size > 0) { + size_t current_batch_size = next_batch_size; + next_batch_size = 0; + for (size_t i = 0; i < current_batch_size; ++i) { + row = indices_data[next_batch[i]]; + if (processed[row]) { + next_batch[next_batch_size++] = next_batch[i]; + } else { + addition_inplace_kernel<<>>( + output_data + row * stride, output_data + row * stride, + updates_data + next_batch[i] * stride, stride_); + processed[row] = 1; + } + } + + // We need to synchronize. 
+ if (next_batch_size > 0) { + memset(processed.data(), 0, processed.size() * sizeof(uint8_t)); + CUDA_THROW_IF_ERROR(cudaStreamSynchronize(stream)); + } + } +} + +template +void ScatterNDOfShapeKernel::ComputeOptimize(cudaStream_t &stream, + const std::vector &input_shape, + const std::vector &indices_shape, + T *output_data, const int64_t *indices_data, + const T *updates_data) const { + _ComputeOptimize(stream, input_shape, indices_shape, output_data, indices_data, updates_data, + maxThreadPerBlock_); +} + static ScatterNDOfShapeOp _op32; static ScatterNDOfShapeOp _op16; diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h index 688f9290..88571f7b 100644 --- a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h +++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h @@ -14,12 +14,26 @@ enum class Reduction : int { Max = 4, }; +enum class Strategy : int { + None = 0, + Optimize = 1, +}; + template struct ScatterNDOfShapeKernel { ScatterNDOfShapeKernel(const OrtApi &api, const OrtKernelInfo *info); void Compute(OrtKernelContext *context); private: + void ComputeNone(cudaStream_t &stream, const std::vector &input_shape, + const std::vector &indices_shape, T *output_data, + const int64_t *indices_data, const T *updates_data) const; + void ComputeOptimize(cudaStream_t &stream, const std::vector &input_shape, + const std::vector &indices_shape, T *output_data, + const int64_t *indices_data, const T *updates_data) const; + Reduction reduction_; + Strategy strategy_; + int maxThreadPerBlock_; }; template diff --git a/onnx_extended/reference/c_reference_evaluator.py b/onnx_extended/reference/c_reference_evaluator.py index f14988f0..55a9843f 100644 --- a/onnx_extended/reference/c_reference_evaluator.py +++ b/onnx_extended/reference/c_reference_evaluator.py @@ -168,9 +168,13 @@ def default_ops(): ) from onnx_extended.reference.c_ops.c_op_tfidf_vectorizer import TfIdfVectorizer from onnx_extended.reference.other_ops.op_tokenizer import Tokenizer + from onnx_extended.reference.other_ops.op_scatternd_of_shape import ( + ScatterNDOfShape, + ) return [ Conv, + ScatterNDOfShape, SVMClassifier, SVMRegressor, TfIdfVectorizer, diff --git a/onnx_extended/reference/other_ops/op_scatternd_of_shape.py b/onnx_extended/reference/other_ops/op_scatternd_of_shape.py new file mode 100644 index 00000000..1bcda940 --- /dev/null +++ b/onnx_extended/reference/other_ops/op_scatternd_of_shape.py @@ -0,0 +1,12 @@ +import numpy as np +from onnx.reference.op_run import OpRun +from onnx.reference.ops.op_scatternd import _scatter_nd_impl + + +class ScatterNDOfShape(OpRun): + op_domain = "onnx_extended.ortops.optim.cuda" + + def _run(self, shape, indices, updates, reduction=None, strategy=None): + data = np.zeros(shape, dtype=updates.dtype) + y = _scatter_nd_impl(data, indices, updates, reduction=reduction) + return (y,) From 3f3ee5687f730a1f852ae31336fb03d6b18a4522 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 15 Apr 2024 18:41:52 +0000 Subject: [PATCH 3/8] first try --- _doc/examples/plot_op_scatternd_cuda.py | 269 +++++++++++++++++- _unittests/ut_ortops/test_optim_cuda.py | 22 +- .../ortops/optim/cuda/scatter_nd_of_shape.cu | 169 +++-------- .../ortops/optim/cuda/scatter_nd_of_shape.h | 5 + pyproject.toml | 1 + 5 files changed, 324 insertions(+), 142 deletions(-) diff --git a/_doc/examples/plot_op_scatternd_cuda.py b/_doc/examples/plot_op_scatternd_cuda.py index eced8d07..591ec4b4 100644 --- a/_doc/examples/plot_op_scatternd_cuda.py +++ 
b/_doc/examples/plot_op_scatternd_cuda.py @@ -23,33 +23,87 @@ * updates: 2x1024x4096 """ +from onnx_extended.args import get_parsed_args + +script_args = get_parsed_args( + "plot_op_scatternd_cuda", + description=__doc__, + config=( + "small", + "small, short optimization (default), " + "medium for medium sizes, " + "large for big sizes", + ), + warmup=3, + repeat=5, + itype=(1, "1 or 10 for float or float16"), + expose="config,itype,warmup,repeat", +) + +import time import numpy as np +from numpy.testing import assert_almost_equal +from pandas import DataFrame +from tqdm import tqdm import onnx.helper as oh from onnx import TensorProto from onnx.reference import ReferenceEvaluator from onnx.reference.op_run import OpRun from onnx_array_api.plotting.text_plot import onnx_simple_text_plot +itype = script_args.itype +config = script_args.config +print(f"config={config}") +print(f"itype={itype}") + +if config == "small": + sizes = (256, 512, 1024) +elif config == "medium": + sizes = (512, 1024, 2048) +elif config == "large": + sizes = (1024, 2048, 4096, 8192) +else: + try: + sizes = list(map(int, config.split(","))) + except (ValueError, TypeError) as e: + raise AssertionError(f"Unexpected config value {config!r}.") from e + -def get_model(d3=True): +def get_model(d3=True, optimize=False, shape_input=False, itype=TensorProto.FLOAT): indices_shape = ["i", "j", 1] if d3 else ["m", 1] updates_shape = ["i", "j", "b"] if d3 else ["m", "b"] + kwargs = dict(reduction="add") + if shape_input: + kwargs["domain"] = "onnx_extended.ortops.optim.cuda" + if optimize: + kwargs["strategy"] = "optimize" + model = oh.make_model( oh.make_graph( [ oh.make_node( - "ScatterND", ["X", "indices", "updates"], ["Y"], reduction="add" + "ScatterNDOfShape" if shape_input else "ScatterND", + ["shape" if shape_input else "X", "indices", "updates"], + ["Y"], + **kwargs, ) ], "g", [ - oh.make_tensor_value_info("X", TensorProto.FLOAT, ["a", "b"]), + ( + oh.make_tensor_value_info("shape", TensorProto.INT64, ["s"]) + if shape_input + else oh.make_tensor_value_info("X", itype, ["a", "b"]) + ), oh.make_tensor_value_info("indices", TensorProto.INT64, indices_shape), - oh.make_tensor_value_info("updates", TensorProto.FLOAT, updates_shape), + oh.make_tensor_value_info("updates", itype, updates_shape), ], - [oh.make_tensor_value_info("Y", TensorProto.FLOAT, ["a", "b"])], + [oh.make_tensor_value_info("Y", itype, ["a", "b"])], ), - opset_imports=[oh.make_opsetid("", 18)], + opset_imports=[ + oh.make_opsetid("", 18), + oh.make_opsetid("onnx_extended.ortops.optim.cuda", 1), + ], ir_version=9, ) return model @@ -63,10 +117,11 @@ def get_model(d3=True): # Let's see the evaluation by the ReferenceEvaluator. 
-def _scatter_nd_impl(data, indices, updates, reduction=None): # type: ignore +def _scatter_nd_impl(data, indices, updates, reduction=None, verbose=False): # type: ignore output = np.copy(data) for i in np.ndindex(indices.shape[:-1]): - print(f"updates for i={i}, indices={indices[i]}, updates={updates[i]}") + if verbose: + print(f"updates for i={i}, indices={indices[i]}, updates={updates[i]}") if reduction == "add": output[tuple(indices[i])] += updates[i] elif reduction == "mul": @@ -81,7 +136,16 @@ def _scatter_nd_impl(data, indices, updates, reduction=None): # type: ignore class ScatterND(OpRun): - def _run(self, data, indices, updates, reduction=None): # type: ignore + def _run(self, data, indices, updates, reduction=None, optimize=None): # type: ignore + y = _scatter_nd_impl(data, indices, updates, reduction=reduction, verbose=True) + return (y,) + + +class ScatterNDOfShape(OpRun): + op_domain = "onnx_extended.ortops.optim.cuda" + + def _run(self, shape, indices, updates, reduction=None, optimize=None): # type: ignore + data = np.zeros(tuple(shape.tolist()), dtype=updates.dtype) y = _scatter_nd_impl(data, indices, updates, reduction=reduction) return (y,) @@ -102,7 +166,7 @@ def _run(self, data, indices, updates, reduction=None): # type: ignore ########################################### # To generalize, let's change the shapes. -model = get_model(d3=False) +model = get_model(d3=False, itype=itype) print(onnx_simple_text_plot(model)) @@ -113,3 +177,188 @@ def _run(self, data, indices, updates, reduction=None): # type: ignore ref = ReferenceEvaluator(model, new_ops=[ScatterND]) got = ref.run(None, feeds)[0] print(got) + + +############################################## +# First scenario +# ============== + +model = get_model(d3=False, shape_input=True, itype=itype) +print(onnx_simple_text_plot(model)) + + +feeds = { + "shape": np.array(X.shape, dtype=np.int64), + "indices": indices.reshape((-1, 1)), + "updates": updates.reshape((-1, updates.shape[-1])), +} + +ref = ReferenceEvaluator(model, new_ops=[ScatterNDOfShape]) +expected = ref.run(None, feeds)[0] +print(expected) + + +################################### +# With onnxruntime + + +def get_session(model): + import onnxruntime + from onnx_extended.ortops.optim.cuda import get_ort_ext_libs + + if "CUDAExecutionProvider" not in onnxruntime.get_available_providers(): + return None + + opts = onnxruntime.SessionOptions() + opts.register_custom_ops_library(get_ort_ext_libs()[0]) + sess = onnxruntime.InferenceSession( + model.SerializeToString(), + opts, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + return sess + + +sess1 = get_session(model) +if sess1 is not None: + for k, v in feeds.items(): + print(k, v.dtype, v.shape) + got = sess1.run(None, feeds)[0] + print(got) + assert_almost_equal(expected, got) + +################################################## +# Same model but using an optimization to compute it. 
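+#
+# The ``strategy="optimize"`` attribute asks the custom kernel to avoid
+# ``atomicAdd``. Roughly, each CUDA thread owns one column of the output:
+# it first sets that column to zero for every row, then loops over all the
+# indices and accumulates the matching update value, so no two threads ever
+# write to the same address. A Python sketch of what one thread computes
+# (for illustration only, this is not executed here):
+#
+# ::
+#
+#     col = thread_id                      # one thread per output column
+#     for row in range(n_rows):
+#         Y[row, col] = 0
+#     for i, row in enumerate(indices.reshape(-1)):
+#         Y[row, col] += updates[i, col]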
+ +model = get_model(d3=False, shape_input=True, optimize=True, itype=itype) +print(onnx_simple_text_plot(model)) + +sess2 = get_session(model) +if sess2 is not None: + got = sess2.run(None, feeds)[0] + print(got) + assert_almost_equal(expected, got) + +################################################# +# Benchmark +# ========= + + +def move_inputs(sess, feeds): + from onnxruntime.capi._pybind_state import ( + SessionIOBinding, + OrtDevice as C_OrtDevice, + OrtValue as C_OrtValue, + ) + + input_names = [i.name for i in sess.get_inputs()] + + ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + + feed_ort_value = [ + (name, C_OrtValue.ortvalue_from_numpy(feeds[name], ort_device)) + for name in input_names + ] + + bind = SessionIOBinding(sess._sess) + for name, value in feed_ort_value: + bind.bind_input( + name, ort_device, feeds[name].dtype, value.shape(), value.data_ptr() + ) + for o in sess.get_outputs(): + bind.bind_output(o.name, ort_device) + return bind, feed_ort_value + + +def benchmark(sess, sizes, label, itype, times_col: int = 1, times_indices: int = 1): + + data = [] + for size in tqdm(sizes): + + nrow, ncol = size, int(size * times_col) + nind = int(size * times_indices) + shape = np.array([nrow, ncol], dtype=np.int64) + indices = np.array( + [np.random.randint(0, nrow - 1) for _ in range(nind)], dtype=np.int64 + ).reshape((-1, 1)) + updates = np.random.randn(nind, ncol).astype( + np.float32 if itype == TensorProto.FLOAT else np.float16 + ) + feeds = dict(shape=shape, indices=indices, updates=updates) + bind, cuda_feeds = move_inputs(sess, feeds) + + begin = time.perf_counter() + for i in range(script_args.warmup): + # sess.run(None, feeds) + sess._sess.run_with_iobinding(bind, None) + warmup = time.perf_counter() - begin + + times = [] + for i in range(script_args.repeat): + begin = time.perf_counter() + # sess.run(None, feeds) + sess._sess.run_with_iobinding(bind, None) + times.append(time.perf_counter() - begin) + + npt = np.array(times) + obs = dict( + warmup=warmup, + time=npt.mean(), + std=npt.std(), + min=npt.min(), + max=npt.max(), + repeat=script_args.repeat, + size=size, + label=label, + ) + data.append(obs) + return data + + +####################################### +# Not Fused. + + +if sess1 is not None: + + print(f"sizes={sizes}") + + data_nd1 = benchmark(sess1, sizes, "Atomic", itype=itype) + +####################################### +# Fused. + +if sess2 is not None: + + data_nd2 = benchmark(sess2, sizes, "No Atomic", itype=itype) + + +########################################## +# Data +# ++++ + +if sess2 is not None: + + df = DataFrame(data_nd1 + data_nd2) + df.to_csv("plot_op_scatternd_cuda.csv", index=False) + df.to_csv("plot_op_scatternd_cuda.xlsx", index=False) + print(df.head()) + +##################### +# Pivot. + +if sess2 is not None: + + pivot = df.pivot(index="size", columns="label", values="time") + pivot["ratio"] = pivot["Atomic"] / pivot["No Atomic"] + print(pivot) + + ax = pivot[["Atomic", "No Atomic"]].plot( + logx=True, + logy=True, + title=f"Atomic/No-Atomic implementation for ScatterND on CUDA\nitype={itype}", + ) + ax.get_figure().savefig("plot_op_scatternd_cuda.png") + +############################## +# The best choice depends on the on input sizes. 
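+
+##############################
+# Note on methodology: the benchmark binds inputs and outputs on the CUDA
+# device with ``SessionIOBinding`` and runs through ``run_with_iobinding``,
+# so the measured times exclude host-to-device copies. Timing a plain
+# ``sess.run(None, feeds)`` instead (the call kept commented out in the
+# loop above) would include those transfers in every measurement.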
diff --git a/_unittests/ut_ortops/test_optim_cuda.py b/_unittests/ut_ortops/test_optim_cuda.py index ad60e1cf..110b9f68 100644 --- a/_unittests/ut_ortops/test_optim_cuda.py +++ b/_unittests/ut_ortops/test_optim_cuda.py @@ -252,16 +252,26 @@ def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype): opts = onnxruntime.SessionOptions() opts.register_custom_ops_library(get_ort_ext_libs()[0]) - # opts.log_severity_level = 0 - # opts.log_verbosity_level = 0 + if __name__ == "disabled__main__": + opts.log_severity_level = 0 + opts.log_verbosity_level = 0 sess = onnxruntime.InferenceSession( model.SerializeToString(), opts, providers=["CUDAExecutionProvider"] ) - if __name__ == "__main__": - print(f"running itype={itype}, optimize={optimize}, dim3={dim3}") - got = sess.run(None, feeds)[0] + if __name__ == "disabled__main__": + print( + f"running itype={itype}, optimize={optimize}, dim3={dim3}, " + f"shape={shape}, indices.shape={indices.shape}, " + f"updates.shape={updates.shape}" + ) + ro = onnxruntime.RunOptions() + ro.log_severity_level = 0 + ro.log_verbosity_level = 0 + else: + ro = None + got = sess.run(None, feeds, ro)[0] self.assertEqual(expected.tolist(), got.tolist()) - if __name__ == "__main__": + if __name__ == "disabled__main__": print("done.") def test_scatternd_of_shape_optimize_cuda(self): diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu index 97f57bc4..18b7b234 100644 --- a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu +++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu @@ -74,7 +74,7 @@ template __global__ void _ScatterNDKernelReduction(T *output_data, const size_t num_indices, const int64_t *indices_data, const int64_t last_index_dimension, - const int64_t *element_counts_and_input_dims, const T *updates_data, + Shape2 element_counts_and_input_dims, const T *updates_data, const size_t num_updates_elements, const TFunc func) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices); @@ -86,8 +86,8 @@ _ScatterNDKernelReduction(T *output_data, const size_t num_indices, const int64_ for (size_t i = indices_start; i < indices_end; ++i) { int64_t index = indices_data[i]; - int64_t element_count_dim = element_counts_and_input_dims[i - indices_start]; - int64_t dim_value = element_counts_and_input_dims[i - indices_start + last_index_dimension]; + int64_t element_count_dim = element_counts_and_input_dims.dims[i - indices_start]; + int64_t dim_value = element_counts_and_input_dims.dims[i - indices_start + last_index_dimension]; // Clamp the index if out of range // This would have been an error in the CPU kernel, but throwing in the CUDA EP @@ -132,7 +132,7 @@ struct GridDim { void ScatterNDImplReduction(cudaStream_t stream, void *output_data, const int32_t element_type, const size_t num_indices, const int64_t *indices_data, const int64_t last_index_dimension, - const int64_t *element_counts_and_input_dims, + const Shape2& element_counts_and_input_dims, const void *updates_data, const size_t num_updates_elements, Reduction reduction) { if (num_indices == 0) @@ -215,27 +215,18 @@ struct TensorPitches : std::vector { }; template -__global__ void addition_inplace_kernel(T *dst, const T *a, const T *b, const CUDA_LONG size) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); - dst[id] = a[id] + b[id]; -} - -template -__global__ void set_inplace_kernel(T *a, const T *b, const CUDA_LONG size) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); - a[id] = b[id]; -} - -template -__global__ void 
set_zero_inplace_kernel(T *a, const T *b, const CUDA_LONG size) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); - a[id] = 0; -} +__global__ void addition_inplace_kernel(T *output_data, const int64_t *indices_data, const T *updates_data, const CUDA_LONG indice_size, const CUDA_LONG nrows, const CUDA_LONG stride) { + HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x; + if (id >= stride) + return; + + for(size_t i=0; i < nrows; ++i) { + output_data[i * stride + id] = 0; + } -template -__global__ void set_weird_zero_inplace_kernel(T *a, const T *b, const CUDA_LONG size) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, size); - a[id] = static_cast(0); + for(size_t i=0;i < indice_size; ++i) { + output_data[indices_data[i] * stride + id] += updates_data[i * stride + id]; + } } ////////////////// @@ -283,6 +274,19 @@ ONNXTensorElementDataType ScatterNDOfShapeOp::GetInputType(std::size_t ind } } +template +OrtMemType ScatterNDOfShapeOp::GetInputMemoryType(std::size_t index) const { + switch (index) { + case 0: + return OrtMemTypeCPUInput; + case 1: + case 2: + return OrtMemTypeDefault; + default: + EXT_THROW("Input index=", (int64_t)index, " is out of boundary."); + } +} + template OrtCustomOpInputOutputCharacteristic ScatterNDOfShapeOp::GetInputCharacteristic(std::size_t index) const { @@ -383,59 +387,27 @@ template void ScatterNDOfShapeKernel::Compute(OrtKernelContext * "updates are not on GPU"); auto mem = shape.GetTensorMemoryInfo(); - if (mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) { - std::vector buf(dimensions[0]); - const int64_t *ptr = shape.GetTensorData(); - CUDA_THROW_IF_ERROR( - cudaMemcpy(buf.data(), ptr, dimensions[0] * sizeof(int64_t), cudaMemcpyDeviceToHost)); - output = ctx.GetOutput(0, buf); - } else if (mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) { - const int64_t *X = shape.GetTensorData(); - std::vector dims(dimensions[0]); - for (size_t i = 0; i < dimensions[0]; ++i) - dims[i] = X[i]; - output = ctx.GetOutput(0, dims); - } else { - EXT_THROW("Unexpected device for input 0."); - } + EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU, + "The shape should be on CPU already, but mem.GetDeviceType()=", mem.GetDeviceType(), "."); + const int64_t *X = shape.GetTensorData(); + std::vector dims(X, X + dimensions[0]); + output = ctx.GetOutput(0, dims); std::vector input_shape = output.GetTensorTypeAndShapeInfo().GetShape(); if (reduction_ == Reduction::Add && strategy_ == Strategy::Optimize && indices_shape[indices_shape.size() - 1] == 1 && input_shape.size() == 2 && input_shape[input_shape.size() - 1] >= maxThreadPerBlock_) { - // We need the indices on CPU for this code. 
size_t indice_size = static_cast(onnx_c_ops::flattened_dimension(indices_shape)); size_t update_size = static_cast(onnx_c_ops::flattened_dimension(updates_shape)); + EXT_ENFORCE(update_size == indice_size * input_shape[input_shape.size() - 1], "Size mismatch, update_size=", update_size, "indice_size=", indice_size, "input_shape[-1]=", input_shape[input_shape.size() - 1], "."); - const int64_t *indices_data; - std::vector indices_buffer; - - auto mem = indices.GetTensorMemoryInfo(); - if (mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) { - indices_buffer.resize(indice_size); - indices_data = indices_buffer.data(); - CUDA_THROW_IF_ERROR(cudaMemcpy(static_cast(indices_buffer.data()), - indices.GetTensorData(), - indice_size * sizeof(int64_t), cudaMemcpyDeviceToHost)); - } else if (mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) { - indices_data = indices.GetTensorData(); - } else { - EXT_THROW("Unexpected device for input 1."); - } - ComputeOptimize(stream, input_shape, indices_shape, output.GetTensorMutableData(), - indices_data, updates.GetTensorData()); - - auto n_elements = onnx_c_ops::flattened_dimension(input_shape); - - // The kernel does not execute if this line is not present? - set_weird_zero_inplace_kernel<<>>( - output.GetTensorMutableData(), 0, n_elements); + indices.GetTensorData(), updates.GetTensorData()); } else { ComputeNone(stream, input_shape, indices_shape, output.GetTensorMutableData(), indices.GetTensorData(), updates.GetTensorData()); @@ -464,20 +436,14 @@ void ScatterNDOfShapeKernel::ComputeNone(cudaStream_t &stream, // for the range [0, last_index_dimension). // To avoid multiple GPU data transfers, we combine this into one array and send it through TensorPitches input_strides(input_shape); - std::vector element_counts_and_input_dims(last_index_dimension * 2, 0LL); + Shape2 element_counts_and_input_dims; + memset(element_counts_and_input_dims.dims, 0, sizeof(int64_t) * last_index_dimension * 2); for (int64_t i = 0; i < last_index_dimension; ++i) { - element_counts_and_input_dims[i] = input_strides[i]; - element_counts_and_input_dims[i + last_index_dimension] = input_shape[i]; + element_counts_and_input_dims.dims[i] = input_strides[i]; + element_counts_and_input_dims.dims[i + last_index_dimension] = input_shape[i]; } - int64_t *workspace; - CUDA_THROW_IF_ERROR( - cudaMalloc((void **)&workspace, element_counts_and_input_dims.size() * sizeof(int64_t))); - CUDA_THROW_IF_ERROR(cudaMemcpyAsync(workspace, element_counts_and_input_dims.data(), - element_counts_and_input_dims.size() * sizeof(int64_t), - cudaMemcpyHostToDevice, stream)); - // Let's synchronize after the initialization of the results. 
// CUDA_THROW_IF_ERROR(cudaStreamSynchronize(stream)); @@ -487,7 +453,7 @@ void ScatterNDOfShapeKernel::ComputeNone(cudaStream_t &stream, ScatterNDImplReduction( stream, output_data, element_type, indice_size / static_cast(last_index_dimension), indices_data, - last_index_dimension, workspace, updates_data, + last_index_dimension, element_counts_and_input_dims, updates_data, onnx_c_ops::SizeFromDimension(input_shape, last_index_dimension, input_shape.size()), reduction_); } break; @@ -495,8 +461,6 @@ void ScatterNDOfShapeKernel::ComputeNone(cudaStream_t &stream, EXT_THROW("ScatterNDOfShape not supported for other reduction than Add, None."); break; } - - CUDA_THROW_IF_ERROR(cudaFree(workspace)); } template @@ -510,67 +474,20 @@ void _ComputeOptimize(cudaStream_t stream, const std::vector &input_sha // indices_shape[indices_shape.size() - 1] == 1 // input_shape.size() == 2 size_t indice_size = static_cast(onnx_c_ops::flattened_dimension(indices_shape)); - size_t next_batch_size = 0; + size_t input_size = static_cast(onnx_c_ops::flattened_dimension(input_shape)); size_t stride = input_shape[input_shape.size() - 1]; - CUDA_LONG stride_ = static_cast(stride); + size_t nrows = input_size / stride; std::vector next_batch(indice_size); std::vector processed(input_shape[0], 0); std::vector processed_once(input_shape[0], 0); - size_t row; int threads_per_block = std::min(256, maxThreadPerBlock_ / 2); int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block; dim3 threads(threads_per_block); dim3 blocks(blocks_per_grid); - // First iteration. - for (size_t i = 0; i < indice_size; ++i) { - row = static_cast(indices_data[i]); - if (processed[row]) { - next_batch[next_batch_size++] = i; - } else { - set_inplace_kernel<<>>(output_data + row * stride, - updates_data + i * stride, stride_); - processed[row] = 1; - processed_once[row] = 1; - } - } - - // We set to zero all rows not impacted. - for (size_t i = 0; i < processed_once.size(); ++i) { - if (processed_once[i]) - continue; - CUDA_THROW_IF_ERROR( - cudaMemsetAsync(output_data + row * stride, 0, sizeof(T) * stride, stream)); - } - - // We need to synchronize. - memset(processed.data(), 0, processed.size() * sizeof(uint8_t)); - CUDA_THROW_IF_ERROR(cudaStreamSynchronize(stream)); - - // Then the next iterations. - while (next_batch_size > 0) { - size_t current_batch_size = next_batch_size; - next_batch_size = 0; - for (size_t i = 0; i < current_batch_size; ++i) { - row = indices_data[next_batch[i]]; - if (processed[row]) { - next_batch[next_batch_size++] = next_batch[i]; - } else { - addition_inplace_kernel<<>>( - output_data + row * stride, output_data + row * stride, - updates_data + next_batch[i] * stride, stride_); - processed[row] = 1; - } - } - - // We need to synchronize. 
- if (next_batch_size > 0) { - memset(processed.data(), 0, processed.size() * sizeof(uint8_t)); - CUDA_THROW_IF_ERROR(cudaStreamSynchronize(stream)); - } - } + addition_inplace_kernel<<>>(output_data, indices_data, updates_data, indice_size, nrows, stride); } template diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h index 88571f7b..2cdce54f 100644 --- a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h +++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.h @@ -19,6 +19,10 @@ enum class Strategy : int { Optimize = 1, }; +struct Shape2 { + int64_t dims[12]; +}; + template struct ScatterNDOfShapeKernel { ScatterNDOfShapeKernel(const OrtApi &api, const OrtKernelInfo *info); void Compute(OrtKernelContext *context); @@ -48,6 +52,7 @@ struct ScatterNDOfShapeOp std::size_t GetInputTypeCount() const; ONNXTensorElementDataType GetInputType(std::size_t index) const; OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(std::size_t index) const; + OrtMemType GetInputMemoryType(std::size_t index) const; std::size_t GetOutputTypeCount() const; ONNXTensorElementDataType GetOutputType(std::size_t index) const; diff --git a/pyproject.toml b/pyproject.toml index 8ba00bc3..1a639a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -323,6 +323,7 @@ max-complexity = 10 [tool.ruff.lint.per-file-ignores] "_doc/examples/plot_op_mul_cuda.py" = ["E402"] +"_doc/examples/plot_op_scatternd_cuda.py" = ["E402"] "onnx_extended/helper/__init__.py" = ["F401"] "onnx_extended/reference/__init__.py" = ["F401"] "onnx_extended/tools/__init__.py" = ["F401"] From de9ef923bcfc38d5e6a17ab6b5e6c3880581963d Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 16 Apr 2024 09:41:14 +0000 Subject: [PATCH 4/8] fix computation --- _doc/examples/plot_op_scatternd_cuda.py | 34 +++++--- .../ortops/optim/cuda/scatter_nd_of_shape.cu | 78 +++++++++++++++---- 2 files changed, 89 insertions(+), 23 deletions(-) diff --git a/_doc/examples/plot_op_scatternd_cuda.py b/_doc/examples/plot_op_scatternd_cuda.py index 591ec4b4..62136adf 100644 --- a/_doc/examples/plot_op_scatternd_cuda.py +++ b/_doc/examples/plot_op_scatternd_cuda.py @@ -33,6 +33,7 @@ "small, short optimization (default), " "medium for medium sizes, " "large for big sizes", + "llama for a specific case on llama", ), warmup=3, repeat=5, @@ -52,9 +53,10 @@ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot itype = script_args.itype +dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 config = script_args.config print(f"config={config}") -print(f"itype={itype}") +print(f"itype={itype}, dtype={dtype}") if config == "small": sizes = (256, 512, 1024) @@ -62,6 +64,8 @@ sizes = (512, 1024, 2048) elif config == "large": sizes = (1024, 2048, 4096, 8192) +elif config == "llama": + sizes = (16000, 32000) else: try: sizes = list(map(int, config.split(","))) @@ -151,10 +155,10 @@ def _run(self, shape, indices, updates, reduction=None, optimize=None): # type: shape = (5, 7) -X = np.zeros(shape, dtype=np.float32) +X = np.zeros(shape, dtype=dtype) indices = np.zeros((2, 10, 1)).astype(np.int64) indices[:, ::2, 0] = 3 -updates = np.ones((2, 10, 7)).astype(np.float32) +updates = np.ones((2, 10, 7)).astype(dtype) feeds = {"X": X, "indices": indices, "updates": updates} @@ -270,13 +274,23 @@ def move_inputs(sess, feeds): return bind, feed_ort_value -def benchmark(sess, sizes, label, itype, times_col: int = 1, times_indices: int = 1): +def benchmark( + sess, sizes, config, label, itype, 
times_col: int = 1, times_indices: int = 1 +): data = [] for size in tqdm(sizes): - nrow, ncol = size, int(size * times_col) - nind = int(size * times_indices) + if config == "llama": + # zeros: 32000x4096 + # indices: 2x1024x1 + # updates: 2x1024x4096 + nrow, ncol = size, 4096 + nind = 1024 + else: + nrow, ncol = size, int(size * times_col) + nind = int(size * times_indices) + shape = np.array([nrow, ncol], dtype=np.int64) indices = np.array( [np.random.randint(0, nrow - 1) for _ in range(nind)], dtype=np.int64 @@ -323,14 +337,14 @@ def benchmark(sess, sizes, label, itype, times_col: int = 1, times_indices: int print(f"sizes={sizes}") - data_nd1 = benchmark(sess1, sizes, "Atomic", itype=itype) + data_nd1 = benchmark(sess1, sizes, script_args.config, "Atomic", itype=itype) ####################################### # Fused. if sess2 is not None: - data_nd2 = benchmark(sess2, sizes, "No Atomic", itype=itype) + data_nd2 = benchmark(sess2, sizes, script_args.config, "No Atomic", itype=itype) ########################################## @@ -361,4 +375,6 @@ def benchmark(sess, sizes, label, itype, times_col: int = 1, times_indices: int ax.get_figure().savefig("plot_op_scatternd_cuda.png") ############################## -# The best choice depends on the on input sizes. +# The best choice depends on the input sizes, +# For big matrices, the use of atomic is slowing down +# the computation. diff --git a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu index 18b7b234..60fdc4de 100644 --- a/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu +++ b/onnx_extended/ortops/optim/cuda/scatter_nd_of_shape.cu @@ -87,7 +87,8 @@ _ScatterNDKernelReduction(T *output_data, const size_t num_indices, const int64_ int64_t index = indices_data[i]; int64_t element_count_dim = element_counts_and_input_dims.dims[i - indices_start]; - int64_t dim_value = element_counts_and_input_dims.dims[i - indices_start + last_index_dimension]; + int64_t dim_value = + element_counts_and_input_dims.dims[i - indices_start + last_index_dimension]; // Clamp the index if out of range // This would have been an error in the CPU kernel, but throwing in the CUDA EP @@ -132,7 +133,7 @@ struct GridDim { void ScatterNDImplReduction(cudaStream_t stream, void *output_data, const int32_t element_type, const size_t num_indices, const int64_t *indices_data, const int64_t last_index_dimension, - const Shape2& element_counts_and_input_dims, + const Shape2 &element_counts_and_input_dims, const void *updates_data, const size_t num_updates_elements, Reduction reduction) { if (num_indices == 0) @@ -215,20 +216,54 @@ struct TensorPitches : std::vector { }; template -__global__ void addition_inplace_kernel(T *output_data, const int64_t *indices_data, const T *updates_data, const CUDA_LONG indice_size, const CUDA_LONG nrows, const CUDA_LONG stride) { +__global__ void +addition_inplace_kernel(T *__restrict__ output_data, const int64_t *__restrict__ indices_data, + const T *__restrict__ updates_data, const CUDA_LONG indice_size, + const CUDA_LONG nrows, const CUDA_LONG stride) { HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x; if (id >= stride) return; - - for(size_t i=0; i < nrows; ++i) { + + for (size_t i = 0; i < nrows; ++i) { output_data[i * stride + id] = 0; } - for(size_t i=0;i < indice_size; ++i) { + for (size_t i = 0; i < indice_size; ++i) { output_data[indices_data[i] * stride + id] += updates_data[i * stride + id]; } } +#ifdef ENABLE_NCONT + +template +__global__ void 
+addition_inplace_kernelN(T *__restrict__ output_data, const int64_t *__restrict__ indices_data, + const T *__restrict__ updates_data, const CUDA_LONG indice_size, + const CUDA_LONG nrows, const CUDA_LONG stride) { + HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x * NCONT; + + T *out; + for (size_t i = 0; i < nrows; ++i) { + out = output_data + i * stride + id; +#pragma unroll + for (int k = 0; k < NCONT; ++k) { + out[k] = 0; + } + } + + const T *up; + for (size_t i = 0; i < indice_size; ++i) { + out = output_data + (indices_data[i] * stride + id); + up = updates_data + (i * stride + id); +#pragma unroll + for (int k = 0; k < NCONT; ++k) { + out[k] += up[k]; + } + } +} + +#endif + ////////////////// // ScatterNDOfShapeOp... ////////////////// @@ -387,8 +422,9 @@ template void ScatterNDOfShapeKernel::Compute(OrtKernelContext * "updates are not on GPU"); auto mem = shape.GetTensorMemoryInfo(); - EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU, - "The shape should be on CPU already, but mem.GetDeviceType()=", mem.GetDeviceType(), "."); + EXT_ENFORCE( + mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU, + "The shape should be on CPU already, but mem.GetDeviceType()=", mem.GetDeviceType(), "."); const int64_t *X = shape.GetTensorData(); std::vector dims(X, X + dimensions[0]); output = ctx.GetOutput(0, dims); @@ -436,7 +472,7 @@ void ScatterNDOfShapeKernel::ComputeNone(cudaStream_t &stream, // for the range [0, last_index_dimension). // To avoid multiple GPU data transfers, we combine this into one array and send it through TensorPitches input_strides(input_shape); - Shape2 element_counts_and_input_dims; + Shape2 element_counts_and_input_dims; memset(element_counts_and_input_dims.dims, 0, sizeof(int64_t) * last_index_dimension * 2); for (int64_t i = 0; i < last_index_dimension; ++i) { @@ -482,12 +518,26 @@ void _ComputeOptimize(cudaStream_t stream, const std::vector &input_sha std::vector processed(input_shape[0], 0); std::vector processed_once(input_shape[0], 0); - int threads_per_block = std::min(256, maxThreadPerBlock_ / 2); - int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block; - dim3 threads(threads_per_block); - dim3 blocks(blocks_per_grid); + int threads_per_block = std::min(256, maxThreadPerBlock_ / 8); - addition_inplace_kernel<<>>(output_data, indices_data, updates_data, indice_size, nrows, stride); +#ifdef ENABLE_NCONT +#define NCONT 65536 + if (stride % NCONT == 0 && stride > threads_per_block * NCONT) { + int blocks_per_grid = (stride / NCONT + threads_per_block - 1) / threads_per_block; + dim3 threads(threads_per_block); + dim3 blocks(blocks_per_grid); + addition_inplace_kernelN<<>>( + output_data, indices_data, updates_data, indice_size, nrows, stride); + } else { +#endif + int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block; + dim3 threads(threads_per_block); + dim3 blocks(blocks_per_grid); + addition_inplace_kernel<<>>( + output_data, indices_data, updates_data, indice_size, nrows, stride); +#ifdef ENABLE_NCONT + } +#endif } template From 645a04c70d7db2c4ae42a6c0cd40a411b4b780d9 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 16 Apr 2024 17:27:05 +0000 Subject: [PATCH 5/8] fix issue --- _unittests/ut_ortops/test_optim_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/_unittests/ut_ortops/test_optim_cuda.py b/_unittests/ut_ortops/test_optim_cuda.py index 110b9f68..282f33d6 100644 --- a/_unittests/ut_ortops/test_optim_cuda.py +++ 
b/_unittests/ut_ortops/test_optim_cuda.py @@ -274,6 +274,7 @@ def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype): if __name__ == "disabled__main__": print("done.") + @unittest.skipIf(not has_cuda(), reason="cuda not available") def test_scatternd_of_shape_optimize_cuda(self): with self.subTest(optimize=True, dim3=True): self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT) From 4b2c524e08479e87bc5f58878454a0e258bf520b Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 16 Apr 2024 17:56:43 +0000 Subject: [PATCH 6/8] doc --- CHANGELOGS.rst | 1 + _doc/examples/plot_op_scatternd_cuda.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index b81d9e4f..4ca5a105 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.3.0 +++++ +* :pr:`162`: add ScatterNDOfShape implementation on CUDA without atomics * :pr:`159`: add AddAdd custom operator on CUDA * :pr:`158`: add MulMul custom operator on CUDA * :pr:`157`: add ScatterNDOfShape custom operator diff --git a/_doc/examples/plot_op_scatternd_cuda.py b/_doc/examples/plot_op_scatternd_cuda.py index 62136adf..921f2cf9 100644 --- a/_doc/examples/plot_op_scatternd_cuda.py +++ b/_doc/examples/plot_op_scatternd_cuda.py @@ -10,7 +10,7 @@ ScatterND ========= -This configuration happens in a :epkg:`LLAMA` model. +This configuration happens in a :epkg:`Llama` model. :: From 15ded26b2769c1d10054d4f179d1051a04d34203 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 17 Apr 2024 10:31:59 +0200 Subject: [PATCH 7/8] Upgrade to onnxruntime==1.17.3 --- CHANGELOGS.rst | 2 +- _doc/tutorial/old_version.rst | 2 +- _doc/tutorial/trees.rst | 2 +- _unittests/ut_ortcy/test_ortcy.py | 2 +- _unittests/ut_ortops/test_optim_tfidf_vectorizer.py | 2 +- _unittests/ut_ortops/test_optim_tfidf_vectorizer_sparse.py | 2 +- _unittests/ut_xrun_doc/test_documentation_examples.py | 2 +- onnx_extended/__init__.py | 2 +- onnx_extended/tools/run_onnx.py | 2 +- requirements-dev.txt | 2 +- setup.py | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 4ca5a105..e9ed890d 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,11 +4,11 @@ Change Logs 0.3.0 +++++ +* :pr:`163`: use onnxruntime==1.17.3 as default * :pr:`162`: add ScatterNDOfShape implementation on CUDA without atomics * :pr:`159`: add AddAdd custom operator on CUDA * :pr:`158`: add MulMul custom operator on CUDA * :pr:`157`: add ScatterNDOfShape custom operator -* :pr:`153`: use onnxruntime==1.17.1 as default * :pr:`155`: add a function to draw a timeline from a profile * :pr:`154`: improves ploting legend for profiling * :pr:`151`: refactoring of TreeEnsemble code to make them faster diff --git a/_doc/tutorial/old_version.rst b/_doc/tutorial/old_version.rst index 0aeec4b5..a2736c55 100644 --- a/_doc/tutorial/old_version.rst +++ b/_doc/tutorial/old_version.rst @@ -112,7 +112,7 @@ It calls function :func:`bench_virtual =1.17.1 +onnxruntime>=1.17.3 openpyxl opt_einsum packaging diff --git a/setup.py b/setup.py index c2eb9eb3..e9950183 100644 --- a/setup.py +++ b/setup.py @@ -730,7 +730,7 @@ def get_ext_modules(): # beginning of setup ###################### -DEFAULT_ORT_VERSION = "1.17.1" +DEFAULT_ORT_VERSION = "1.17.3" here = os.path.dirname(__file__) if here == "": here = "." 
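Note: the examples and unit tests in this series assume an onnxruntime build with
CUDA support. A quick sanity check before running them could look like this
(a minimal sketch, not part of the patch):

::

    import onnxruntime

    print(onnxruntime.__version__)                 # expected >= 1.17.3 after this patch
    print(onnxruntime.get_available_providers())   # should include CUDAExecutionProvider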
From 88f9f0221bb6e4571fbd6d15298e6a5ac685d00d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Wed, 17 Apr 2024 11:03:10 +0200 Subject: [PATCH 8/8] fix documentation --- .github/workflows/documentation.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 5acf9d7b..e3de59c9 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -82,14 +82,14 @@ jobs: - name: Check for errors and warnings run: | - if [[ $(grep ERROR doc.txt | grep -v 'validation.cuda') ]]; then + if [[ $(grep ERROR doc.txt | grep -v 'validation.cuda' | grep -v 'pickable') ]]; then echo "Documentation produces errors." - grep ERROR doc.txt | grep -v 'validation.cuda' + grep ERROR doc.txt | grep -v 'validation.cuda' | grep -v 'pickable' exit 1 fi - if [[ $(grep WARNING doc.txt | grep -v 'validation.cuda') ]]; then + if [[ $(grep WARNING doc.txt | grep -v 'validation.cuda' | grep -v 'pickable') ]]; then echo "Documentation produces warnings." - grep WARNING doc.txt | grep -v 'validation.cuda' + grep WARNING doc.txt | grep -v 'validation.cuda' | grep -v 'pickable' exit 1 fi
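Usage note: the new benchmark script exposes ``config``, ``itype``, ``warmup`` and
``repeat`` through ``get_parsed_args``. Assuming the usual ``--name value`` flag
syntax (defined in ``onnx_extended.args``, not shown in this patch series), the
Llama-like scenario can be reproduced with something like:

::

    python _doc/examples/plot_op_scatternd_cuda.py --config llama --itype 10 --warmup 3 --repeat 5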