From 20f2dd8b6ba1c461f9a8d90a578178eab1ff20f7 Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Tue, 31 Oct 2023 14:58:21 -0700 Subject: [PATCH 01/12] use onnx rel-1.15.0, update cgman, cmake/external and requirement hash (#18177) --- cgmanifests/generated/cgmanifest.json | 12 +----------- cmake/deps.txt | 2 +- cmake/external/onnx | 2 +- .../azure-pipelines/templates/download-deps.yml | 4 ++-- .../x64/python/cpu/scripts/requirements.txt | 2 +- .../linux/docker/scripts/manylinux/requirements.txt | 2 +- .../github/linux/docker/scripts/requirements.txt | 2 +- 7 files changed, 8 insertions(+), 18 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f9501253661a..6b0e3659bd23 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -26,7 +26,7 @@ "component": { "type": "git", "git": { - "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40", + "commitHash": "b86cc54efce19530fb953e4b21f57e6b3888534c", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "git submodule at cmake/external/onnx" @@ -192,16 +192,6 @@ "comments": "mp11" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "onnx" - } - }, { "component": { "type": "git", diff --git a/cmake/deps.txt b/cmake/deps.txt index 631d326e2ba5..aeb7c05080ab 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e +onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa diff --git a/cmake/external/onnx b/cmake/external/onnx index 6a20ba82b439..b86cc54efce1 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 6a20ba82b439ea1fd650da4d389e96b60a1dd828 +Subproject commit b86cc54efce19530fb953e4b21f57e6b3888534c diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 1373381e4c83..0f6310724e9a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.97 + version: 1.0.104 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.97 + version: 1.0.104 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index 5341ae062d33..680b12602910 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index b2893286803b..8ef1fd452297 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -4,7 +4,7 @@ mypy pytest setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 5d48a93b09c9..5673bddfe058 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -5,7 +5,7 @@ mypy pytest setuptools>=68.2.2 wheel>=0.35.1 -git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx +git+http://github.com/onnx/onnx.git@b86cc54efce19530fb953e4b21f57e6b3888534c#egg=onnx argparse sympy==1.12 flatbuffers From ed41a2836c7963f4c46a073ea7bc29f971e06618 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Tue, 31 Oct 2023 22:48:32 +0000 Subject: [PATCH 02/12] Fix cast removal bug (#17953) The `RemoveDuplicateCastTransformer` fairly naively removed Cast nodes from the graph without considering precision loss when using the same `TypeGroup`. For instance, F64 -> F32 -> F64 would be optimised out of the graph. I also noticed that signedness was not accounted for, which is not covered by any existing issue but is a problem. For example doing int -> unsigned int -> int produces very different values for negative inputs and so should not be optimised out One could argue that we shouldn't be performing such cast elimination at all (at least not in this transformer). The original scope might be well restricted to only eliminating unnecessary casts from the `InsertCastTransformer` and no others. ### Motivation and Context This should fix https://github.com/microsoft/onnxruntime/issues/17565, ttps://github.com/microsoft/onnxruntime/issues/9915 and https://github.com/microsoft/onnxruntime/issues/8787. --- .../core/optimizer/insert_cast_transformer.cc | 86 +++++++++++++++---- .../framework/insert_cast_transformer_test.cc | 65 ++++++++++++++ 2 files changed, 133 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 7c087ec77d9f..959fcd6efdc3 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -32,7 +32,7 @@ onnxruntime::NodeArg* AddCastNode(onnxruntime::Graph& graph, int64_t to_type, onnxruntime::ProviderType providerType) { // insert cast op to cast input - std::string node_name = graph.GenerateNodeName("InsertedCast_" + old_arg->Name()); + std::string node_name = graph.GenerateNodeName("InsertedPrecisionFreeCast_" + old_arg->Name()); auto* new_arg = &graph.GetOrCreateNodeArg(node_name, new_type); @@ -235,7 +235,8 @@ enum TypeGroup { Unknown = -1, Bool = 0, Integer = 1, - Float = 2, + Unsigned = 2, + Float = 3, }; TypeGroup GetTypeGroup(DataType type) { @@ -243,11 +244,14 @@ TypeGroup GetTypeGroup(DataType type) { return Bool; } - if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)" || - *type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") { return Integer; } + if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Unsigned; + } + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { return Float; } @@ -255,6 +259,22 @@ TypeGroup GetTypeGroup(DataType type) { return Unknown; } +int BitLength(DataType type) { + if (*type == "tensor(bool)") { + return 1; + } else if (*type == "tensor(uint8)" || *type == "tensor(int8)") { + return 8; + } else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") { + return 16; + } else if (*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") { + return 32; + } else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") { + return 64; + } else { + return -1; + } +} + /** Transformer to remove duplicate Cast nodes. */ class RemoveDuplicateCastTransformer : public GraphTransformer { public: @@ -262,6 +282,48 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: + static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) { + // This is not a complete cast optimisation pass, and is more conservative than it could be. + // For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass. + + // The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer. + // Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support. + auto src_type_group = GetTypeGroup(src_type); + auto dst_type_group = GetTypeGroup(dst_type); + if (Unknown == src_type_group || Unknown == dst_type_group) { + return true; + } + + // Do not remove any signed -> unsigned cast. + if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) { + return true; + } + + // Do not remove any floating point -> non floating point cast. + if (Float == src_type_group && Float != dst_type_group) { + return true; + } + + auto src_bit_length = BitLength(src_type); + auto dst_bit_length = BitLength(dst_type); + + // unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer. + if (Unsigned == src_type_group && Integer == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + // integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise. + if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + return true; + } + + return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); + } + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { auto output_args = graph.GetOutputs(); InlinedHashSet graph_outputs; @@ -293,17 +355,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { // - for each consumer cast node, it meets above condition for this optimization. auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - TypeGroup src_type_group = GetTypeGroup(src_type); - TypeGroup dst_type_group = GetTypeGroup(dst_type); - if (src_type_group == Unknown || dst_type_group == Unknown) { - continue; - } - - bool loss_precision_cast = false; - if (src_type_group > dst_type_group) { - loss_precision_cast = true; - } + bool loss_precision_cast = UnsafeCast(src_type, dst_type, node); size_t num_children = node.GetOutputEdgesCount(); bool inconsistent_casts = false; @@ -312,10 +365,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (output_node.OpType() == "Cast") { auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); - TypeGroup src_type_group1 = GetTypeGroup(src_type1); - TypeGroup dst_type_group1 = GetTypeGroup(dst_type1); - if (src_type_group1 == Unknown || dst_type_group1 == Unknown || - (loss_precision_cast && dst_type_group1 > src_type_group1)) { + if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) { inconsistent_casts = true; break; } diff --git a/onnxruntime/test/framework/insert_cast_transformer_test.cc b/onnxruntime/test/framework/insert_cast_transformer_test.cc index c38baee39216..1804c09043c7 100644 --- a/onnxruntime/test/framework/insert_cast_transformer_test.cc +++ b/onnxruntime/test/framework/insert_cast_transformer_test.cc @@ -4,6 +4,7 @@ #include "core/framework/allocator.h" #include "core/optimizer/insert_cast_transformer.h" #include "core/graph/model.h" +#include "core/graph/node_attr_utils.h" #include "gtest/gtest.h" #include "test_utils.h" #include "test/test_environment.h" @@ -110,6 +111,70 @@ TEST(TransformerTest, InsertCastAllCPUTest) { } } +TEST(TransformerTest, CastRemovalDoesNotLowerPrecisionTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_float_32; + tensor_float_32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + TypeProto tensor_float_64; + tensor_float_64.mutable_tensor_type()->set_elem_type(TensorProto_DataType_DOUBLE); + onnxruntime::NodeArg n1_def("N1", &tensor_float_64), + n2_def("N2", &tensor_float_32), + n3_def("N3", &tensor_float_64); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))}}; + + graph.AddNode("node1", "Cast", "F64 to F32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "F32 to F64 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting f64 -> f32 -> f64 we should not be optimising away the cast since there is a loss of precision. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + +TEST(TransformerTest, CastRemovalDoesNotRemoveSignednessTest) { + auto model = std::make_shared("test", false, DefaultLoggingManager().DefaultLogger()); + onnxruntime::Graph& graph = model->MainGraph(); + TypeProto tensor_uint32; + tensor_uint32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_UINT32); + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + onnxruntime::NodeArg n1_def("N1", &tensor_int32), + n2_def("N2", &tensor_uint32), + n3_def("N3", &tensor_int32); + + NodeAttributes n1_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_UINT32))}}; + NodeAttributes n2_attrs = {{"to", utils::MakeAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32))}}; + + graph.AddNode("node1", "Cast", "I32 to UI32 cast", ArgMap{&n1_def}, ArgMap{&n2_def}, &n1_attrs); + graph.AddNode("node2", "Cast", "UI32 to I32 cast", ArgMap{&n2_def}, ArgMap{&n3_def}, &n2_attrs); + + auto status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + InsertCastTransformer cast_inserter("Test", DefaultCpuExecutionProvider()->GetKernelRegistry().get()); + + bool modified = true; + status = cast_inserter.Apply(graph, modified, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // When casting i32 -> ui32 -> i32 we should not be optimising away the cast since applying the casts produces a very different result. + EXPECT_EQ(graph.NumberOfNodes(), 2); +} + // test that when there are 3 Cast ops in a row we remove the correct ones TEST(TransformerTest, ThreeInARowRemoval) { auto model_uri = MODEL_FOLDER ORT_TSTR("triple-cast.onnx"); From 62c7894ffe15efb7d43d891a326c2cbdcfbb529d Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 1 Nov 2023 09:25:48 +1000 Subject: [PATCH 03/12] Add mobile CIs to list run by script for external PRs. (#18094) ### Description Add the mobile CIs to the list so we check external PRs don't break those. ### Motivation and Context Recent external PR was found to break iOS CI after checkin --- tools/python/run_CIs_for_external_pr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index dcc6a92d84ef..7a77839c4a4e 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,10 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # not currently required, but running ensures we're hitting all mobile platforms + "Android CI Pipeline", + "iOS CI Pipeline", + "ONNX Runtime React Native CI Pipeline", ] # remove pipelines that have already run successfully From 2b95e74fa113ec168a79974987b2c6b98cecf700 Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Tue, 31 Oct 2023 16:50:27 -0700 Subject: [PATCH 04/12] Versioning for custom op (#18088) Allow custom ops to have versions. --------- Co-authored-by: Randy Shuai --- .../core/session/onnxruntime_c_api.h | 4 ++ .../core/session/onnxruntime_cxx_api.h | 13 ++++ .../core/session/onnxruntime_lite_custom_op.h | 59 ++++++++++++++----- onnxruntime/core/session/custom_ops.cc | 22 ++++++- onnxruntime/test/shared_lib/test_inference.cc | 16 +++++ .../testdata/custom_op_library/cpu/cpu_ops.cc | 37 +++++++----- .../test/testdata/fuse_select_filter.onnx | 5 +- .../testdata/fuse_select_filter_opset_8.onnx | 29 +++++++++ 8 files changed, 148 insertions(+), 37 deletions(-) create mode 100644 onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4a63018f870a..613c1ac93cf1 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4605,6 +4605,10 @@ struct OrtCustomOp { OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 467eb31ee2c8..92c25d8688b6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2228,6 +2228,8 @@ struct ShapeInferContext { using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); +#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 + template struct CustomOpBase : OrtCustomOp { CustomOpBase() { @@ -2280,6 +2282,14 @@ struct CustomOpBase : OrtCustomOp { } SetShapeInferFn(0); + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { + return static_cast(this_)->end_ver_; + }; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider @@ -2348,6 +2358,9 @@ struct CustomOpBase : OrtCustomOp { protected: // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. void GetSessionConfigs(std::unordered_map& out, ConstSessionOptions options) const; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index b12221e56b79..443710884743 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp { PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ) OrtLiteCustomOp(const char* op_name, - const char* execution_provider) : op_name_(op_name), - execution_provider_(execution_provider) { + const char* execution_provider, + int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -837,6 +840,16 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtCustomOp::KernelCompute = {}; OrtCustomOp::InferOutputShapeFn = {}; + + OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->start_ver_; + }; + + OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) { + auto self = reinterpret_cast(op); + return self->end_ver_; + }; } const std::string op_name_; @@ -844,6 +857,9 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + + int start_ver_ = 1; + int end_ver_ = MAX_CUSTOM_OP_END_VER; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -873,9 +889,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFn compute_fn, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_(compute_fn), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -911,9 +929,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtLiteCustomFunc(const char* op_name, const char* execution_provider, ComputeFnReturnStatus compute_fn_return_status, - ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + ShapeInferFn shape_infer_fn = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), + compute_fn_return_status_(compute_fn_return_status), + shape_infer_fn_(shape_infer_fn) { ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -985,8 +1005,9 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { }; OrtLiteCustomStruct(const char* op_name, - const char* execution_provider) : OrtLiteCustomOp(op_name, - execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { @@ -1049,25 +1070,31 @@ template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, void (*custom_compute_fn)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider, Status (*custom_compute_fn_v2)(Args...), - Status (*shape_infer_fn)(ShapeInferContext&) = {}) { + Status (*shape_infer_fn)(ShapeInferContext&) = {}, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomFunc; - return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn).release(); + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release(); } template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, - const char* execution_provider) { + const char* execution_provider, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) { using LiteOp = OrtLiteCustomStruct; - return std::make_unique(op_name, execution_provider).release(); + return std::make_unique(op_name, execution_provider, start_ver, end_ver).release(); } } // namespace Custom diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 041250adc3fc..b827c28f129b 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -25,6 +25,7 @@ #if !defined(ORT_MINIMAL_BUILD) static constexpr uint32_t min_ort_version_with_optional_io_support = 8; static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; +static constexpr uint32_t min_ort_version_with_custom_version = 17; #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -698,8 +699,19 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust KernelDefBuilder def_builder; def_builder.SetName(op->GetName(op)) - .SetDomain(domain) - .SinceVersion(1); + .SetDomain(domain); + + if (op->version >= min_ort_version_with_custom_version) { + if (op->GetStartVersion && op->GetEndVersion) { + def_builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op)); + } else if (op->GetStartVersion) { + def_builder.SinceVersion(op->GetStartVersion(op)); + } else { + def_builder.SinceVersion(1); + } + } else { + def_builder.SinceVersion(1); + } // GetInputMemoryType was introduced in ver 13. This check allows custom ops compiled using older versions // to work with newer versions (> 12) of the ORT binary. @@ -820,7 +832,11 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types"); } schema.SetDomain(domain); - schema.SinceVersion(1); + if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) { + schema.SinceVersion(op->GetStartVersion(op)); + } else { + schema.SinceVersion(1); + } schema.AllowUncheckedAttributes(); if (op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn) { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ba282193c5ca..33d50f90333c 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3323,6 +3323,22 @@ TEST(LiteCustomOpTest, CustomFunc) { ASSERT_TRUE(floats_output[1] == 16); } +TEST(LiteCustomOpTest, CustomFuncOpsetMismatch) { + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); + session_options.SetLogSeverityLevel(0); +#if defined(_WIN32) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll")); +#elif defined(__APPLE__) + session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib")); +#else + session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so")); +#endif + + EXPECT_THROW(Ort::Session(*ort_env, TSTR("testdata/fuse_select_filter_opset_8.onnx"), session_options), std::exception); +} + struct Merge { Merge(const OrtApi* ort_api, const OrtKernelInfo* info) { int64_t reverse; diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index ad99b675c7d2..85edfa0e59f1 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -94,23 +94,28 @@ void Select(const Ort::Custom::Span& indices_in, } } -void Filter(const Ort::Custom::Tensor& floats_in, - Ort::Custom::Tensor& floats_out) { - const float* in = floats_in.Data(); - auto in_len = floats_in.NumberOfElement(); +struct Filter { + Filter(const OrtApi*, const OrtKernelInfo*) {} + Ort::Status Compute(const Ort::Custom::Tensor& floats_in, + Ort::Custom::Tensor& floats_out) { + const float* in = floats_in.Data(); + auto in_len = floats_in.NumberOfElement(); + + std::vector filter_floats; + for (int64_t i = 0; i < in_len; ++i) { + if (in[i] > 1.f) { + filter_floats.push_back(in[i]); + } + } - std::vector filter_floats; - for (int64_t i = 0; i < in_len; ++i) { - if (in[i] > 1.f) { - filter_floats.push_back(in[i]); + float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); + for (size_t j = 0; j < filter_floats.size(); ++j) { + out[j] = filter_floats[j]; } - } - float* out = static_cast(floats_out.Allocate({static_cast(filter_floats.size())})); - for (size_t j = 0; j < filter_floats.size(); ++j) { - out[j] = filter_floats[j]; + return Ort::Status{nullptr}; } -} +}; void Box(const Ort::Custom::Tensor* float_in_1, const Ort::Custom::Tensor* float_in_2, @@ -293,9 +298,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; static const std::unique_ptr c_MulTopOpFloat{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; static const std::unique_ptr c_MulTopOpInt32{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop)}; - static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse)}; + static const std::unique_ptr c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse, {}, 10, 12)}; static const std::unique_ptr c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)}; - static const std::unique_ptr c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; + static const std::unique_ptr c_Filter{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", 15, 17)}; static const std::unique_ptr c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)}; static const std::unique_ptr c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic)}; static const std::unique_ptr c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined)}; @@ -314,7 +319,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_MulTopOpInt32.get()); domain.Add(c_Fuse.get()); domain.Add(c_Select.get()); - domain.Add(c_Fill.get()); + domain.Add(c_Filter.get()); domain.Add(c_Box.get()); domain.Add(c_CopyTensorArrayAllVariadic.get()); domain.Add(c_CopyTensorArrayCombined.get()); diff --git a/onnxruntime/test/testdata/fuse_select_filter.onnx b/onnxruntime/test/testdata/fuse_select_filter.onnx index 15d7dd64788d..0b881228edb9 100644 --- a/onnxruntime/test/testdata/fuse_select_filter.onnx +++ b/onnxruntime/test/testdata/fuse_select_filter.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä P vector_1 vector_2 @@ -25,4 +25,5 @@ N ÿÿÿÿÿÿÿÿÿb& vector_filtered  - ÿÿÿÿÿÿÿÿÿB \ No newline at end of file + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file diff --git a/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx new file mode 100644 index 000000000000..3ea27767eb9f --- /dev/null +++ b/onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx @@ -0,0 +1,29 @@ + :Ä +P +vector_1 +vector_2 +alpha vector_fused fuse_node"Fuse* + fuse_algo :v2 +4 +indicesindices_selected select_node"Select:v2 +N + vector_fused +indices_selectedvector_gathered gather_node"GatherElements +; +vector_gatheredvector_filtered filter_node"Filter:v2graphZ +vector_1 + + ÿÿÿÿÿÿÿÿÿZ +vector_2 + + ÿÿÿÿÿÿÿÿÿZ +alpha + + ÿÿÿÿÿÿÿÿÿZ +indices + + ÿÿÿÿÿÿÿÿÿb& +vector_filtered + + ÿÿÿÿÿÿÿÿÿB +v2 \ No newline at end of file From d1b85f5fb4fff6fc674e50e2053039c7ded4969e Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Tue, 31 Oct 2023 17:53:52 -0700 Subject: [PATCH 05/12] Reduce LLaMA memory usage (#18181) ### Description This PR reduces the memory usage when exporting and benchmarking LLaMA. ### Motivation and Context - Exporting: The PyTorch model is deleted from memory after a successful export instead of deleting it from memory after exporting + converting the ONNX model to the desired precision. - Benchmarking: In the ONNX model with GroupQueryAttention, the KV cache inputs use the same GPU memory for both the prompt and token generation benchmarks. --- .../transformers/models/llama/benchmark.py | 104 +++---- .../models/llama/convert_to_onnx.py | 2 +- .../transformers/models/llama/llama_inputs.py | 271 +++++++++++++----- .../transformers/models/llama/llama_parity.py | 57 ++-- 4 files changed, 248 insertions(+), 186 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index a721979eb0bc..245ff3dfe7f9 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,9 +11,8 @@ import onnx import psutil import torch -from benchmark_helper import setup_logger from llama_inputs import ( - convert_inputs_for_ort, + add_io_bindings, get_merged_sample_with_past_kv_inputs, get_msft_sample_inputs, get_sample_inputs, @@ -25,7 +24,7 @@ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory +from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger logger = logging.getLogger(__name__) @@ -48,9 +47,19 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): init_inputs, iter_inputs = None, None # For past_present_share_buffer: - # Set max_seq_len to 2048 for Hugging Face model since that is the default value - # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported - max_seq_len = 2048 + # Set max_seq_len to 16384 for CodeLLaMA (finetuned variant of LLaMA-2) + # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value + # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_seq_len = ( + 2048 + if args.benchmark_type == "ort-msft" + else 16384 + if "codellama" in temp_name + else 4096 + if "llama2" in temp_name + else 2048 + ) if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( @@ -95,7 +104,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( @@ -104,7 +115,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="pt", return_dict=True, ) @@ -116,7 +129,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=args.sequence_length, past_seq_len=0, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, + engine="ort", return_dict=True, ) iter_inputs = get_merged_sample_with_past_kv_inputs( @@ -125,26 +140,10 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, seq_len=1, past_seq_len=args.sequence_length, - use_fp16=args.use_fp16, - return_dict=True, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + engine="ort", + return_dict=True, ) elif args.benchmark_type == "ort-msft": @@ -156,6 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=0, seq_len=args.sequence_length, + max_seq_len=max_seq_len, use_fp16=args.use_fp16, split_kv=split_kv, ) @@ -164,26 +164,9 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): args.batch_size, past_seq_len=args.sequence_length, seq_len=1, - use_fp16=args.use_fp16, - split_kv=split_kv, - ) - init_inputs = convert_inputs_for_ort( - init_inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=0, max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, - ) - iter_inputs = convert_inputs_for_ort( - iter_inputs, use_fp16=args.use_fp16, - use_buffer_share=args.past_present_share_buffer, - past_seq_len=args.sequence_length, - max_seq_len=max_seq_len, - device=args.device, - device_id=args.device_id, + split_kv=split_kv, ) else: @@ -449,7 +432,7 @@ def get_logits(inputs): def run_ort_inference(args, init_inputs, iter_inputs, model): - def prepare_ort_inputs(inputs): + def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Check that all model inputs will be provided model_inputs = set(map(lambda model_input: model_input.name, model.get_inputs())) user_inputs = set(inputs.keys()) @@ -467,29 +450,13 @@ def prepare_ort_inputs(inputs): # Add IO bindings for non-CPU execution providers if args.device != "cpu": - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.past_present_share_buffer: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.past_present_share_buffer and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) - + io_binding, kv_cache_ortvalues = add_io_bindings( + model, inputs, args.device, int(args.device_id), kv_cache_ortvalues + ) setattr(args, "io_binding", io_binding) # noqa: B010 - return io_binding + return io_binding, kv_cache_ortvalues - return inputs + return inputs, kv_cache_ortvalues def with_io_binding(io_binding): # Inference pass with IO binding @@ -501,9 +468,10 @@ def without_io_binding(inputs): return outputs generate_fn = with_io_binding if args.device != "cpu" else without_io_binding + kv_cache_ortvalues = {} if args.profile: - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt") # Turn profiling off to stop appending to log file @@ -513,7 +481,7 @@ def without_io_binding(inputs): # Re-initialize model for new log file instead of appending to old log file model = get_model(args) - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") # Turn profiling off to stop appending to log @@ -524,12 +492,12 @@ def without_io_binding(inputs): # ORT evaluations logger.info("\nEvaluating `model(inputs)` step to get past_key_values") - ort_init_inputs = prepare_ort_inputs(init_inputs) + ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_init_inputs) measure_fn(args, generate_fn, ort_init_inputs) logger.info("\nEvaluating `model(inputs)` step with past_key_values") - ort_iter_inputs = prepare_ort_inputs(iter_inputs) + ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues) time_fn(args, generate_fn, ort_iter_inputs) measure_fn(args, generate_fn, ort_iter_inputs) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 69603fd3ed48..3f05be53c672 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -716,6 +716,7 @@ def main(): run_torchscript_separate_export(args, l_config, llama) else: run_torchscript_merged_export(args, l_config, llama) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check # Set model paths to store FP32 optimized model decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") @@ -811,7 +812,6 @@ def main(): logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") remove_existing_model(fp_path) - del llama # Delete LLaMA model from memory since it will be loaded again during parity check logger.info("Verifying parity on all ONNX models created") # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 2652e9f0ca64..f7a1b05249ab 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -4,7 +4,7 @@ import torch from transformers import LlamaConfig -from onnxruntime import OrtValue +from onnxruntime import InferenceSession, OrtValue # Get position_ids from attention_mask @@ -12,22 +12,36 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if use_past_kv: + # Shape: (batch_size, 1) position_ids = position_ids[:, -1].unsqueeze(-1) + + # Shape: (batch_size, sequence_length) return position_ids # Inputs for first pass to get initial past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, sequence_length) +# position_ids: (batch_size, sequence_length) def get_sample_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, return_dict: bool = False + config: LlamaConfig, + device: torch.device, + batch_size: int, + seq_len: int, + engine: str = "pt", + return_dict: bool = False, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, seq_len, device=device, dtype=torch.int64) - # position_ids is of shape (batch_size, seq_len) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64) position_ids = get_position_ids(attention_mask, use_past_kv=False) + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + if not return_dict: + # For export return (input_ids, attention_mask, position_ids) inputs = { @@ -39,85 +53,192 @@ def get_sample_inputs( # Inputs for subsequent passes with past_key_values +# input_ids: (batch_size, 1) +# attention_mask: (batch_size, past_sequence_length + 1) +# position_ids: (batch_size, 1) +# past_key: (batch_size, num_heads, past_sequence_length, head_size) +# past_value: (batch_size, num_heads, past_sequence_length, head_size) def get_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, ): - input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), device=device, dtype=torch.int64) - attention_mask = torch.ones(batch_size, past_seq_len + 1, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64) # position_ids is of shape (batch_size, 1) position_ids = get_position_ids(attention_mask, use_past_kv=True) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs # Inputs for all passes with past_key_values +# input_ids: (batch_size, sequence_length) +# attention_mask: (batch_size, past_sequence_length + sequence_length) +# position_ids: (batch_size, sequence_length) +# past_key: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length +# past_value: (batch_size, num_heads, kv_sequence_length, head_size) +# For models with GQA, kv_sequence_length = max_sequence_length +# For models without GQA, kv_sequence_length = past_sequence_length def get_merged_sample_with_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, seq_len: int, past_seq_len: int, + max_seq_len: int, use_fp16: bool = False, + engine: str = "pt", return_dict: bool = False, ): - input_ids = torch.randint( - low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 - ) - attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) + input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64) # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) - past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16) + + # Convert inputs to NumPy (for ORT) or send to device (for PyTorch) + input_ids = input_ids.numpy() if engine == "ort" else input_ids.to(device) + attention_mask = attention_mask.numpy() if engine == "ort" else attention_mask.to(device) + position_ids = position_ids.numpy() if engine == "ort" else position_ids.to(device) + past_kv = ( + flatten_past_kv_inputs(past_kv) + if engine == "ort" + else list(map(lambda kv: (kv[0].to(device), kv[1].to(device)), past_kv)) + ) if not return_dict: + # For export + assert isinstance(past_kv, list) return (input_ids, attention_mask, position_ids, past_kv) inputs = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, - "past_key_values": past_kv, } + if engine == "ort": + assert isinstance(past_kv, dict) + inputs.update(past_kv) + + if use_fp16: # If model has GQA + del inputs["attention_mask"] + inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) + inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) + + else: + assert isinstance(past_kv, list) + inputs["past_key_values"] = past_kv + return inputs -# Create past_key_values -def get_sample_past_kv_inputs( - config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool +# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx +def get_msft_sample_inputs( + config: LlamaConfig, + batch_size: int, + past_seq_len: int, + seq_len: int, + max_seq_len: int, + use_fp16: bool, + split_kv: bool, ): + np_dtype = np.float16 if use_fp16 else np.float32 + head_size = config.hidden_size // config.num_attention_heads + + if not split_kv: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), + "k_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "v_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "pos": np.array(past_seq_len, dtype=np.int64), + } + else: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( + np.int32 + ), + "pos": np.array(past_seq_len, dtype=np.int64), + } + for i in range(config.num_hidden_layers): + ort_inputs.update( + { + f"k_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + f"v_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + } + ) + + if use_fp16: # If model has GQA + del ort_inputs["attn_mask"] + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) + + return ort_inputs + + +# Create past_key_values +# Each is of shape (batch_size, num_heads, past_sequence_length, head_size) +def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool): num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), - torch.rand(batch_size, num_heads, past_seq_len, head_size, device=device, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), + torch.rand(batch_size, num_heads, past_seq_len, head_size, dtype=torch_dtype), ) for _ in range(config.num_hidden_layers) ] return past_kv -# Convert list of past_kv to dict of past_key and past_value -def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], use_fp16: bool): +# Convert list of past_key_values to dict of past_key and past_value +def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]]): past_kv = {} - np_dtype = np.float16 if use_fp16 else np.float32 for i, (past_k, past_v) in enumerate(past_key_values): - past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy().astype(np_dtype) - past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy().astype(np_dtype) + past_kv[f"past_key_values.{i}.key"] = past_k.detach().cpu().numpy() + past_kv[f"past_key_values.{i}.value"] = past_v.detach().cpu().numpy() return past_kv @@ -136,7 +257,7 @@ def convert_inputs_for_ort( if isinstance(v, np.ndarray): ort_inputs[k] = v elif k == "past_key_values": - ort_inputs.update(flatten_past_kv_inputs(v, use_fp16)) + ort_inputs.update(flatten_past_kv_inputs(v)) elif k == "attention_mask" and use_fp16 and use_buffer_share: # Skip because FP16 model has GroupQueryAttention, uses buffer sharing, # and GQA supports a causal mask by default @@ -146,59 +267,55 @@ def convert_inputs_for_ort( else: ort_inputs[k] = v.detach().cpu().numpy() - # Enable past-present-share-buffer by using device memory directly + # Reshape kv caches if using past-present-share-buffer if use_buffer_share and device != "" and device != "cpu" and device_id > -1: - for k, v in ort_inputs.items(): - new_v = v - # Allocate new buffers with max_sequence_length for GQA - if "cache" in k or "past_key_values" in k: - # Copy v (BxSxPxH) into new_v (BxSxMxH) - batch_size, num_heads, _, head_size = v.shape - new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) - new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v - ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id) + ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs -# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx -def get_msft_sample_inputs( - config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool -): - np_dtype = np.float16 if use_fp16 else np.float32 - head_size = config.hidden_size // config.num_attention_heads - max_seq_len = 2048 +def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int): + for k, v in ort_inputs.items(): + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = new_v + return ort_inputs - if not split_kv: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), - "k_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "v_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "pos": np.array(past_seq_len, dtype=np.int64), - } - else: - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( - np.int32 - ), - "pos": np.array(past_seq_len, dtype=np.int64), - } - for i in range(config.num_hidden_layers): - ort_inputs.update( - { - f"k_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - f"v_{i}_cache": np.random.rand( - batch_size, config.num_attention_heads, past_seq_len, head_size - ).astype(np_dtype), - } - ) - return ort_inputs +# Add IO bindings for execution providers +def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict): + use_fp16 = False + io_binding = model.io_binding() + + for k, v in ort_inputs.items(): + # Detect if model is in FP16 + if v.dtype == np.float16: + use_fp16 = True + + # Bind OrtValue inputs to device + if use_fp16 and ("cache" in k or "past_key_values" in k): + if k not in kv_cache_ortvalues: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + kv_cache_ortvalues[k] = v_device + else: + kv_cache_ortvalues[k].update_inplace(v) + io_binding.bind_ortvalue_input(k, kv_cache_ortvalues[k]) + else: + v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) + io_binding.bind_ortvalue_input(k, v_device) + + for output in model.get_outputs(): + name = output.name + if use_fp16 and ("out" in name or "present" in name): + # Bind present KV cache outputs to past KV cache inputs in order to buffer share + input_name = name.replace("out", "cache").replace("present", "past_key_values") + io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name]) + else: + io_binding.bind_output(name, device_type=device, device_id=device_id) + + return io_binding, kv_cache_ortvalues diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 4353d0606803..c1c5d3c412f2 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -8,6 +8,7 @@ import torch from benchmark_helper import setup_logger from llama_inputs import ( + add_io_bindings, convert_inputs_for_ort, get_merged_sample_with_past_kv_inputs, get_sample_inputs, @@ -22,22 +23,24 @@ def get_sequence_lengths(args: argparse.Namespace): past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) - max_sequence_length = 2048 + temp_name = args.model_name.lower().replace("-", "").replace("_", "") + max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048 return past_sequence_length, curr_sequence_length, max_sequence_length def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity batch_size = 2 - past_sequence_length, sequence_length, _ = get_sequence_lengths(args) + past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args) if args.merged: inputs = get_merged_sample_with_past_kv_inputs( config, args.device, batch_size, - sequence_length, - past_sequence_length, + seq_len=sequence_length, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, use_fp16=args.use_fp16, return_dict=True, ) @@ -51,31 +54,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): return inputs -def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict): - # Add IO bindings for non-CPU execution providers - io_binding = model.io_binding() - - for k, v in inputs.items(): - if args.use_fp16: - # Bind all OrtValue inputs to device - io_binding.bind_ortvalue_input(k, v) - else: - io_binding.bind_cpu_input(k, v) - - for output in model.get_outputs(): - name = output.name - if args.use_fp16 and ("out" in name or "present" in name): - # Bind present KV cache outputs to OrtValue with buffer sharing - io_binding.bind_ortvalue_output( - name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] - ) - else: - io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id)) - - return io_binding - - -def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): +def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict): inputs = get_inputs(args, config) # Run inference with PyTorch @@ -111,7 +90,9 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": - io_binding = add_io_bindings(args, ort_model, inputs) + io_binding, kv_cache_ortvalues = add_io_bindings( + ort_model, inputs, args.execution_provider, int(args.device_id), kv_cache_ortvalues + ) io_binding.synchronize_inputs() start_time = time.time() @@ -131,17 +112,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama logger.info(f"ONNX Runtime took {end_time - start_time} s") # Compare PyTorch and ONNX Runtime accuracy - tol = ( - 2e1 - if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path - else 1e-3 - if args.precision == "fp32" - else 5e-1 - ) + tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1 parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") if not parity: logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}") + return kv_cache_ortvalues def get_args(argv: List[str]): @@ -250,16 +226,17 @@ def main(argv: List[str] = []): # noqa: B006 use_cache=True, ).to(args.device) + kv_cache_ortvalues = {} if not args.merged: - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) else: # Verify prompt generation in merged model (decoder_model.onnx) args.use_past_kv = False - verify_parity(args, config, llama) + kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues) # Verify token generation in merged model (decoder_with_past_model.onnx) args.use_past_kv = True - verify_parity(args, config, llama) + verify_parity(args, config, llama, kv_cache_ortvalues) if __name__ == "__main__": From c181159783d1245adb9bb1af18a469ad7d89df45 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 1 Nov 2023 11:30:32 +0800 Subject: [PATCH 06/12] [WebNN EP] Restore to use deviceType enum (#18154) The Chromium implementation will support `MLDeviceType` enum to align with spec. CL: https://chromium-review.googlesource.com/c/chromium/src/+/4986939 --- .../core/providers/webnn/webnn_execution_provider.cc | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 26c739e9a1ce..02a3d16b5b64 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -26,11 +26,7 @@ WebNNExecutionProvider::WebNNExecutionProvider( ORT_THROW("Failed to get ml from navigator."); } emscripten::val context_options = emscripten::val::object(); - // Currently WebNN implementation in Chromium temporarily reuses the MLContextOptions - // defined in Model Loader API, which uses MLDevicePreference instead of MLDeviceType - // defined in WebNN. Because there's an ongoing spec discussion to simplify this API at - // https://github.com/webmachinelearning/webnn/issues/302. - context_options.set("devicePreference", emscripten::val(webnn_device_flags)); + context_options.set("deviceType", emscripten::val(webnn_device_flags)); // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { preferred_layout_ = DataLayout::NHWC; From 819b5a3eba85cca9276c9d763c814eb45067b280 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Tue, 31 Oct 2023 21:05:42 -0700 Subject: [PATCH 07/12] Split KV on MHA and Attention ops (#18007) ### Description Implement Split KV optimization for FlashAttention in MHA and Attention operators. ### Motivation and Context Can help further accelerate these ops. --- .../contrib_ops/cpu/bert/attention_common.h | 3 ++- .../contrib_ops/cuda/bert/attention.cc | 22 +++++++++++++++ .../contrib_ops/cuda/bert/attention_impl.cu | 4 ++- .../contrib_ops/cuda/bert/attention_impl.h | 5 ++++ .../cuda/bert/flash_attention/flash_api.cc | 27 ++++++++++++++++--- .../cuda/bert/flash_attention/flash_api.h | 6 ++--- .../flash_fwd_launch_template.h | 12 ++------- .../cuda/bert/group_query_attention.cc | 22 ++++++--------- .../cuda/bert/multihead_attention.cc | 22 +++++++++++++++ 9 files changed, 90 insertions(+), 33 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 5184dd99309b..0fd8790e0d29 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -55,6 +55,7 @@ struct AttentionParameters { int v_hidden_size; // hidden size of V int v_head_size; // hidden size per head of V int num_heads; + int num_splits; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; @@ -95,9 +96,9 @@ struct GroupQueryAttentionParameters { int head_size; int kv_hidden_size; int kv_num_heads; + int num_splits; // number of splits for splitkv bool is_unidirectional; // causal float scale; - int num_splits; // number of splits for splitkv AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 0dc7de0e9e51..bf6431cf1afb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -135,8 +135,24 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif if (!use_flash_attention) { @@ -279,6 +295,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.fused_runner = reinterpret_cast(fused_runner); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } return QkvToContext(device_prop, cublas, context->GetComputeStream(), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index b4a4ae208ceb..eb9e6d5c6246 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -316,7 +316,9 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, + parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + true)); DUMP_TENSOR("flash attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index d0a5fb51a25d..3e78978c3cc4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -88,6 +88,11 @@ struct AttentionData { T* v = nullptr; T* scratch = nullptr; AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + // Flash buffers + T* softmax_lse = nullptr; + T* softmax_lse_accum = nullptr; + T* out_accum = nullptr; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index ff7a22d253a5..89a27c4d2b0d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -140,11 +140,10 @@ void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream, bool force_split // So we find the best efficiency, then find the smallest number of splits that gets 85% // of the best efficiency. int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, - int max_splits, bool new_kv, bool is_sm8x) { + int max_splits) { // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = is_sm8x ? (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64)) - : (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64)); - const int num_n_blocks = (seqlen_k + (!new_kv ? 0 : seqlen_q) + block_n - 1) / block_n; + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. // In any case we don't expect seqlen_q to be larger than 64 for inference. const int num_m_blocks = (seqlen_q + 64 - 1) / 64; @@ -190,6 +189,26 @@ int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_hea return 1; } +// Returns (num_splits, softmax_lse_accum bytes, out_accum bytes) +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs) { + int max_splits = 128; + // split kv buffers + int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, + num_SMs, max_splits); + if (num_splits > 1) { + // softmax_lse_accum buffer + int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q); + // out_accum buffer + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + int out_accum_bytes = get_out_accum_size(num_splits, batch_size, num_heads, seqlen_q, head_size_rounded); + return {num_splits, softmax_lse_accum_bytes, out_accum_bytes}; + } else { + return {0, 0, 0}; + } +} + Status mha_fwd(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 0a0328edb005..58f430425187 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -31,6 +31,7 @@ #if USE_FLASH_ATTENTION #include "core/providers/cuda/cuda_common.h" +#include namespace onnxruntime { namespace flash { @@ -99,10 +100,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, ); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); -size_t get_softmax_lse_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q); -size_t get_out_accum_size(int num_splits, int batch_size, int num_heads, int seqlen_q, int head_size_rounded); -int num_splits_heuristic(int batch_size, int seqlen_q, int seqlen_k, int num_heads, int head_size, int num_SMs, int max_splits, bool new_kv, bool is_sm8x); +std::tuple get_num_splits_and_buffer_sizes(int batch_size, int seqlen_q, int seqlen_k, int num_heads, + int head_size, int num_SMs); bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index 784335a124c7..82dfa59b8f8e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -123,17 +123,9 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) { template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) { - bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; constexpr int kBlockM = 64; // Fixed for all head dimensions - if (!is_sm8x) { // A100, H100 - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } else { // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above - constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); - run_flash_splitkv_fwd>(params, stream); - } + constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd>(params, stream); } template diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 65d19d447387..67d750aeac11 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -116,22 +116,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { size_t out_accum_bytes = 0; size_t seqlens_k_bytes = 0; if (use_flash_attention) { + // softmax buffer softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); - // split kv buffers - parameters.num_splits = onnxruntime::flash::num_splits_heuristic( + // split kv buffer + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, - parameters.head_size, device_prop.multiProcessorCount, 128, false, - device_prop.major == 8 && device_prop.minor > 0); - if (parameters.num_splits > 1) { - // softmax_lse_accum buffer - softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); - // out_accum buffer - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - const int head_size_rounded = round_multiple(parameters.head_size, 32); - out_accum_bytes = onnxruntime::flash::get_out_accum_size( - parameters.num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, head_size_rounded); - } + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; // seqlens_k buffer if (past_key != nullptr) { seqlens_k_bytes = sizeof(int) * parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index e3f53ca6a63c..ebd66d8c6528 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -153,8 +153,24 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { use_flash_attention = false; } + // Allocate buffers + size_t softmax_lse_accum_bytes = 0; + size_t out_accum_bytes = 0; + if (use_flash_attention) { + using namespace std; + auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( + parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = num_splits; + softmax_lse_accum_bytes = slse_accum_bytes; + out_accum_bytes = o_accum_bytes; + } + auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); + auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; + auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr + auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif bool use_fused_cross_attention = !use_flash_attention && @@ -291,6 +307,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); + if (softmax_lse_accum_buffer != nullptr) { + data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); + } + if (out_accum_buffer != nullptr) { + data.out_accum = reinterpret_cast(out_accum_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); From 69f029797d24e44c6854d9c68231bef174e627e4 Mon Sep 17 00:00:00 2001 From: weischan-quic <138087696+weischan-quic@users.noreply.github.com> Date: Wed, 1 Nov 2023 14:04:42 +0800 Subject: [PATCH 08/12] [QNN EP] Fix Batch Normalization Op Builder (#17981) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description There is a gap between onnx’s definition of batch normalization and QNN’s. According to the formula: onnx: `(X - input_mean) / sqrt(input_var + epsilon) * scale + B` QNN: `X * weight + bias` We can then deduce that: `weight = scale / sqrt(var + epsilon)` `bias = B – (mean * scale / sqrt(var + epsilon))` We must calculate the weight and bias, and their quantization parameters for QNN in QNN EP. Therefore, `scale`, `B`, `input_mean`, and `input_var` must be static (`initializer`). Implementation: Firstly, dequantize `scale`, `B`, `input_mean`, and `input_var` to floating point. Second, calculate `weight` and `bias`, and their quantization parameters. Finally, quantize `weight` and `bias`, and add them into `TensorWrapper` ### Motivation and Context Fix QnnHTPBackendTests.BatchNorm1D and QnnHTPBackendTests.BatchNorm2D failures --- .../opbuilder/batch_norm_op_builder.cc | 589 +++++++++++++++++- .../test/providers/qnn/batch_norm_htp_test.cc | 18 +- 2 files changed, 583 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index ccbc1acaa2f9..3e17fb157b16 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -1,16 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include + #include "core/providers/common.h" +#include "core/util/qmath.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class BatchNormOpBuilder : public BaseOpBuilder { @@ -18,9 +22,446 @@ class BatchNormOpBuilder : public BaseOpBuilder { BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(BatchNormOpBuilder); + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const override final ORT_MUST_USE_RESULT; + + std::pair CheckMinMax(float rmin, float rmax) const { + // Ensure a minimum range of 0.0001 (required by QNN) + rmax = std::max(rmax, rmin + 0.0001f); + + // Both QNN and ORT require the range to include 0.0f + rmin = std::min(rmin, 0.0f); + rmax = std::max(rmax, 0.0f); + + return std::make_pair(rmin, rmax); + } + + template + Status GetQminQmax(const Qnn_DataType_t qnn_data_type, + T& qmin, + T& qmax) const { + if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else if (qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + qmin = static_cast(std::numeric_limits::min()); + qmax = static_cast(std::numeric_limits::max()); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + Status GetQuantParams(float rmin, + float rmax, + const Qnn_DataType_t qnn_data_type, + float& scale, + int& zero_point) const { + std::tie(rmin, rmax) = CheckMinMax(rmin, rmax); + float qmin = 0.0f; + float qmax = 255.0f; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + + scale = (rmax - rmin) / (qmax - qmin); + const float initial_zero_point = qmin - (rmin / scale); + zero_point = static_cast(RoundHalfToEven(Saturate(qmax, qmin, initial_zero_point))); + // To match QNN quantization definition + zero_point = 0 - zero_point; + return Status::OK(); + } + + inline Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const uint8_t* raw_ptr, + double& value, + int& offset) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int8_t); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int16_t); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int32_t); + break; + } + case QNN_DATATYPE_INT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(int64_t); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint8_t); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint16_t); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint32_t); + break; + } + case QNN_DATATYPE_UINT_64: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(uint64_t); + break; + } + case QNN_DATATYPE_FLOAT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + offset += sizeof(float); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status AssertUnpackedTensorSize(const Qnn_DataType_t qnn_data_type, + const uint32_t channel, + const size_t raw_ptr_length) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_INT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(int64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint8_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint16_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint32_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_UINT_64: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(uint64_t)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_FLOAT_32: { + ORT_ENFORCE(channel == static_cast(raw_ptr_length / sizeof(float)), + "initializer size not match Qnn data type."); + break; + } + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline Status ConvertToRawOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const std::vector& double_tensor, + std::vector& raw_tensor) const { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(int8_t)); + int8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(int16_t)); + int16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(int32_t)); + int32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_INT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(int64_t)); + int64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_8: { + raw_tensor.resize(double_tensor.size() * sizeof(uint8_t)); + uint8_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_16: { + raw_tensor.resize(double_tensor.size() * sizeof(uint16_t)); + uint16_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(uint32_t)); + uint32_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UINT_64: { + raw_tensor.resize(double_tensor.size() * sizeof(uint64_t)); + uint64_t* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_FLOAT_32: { + raw_tensor.resize(double_tensor.size() * sizeof(float)); + float* raw_ptr = reinterpret_cast(raw_tensor.data()); + for (size_t i = 0; i < double_tensor.size(); ++i) { + raw_ptr[i] = static_cast(double_tensor[i]); + } + break; + } + case QNN_DATATYPE_UFIXED_POINT_32: + case QNN_DATATYPE_UFIXED_POINT_16: + case QNN_DATATYPE_UFIXED_POINT_8: + case QNN_DATATYPE_SFIXED_POINT_32: + case QNN_DATATYPE_SFIXED_POINT_16: + case QNN_DATATYPE_SFIXED_POINT_8: + case QNN_DATATYPE_BOOL_8: + case QNN_DATATYPE_STRING: + case QNN_DATATYPE_FLOAT_16: + default: + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", qnn_data_type); + } + return Status::OK(); + } + + inline double Dequantize(const OnnxInputInfo& info, + const double quant_value) const { + auto offset = static_cast(info.quant_param.scaleOffsetEncoding.offset); + auto scale = static_cast(info.quant_param.scaleOffsetEncoding.scale); + return (quant_value + offset) * scale; + } + + template + inline T Saturate(const T qmax, + const T qmin, + const T quant_value) const { + if (quant_value > qmax) { + return qmax; + } else if (quant_value < qmin) { + return qmin; + } else { + return quant_value; + } + } + + inline Status Quantize(const double double_value, + const float scale, + const int zero_point, + const Qnn_DataType_t qnn_data_type, + int& quant_value) const { + int qmin = 0; + int qmax = 255; + ORT_RETURN_IF_ERROR(GetQminQmax(qnn_data_type, qmin, qmax)); + quant_value = Saturate(qmax, qmin, static_cast(std::round((double_value / scale) - zero_point))); + return Status::OK(); + } + + Status PreprocessMean(const OnnxInputInfo& mean_info, + const bool is_npu_backend, + const uint8_t* mean_raw_ptr, + const size_t mean_raw_ptr_length, + std::vector& mean_out) const { + // tensor length (channel) + uint32_t channel = mean_info.shape[0]; + mean_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double mean_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset)); + mean_out[i] = (is_npu_backend) ? Dequantize(mean_info, mean_value) : mean_value; + } + return Status::OK(); + } + + Status PreprocessStd(const OnnxInputInfo& var_info, + const bool is_npu_backend, + const uint8_t* var_raw_ptr, + const size_t var_raw_ptr_length, + const float epsilon, + std::vector& std_out) const { + // tensor length (channel) + uint32_t channel = var_info.shape[0]; + std_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double var_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset)); + std_out[i] = (is_npu_backend) ? Dequantize(var_info, var_value) : var_value; + std_out[i] = std::sqrt(std_out[i] + static_cast(epsilon)); + } + return Status::OK(); + } + + Status PreprocessScale(const OnnxInputInfo& scale_info, + const bool is_npu_backend, + const uint8_t* scale_raw_ptr, + const size_t scale_raw_ptr_length, + const std::vector& std_double_tensor, + double& rmax, + double& rmin, + std::vector& scale_out) const { + // tensor length (channel) + uint32_t channel = scale_info.shape[0]; + scale_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double scale_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset)); + scale_out[i] = (is_npu_backend) ? Dequantize(scale_info, scale_value) : scale_value; + scale_out[i] = scale_out[i] / std_double_tensor[i]; + rmax = std::max(rmax, scale_out[i]); + rmin = std::min(rmin, scale_out[i]); + } + return Status::OK(); + } + + Status PreprocessBias(const OnnxInputInfo& bias_info, + const bool is_npu_backend, + const uint8_t* bias_raw_ptr, + const size_t bias_raw_ptr_length, + const std::vector& scale_double_tensor, + const std::vector& mean_double_tensor, + double& rmax, + double& rmin, + std::vector& bias_out) const { + // tensor length (channel) + uint32_t channel = bias_info.shape[0]; + bias_out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length)); + int i = 0; + int offset = 0; + for (; i < static_cast(channel); ++i) { + double bias_value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset)); + bias_out[i] = (is_npu_backend) ? Dequantize(bias_info, bias_value) : bias_value; + bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]); + rmax = std::max(rmax, bias_out[i]); + rmin = std::min(rmin, bias_out[i]); + } + return Status::OK(); + } + + Status Postprocess(const OnnxInputInfo& info, + const bool is_npu_backend, + const std::vector& double_tensor, + const double rmax, + const double rmin, + Qnn_QuantizeParams_t& quant_param, + std::vector& raw_tensor) const { + if (is_npu_backend) { + raw_tensor.resize(double_tensor.size()); + float scale = 0.0f; + int zero_point = 0; + ORT_RETURN_IF_ERROR(GetQuantParams(static_cast(rmin), + static_cast(rmax), + info.qnn_data_type, + scale, + zero_point)); + quant_param = QNN_QUANTIZE_PARAMS_INIT; + utils::InitializeQuantizeParam(quant_param, true, scale, zero_point); + for (size_t i = 0; i < double_tensor.size(); ++i) { + // onnx only supports 8 bits quantization + int quant_value_int = 0; + ORT_RETURN_IF_ERROR(Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int)); + if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { + raw_tensor[i] = static_cast(quant_value_int); + } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { + int8_t quant_value = static_cast(quant_value_int); + raw_tensor[i] = *reinterpret_cast(&quant_value); + } else { + ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type); + } + } + } else { + ORT_RETURN_IF_ERROR(ConvertToRawOnQnnDataType(info.qnn_data_type, double_tensor, raw_tensor)); + } + return Status::OK(); + } }; // BatchNorm is sensitive with data layout, no special validation so far @@ -34,11 +475,6 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, // Still do it here so hopefully QNN Op validation API can tell us some details why it's not supported return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } else { - NodeAttrHelper node_helper(node_unit); - const float default_epsilon = 1e-05f; - const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. - ORT_RETURN_IF(abs(epsilon - default_epsilon) > default_epsilon, "QNN BatchNorm doesn't support epsilon."); - const auto& inputs = node_unit.Inputs(); ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); @@ -56,11 +492,16 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, std::vector scale_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[1].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic scale."); ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels, "QNN BatchNorm input 1 (scale) must have 1D shape [channel]."); std::vector bias_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[2].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic bias."); + ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels, "QNN BatchNorm input 2 (bias) must have 1D shape [channel]."); @@ -68,13 +509,15 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean)."); ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels, "QNN BatchNorm input 3 (mean) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[3].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic mean."); std::vector var_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var)."); ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels, "QNN BatchNorm input 4 (var) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.IsInitializerInput(inputs[4].node_arg.Name()), + "QNN BatchNorm doesn't support dynamic var."); ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output."); } @@ -82,6 +525,134 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + ORT_UNUSED_PARAMETER(do_op_validation); + ORT_UNUSED_PARAMETER(logger); + + const auto& inputs = node_unit.Inputs(); + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + // + // Input 0 + // + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, input_names)); + + // + // Input 1: scale + // Input 2: bias + // QNN only accept 3 input. We need to first combine mean and variance into scale and bias. + // + { + const std::string& scale_name = inputs[1].node_arg.Name(); + const std::string& bias_name = inputs[2].node_arg.Name(); + OnnxInputInfo var_info = {}; + OnnxInputInfo mean_info = {}; + OnnxInputInfo scale_info = {}; + OnnxInputInfo bias_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], scale_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], bias_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[3], mean_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[4], var_info)); + + // scale, bias, mean, and var must be initializers + ORT_RETURN_IF_NOT(scale_info.is_initializer, "scale must be initializers"); + ORT_RETURN_IF_NOT(bias_info.is_initializer, "bias must be initializers"); + ORT_RETURN_IF_NOT(mean_info.is_initializer, "mean must be initializers"); + ORT_RETURN_IF_NOT(var_info.is_initializer, "var must be initializers"); + + std::vector scale_unpacked_tensor; + std::vector bias_unpacked_tensor; + std::vector var_unpacked_tensor; + std::vector mean_unpacked_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*scale_info.initializer_tensor, scale_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*bias_info.initializer_tensor, bias_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*mean_info.initializer_tensor, mean_unpacked_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*var_info.initializer_tensor, var_unpacked_tensor)); + + std::vector mean_double_tensor; + std::vector std_double_tensor; + std::vector scale_double_tensor; + std::vector bias_double_tensor; + + NodeAttrHelper node_helper(node_unit); + const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec. + + double scale_rmax = std::numeric_limits::min(); + double scale_rmin = std::numeric_limits::max(); + double bias_rmax = std::numeric_limits::min(); + double bias_rmin = std::numeric_limits::max(); + + // Calculate and convert new scale, new bias, mean and std to double array (may be dequantized) + ORT_RETURN_IF_ERROR(PreprocessMean(mean_info, + is_npu_backend, + mean_unpacked_tensor.data(), + mean_unpacked_tensor.size(), + mean_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessStd(var_info, + is_npu_backend, + var_unpacked_tensor.data(), + var_unpacked_tensor.size(), + epsilon, + std_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessScale(scale_info, + is_npu_backend, + scale_unpacked_tensor.data(), + scale_unpacked_tensor.size(), + std_double_tensor, + scale_rmax, + scale_rmin, + scale_double_tensor)); + ORT_RETURN_IF_ERROR(PreprocessBias(bias_info, + is_npu_backend, + bias_unpacked_tensor.data(), + bias_unpacked_tensor.size(), + scale_double_tensor, + mean_double_tensor, + bias_rmax, + bias_rmin, + bias_double_tensor)); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) { + std::vector scale_raw_tensor; + Qnn_QuantizeParams_t scale_quant_param = scale_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(scale_info, + is_npu_backend, + scale_double_tensor, + scale_rmax, + scale_rmin, + scale_quant_param, + scale_raw_tensor)); + Qnn_TensorType_t scale_tensor_type = GetInputTensorType(qnn_model_wrapper, scale_name); + QnnTensorWrapper input_tensorwrapper(scale_name, scale_tensor_type, scale_info.qnn_data_type, scale_quant_param, + std::move(scale_info.shape), std::move(scale_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(scale_name); + + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(bias_name)) { + std::vector bias_raw_tensor; + Qnn_QuantizeParams_t bias_quant_param = bias_info.quant_param; + ORT_RETURN_IF_ERROR(Postprocess(bias_info, + is_npu_backend, + bias_double_tensor, + bias_rmax, + bias_rmin, + bias_quant_param, + bias_raw_tensor)); + Qnn_TensorType_t bias_tensor_type = GetInputTensorType(qnn_model_wrapper, bias_name); + QnnTensorWrapper input_tensorwrapper(bias_name, bias_tensor_type, bias_info.qnn_data_type, bias_quant_param, + std::move(bias_info.shape), std::move(bias_raw_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); + } + input_names.push_back(bias_name); + } + + return Status::OK(); +} + void CreateBatchNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.AddOpBuilder(op_type, std::make_unique()); } diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc index 9b65ca7bda3e..b4e8f5390787 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc @@ -175,13 +175,7 @@ static void RunBatchNormQDQTest(const TestInputDef& input_def, // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 3. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 4. -// Output quant params: scale=0.019084848463535309, zero_point=9. -// Expected val: 1.7755576372146606 -// QNN QDQ val: 2.9963212013244629 (err 1.2207635641098022) -// CPU QDQ val: 0.82064849138259888 (err 0.95490914583206177) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { +TEST_F(QnnHTPBackendTests, BatchNorm1D) { constexpr int64_t num_channels = 2; RunBatchNormQDQTest(TestInputDef({1, num_channels, 3}, false, {-5.0f, -4.0f, -3.0f, 0.0f, 2.0f, 5.0f}), // Input data @@ -193,13 +187,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm1D) { // TODO: FIX TRANSLATION!!! // Check that QNN compiles DQ -> BatchNormalization -> Q as a single unit. // Use an input of rank 4. -// QNN v2.13 -// Inaccuracy detected for output 'output', element 14. -// Output quant params: scale=0.023071292787790298, zero_point=19. -// Expected val: 2.8554618358612061 -// QNN QDQ val: 5.3294687271118164 (err 2.4740068912506104) -// CPU QDQ val: 1.6611330509185791 (err 1.194328784942627) -TEST_F(QnnHTPBackendTests, DISABLED_BatchNorm2D) { +TEST_F(QnnHTPBackendTests, BatchNorm2D) { constexpr int64_t num_channels = 2; std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; @@ -226,4 +214,4 @@ TEST_F(QnnHTPBackendTests, BatchNorm3D) { } // namespace test } // namespace onnxruntime -#endif \ No newline at end of file +#endif From d87216bcb13c8a3937a74b1cd2160aeb7d9cffb7 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Wed, 1 Nov 2023 08:39:39 -0700 Subject: [PATCH 09/12] Openvino ep ort 23.1 (#17911) ### Description Integration to OpenVINO 2023.1 ### Motivation and Context - Alignment with latest OpenVINO Version. - Device name change from VPUX to NPU and Remove from supported list until official public support is available. --------- Co-authored-by: Sahar Fatima Co-authored-by: Saurabh Kale Co-authored-by: Suryaprakash Shanmugam Co-authored-by: sfatimar --- cmake/CMakeLists.txt | 18 - docs/python/ReadMeOV.rst | 2 - .../core/session/onnxruntime_c_api.h | 4 +- .../providers/openvino/backend_manager.cc | 24 +- .../core/providers/openvino/backend_manager.h | 13 +- .../core/providers/openvino/backend_utils.cc | 11 +- .../core/providers/openvino/backend_utils.h | 12 +- .../openvino/backends/backend_factory.cc | 2 +- .../openvino/backends/basic_backend.cc | 77 ++-- .../openvino/backends/basic_backend.h | 11 +- .../core/providers/openvino/contexts.h | 7 +- .../openvino/openvino_execution_provider.cc | 24 +- .../openvino/openvino_execution_provider.h | 50 ++- .../openvino/openvino_provider_factory.cc | 55 ++- .../core/providers/openvino/ov_interface.cc | 10 +- .../core/providers/openvino/ov_interface.h | 13 +- .../openvino/ov_versions/capabilities.h | 2 + .../openvino/ov_versions/capability.cc | 17 +- .../openvino/ov_versions/data_ops.cc | 419 +++++++++++------- .../providers/openvino/ov_versions/data_ops.h | 13 +- .../providers/openvino/ov_versions/utils.cc | 12 +- .../providers/openvino/ov_versions/utils.h | 21 +- .../core/session/provider_bridge_ort.cc | 2 +- .../python/onnxruntime_pybind_state.cc | 4 +- .../python/onnxruntime_pybind_state_common.h | 8 +- .../test/perftest/command_args_parser.cc | 4 +- onnxruntime/test/perftest/ort_test_session.cc | 11 +- .../test/providers/cpu/nn/lp_norm_op_test.cc | 4 +- .../test/providers/cpu/rnn/rnn_op_test.cc | 4 +- .../providers/cpu/tensor/compress_op.test.cc | 2 +- .../providers/cpu/tensor/unsqueeze_op_test.cc | 2 +- .../test/python/onnx_backend_test_series.py | 3 + .../onnx_backend_test_series_filters.jsonc | 4 + tools/ci_build/build.py | 13 +- .../nuget/generate_nuspec_for_native_nuget.py | 44 +- 35 files changed, 564 insertions(+), 358 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index f81a268d38df..94181448fd21 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1282,14 +1282,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_CONFIG_CPU_FP16=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP16) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_U8) - add_definitions(-DOPENVINO_CONFIG_VPUX_U8=1) - endif() - if (onnxruntime_USE_OPENVINO_GPU_FP32_NP) add_definitions(-DOPENVINO_CONFIG_GPU_FP32=1) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) @@ -1310,16 +1302,6 @@ if (onnxruntime_USE_OPENVINO) add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) endif() - if (onnxruntime_USE_OPENVINO_VPUX_FP32_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP32=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - - if (onnxruntime_USE_OPENVINO_VPUX_FP16_NP) - add_definitions(-DOPENVINO_CONFIG_VPUX_FP16=1) - add_definitions(-DOPENVINO_DISABLE_GRAPH_PARTITION=1) - endif() - if (onnxruntime_USE_OPENVINO_HETERO) add_definitions(-DOPENVINO_CONFIG_HETERO=1) add_definitions(-DDEVICE_NAME="${onnxruntime_USE_OPENVINO_DEVICE}") diff --git a/docs/python/ReadMeOV.rst b/docs/python/ReadMeOV.rst index f12c01d278dc..6ef16e137813 100644 --- a/docs/python/ReadMeOV.rst +++ b/docs/python/ReadMeOV.rst @@ -7,7 +7,6 @@ OpenVINOâ„¢ Execution Provider for ONNX Runtime accelerates inference across man - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs Installation ------------ @@ -22,7 +21,6 @@ This package supports: - Intel® CPUs - Intel® integrated GPUs - Intel® discrete GPUs - - Intel® integrated VPUs ``pip3 install onnxruntime-openvino`` diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 613c1ac93cf1..729a302f3dd0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -611,7 +611,7 @@ typedef struct OrtMIGraphXProviderOptions { typedef struct OrtOpenVINOProviderOptions { #ifdef __cplusplus OrtOpenVINOProviderOptions() : device_type{}, - enable_vpu_fast_compile{}, + enable_npu_fast_compile{}, device_id{}, num_of_threads{}, cache_dir{}, @@ -624,7 +624,7 @@ typedef struct OrtOpenVINOProviderOptions { * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" */ const char* device_type; - unsigned char enable_vpu_fast_compile; ///< 0 = disabled, nonzero = enabled + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled const char* device_id; size_t num_of_threads; ///< 0 = Use default number of threads const char* cache_dir; // path is set to empty by default diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 78467b646b19..7e4c0dc8d726 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -2,9 +2,7 @@ // Licensed under the MIT License #include -#include -#include -#include +#include #include "core/providers/shared_library/provider_api.h" #include "contexts.h" @@ -18,7 +16,8 @@ namespace openvino_ep { static std::unique_ptr g_global_context; GlobalContext& BackendManager::GetGlobalContext() { - // This is not thread safe to call for the first time, but it is first called on the main thread by the constructor so it is safe. + // This is not thread safe to call for the first time, + // but it is first called on the main thread by the constructor so it is safe. if (!g_global_context) g_global_context = std::make_unique(); return *g_global_context; @@ -88,7 +87,9 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, << "Backend created for graph " << subgraph_context_.subgraph_name; } } else { - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. Initializing backend for graph " << subgraph_context_.subgraph_name; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has concrete input dims. " + << "Initializing backend for graph " + << subgraph_context_.subgraph_name; subgraph_context_.has_dynamic_input_shape = false; try { @@ -104,7 +105,7 @@ BackendManager::BackendManager(const onnxruntime::Node& fused_node, bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { bool has_batched_inputs = true; - for (int i = 0; i < (int)subgraph_context_.input_indexes.size(); i++) { + for (int i = 0; i < static_cast(subgraph_context_.input_indexes.size()); i++) { auto& input = model_proto.graph().input(subgraph_context_.input_indexes[i]); // Batch-process only raw image inputs (NCHW or NHWC layouts) @@ -215,7 +216,10 @@ BackendManager::ReWriteInputShapeInfo(const ONNX_NAMESPACE::ModelProto& model_pr auto graph_proto = model_copy->mutable_graph(); for (size_t i = 0, limit = input_shapes.size(); i < limit; i++) { - auto g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + auto g_in_shape = graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->clear_dim(); const auto& shape = input_shapes[i]; for (size_t dim = 0, end = shape.size(); dim < end; dim++) { @@ -234,7 +238,11 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p auto graph_proto = model_copy->mutable_graph(); for (int i = 0; i < graph_proto->input_size(); i++) { - ONNX_NAMESPACE::TensorShapeProto* g_in_shape = graph_proto->mutable_input((int)i)->mutable_type()->mutable_tensor_type()->mutable_shape(); + ONNX_NAMESPACE::TensorShapeProto* g_in_shape = + graph_proto->mutable_input(static_cast(i)) + ->mutable_type() + ->mutable_tensor_type() + ->mutable_shape(); g_in_shape->mutable_dim(0)->clear_dim_value(); g_in_shape->mutable_dim(0)->set_dim_value(1); } diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index c247ab60d3a6..a177324b23f7 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -3,6 +3,11 @@ #pragma once +#include +#include +#include +#include + #include "ov_interface.h" #include "contexts.h" #include "ibackend.h" @@ -13,7 +18,9 @@ namespace openvino_ep { // Singleton class that manages all the backends class BackendManager { public: - BackendManager(const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger); + BackendManager(const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger); void Compute(OrtKernelContext* context); void ShutdownBackendManager(); static GlobalContext& GetGlobalContext(); @@ -21,7 +28,9 @@ class BackendManager { private: std::unique_ptr GetModelProtoFromFusedNode( - const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger) const; + const onnxruntime::Node& fused_node, + const onnxruntime::GraphViewer& subgraph, + const logging::Logger& logger) const; bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index d49968cdb7f3..d47c91dd4662 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -1,9 +1,7 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License -#include -#include -#include +#include #include #include @@ -58,7 +56,7 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext try { auto cnn_network = global_context.ie_core.ReadModel(model); if ((subgraph_context.precision == "FP16") && - (global_context.device_type.find("VPUX") == std::string::npos)) { + (global_context.device_type.find("NPU") == std::string::npos)) { // FP16 transformations ov::pass::ConvertFP32ToFP16 pass_obj; pass_obj.run_on_model(cnn_network); @@ -88,7 +86,8 @@ CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext size_t index = results.size() - 1; for (auto it = results.rbegin(); it != results.rend(); ++it) { - if (auto const_node = std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { + if (auto const_node = + std::dynamic_pointer_cast((*it)->input_value(0).get_node_shared_ptr())) { const_outputs_map[(*it)->get_friendly_name()] = const_node; results.erase(results.begin() + index); } @@ -254,7 +253,7 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, void printPerformanceCounts(const std::vector& performanceMap, std::ostream& stream, std::string deviceName) { - long long totalTime = 0; + int64_t totalTime = 0; // Print performance counts stream << std::endl << "performance counts:" << std::endl diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index de78a150fe2d..82b0351e87da 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -4,9 +4,15 @@ #pragma once #define ORT_API_MANUAL_INIT +#include +#include +#include +#include +#include +#include + #include "core/session/onnxruntime_cxx_api.h" #include "contexts.h" -#include #include "ov_interface.h" #ifdef _WIN32 #include @@ -57,7 +63,9 @@ void FillOutputBlob(OVTensorPtr outputBlob, Ort::UnownedValue& output_tensor, size_t batch_slice_idx); std::shared_ptr -CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, const GlobalContext& global_context, const SubGraphContext& subgraph_context, +CreateOVModel(const ONNX_NAMESPACE::ModelProto& model_proto, + const GlobalContext& global_context, + const SubGraphContext& subgraph_context, std::map>& const_outputs_map); void printPerformanceCounts(const std::vector& performanceMap, diff --git a/onnxruntime/core/providers/openvino/backends/backend_factory.cc b/onnxruntime/core/providers/openvino/backends/backend_factory.cc index c339f24e7022..c586dd8b38af 100644 --- a/onnxruntime/core/providers/openvino/backends/backend_factory.cc +++ b/onnxruntime/core/providers/openvino/backends/backend_factory.cc @@ -16,7 +16,7 @@ BackendFactory::MakeBackend(const ONNX_NAMESPACE::ModelProto& model_proto, const SubGraphContext& subgraph_context) { std::string type = global_context.device_type; if (type == "CPU" || type.find("GPU") != std::string::npos || - type.find("VPUX") != std::string::npos || + type.find("NPU") != std::string::npos || type.find("HETERO") != std::string::npos || type.find("MULTI") != std::string::npos || type.find("AUTO") != std::string::npos) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index f9517d794266..09e1322ff59f 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -6,10 +6,10 @@ #include #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" -// #include #include "basic_backend.h" #include "../backend_manager.h" @@ -57,33 +57,39 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, cl_context ctx = static_cast(global_context_.context); remote_context_ = new ov::intel_gpu::ocl::ClContext(global_context_.ie_core.Get(), ctx); ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) if (!subgraph_context_.has_dynamic_input_shape && dev_prec != "CPU_FP16") { const std::string model = model_proto.SerializeAsString(); - exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + model, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; #endif #endif } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); - exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); + exe_network_ = global_context_.ie_core.LoadNetwork( + ie_cnn_network_, hw_target, device_config, subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } } catch (const char* msg) { @@ -127,10 +133,10 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { } #endif #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - if (global_context_.device_type.find("VPUX") != std::string::npos) { + if (global_context_.device_type.find("NPU") != std::string::npos) { std::pair device_property; - device_property = std::make_pair("VPU_COMPILER_TYPE", "MLIR"); - device_config.emplace(ov::device::properties("VPUX", device_property)); + device_property = std::make_pair("NPU_COMPILER_TYPE", "DRIVER"); + device_config.emplace(ov::device::properties("NPU", device_property)); } #endif } @@ -152,12 +158,12 @@ void BasicBackend::EnableCaching() { } void BasicBackend::EnableGPUThrottling(ov::AnyMap& device_config) { - if (global_context_.enable_opencl_throttling == true && global_context_.device_type.find("GPU") != std::string::npos) { + if (global_context_.enable_opencl_throttling == true && + global_context_.device_type.find("GPU") != std::string::npos) { LOGS_DEFAULT(INFO) << log_tag << "Enabled OpenCL queue throttling for GPU device"; std::pair device_property; device_property = std::make_pair("PLUGIN_THROTTLE", "1"); device_config.emplace(ov::device::properties("GPU_CONFIG_KEY", device_property)); - // device_config[GPU_CONFIG_KEY(PLUGIN_THROTTLE)] = "1"; } } @@ -187,7 +193,9 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } size_t batch_slice_idx = 0; if (subgraph_context_.has_dynamic_input_shape && @@ -197,6 +205,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque auto tensor_info = tensor.GetTensorTypeAndShapeInfo(); auto tensor_shape = tensor_info.GetShape(); auto tensor_size = tensor_shape.size(); + const char* tensor_data = tensor.GetTensorData(); auto tensor_iter = 0; ov::Shape input_tensor_shape = ov::Shape(tensor_size, 0); for (auto i = tensor_shape.begin(); i != tensor_shape.end(); ++i) { @@ -204,8 +213,16 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque tensor_iter += 1; } auto input = ie_cnn_network_->get_parameters().at(input_idx); - OVTensorPtr tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); - FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + OVTensorPtr tensor_ptr; + // avoid input copies on the CPU device + if (global_context_.device_type.find("CPU") != std::string::npos) { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape, + (void*)tensor_data); + } else { + tensor_ptr = std::make_shared(input->get_element_type(), input_tensor_shape); + FillInputBlob(tensor_ptr, batch_slice_idx, input_name, context, subgraph_context_); + } + try { infer_request->SetTensor(input_name, tensor_ptr); } catch (const char* msg) { @@ -251,7 +268,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe if (input_names.find(onnx_input_name) != input_names.end()) { input_name = onnx_input_name; } else { - throw(log_tag + "Input names mismatch between OpenVINO and ONNX. " + onnx_input_name + " doesn't exist in the list of OpenVINO input tensor names"); + throw(log_tag + + "Input names mismatch between OpenVINO and ONNX. " + + onnx_input_name + + " doesn't exist in the list of OpenVINO input tensor names"); } input_idx++; // Kernel Context Input Buffer @@ -264,9 +284,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create an Input Remote Blob auto input = ie_cnn_network_->get_parameters().at(0); - auto remote_blob = remote_context_->create_tensor(input->get_element_type(), input->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_blob); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_blob = remote_context_->create_tensor( + input->get_element_type(), input->get_shape(), *shared_buffer_const); + ov::Tensor tensor_remote = static_cast(remote_blob); + OVTensorPtr tensor_ptr = std::make_shared(tensor_remote); infer_request->SetTensor(input_name, tensor_ptr); } else { OVTensorPtr graph_input_blob; @@ -295,7 +316,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe } } if (!output_name_found) { - throw std::string(log_tag + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); + throw std::string( + log_tag + + "Output names mismatch between OpenVINO and ONNX. [ONNX Output: ] " + + onnx_output_name + " doesn't exist in the list of OpenVINO output tensor names"); } size_t batch_size = 1; @@ -307,9 +331,10 @@ void BasicBackend::StartRemoteAsyncInference(Ort::KernelContext& context, OVInfe const cl::Buffer* shared_buffer_const = static_cast(tensor_data); // Create a shared Blob, set the Infer Request Output Blob auto output = ie_cnn_network_->get_results().at(0); - auto remote_tensor = remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); - ov::Tensor tensor = static_cast(remote_tensor); - OVTensorPtr tensor_ptr = std::make_shared(tensor); + auto remote_tensor = + remote_context_->create_tensor(output->get_element_type(), output->get_shape(), *shared_buffer_const); + ov::Tensor tensor_t = static_cast(remote_tensor); + OVTensorPtr tensor_ptr = std::make_shared(tensor_t); try { infer_request->SetTensor(output_name, tensor_ptr); } catch (const char* msg) { @@ -364,7 +389,8 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe throw(msg); } size_t batch_size = 1; - auto output_tensor = GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); + auto output_tensor = + GetOutputTensor(context, batch_size, infer_request, output_name, subgraph_context_.output_names); auto mem_info = output_tensor.GetTensorMemoryInfo(); if (mem_info.GetAllocatorName() == OpenVINO_GPU) { return; @@ -465,7 +491,8 @@ void BasicBackend::Infer(OrtKernelContext* ctx) { #ifndef IO_BUFFER_ENABLED // Printing performance counts is disabled when IO_BUFFER_ENABLED if (openvino_ep::backend_utils::IsDebugEnabled()) { inferRequestsQueue_->printstatus(); // Printing the elements of infer_requests_ vector pool only in debug mode - std::string& hw_target = (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; + std::string& hw_target = + (global_context_.device_id != "") ? global_context_.device_id : global_context_.device_type; printPerformanceCounts(infer_request, std::cout, hw_target); } #endif diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 2f1d60364080..6eda641451a7 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -6,16 +6,17 @@ #include #define ORT_API_MANUAL_INIT -#include "core/session/onnxruntime_cxx_api.h" -#include "core/providers/openvino/contexts.h" -#include "core/providers/openvino/ibackend.h" -#include "core/providers/openvino/ov_interface.h" #include #include #include #include #include +#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/openvino/contexts.h" +#include "core/providers/openvino/ibackend.h" +#include "core/providers/openvino/ov_interface.h" + namespace onnxruntime { namespace openvino_ep { @@ -29,7 +30,7 @@ class BasicBackend : public IBackend { void Infer(OrtKernelContext* context) override; private: - bool ImportBlob(std::string hw_target, bool vpu_status); + bool ImportBlob(std::string hw_target, bool npu_status); void PopulateCompiledDirectory(std::string, std::string&, std::string&, bool&); bool ValidateSubgraph(std::map>& const_outputs_map); void PopulateConfigValue(ov::AnyMap& device_config); diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index b61dcf8ca492..29233e72c33b 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -3,6 +3,9 @@ #pragma once +#include +#include +#include #include "ov_interface.h" namespace onnxruntime { @@ -12,7 +15,7 @@ namespace openvino_ep { struct GlobalContext { OVCore ie_core; bool is_wholly_supported_graph = false; - bool enable_vpu_fast_compile = false; + bool enable_npu_fast_compile = false; bool enable_opencl_throttling = false; bool enable_dynamic_shapes = false; size_t num_of_threads; @@ -34,7 +37,7 @@ struct GlobalContext { struct SubGraphContext { bool has_dynamic_input_shape = false; bool enable_batching = false; - bool set_vpu_config = false; + bool set_npu_config = false; bool is_constant = false; void* context = 0; std::string subgraph_name; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 990809926299..a4c6b0f851c0 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -17,17 +17,18 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv openvino_ep::BackendManager::GetGlobalContext().device_type = info.device_type_; openvino_ep::BackendManager::GetGlobalContext().precision_str = info.precision_; - openvino_ep::BackendManager::GetGlobalContext().enable_vpu_fast_compile = info.enable_vpu_fast_compile_; + openvino_ep::BackendManager::GetGlobalContext().enable_npu_fast_compile = info.enable_npu_fast_compile_; openvino_ep::BackendManager::GetGlobalContext().cache_dir = info.cache_dir_; openvino_ep::BackendManager::GetGlobalContext().num_streams = info.num_streams_; openvino_ep::BackendManager::GetGlobalContext().context = info.context_; openvino_ep::BackendManager::GetGlobalContext().enable_opencl_throttling = info.enable_opencl_throttling_; openvino_ep::BackendManager::GetGlobalContext().enable_dynamic_shapes = info.enable_dynamic_shapes_; - if ((int)info.num_of_threads_ <= 0) { + if (static_cast(info.num_of_threads_) <= 0) { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = 8; - } else if ((int)info.num_of_threads_ > 8) { - std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; + } else if (static_cast(info.num_of_threads_) > 8) { + std::string err_msg = std::string("\n [ERROR] num_of_threads configured during runtime is: ") + + std::to_string(info.num_of_threads_) + "\nnum_of_threads configured should be >0 and <=8.\n"; ORT_THROW(err_msg); } else { openvino_ep::BackendManager::GetGlobalContext().num_of_threads = info.num_of_threads_; @@ -56,7 +57,8 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProv device_found = true; break; } - if (info.device_type_.find("VPUX") != std::string::npos && (info.precision_ == "FP16" || info.precision_ == "U8")) { + if ((info.device_type_.find("NPU") != std::string::npos) && + (info.precision_ == "FP16" || info.precision_ == "U8")) { device_found = true; break; } @@ -109,11 +111,14 @@ OpenVINOExecutionProvider::GetCapability(const GraphViewer& graph_viewer, openvino_ep::BackendManager::GetGlobalContext().onnx_model_name = graph_viewer.Name(); #ifdef _WIN32 std::wstring onnx_path = graph_viewer.ModelPath().ToPathString(); - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = std::string(onnx_path.begin(), onnx_path.end()); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + std::string(onnx_path.begin(), onnx_path.end()); #else - openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = graph_viewer.ModelPath().ToPathString(); + openvino_ep::BackendManager::GetGlobalContext().onnx_model_path_name = + graph_viewer.ModelPath().ToPathString(); #endif - openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = graph_viewer.DomainToVersionMap().at(kOnnxDomain); + openvino_ep::BackendManager::GetGlobalContext().onnx_opset_version = + graph_viewer.DomainToVersionMap().at(kOnnxDomain); #if defined(OPENVINO_2022_1) openvino_ep::GetCapability obj(graph_viewer, @@ -151,7 +156,8 @@ common::Status OpenVINOExecutionProvider::Compile( openvino_ep::BackendManager::GetGlobalContext().use_api_2 = true; - std::shared_ptr backend_manager = std::make_shared(fused_node, graph_body_viewer, *GetLogger()); + std::shared_ptr backend_manager = + std::make_shared(fused_node, graph_body_viewer, *GetLogger()); compute_info.create_state_func = [backend_manager](ComputeContext* context, FunctionState* state) { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index a4fc09362fa2..3b56b54410e4 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -3,19 +3,28 @@ #pragma once -#include "backend_manager.h" #include #include #include +#include +#include +#include + +#include "backend_manager.h" namespace onnxruntime { static void print_build_options() { std::cout << "[ERROR] INVALID DEVICE BUILD TYPE SPECIFIED" << std::endl; - std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority you want to build" << std::endl; - std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build "; - std::cout << "are ['CPU','GPU','VPUX']" << std::endl; - std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" << std::endl; + std::cout << "Specify the keyword HETERO (or) MULTI (or) AUTO followed by the devices in the order of priority " + << "you want to build" + << std::endl; + std::cout << "The different hardware devices that can be added with HETERO/MULTI/AUTO build " + << "are ['CPU','GPU']" + << std::endl; + std::cout << "An example of how to specify the HETERO or MULTI or AUTO build type. " + << "Ex: HETERO:GPU,CPU Ex: MULTI:GPU,CPU Ex: AUTO:GPU,CPU" + << std::endl; } static std::vector split(const std::string& s, char delim) { @@ -39,7 +48,7 @@ static std::vector parseDevices(const std::string& device_string) { print_build_options(); ORT_THROW("Invalid device string: " + device_string); } - std::vector dev_options = {"CPU", "GPU", "VPUX"}; + std::vector dev_options = {"CPU", "GPU"}; for (std::string dev : devices) { if (!std::count(dev_options.begin(), dev_options.end(), dev)) { print_build_options(); @@ -53,7 +62,7 @@ static std::vector parseDevices(const std::string& device_string) { struct OpenVINOExecutionProviderInfo { std::string device_type_; std::string precision_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -62,11 +71,18 @@ struct OpenVINOExecutionProviderInfo { bool enable_opencl_throttling_; bool enable_dynamic_shapes_; - explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_vpu_fast_compile, std::string dev_id, + explicit OpenVINOExecutionProviderInfo(std::string dev_type, bool enable_npu_fast_compile, std::string dev_id, size_t num_of_threads, std::string cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), device_id_(dev_id), num_of_threads_(num_of_threads), cache_dir_(cache_dir), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + device_id_(dev_id), + num_of_threads_(num_of_threads), + cache_dir_(cache_dir), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { if (dev_type == "") { LOGS_DEFAULT(INFO) << "[OpenVINO-EP]" << "No runtime device selection option provided."; @@ -82,11 +98,11 @@ struct OpenVINOExecutionProviderInfo { #elif defined OPENVINO_CONFIG_GPU_FP16 device_type_ = "GPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_FP16 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_FP16 + device_type_ = "NPU"; precision_ = "FP16"; -#elif defined OPENVINO_CONFIG_VPUX_U8 - device_type_ = "VPUX"; +#elif defined OPENVINO_CONFIG_NPU_U8 + device_type_ = "NPU"; precision_ = "U8"; #elif defined OPENVINO_CONFIG_HETERO || defined OPENVINO_CONFIG_MULTI || defined OPENVINO_CONFIG_AUTO #ifdef DEVICE_NAME @@ -126,11 +142,11 @@ struct OpenVINOExecutionProviderInfo { } else if (dev_type == "GPU.1_FP16") { device_type_ = "GPU.1"; precision_ = "FP16"; - } else if (dev_type == "VPUX_FP16") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_FP16") { + device_type_ = "NPU"; precision_ = "FP16"; - } else if (dev_type == "VPUX_U8") { - device_type_ = "VPUX"; + } else if (dev_type == "NPU_U8") { + device_type_ = "NPU"; precision_ = "U8"; } else if (dev_type.find("HETERO") == 0 || dev_type.find("MULTI") == 0) { std::vector devices = parseDevices(dev_type); diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 95b39bcc0598..fbb89710c800 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -8,11 +8,16 @@ namespace onnxruntime { struct OpenVINOProviderFactory : IExecutionProviderFactory { - OpenVINOProviderFactory(const char* device_type, bool enable_vpu_fast_compile, + OpenVINOProviderFactory(const char* device_type, bool enable_npu_fast_compile, const char* device_id, size_t num_of_threads, const char* cache_dir, int num_streams, void* context, bool enable_opencl_throttling, bool enable_dynamic_shapes) - : enable_vpu_fast_compile_(enable_vpu_fast_compile), num_of_threads_(num_of_threads), num_streams_(num_streams), context_(context), enable_opencl_throttling_(enable_opencl_throttling), enable_dynamic_shapes_(enable_dynamic_shapes) { + : enable_npu_fast_compile_(enable_npu_fast_compile), + num_of_threads_(num_of_threads), + num_streams_(num_streams), + context_(context), + enable_opencl_throttling_(enable_opencl_throttling), + enable_dynamic_shapes_(enable_dynamic_shapes) { device_type_ = (device_type == nullptr) ? "" : device_type; device_id_ = (device_id == nullptr) ? "" : device_id; cache_dir_ = (cache_dir == nullptr) ? "" : cache_dir; @@ -24,7 +29,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { private: std::string device_type_; - bool enable_vpu_fast_compile_; + bool enable_npu_fast_compile_; std::string device_id_; size_t num_of_threads_; std::string cache_dir_; @@ -35,7 +40,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { }; std::unique_ptr OpenVINOProviderFactory::CreateProvider() { - OpenVINOExecutionProviderInfo info(device_type_, enable_vpu_fast_compile_, device_id_, num_of_threads_, + OpenVINOExecutionProviderInfo info(device_type_, enable_npu_fast_compile_, device_id_, num_of_threads_, cache_dir_, num_streams_, context_, enable_opencl_throttling_, enable_dynamic_shapes_); return std::make_unique(info); @@ -59,17 +64,18 @@ struct OpenVINO_Provider : Provider { std::string device_type = ""; // [device_type]: Overrides the accelerator hardware type and precision // with these values at runtime. - bool enable_vpu_fast_compile = false; // [enable_vpu_fast_compile]: Fast-compile may be optionally enabled to - // speeds up the model's compilation to VPU device specific format. + bool enable_npu_fast_compile = false; // [enable_npu_fast_compile]: Fast-compile may be optionally enabled to + // speeds up the model's compilation to NPU device specific format. const char* device_id = ""; // [device_id]: Selects a particular hardware device for inference. - size_t num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of + int num_of_threads = 8; // [num_of_threads]: Overrides the accelerator default value of number of // threads with this value at runtime. const char* cache_dir = ""; // [cache_dir]: specify the path to // dump and load the blobs for the model caching/kernel caching (GPU) // feature. If blob files are already present, it will be directly loaded. int num_streams = 1; // [num_streams]: Option that specifies the number of parallel inference // requests to be processed on a given `device_type`. Overrides the - // accelerator default value of number of streams with this value at runtime. + // accelerator default value of number of streams + // with this value at runtime. bool enable_opencl_throttling = false; // [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU // device (Reduces CPU Utilization when using GPU) bool enable_dynamic_shapes = false; // [enable_dynamic_shapes]: Enables Dynamic Shapes feature for CPU device) @@ -80,14 +86,15 @@ struct OpenVINO_Provider : Provider { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (!((ov_supported_device_types.find(device_type) != ov_supported_device_types.end()) || - (device_type.find("HETERO:") == 0) || (device_type.find("MULTI:") == 0) || (device_type.find("AUTO:") == 0))) { + (device_type.find("HETERO:") == 0) || + (device_type.find("MULTI:") == 0) || + (device_type.find("AUTO:") == 0))) { ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } @@ -97,30 +104,37 @@ struct OpenVINO_Provider : Provider { if (provider_options_map.find("cache_dir") != provider_options_map.end()) { cache_dir = provider_options_map.at("cache_dir").c_str(); } + if (provider_options_map.find("context") != provider_options_map.end()) { - context = (void*)provider_options_map.at("context").c_str(); + std::string str = provider_options_map.at("context"); + uint64_t number = std::strtoull(str.c_str(), nullptr, 16); + context = reinterpret_cast(number); } if (provider_options_map.find("num_of_threads") != provider_options_map.end()) { num_of_threads = std::stoi(provider_options_map.at("num_of_threads")); if (num_of_threads <= 0) { num_of_threads = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_threads' should be in the positive range.\n " + << "Executing with num_threads=1"; } } if (provider_options_map.find("num_streams") != provider_options_map.end()) { num_streams = std::stoi(provider_options_map.at("num_streams")); - if (num_streams <= 0 && num_streams > 8) { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'num_streams' should be in the range of 1-8 \n"); + if (num_streams <= 0) { + num_streams = 1; + LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] The value for the key 'num_streams' should be in the range of 1-8.\n " + << "Executing with num_streams=1"; } } std::string bool_flag = ""; - if (provider_options_map.find("enable_vpu_fast_compile") != provider_options_map.end()) { - bool_flag = provider_options_map.at("enable_vpu_fast_compile"); + if (provider_options_map.find("enable_npu_fast_compile") != provider_options_map.end()) { + bool_flag = provider_options_map.at("enable_npu_fast_compile"); if (bool_flag == "true" || bool_flag == "True") - enable_vpu_fast_compile = true; + enable_npu_fast_compile = true; else if (bool_flag == "false" || bool_flag == "False") - enable_vpu_fast_compile = false; + enable_npu_fast_compile = false; bool_flag = ""; } @@ -141,7 +155,7 @@ struct OpenVINO_Provider : Provider { enable_dynamic_shapes = false; } return std::make_shared(const_cast(device_type.c_str()), - enable_vpu_fast_compile, + enable_npu_fast_compile, device_id, num_of_threads, cache_dir, @@ -157,7 +171,6 @@ struct OpenVINO_Provider : Provider { void Shutdown() override { openvino_ep::BackendManager::ReleaseGlobalContext(); } - } g_provider; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 3914488fc523..d2ce378c97e0 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -29,7 +29,10 @@ std::shared_ptr OVCore::ReadModel(const std::string& model) const { } } -OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(ie_cnn_network, hw_target, device_config); @@ -43,7 +46,10 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, std } #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) -OVExeNetwork OVCore::LoadNetwork(const std::string& model, std::string& hw_target, ov::AnyMap& device_config, std::string name) { +OVExeNetwork OVCore::LoadNetwork(const std::string& model, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name) { ov::CompiledModel obj; try { obj = oe.compile_model(model, ov::Tensor(), hw_target, device_config); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index ed9583033ab3..935ac8f68411 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -4,6 +4,7 @@ #pragma once #include +#include #if defined(OPENVINO_2022_1) || (OPENVINO_2022_2) || (OPENVINO_2022_3) || (OPENVINO_2023_0) || (OPENVINO_2023_1) #define OV_API_20 @@ -43,9 +44,15 @@ class OVCore { public: std::shared_ptr ReadModel(const std::string& model_stream) const; - OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(std::shared_ptr& ie_cnn_network, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #if defined(OPENVINO_2023_0) || (OPENVINO_2023_1) - OVExeNetwork LoadNetwork(const std::string& model_stream, std::string& hw_target, ov::AnyMap& device_config, std::string name); + OVExeNetwork LoadNetwork(const std::string& model_stream, + std::string& hw_target, + ov::AnyMap& device_config, + std::string name); #endif void SetCache(std::string cache_dir_path); #ifdef IO_BUFFER_ENABLED @@ -62,7 +69,7 @@ class OVExeNetwork { ov::CompiledModel obj; public: - OVExeNetwork(ov::CompiledModel md) { obj = md; } + explicit OVExeNetwork(ov::CompiledModel md) { obj = md; } OVExeNetwork() { obj = ov::CompiledModel(); } ov::CompiledModel& Get() { return obj; } OVInferRequest CreateInferRequest(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h index b76d1cf534c2..5bcf9d68cd94 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capabilities.h +++ b/onnxruntime/core/providers/openvino/ov_versions/capabilities.h @@ -3,6 +3,8 @@ #pragma once #include +#include +#include #include "data_ops.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 171dd45c508c..b030efa23820 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -24,7 +24,8 @@ namespace openvino_ep { // Constructor GetCapability::GetCapability(const GraphViewer& graph_viewer_param, std::string device_type_param, - const std::string version_param) : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { + const std::string version_param) + : graph_viewer_(graph_viewer_param), device_type_(device_type_param) { if (version_param == "V_2022_1") { data_ops_ = new DataOps(graph_viewer_, V_2022_1, device_type_); } else if (version_param == "V_2022_2") { @@ -114,11 +115,11 @@ std::vector> GetCapability::Execute() { } openvino_ep::BackendManager::GetGlobalContext().is_wholly_supported_graph = true; - } else { // unsupported_nodes_idx.empty() - + } else { // unsupported_nodes_idx.empty() #if defined(OPENVINO_DISABLE_GRAPH_PARTITION) // disables graph partition at build time LOGS_DEFAULT(INFO) << "[OpenVINO-EP] DISABLE_GRAPH_PARTITION option is set"; - LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, so making the full model fall back to default CPU Execution Provider"; + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model is not fully supported by OpenVINO, " + << "so making the full model fall back to default CPU Execution Provider"; return result; #endif @@ -159,7 +160,13 @@ std::vector> GetCapability::Execute() { std::vector cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs; - GetInputsOutputsOfCluster(graph_viewer_, this_cluster, ng_required_initializers, cluster_graph_inputs, cluster_inputs, const_inputs, cluster_outputs); + GetInputsOutputsOfCluster(graph_viewer_, + this_cluster, + ng_required_initializers, + cluster_graph_inputs, + cluster_inputs, + const_inputs, + cluster_outputs); bool omit_subgraph = false; // Omitting zero dim subgraphs diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 70118c94f9ff..a5a0faa3a8f2 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -2,11 +2,15 @@ // Licensed under the MIT License #include +#include +#include +#include +#include +#include + #include "core/providers/shared_library/provider_api.h" #include "../backend_utils.h" #include "../backend_manager.h" -#include -#include #include "data_ops.h" #include "capabilities.h" #include "utils.h" @@ -72,269 +76,355 @@ std::set ops_supported_as_function = { std::vector supported_op_mode = { {"Abs", V_2020_4, {"CPU", "GPU"}}, - {"Abs", V_2023_0, {"VPUX"}}, + {"Abs", V_2023_0, {"NPU"}}, {"Acos", V_2020_4, {"CPU"}}, {"Acos", V_2022_1, {"GPU"}}, + {"Acos", V_2023_1, {"NPU"}}, {"Acosh", V_2020_4, {"CPU"}}, {"Acosh", V_2022_1, {"GPU"}}, + {"Acosh", V_2023_1, {"NPU"}}, {"Add", V_2020_4, {"CPU", "GPU"}}, - {"Add", V_2023_0, {"VPUX"}}, + {"Add", V_2023_0, {"NPU"}}, {"And", V_2020_4, {"CPU", "GPU"}}, + {"And", V_2023_1, {"NPU"}}, {"ArgMax", V_2020_4, {"CPU"}}, {"ArgMax", V_2021_1, {"GPU"}}, {"ArgMin", V_2020_4, {"CPU"}}, {"ArgMin", V_2022_1, {"GPU"}}, {"Asin", V_2020_4, {"CPU", "GPU"}}, + {"Asin", V_2023_1, {"NPU"}}, {"Asinh", V_2020_4, {"CPU", "GPU"}}, + {"Asinh", V_2023_1, {"NPU"}}, {"Atan", V_2020_4, {"CPU", "GPU"}}, + {"Atan", V_2023_1, {"NPU"}}, {"Atanh", V_2020_4, {"CPU"}}, {"Atanh", V_2022_1, {"GPU"}}, + {"Atanh", V_2023_1, {"NPU"}}, {"AveragePool", V_2020_4, {"CPU", "GPU"}}, - {"AveragePool", V_2023_0, {"VPUX"}}, + {"AveragePool", V_2023_0, {"NPU"}}, {"BatchNormalization", V_2020_4, {"CPU", "GPU"}}, - {"BatchNormalization", V_2023_0, {"VPUX"}}, + {"BatchNormalization", V_2023_0, {"NPU"}}, {"BitShift", V_2022_1, {"CPU"}}, + {"BitShift", V_2023_1, {"NPU"}}, {"Cast", V_2020_4, {"CPU", "GPU"}}, - {"Cast", V_2023_0, {"VPUX"}}, + {"Cast", V_2023_0, {"NPU"}}, + {"CastLike", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Ceil", V_2020_4, {"GPU"}}, {"Ceil", V_2021_4, {"CPU"}}, + {"Ceil", V_2023_1, {"NPU"}}, {"Celu", V_2022_1, {"CPU", "GPU"}}, {"Clip", V_2020_4, {"CPU", "GPU"}}, - {"Clip", V_2023_0, {"VPUX"}}, + {"Clip", V_2023_0, {"NPU"}}, + {"Compress", V_2023_1, {"CPU", "GPU"}}, {"Concat", V_2020_4, {"CPU", "GPU"}}, - {"Concat", V_2023_0, {"VPUX"}}, + {"Concat", V_2023_0, {"NPU"}}, {"Constant", V_2020_4, {"CPU", "GPU"}}, - {"Constant", V_2023_0, {"VPUX"}}, + {"Constant", V_2023_0, {"NPU"}}, {"ConstantOfShape", V_2020_4, {"CPU", "GPU"}}, - {"ConstantOfShape", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op in the plugin. + {"ConstantOfShape", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op in the plugin. {"Conv", V_2020_4, {"CPU", "GPU"}}, - {"Conv", V_2023_0, {"VPUX"}}, + {"Conv", V_2023_0, {"NPU"}}, {"ConvInteger", V_2022_1, {"CPU", "GPU"}}, + {"ConvInteger", V_2023_1, {"NPU"}}, {"ConvTranspose", V_2020_4, {"CPU", "GPU"}}, + {"ConvTranspose", V_2023_1, {"NPU"}}, {"Cos", V_2020_4, {"CPU"}}, {"Cos", V_2022_1, {"GPU"}}, - {"Cos", V_2023_0, {"VPUX"}}, + {"Cos", V_2023_0, {"NPU"}}, {"Cosh", V_2020_4, {"CPU"}}, {"Cosh", V_2022_1, {"GPU"}}, + {"Cosh", V_2023_1, {"NPU"}}, {"CumSum", V_2022_1, {"CPU", "GPU"}}, - {"CumSum", V_2023_0, {"VPUX"}}, + {"CumSum", V_2023_0, {"NPU"}}, {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, - {"DepthToSpace", V_2023_0, {"VPUX"}}, + {"DepthToSpace", V_2023_0, {"NPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"DequantizeLinear", V_2023_0, {"VPUX"}}, + {"DequantizeLinear", V_2023_0, {"NPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, - {"Div", V_2023_0, {"VPUX"}}, + {"Div", V_2023_0, {"NPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, - {"Dropout", V_2023_0, {"VPUX"}}, + {"Dropout", V_2023_0, {"NPU"}}, {"Elu", V_2020_4, {"CPU", "GPU"}}, - {"Elu", V_2023_0, {"VPUX"}}, + {"Elu", V_2023_0, {"NPU"}}, // {"Einsum", V_2023_0, {"CPU", "GPU"}}, {"Equal", V_2020_4, {"CPU", "GPU"}}, - {"Equal", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Equal", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Erf", V_2020_4, {"CPU", "GPU"}}, - {"Erf", V_2023_0, {"VPUX"}}, + {"Erf", V_2023_0, {"NPU"}}, {"Exp", V_2020_4, {"CPU", "GPU"}}, - {"Exp", V_2023_0, {"VPUX"}}, + {"Exp", V_2023_0, {"NPU"}}, {"Expand", V_2022_1, {"CPU", "GPU"}}, - {"Expand", V_2023_0, {"VPUX"}}, // Gets mapped to broadcast op and multiply op in the plugin. + {"Expand", V_2023_0, {"NPU"}}, // Gets mapped to broadcast op and multiply op in the plugin. {"EyeLike", V_2022_1, {"CPU"}}, - {"EyeLike", V_2023_0, {"VPUX"}}, // NoOP + {"EyeLike", V_2023_0, {"NPU"}}, // NoOP {"Flatten", V_2020_4, {"CPU", "GPU"}}, - {"Flatten", V_2023_0, {"VPUX"}}, + {"Flatten", V_2023_0, {"NPU"}}, {"Floor", V_2020_4, {"CPU", "GPU"}}, + {"Floor", V_2023_1, {"NPU"}}, {"Gather", V_2020_4, {"CPU", "GPU"}}, - {"Gather", V_2023_0, {"VPUX"}}, + {"Gather", V_2023_0, {"NPU"}}, {"GatherElements", V_2022_2, {"CPU", "GPU"}}, + {"GatherElements", V_2023_1, {"NPU"}}, {"GatherND", V_2021_4, {"CPU", "GPU"}}, + {"GatherND", V_2023_1, {"NPU"}}, {"Gemm", V_2020_4, {"CPU", "GPU"}}, - {"Gemm", V_2023_0, {"VPUX"}}, + {"Gemm", V_2023_0, {"NPU"}}, {"GlobalAveragePool", V_2020_4, {"CPU", "GPU"}}, - {"GlobalAveragePool", V_2023_0, {"VPUX"}}, + {"GlobalAveragePool", V_2023_0, {"NPU"}}, {"GlobalLpPool", V_2020_4, {"CPU", "GPU"}}, + {"GlobalLpPool", V_2023_1, {"NPU"}}, {"GlobalMaxPool", V_2022_1, {"CPU", "GPU"}}, + {"GlobalMaxPool", V_2023_1, {"NPU"}}, {"Greater", V_2020_4, {"CPU", "GPU"}}, - {"Greater", V_2023_0, {"VPUX"}}, + {"Greater", V_2023_0, {"NPU"}}, {"GreaterOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"GreaterOrEqual", V_2023_0, {"VPUX"}}, + {"GreaterOrEqual", V_2023_0, {"NPU"}}, {"GridSample", V_2022_3, {"CPU"}}, {"GridSample", V_2023_0, {"GPU"}}, + {"GridSample", V_2023_1, {"NPU"}}, + {"HardMax", V_2023_1, {"CPU", "GPU", "NPU"}}, {"Identity", V_2020_4, {"CPU", "GPU"}}, - {"Identity", V_2023_0, {"VPUX"}}, // NoOP + {"Identity", V_2023_0, {"NPU"}}, // NoOP {"If", V_2022_3, {"CPU", "GPU"}}, + {"If", V_2023_1, {"NPU"}}, {"ImageScaler", V_2022_1, {"CPU", "GPU"}}, - {"ImageScaler", V_2023_0, {"VPUX"}}, + {"ImageScaler", V_2023_0, {"NPU"}}, {"InstanceNormalization", V_2020_4, {"CPU", "GPU"}}, - {"InstanceNormalization", V_2023_0, {"VPUX"}}, + {"InstanceNormalization", V_2023_0, {"NPU"}}, {"HardSigmoid", V_2020_4, {"CPU", "GPU"}}, + {"HardSigmoid", V_2023_1, {"NPU"}}, {"HardMax", V_2022_1, {"CPU", "GPU"}}, {"LeakyRelu", V_2020_4, {"CPU", "GPU"}}, - {"LeakyRelu", V_2023_0, {"VPUX"}}, + {"LeakyRelu", V_2023_0, {"NPU"}}, {"Less", V_2020_4, {"CPU", "GPU"}}, - {"Less", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Less", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"LessOrEqual", V_2022_1, {"CPU", "GPU"}}, - {"LessOrEqual", V_2023_0, {"VPUX"}}, + {"LessOrEqual", V_2023_0, {"NPU"}}, {"Log", V_2020_4, {"CPU", "GPU"}}, - {"Log", V_2023_0, {"VPUX"}}, + {"Log", V_2023_0, {"NPU"}}, {"LogSoftMax", V_2022_1, {"CPU", "GPU"}}, {"Loop", V_2021_4, {"CPU", "GPU"}}, + {"LpNormalization", V_2023_1, {"CPU", "GPU", "NPU"}}, + {"LpPool", V_2023_1, {"CPU", "GPU", "NPU"}}, {"LRN", V_2020_4, {"CPU", "GPU"}}, - {"LRN", V_2023_0, {"VPUX"}}, + {"LRN", V_2023_0, {"NPU"}}, {"LSTM", V_2020_4, {"CPU", "GPU"}}, + {"LSTM", V_2023_1, {"NPU"}}, {"MatMul", V_2020_4, {"CPU", "GPU"}}, - {"MatMul", V_2023_0, {"VPUX"}}, + {"MatMul", V_2023_0, {"NPU"}}, {"MatMulInteger", V_2022_1, {"CPU"}}, + {"MatMulInteger", V_2023_1, {"NPU"}}, {"Max", V_2020_4, {"CPU", "GPU"}}, - {"Max", V_2023_0, {"VPUX"}}, + {"Max", V_2023_0, {"NPU"}}, {"MaxPool", V_2020_4, {"CPU", "GPU"}}, - {"MaxPool", V_2023_0, {"VPUX"}}, + {"MaxPool", V_2023_0, {"NPU"}}, {"Mean", V_2020_4, {"CPU", "GPU"}}, - {"Mean", V_2023_0, {"VPUX"}}, + {"Mean", V_2023_0, {"NPU"}}, {"MeanVarianceNormalization", V_2022_1, {"CPU", "GPU"}}, + {"MeanVarianceNormalization", V_2023_1, {"NPU"}}, {"Min", V_2020_4, {"CPU", "GPU"}}, - {"Min", V_2023_0, {"VPUX"}}, + {"Min", V_2023_0, {"NPU"}}, {"Mod", V_2022_1, {"CPU", "GPU"}}, {"Mul", V_2020_4, {"CPU", "GPU"}}, - {"Mul", V_2023_0, {"VPUX"}}, + {"Mul", V_2023_0, {"NPU"}}, {"Neg", V_2020_4, {"CPU", "GPU"}}, - {"Neg", V_2023_0, {"VPUX"}}, + {"Neg", V_2023_0, {"NPU"}}, {"NonMaxSuppression", V_2021_1, {"CPU", "GPU"}}, + {"NonMaxSuppression", V_2023_1, {"NPU"}}, {"NonZero", V_2021_1, {"CPU"}}, {"NonZero", V_2023_0, {"GPU"}}, {"Not", V_2021_1, {"CPU", "GPU"}}, {"Not", V_2020_4, {"CPU", "GPU"}}, + {"Not", V_2023_1, {"NPU"}}, {"OneHot", V_2020_4, {"CPU", "GPU"}}, + {"OneHot", V_2023_1, {"NPU"}}, {"Or", V_2022_1, {"CPU", "GPU"}}, + {"Or", V_2023_1, {"NPU"}}, {"Pad", V_2020_4, {"CPU", "GPU"}}, - {"Pad", V_2023_0, {"VPUX"}}, + {"Pad", V_2023_0, {"NPU"}}, {"Pow", V_2020_4, {"CPU", "GPU"}}, - {"Pow", V_2023_0, {"VPUX"}}, + {"Pow", V_2023_0, {"NPU"}}, {"PRelu", V_2020_4, {"CPU", "GPU"}}, - {"PRelu", V_2023_0, {"VPUX"}}, + {"PRelu", V_2023_0, {"NPU"}}, {"QLinearMatMul", V_2022_3, {"CPU"}}, + // {"QLinearMatMul", V_2023_1, {"NPU"}}, {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, - {"QuantizeLinear", V_2023_0, {"VPUX"}}, + {"QuantizeLinear", V_2023_0, {"NPU"}}, + {"RNN", V_2023_1, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, {"RandomNormalLike", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormalLike", V_2023_1, {"NPU"}}, {"RandomNormal", V_2023_0, {"CPU", "GPU"}}, + {"RandomNormal", V_2023_1, {"NPU"}}, {"Range", V_2022_1, {"CPU", "GPU"}}, - {"Range", V_2023_0, {"VPUX"}}, + {"Range", V_2023_0, {"NPU"}}, {"Reciprocal", V_2020_4, {"CPU", "GPU"}}, - {"Reciprocal", V_2023_0, {"VPUX"}}, + {"Reciprocal", V_2023_0, {"NPU"}}, {"ReduceL1", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL1", V_2023_1, {"NPU"}}, {"ReduceL2", V_2022_1, {"CPU", "GPU"}}, + {"ReduceL2", V_2023_1, {"NPU"}}, {"ReduceLogSum", V_2020_4, {"CPU"}}, {"ReduceLogSum", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSum", V_2023_1, {"NPU"}}, {"ReduceLogSumExp", V_2022_1, {"CPU", "GPU"}}, + {"ReduceLogSumExp", V_2023_1, {"NPU"}}, {"ReduceMax", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMax", V_2023_1, {"NPU"}}, {"ReduceMean", V_2020_4, {"CPU", "GPU"}}, - {"ReduceMean", V_2023_0, {"VPUX"}}, + {"ReduceMean", V_2023_0, {"NPU"}}, {"ReduceMin", V_2020_4, {"CPU", "GPU"}}, + {"ReduceMin", V_2023_1, {"NPU"}}, {"ReduceProd", V_2020_4, {"CPU"}}, {"ReduceProd", V_2022_1, {"GPU"}}, + {"ReduceProd", V_2023_1, {"NPU"}}, {"ReduceSum", V_2020_4, {"CPU", "GPU"}}, + // {"ReduceSum", V_2023_1, {"NPU"}}, {"ReduceSumSquare", V_2020_4, {"CPU"}}, {"ReduceSumSquare", V_2022_1, {"CPU", "GPU"}}, + {"ReduceSumSquare", V_2023_1, {"NPU"}}, {"Relu", V_2020_4, {"CPU", "GPU"}}, - {"Relu", V_2023_0, {"VPUX"}}, + {"Relu", V_2023_0, {"NPU"}}, {"Resize", V_2020_4, {"CPU"}}, {"Resize", V_2022_1, {"GPU"}}, + {"Resize", V_2023_1, {"NPU"}}, {"Reshape", V_2020_4, {"CPU", "GPU"}}, - {"Reshape", V_2023_0, {"VPUX"}}, + {"Reshape", V_2023_0, {"NPU"}}, {"ReverseSequence", V_2022_1, {"CPU", "GPU"}}, {"RoiAlign", V_2021_1, {"CPU", "GPU"}}, + {"RoiAlign", V_2023_1, {"NPU"}}, {"Round", V_2021_4, {"CPU", "GPU"}}, + {"Round", V_2023_1, {"NPU"}}, {"Scatter", V_2022_1, {"CPU", "GPU"}}, + {"Scatter", V_2023_1, {"NPU"}}, {"ScatterElements", V_2022_1, {"CPU", "GPU"}}, + {"ScatterElements", V_2023_1, {"NPU"}}, {"ScatterND", V_2022_1, {"CPU", "GPU"}}, + {"ScatterND", V_2023_1, {"NPU"}}, {"Selu", V_2020_4, {"CPU", "GPU"}}, + {"Selu", V_2023_1, {"NPU"}}, {"Shape", V_2020_4, {"CPU", "GPU"}}, - {"Shape", V_2023_0, {"VPUX"}}, + {"Shape", V_2023_0, {"NPU"}}, {"Shrink", V_2022_1, {"CPU", "GPU"}}, - {"Shrink", V_2023_0, {"VPUX"}}, + {"Shrink", V_2023_0, {"NPU"}}, {"Sigmoid", V_2020_4, {"CPU", "GPU"}}, - {"Sigmoid", V_2023_0, {"VPUX"}}, + {"Sigmoid", V_2023_0, {"NPU"}}, {"Sign", V_2020_4, {"CPU"}}, {"Sign", V_2022_1, {"GPU"}}, - {"Sign", V_2023_0, {"VPUX"}}, + {"Sign", V_2023_0, {"NPU"}}, {"Sin", V_2022_1, {"CPU", "GPU"}}, - {"Sin", V_2023_0, {"VPUX"}}, + {"Sin", V_2023_0, {"NPU"}}, {"Sinh", V_2020_4, {"CPU"}}, + {"Sinh", V_2023_1, {"NPU"}}, {"Size", V_2022_1, {"CPU", "GPU"}}, + {"Size", V_2023_1, {"NPU"}}, {"Slice", V_2020_4, {"CPU", "GPU"}}, - {"Slice", V_2023_0, {"VPUX"}}, + {"Slice", V_2023_0, {"NPU"}}, {"Softmax", V_2020_4, {"CPU", "GPU"}}, - {"Softmax", V_2023_0, {"VPUX"}}, + {"Softmax", V_2023_0, {"NPU"}}, {"Softplus", V_2022_1, {"CPU", "GPU"}}, - {"Softplus", V_2023_0, {"VPUX"}}, + {"Softplus", V_2023_0, {"NPU"}}, {"Softsign", V_2022_1, {"CPU", "GPU"}}, {"SpaceToDepth", V_2020_4, {"CPU", "GPU"}}, - {"SpaceToDepth", V_2023_0, {"VPUX"}}, + {"SpaceToDepth", V_2023_0, {"NPU"}}, {"Split", V_2020_4, {"CPU", "GPU"}}, - {"Split", V_2023_0, {"VPUX"}}, + {"Split", V_2023_0, {"NPU"}}, {"Sqrt", V_2020_4, {"CPU", "GPU"}}, - {"Sqrt", V_2023_0, {"VPUX"}}, + {"Sqrt", V_2023_0, {"NPU"}}, {"Squeeze", V_2020_4, {"CPU", "GPU"}}, - {"Squeeze", V_2023_0, {"VPUX"}}, + {"Squeeze", V_2023_0, {"NPU"}}, {"Softsign", V_2020_4, {"CPU"}}, {"Sub", V_2020_4, {"CPU", "GPU"}}, - {"Sub", V_2023_0, {"VPUX"}}, + {"Sub", V_2023_0, {"NPU"}}, {"Sum", V_2020_4, {"CPU", "GPU"}}, - {"Sum", V_2023_0, {"VPUX"}}, + {"Sum", V_2023_0, {"NPU"}}, {"Tan", V_2020_4, {"CPU", "GPU"}}, + {"Tan", V_2023_1, {"NPU"}}, {"Tanh", V_2020_4, {"CPU", "GPU"}}, - {"Tanh", V_2023_0, {"VPUX"}}, + {"Tanh", V_2023_0, {"NPU"}}, {"ThresholdedRelu", V_2022_1, {"CPU", "GPU"}}, - {"ThresholdedRelu", V_2023_0, {"VPUX"}}, + {"ThresholdedRelu", V_2023_0, {"NPU"}}, {"Tile", V_2021_3, {"CPU", "GPU"}}, - {"Tile", V_2023_0, {"VPUX"}}, + {"Tile", V_2023_0, {"NPU"}}, {"Transpose", V_2020_4, {"CPU", "GPU"}}, - {"Transpose", V_2023_0, {"VPUX"}}, + {"Transpose", V_2023_0, {"NPU"}}, {"Trilu", V_2023_0, {"CPU", "GPU"}}, + {"Trilu", V_2023_1, {"NPU"}}, {"TopK", V_2020_4, {"CPU", "GPU"}}, - {"TopK", V_2023_0, {"VPUX"}}, + {"TopK", V_2023_0, {"NPU"}}, + {"Upsample", V_2020_4, {"CPU", "GPU"}}, {"Unsqueeze", V_2020_4, {"CPU", "GPU"}}, - {"Unsqueeze", V_2023_0, {"VPUX"}}, - {"Upsample", V_2021_1, {"CPU"}}, - {"Upsample", V_2021_4, {"GPU"}}, - {"Upsample", V_2023_0, {"VPUX"}}, + {"Unsqueeze", V_2023_0, {"NPU"}}, {"Where", V_2022_1, {"CPU", "GPU"}}, - {"Where", V_2023_0, {"VPUX"}}, // Added for whisper decoder model. + {"Where", V_2023_0, {"NPU"}}, // Added for whisper decoder model. {"Xor", V_2022_1, {"CPU", "GPU"}}, + {"Xor", V_2023_1, {"NPU"}}, }; void DataOps::populate_types_supported() { - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_initializer_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_initializer_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_initializer_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_initializer_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_initializer_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_vpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_vpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_npu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_npu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_cpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_cpu_.insert(std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_cpu_.insert( + std::make_pair(V_2022_2, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); - supported_types_gpu_.insert(std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); - supported_types_gpu_.insert(std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); - supported_types_gpu_.insert(std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); - supported_types_gpu_.insert(std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_gpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); + supported_types_gpu_.insert( + std::make_pair(V_2021_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)); + supported_types_gpu_.insert( + std::make_pair(V_2021_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8)); + supported_types_gpu_.insert( + std::make_pair(V_2022_1, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); } void DataOps::populate_op_mode_supported() { @@ -349,10 +439,10 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Gather", V_2020_4, {"All"}}); - no_dimension_supported_.push_back({"Greater", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Greater", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Less", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); - no_dimension_supported_.push_back({"Max", V_2023_0, {"VPUX"}}); + no_dimension_supported_.push_back({"Max", V_2023_0, {"NPU"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"QuantizeLinear", V_2021_4, {"All"}}); @@ -382,11 +472,14 @@ void DataOps::populate_op_mode_supported() { { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { - // Abs is not supproted with INT8 or INT32 as input data type on GPU - if (device_id_.find("GPU") != std::string::npos) { + // Abs is not supproted with INT8 or INT32 as input data type on GPU and NPU + if ((device_id_.find("GPU") != std::string::npos) || + (device_id_.find("NPU") != std::string::npos)) { for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || - node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 || + node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -399,11 +492,14 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // tensor type does not support select last index auto& attributes = node->GetAttributes(); - auto last_index_arg = attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() : 0; + auto last_index_arg = + attributes.count("select_last_index") > 0 ? attributes.at("select_last_index").i() + : 0; if (last_index_arg != 0) return true; // tensor type supports float as input for argmax and argmin - if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) + if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) return true; return false; }}; @@ -415,7 +511,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { // int64 data type is not supported on GPU - const bool data_is_int64 = node->InputDefs()[0]->Type()->find("int64") != std::string::npos; + const bool data_is_int64 = + node->InputDefs()[0]->Type()->find("int64") != std::string::npos; return data_is_int64; } return false; @@ -506,9 +603,12 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto x_data_type = node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); auto y_data_type = node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - // currently both inputs with int32 are not supported and also both input datatypes should be same - const bool A_is_int32 = node->InputDefs()[0]->Type()->find("int32") != std::string::npos; - const bool B_is_int32 = node->InputDefs()[1]->Type()->find("int32") != std::string::npos; + // currently both inputs with int32 are not supported + // and also both input datatypes should be same + const bool A_is_int32 = + node->InputDefs()[0]->Type()->find("int32") != std::string::npos; + const bool B_is_int32 = + node->InputDefs()[1]->Type()->find("int32") != std::string::npos; if ((A_is_int32 && B_is_int32) || (x_data_type != y_data_type)) return true; } @@ -589,11 +689,13 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { auto slope = node->InputDefs()[1]; // PRelu slope has to be an initializer or needs to come from a constant node - if (initializers.count(slope->Name())) + if (initializers.count(slope->Name())) { return false; - else { - for (auto input_node = node->InputNodesBegin(); input_node != node->InputNodesEnd(); ++input_node) { - if (GetInputCount(this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) + } else { + for (auto input_node = node->InputNodesBegin(); + input_node != node->InputNodesEnd(); ++input_node) { + if (GetInputCount( + this->graph_viewer_.GetNode((*input_node).Index()), initializers) == 0) return false; } } @@ -603,12 +705,12 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"PRelu", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_arg = node->InputDefs()[1]; auto shape = input_arg->Shape(); // Reshape op with empty dim is Rejected for Myriad - //[TODO] Is this condition required anymore with Myriad removed? + // [TODO] Is this condition required anymore with Myriad removed? if (shape != nullptr) { for (const auto& dim : input_arg->Shape()->dim()) { if (utils::HasDimValue(dim) && dim.dim_value() == 0) @@ -638,7 +740,8 @@ void DataOps::populate_op_mode_supported() { if (device_id_.find("GPU") != std::string::npos) { // INT32 dataype is not supported as input for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) return true; } } @@ -650,9 +753,11 @@ void DataOps::populate_op_mode_supported() { UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3}, [this](const Node* node, const InitializedTensorSet&) { if (device_id_.find("GPU") != std::string::npos) { - auto output_data_type = node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + auto output_data_type = + node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); // If the output of ScatterND op is BOOL, it is rejected for GPU. - if (output_data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) + if (output_data_type == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL) return true; } return false; @@ -666,7 +771,8 @@ void DataOps::populate_op_mode_supported() { [this](const Node* node, const InitializedTensorSet&) { // If the Input of Shrink op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { - if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) + if (node->InputDefs()[i]->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) return true; } return false; @@ -714,10 +820,11 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Squeeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze - // If axes is an input, then we cannot produce a static graph. Conversion fails in convert_function_to_cnn_network. + // If axes is an input, then we cannot produce a static graph. + // Conversion fails in convert_function_to_cnn_network. for (size_t i = 0; i < node->InputDefs().size(); i++) { if (node->InputDefs()[i]->Name() == "axes") { return true; @@ -728,14 +835,15 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0}, + UnsupportedOpMode obj = {{V_2022_1, V_2022_2, V_2022_3, V_2023_0, V_2023_1}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); if (upsample_attr.count("scales") > 0) { auto& upsample_arg = upsample_attr.at("scales"); auto float_size = upsample_arg.floats_size(); - if (float_size > 2 && (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { + if (float_size > 2 && + (upsample_arg.floats(0) != 1.f || upsample_arg.floats(1) != 1.f)) { return true; } } @@ -750,9 +858,12 @@ void DataOps::populate_op_mode_supported() { } } // x_arg supports only float, int8 and float16 type - if ((x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || - (x_arg->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { + if ((x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) || + (x_arg->TypeAsProto()->tensor_type().elem_type() == + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16)) { return false; } else { return true; @@ -849,9 +960,9 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { } else { auto dtype = type_proto->tensor_type().elem_type(); - if (device_id_.find("VPUX") != std::string::npos || device_id_.find("HETERO") != std::string::npos || + if (device_id_.find("NPU") != std::string::npos || device_id_.find("HETERO") != std::string::npos || device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { - for (auto const& var : supported_types_vpu_) { + for (auto const& var : supported_types_npu_) { if ((var.first <= version_id_) && (var.second == dtype)) { return true; @@ -1079,7 +1190,9 @@ bool DataOps::node_is_supported(const std::mapsecond.find(optype) == opset->second.end() && op_fun == ops_supported_as_function.end()) { #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { - std::cout << "The operator is not available in OpenVINO ngraph operators list nor the operator is a special ONNX function" << std::endl; + std::cout << "The operator is not available in OpenVINO ngraph operators list" + << "nor the operator is a special ONNX function" + << std::endl; } #endif return false; @@ -1095,10 +1208,12 @@ std::vector DataOps::GetUnsupportedNodeIndices(std::unordered_setForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, bool is_input) { - if(is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { + graph_viewer_.GetNode(node_idx)->ForEachDef([&ng_required_initializers, this](const NodeArg& node_arg, + bool is_input) { + if (is_input && this->graph_viewer_.GetAllInitializedTensors().count(node_arg.Name())) { ng_required_initializers.insert(node_arg.Name()); - } }, true); + } }, + true); } else { unsupported_nodes_idx.push_back(node_idx); } @@ -1110,7 +1225,8 @@ bool DataOps::IsOpSupportedOnlyInModel(std::string name) { return ops_supported_only_in_model.find(name) != ops_supported_only_in_model.end(); } -bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node) { +bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, + const Node* node) { if (node->OpType() == "Reshape") { const auto& shape_arg = node->InputDefs()[1]; if (ng_required_initializers.find(shape_arg->Name()) == ng_required_initializers.end()) { @@ -1119,15 +1235,20 @@ bool DataOps::SpecialConditionForClusterSizeOne(std::unordered_set& } else if (node->OpType() == "Expand") { // nGraph only supports constant shape input values const auto& output = node->OutputDefs()[0]; - if (output->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) + if (output->TypeAsProto()->tensor_type().elem_type() != + ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) return true; } else if (node->OpType() == "RoiAlign") { using onnx_dtype = ONNX_NAMESPACE::TensorProto_DataType; - onnx_dtype input_0_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_1_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype input_2_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); - onnx_dtype output_data_type = (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_0_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_1_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[1]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype input_2_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->InputDefs()[2]->TypeAsProto()->tensor_type().elem_type(); + onnx_dtype output_data_type = + (ONNX_NAMESPACE::TensorProto_DataType)node->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); if ((input_0_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || (input_1_data_type != onnx_dtype::TensorProto_DataType_FLOAT16) || diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index cc968d02ea64..a5aa3f825602 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -3,6 +3,11 @@ #pragma once #include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -47,7 +52,7 @@ class DataOps { std::multimap op_list_; std::vector subgraph_supported_; std::vector no_dimension_supported_; - std::set supported_types_vpu_; + std::set supported_types_npu_; std::set supported_types_cpu_; std::set supported_types_gpu_; std::set supported_types_initializer_; @@ -64,14 +69,16 @@ class DataOps { const NodeIndex node_idx); public: - DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { + DataOps(const GraphViewer& graph_viewer_param, VersionNum ver, std::string dev_id) + : graph_viewer_(graph_viewer_param), version_id_(ver), device_id_(dev_id) { populate_op_mode_supported(); populate_types_supported(); } virtual std::vector GetUnsupportedNodeIndices(std::unordered_set& ng_required_initializers); virtual bool IsOpSupportedOnlyInModel(std::string name); - virtual bool SpecialConditionForClusterSizeOne(std::unordered_set& ng_required_initializers, const Node* node); + virtual bool SpecialConditionForClusterSizeOne( + std::unordered_set& ng_required_initializers, const Node* node); virtual bool DoNotOmitSubGraph(const std::string& name); virtual bool InsertNode(const std::string& name); VersionNum GetVersion() const { return version_id_; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index be509b674362..74369d39b9a2 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License #include "core/providers/shared_library/provider_api.h" +#include "utils.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245 5208) @@ -113,7 +114,8 @@ std::map> GetNgSupportedOps(const int onnx_op * supported_cluster + (UNsupported_node + rest_of_the_graph). This functions returns vector of all supported_clusters by nGraph */ std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes) { +GetPartitionedClusters(const std::vector& topological_order, + const std::vector& unsupported_nodes) { std::vector> ng_clusters; auto prev = topological_order.begin(); @@ -140,7 +142,10 @@ GetPartitionedClusters(const std::vector& topological_order, const st return ng_clusters; } -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster) { +void IdentifyConnectedNodes(const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster) { if (std::find(cluster.begin(), cluster.end(), curr_node_index) == cluster.end()) return; @@ -205,7 +210,8 @@ void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const auto& ext_node = graph_viewer.GetNode((*it).Index()); if (std::find(cluster.begin(), cluster.end(), ext_node->Index()) == cluster.end()) { - // Node is external to this_cluster. Search through its inputs to find the output that is generated by this_cluster. + // Node is external to this_cluster. Search through its inputs to + // find the output that is generated by this_cluster. std::set ext_node_inputs; ext_node->ForEachDef( [&ext_node_inputs](const NodeArg& arg, bool is_input) { diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 70f6954ea991..c256cde97956 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -1,5 +1,15 @@ // Copyright (C) 2019-2022 Intel Corporation // Licensed under the MIT License +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include namespace onnxruntime { namespace openvino_ep { @@ -18,9 +28,14 @@ int GetOnnxOpSet(const GraphViewer& graph_viewer); std::map> GetNgSupportedOps(const int onnx_opset); std::vector> -GetPartitionedClusters(const std::vector& topological_order, const std::vector& unsupported_nodes); - -void IdentifyConnectedNodes(const GraphViewer& graph_viewer, NodeIndex curr_node_index, std::vector& cluster, std::vector& sub_cluster); +GetPartitionedClusters( + const std::vector& topological_order, const std::vector& unsupported_nodes); + +void IdentifyConnectedNodes( + const GraphViewer& graph_viewer, + NodeIndex curr_node_index, + std::vector& cluster, + std::vector& sub_cluster); std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 9e5988347822..df4dd5541775 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1432,7 +1432,7 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_vpu_fast_compile"] = legacy_ov_options->enable_vpu_fast_compile; + ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 7faca3b4681b..2027b592326d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -813,10 +813,10 @@ std::unique_ptr CreateExecutionProviderInstance( if (option.first == "device_type") { OV_provider_options_map[option.first] = option.second; continue; - } else if (option.first == "enable_vpu_fast_compile") { + } else if (option.first == "enable_npu_fast_compile") { if (!(option.second == "True" || option.second == "true" || option.second == "False" || option.second == "false")) { - ORT_THROW("Invalid value passed for enable_vpu_fast_compile: ", option.second); + ORT_THROW("Invalid value passed for enable_npu_fast_compile: ", option.second); } OV_provider_options_map[option.first] = option.second; } else if (option.first == "enable_opencl_throttling") { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 5bb6bcc38b6f..a5bcbce89bac 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -60,11 +60,11 @@ struct OrtStatus { #elif OPENVINO_CONFIG_GPU_FP16 #define BACKEND_OPENVINO "-OPENVINO_GPU_FP16" -#elif OPENVINO_CONFIG_VPUX_FP16 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_FP16" +#elif OPENVINO_CONFIG_NPU_FP16 +#define BACKEND_OPENVINO "-OPENVINO_NPU_FP16" -#elif OPENVINO_CONFIG_VPUX_U8 -#define BACKEND_OPENVINO "-OPENVINO_VPUX_U8" +#elif OPENVINO_CONFIG_NPU_U8 +#define BACKEND_OPENVINO "-OPENVINO_NPU_U8" #elif OPENVINO_CONFIG_MULTI #define BACKEND_OPENVINO "-OPENVINO_MULTI" diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index b1a04a00e89b..6d075fec997b 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -60,7 +60,7 @@ namespace perftest { "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" - "\t [OpenVINO only] [enable_vpu_fast_compile]: Optionally enabled to speeds up the model's compilation on VPU device targets.\n" + "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" @@ -72,7 +72,7 @@ namespace perftest { "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_vpu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" "\t [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 41a1eafebbb5..b7a111783fc9 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -240,8 +240,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (key == "device_type") { std::set ov_supported_device_types = {"CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16", - "GPU.0_FP16", "GPU.1_FP16", - "VPUX_FP16", "VPUX_U8"}; + "GPU.0_FP16", "GPU.1_FP16"}; if (ov_supported_device_types.find(value) != ov_supported_device_types.end()) { ov_options[key] = value; } else if (value.find("HETERO:") == 0) { @@ -254,17 +253,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW( "[ERROR] [OpenVINO] You have selcted wrong configuration value for the key 'device_type'. " "Select from 'CPU_FP32', 'CPU_FP16', 'GPU_FP32', 'GPU.0_FP32', 'GPU.1_FP32', 'GPU_FP16', " - "'GPU.0_FP16', 'GPU.1_FP16', 'VPUX_FP16', 'VPUX_U8' or from" + "'GPU.0_FP16', 'GPU.1_FP16' or from" " HETERO/MULTI/AUTO options available. \n"); } } else if (key == "device_id") { ov_options[key] = value; - } else if (key == "enable_vpu_fast_compile") { + } else if (key == "enable_npu_fast_compile") { if (value == "true" || value == "True" || value == "false" || value == "False") { ov_options[key] = value; } else { - ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_vpu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); + ORT_THROW("[ERROR] [OpenVINO] The value for the key 'enable_npu_fast_compile' should be a boolean i.e. true or false. Default value is false.\n"); } } else if (key == "enable_opencl_throttling") { if (value == "true" || value == "True" || @@ -299,7 +298,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ov_options[key] = value; } } else { - ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_vpu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); + ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling|true'] \n"); } } session_options.AppendExecutionProvider("OpenVINO", ov_options); diff --git a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc index e37206d6aebf..b7cead66bd7f 100644 --- a/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/lp_norm_op_test.cc @@ -143,7 +143,7 @@ void L1NormalizationWithZeroNorm() { vector expected_output = {0.5f, 0.5f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L1NormalizationWithZeroNorm) { @@ -163,7 +163,7 @@ void L2NormalizationWithZeroNorm() { vector expected_output = {1.f, 0.f, 0.f, 0.f}; test.AddOutput("Y", input_dims, expected_output); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(LpNormalizationTest, L2NormalizationWithZeroNorm) { diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index d1a523b1eecf..b9875b9553a5 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -762,7 +762,7 @@ TEST(RNNTest, RNN_invalid_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // the CUDA RNN version allows the invalid sequence lengths, so disable testing on CUDA and TensorRT - test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); }; // should batch batch_size to be valid @@ -860,7 +860,7 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(RNNTest, RNN_with_invalid_activation_load_failure) { diff --git a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc index c95ac1603a31..c3d91100605e 100644 --- a/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/compress_op.test.cc @@ -66,7 +66,7 @@ TEST(CompressTest, Compress_3dims_has_extra_condition) { // has condition length = 3 > input_dim[axis] = 2 test.AddInput("condition", {3}, {0, 1, 1}); test.AddOutput("output", {2, 1, 3}, {4.0f, 5.0f, 6.0f, 10.0f, 11.0f, 12.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } TEST(CompressTest, Compress_3dims_has_extra_input) { diff --git a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc index 2120da604f94..d2aa5dd428fe 100644 --- a/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/unsqueeze_op_test.cc @@ -99,7 +99,7 @@ TEST(TensorOpTest, Unsqueeze_scalar_2) { test.AddInput("input", {}, std::vector{1.0f}); test.AddInput("axes", {2}, std::vector{0, -1}, axes_is_initializer); test.AddOutput("output", {1, 1}, std::vector{1.0f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); }; run_test(false); run_test(true); diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index ecf4b001eec6..c48b07422d45 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -140,6 +140,9 @@ def create_backend_test(test_name=None): if backend.supports_device("OPENVINO_CPU_FP16"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_CPU_FP16") + if backend.supports_device("OPENVINO_NPU_FP16"): + current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_NPU_FP16") + if backend.supports_device("OPENVINO"): current_failing_tests += apply_filters(filters, "current_failing_tests_OPENVINO_opset18") diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 44db7c0078cf..c552ec3aea72 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -521,6 +521,10 @@ "test_scan_sum_cpu", // Disabled due to output mismatch with tolerance. "test_scan9_sum_cpu" // Disabled due to output mismatch with tolerance. ], + "current_failing_tests_OPENVINO_NPU_FP16": [ + "^test_prelu_broadcast", + "test_loop11_cpu" + ], "current_failing_tests_OPENVINO_opset18": [ // pending opset 18 support, RUNTIME_EXCEPTION : Encountered unknown exception in Initialize() "^test_center_crop_pad_crop_axes_chw", diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 806e536cb4dd..a992da8ff993 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -66,15 +66,13 @@ def _str_to_bool(s): def _openvino_verify_device_type(device_read): - choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16", "VPUX_FP16", "VPUX_U8"] + choices = ["CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16"] choices1 = [ "CPU_FP32_NO_PARTITION", "CPU_FP16_NO_PARTITION", "GPU_FP32_NO_PARTITION", "GPU_FP16_NO_PARTITION", - "VPUX_FP16_NO_PARTITION", - "VPUX_U8_NO_PARTITION", ] status_hetero = True res = False @@ -89,7 +87,7 @@ def _openvino_verify_device_type(device_read): if len(comma_separated_devices) < 2: print("At least two devices required in Hetero/Multi/Auto Mode") status_hetero = False - dev_options = ["CPU", "GPU", "VPUX"] + dev_options = ["CPU", "GPU"] for dev in comma_separated_devices: if dev not in dev_options: status_hetero = False @@ -100,7 +98,7 @@ def invalid_hetero_build(): print("specify the keyword HETERO or MULTI or AUTO followed by the devices ") print("in the order of priority you want to build\n") print("The different hardware devices that can be added in HETERO or MULTI or AUTO") - print("are ['CPU','GPU', 'VPUX'] \n") + print("are ['CPU','GPU'] \n") print("An example of how to specify the hetero build type. Ex: HETERO:GPU,CPU \n") print("An example of how to specify the MULTI build type. Ex: MULTI:GPU,CPU \n") print("An example of how to specify the AUTO build type. Ex: AUTO:GPU,CPU \n") @@ -1158,8 +1156,6 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_GPU_FP16=" + ("ON" if args.use_openvino == "GPU_FP16" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP32=" + ("ON" if args.use_openvino == "CPU_FP32" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16=" + ("ON" if args.use_openvino == "CPU_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16=" + ("ON" if args.use_openvino == "VPUX_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8=" + ("ON" if args.use_openvino == "VPUX_U8" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP32_NP=" + ("ON" if args.use_openvino == "GPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_GPU_FP16_NP=" @@ -1168,9 +1164,6 @@ def generate_build_tree( + ("ON" if args.use_openvino == "CPU_FP32_NO_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_CPU_FP16_NP=" + ("ON" if args.use_openvino == "CPU_FP16_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_FP16_NP=" - + ("ON" if args.use_openvino == "VPUX_FP16_NP_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VPUX_U8_NP=" + ("ON" if args.use_openvino == "VPUX_U8_NP_PARTITION" else "OFF"), "-Donnxruntime_USE_OPENVINO_HETERO=" + ("ON" if args.use_openvino.startswith("HETERO") else "OFF"), "-Donnxruntime_USE_OPENVINO_DEVICE=" + (args.use_openvino), "-Donnxruntime_USE_OPENVINO_MULTI=" + ("ON" if args.use_openvino.startswith("MULTI") else "OFF"), diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index cc27cdc29364..f7b68551b9c5 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -552,6 +552,7 @@ def generate_files(line_list, args): files_list.append( "" ) + else: files_list.append( "' - ) + dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") + tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") + for dll_element in os.listdir(dll_list_path): if dll_element.endswith("dll"): files_list.append( @@ -735,26 +720,7 @@ def generate_files(line_list, args): + args.target_architecture + '\\native" />' ) - # plugins.xml - files_list.append( - "' - ) - # usb-ma2x8x.mvcmd - # OpenVINO 2022.3 doesn't have usb-ma2x8x.mvcmd - if "2022.3" not in openvino_path: - files_list.append( - "' - ) + for tbb_element in os.listdir(tbb_list_path): if tbb_element.endswith("dll"): files_list.append( From 9e8ad398479d9c2dc0ca91a8df89e452d059f6ee Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 1 Nov 2023 08:49:33 -0700 Subject: [PATCH 10/12] Distributed Reduction (#18206) This PR implements distributed reduciton for llama 2. This version doesn't consider any cases requring re-sharding because we haven't seen any use cases. Intutive examples: - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[0]) -> [1,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[1]) -> [2,1,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [not supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[2]) -> [2,4,1]-tensor with spec=RRS[0] and device_mesh=[0,1] Algorithm: When the reduced axes are not sharded, each device can call reduction directly. The output sharding spec will be identical to input sharding spec. We currently throw when input and output sharding specs are different. Review guideline: - Check 97b8d2f for new op's schema and how new op is registered. - Read tests in 2450f93 to get faimilar with the behavior of these ops. - Check the implementation details in 753d9af. --- cmake/onnxruntime_providers_cuda.cmake | 1 + cmake/onnxruntime_rocm_hipify.cmake | 1 + .../cuda/collective/distributed_reduce.cc | 175 +++++++++ .../cuda/collective/distributed_reduce.h | 59 +++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 18 + .../core/graph/contrib_ops/collective_defs.cc | 123 +++++++ .../providers/cuda/reduction/reduction_ops.cc | 24 ++ .../python/onnxruntime_test_distributed.py | 345 ++++++++++++------ 8 files changed, 638 insertions(+), 108 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 043789c36c32..ce0c12804b08 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -40,6 +40,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6ccf063c7129..9bc2bdd208a9 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -109,6 +109,7 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc new file mode 100644 index 000000000000..967f30a304ac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -0,0 +1,175 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reduce.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/reduction/reduction_ops.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedReduceBase::DistributedReduceBase( + const OpKernelInfo& info, + cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) { + keepdims_ = info.GetAttrOrDefault("keepdims", 1); + cudnn_reduce_op_ = cudnn_reduce_op; +}; + +template +Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const { + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& axes_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(axes_sharding_spec.HasNoShard(), + "It's not worthy to shard axes tensor. " + "If sharding axes is needed, please submit a feature request."); + + const Tensor* input_tensor = context->Input(0); + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor."); + auto axes_span = axes_tensor->DataAsSpan(); + + // Case 1: empty axes means treating this reduction as an identity. + if (axes_span.empty()) { + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData(), input_tensor->Data(), input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + // Case 2: this is a valid reduction. Let's prepare for it. + + bool sharding_on_reduced_axes = false; + for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) { + if (*axis_it == input_sharding_spec.GetPartitionAxis()) { + sharding_on_reduced_axes = true; + break; + } + } + + if (sharding_on_reduced_axes) { + // Case 2-1: sharding on reduced axes. + ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica."); + } else { + // Case 2-2: sharding on passing-through axes or no shard. + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + onnxruntime::cuda::PrepareReduceMetadata metadata; + ORT_RETURN_IF_ERROR( + onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)); + auto output_tensor = context->Output(0, metadata.squeezed_output_dims); + + // Fast reduction is not deterministic, so sometimes we want to turn it off. + const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); + return onnxruntime::cuda::ReduceComputeCore( + /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, + /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, + enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + } + return Status::OK(); +} + +template +DistributedReduceSum::DistributedReduceSum( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_ADD){}; + +template +DistributedReduceMean::DistributedReduceMean( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_AVG){}; + +template +DistributedReduceMax::DistributedReduceMax( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_MAX){}; + +// ReduceSum +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); + +// ReduceMean +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); + +// ReduceMax +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h new file mode 100644 index 000000000000..2939852c75c6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReduceBase : public DistributedKernel { + public: + explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + // ONNX attribute. If true, reduced axes are retained as dimensions with size one. + // Otherwise, drop reduced axes. + bool keepdims_; + cudnnReduceTensorOp_t cudnn_reduce_op_; +}; + +template +class DistributedReduceSum final : public DistributedReduceBase { + public: + explicit DistributedReduceSum(const OpKernelInfo& info); +}; + +template +class DistributedReduceMean final : public DistributedReduceBase { + public: + explicit DistributedReduceMean(const OpKernelInfo& info); +}; + +template +class DistributedReduceMax final : public DistributedReduceBase { + public: + explicit DistributedReduceMax(const OpKernelInfo& info); +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index d51915b85095..8e157da6cb43 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -175,6 +175,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean); #endif template <> @@ -354,6 +363,15 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 070df487a264..8b5b561c1ad8 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -273,6 +273,129 @@ void RegisterCollectiveOps() { OpSchema::NonDifferentiable) .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceSum) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMax) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMean) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index d46ed9c245a8..bc78e577c505 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -614,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index e0fb3979a9f5..6f691972181b 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -7,7 +7,7 @@ import numpy as np import onnxscript from mpi4py import MPI -from onnxscript import FLOAT, INT64 +from onnxscript import FLOAT, FLOAT16, INT64 import onnxruntime as ort @@ -27,12 +27,23 @@ def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): return np.concatenate(selected_shards, axis=axis) -def translate_device_mesh_to_attrs(device_mesh: np.ndarray): +def translate_single_device_mesh(device_mesh: np.ndarray): device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" return device_mesh_shape, device_mesh_elements +def translate_all_device_meshes(device_meshes: np.ndarray): + assert all(len(mesh.shape) == 1 for mesh in device_meshes) + device_mesh_shapes = [] + device_mesh_elements = [] + for device_mesh in device_meshes: + device_mesh_shape, device_mesh_element = translate_single_device_mesh(device_mesh) + device_mesh_shapes.append(device_mesh_shape) + device_mesh_elements.append(device_mesh_element) + return device_mesh_shapes, device_mesh_elements + + def parse_sharding_spec(spec: str): axis_conditions = [] sharding_device_axes = [] @@ -90,29 +101,13 @@ def _check_distributed_reshape( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -134,11 +129,11 @@ def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = np.reshape(data_tensor, shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_reshape_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -176,9 +171,9 @@ def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): 3, ), target_shape=(6,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]",), ) @@ -191,9 +186,9 @@ def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): 4, ), target_shape=(8,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]",), ) @@ -210,9 +205,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): 2, 15, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -229,9 +224,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): 2, 20, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -248,9 +243,9 @@ def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): 2, 18, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) # Two axis fusion. @@ -268,9 +263,9 @@ def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): 2, 3, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -283,9 +278,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): 1, 16, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -298,9 +293,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -313,9 +308,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): 4, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -328,9 +323,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): 8, 2, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -343,9 +338,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): 16, 1, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -359,9 +354,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self) 1, 16, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -375,9 +370,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -390,9 +385,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): 4, 4, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -405,9 +400,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): 8, 2, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -420,9 +415,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self) 16, 1, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -444,9 +439,9 @@ def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -471,9 +466,9 @@ def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rr 64, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]R",), ) @@ -495,9 +490,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -519,9 +514,9 @@ def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self) 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRR",), ) @@ -546,9 +541,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_0101 7, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -573,9 +568,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_ 7, 7, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -600,9 +595,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101 7, 7, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -627,9 +622,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_6 7, 64, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -654,9 +649,9 @@ def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(s 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -678,9 +673,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -690,29 +685,16 @@ def _check_distributed_expand( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -734,11 +716,11 @@ def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = data_tensor * np.ones(shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_expand_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -780,9 +762,9 @@ def test_expand_sharded_on_expanded_axis(self): 8, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -799,9 +781,9 @@ def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): 8, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -818,9 +800,9 @@ def test_expand_replicated_on_expanded_axis(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -837,12 +819,12 @@ def test_expand_with_pass_through_sharding_spec(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=( "S[0]R", "R", ), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -863,13 +845,160 @@ def test_expand_in_tiny_llama(self): 256, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) +class TestDistributedReduce(unittest.TestCase): + def _check_distributed_reduce( + self, + keepdims: int, + dtype: np.dtype, + shape: Tuple[int, ...], + axes: Tuple[int, ...], + input_device_meshes: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshes: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) + + @onnxscript.script() + def distributed_reduce_sum_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceSum( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_max_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMax( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMean( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + + for onnx_func, np_func in zip( + [distributed_reduce_sum_instance, distributed_reduce_max_instance, distributed_reduce_mean_instance], + [np.sum, np.maximum.reduce, np.mean], + ): + data = np.random.randint(4, size=shape).astype(dtype) + expected = np_func(data, axis=axes, keepdims=bool(keepdims)) + + assert len(input_shard_specs) == 2 and len(input_device_meshes) == 2, "Reduce has two inputs." + assert "S" not in input_shard_specs[1], "Tensor `axes` should not be sharded." + assert len(output_shard_specs) == 1 and len(output_device_meshes) == 1, "Reduce has only one output." + + local_data = shard_tensor_per_spec(data, rank, input_shard_specs[0], input_device_meshes[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) + + if dtype == np.float32: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + elif dtype == np.int64: + onnx_model = onnx_func.to_model_proto( + input_types=[INT64[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[INT64[tuple(local_expected.shape)]], + ) + elif dtype == np.float16: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT16[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT16[tuple(local_expected.shape)]], + ) + else: + raise RuntimeError(f"Unsupported dtype: {dtype}") + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data, + "axes_tensor": np.array(axes, dtype=np.int64), + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reduce(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(0,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reduce_sharded(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(1,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2]. From a2e9ba72d5a5f61e1324ffc2a80d748d01be9120 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Wed, 1 Nov 2023 15:34:51 -0700 Subject: [PATCH 11/12] [JS/Web]Added FusedConv. (#17766) ### Description Added FusedConv and FusedConvTranspose ### Motivation and Context Improve performance --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + .../webgpu/ops/3rd-party/activation_util.ts | 4 +- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 5 +- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 4 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 6 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 37 +++--- js/web/test/data/ops/fused-conv.jsonc | 112 ++++++++++++++++++ onnxruntime/contrib_ops/js/fused_conv.cc | 20 ++++ .../contrib_ops/js/js_contrib_kernels.cc | 5 +- .../core/optimizer/conv_activation_fusion.cc | 31 ++++- .../core/optimizer/conv_add_act_fusion.cc | 7 +- .../core/optimizer/graph_transformer_utils.cc | 13 +- .../selector_action_transformer.cc | 20 ++-- .../selector_action_transformer.h | 17 ++- .../core/providers/js/operators/conv.cc | 2 + .../core/providers/js/operators/conv.h | 78 ++++++++---- .../providers/js/operators/conv_transpose.cc | 2 + .../providers/js/operators/conv_transpose.h | 55 ++++++--- .../test/optimizer/graph_transform_test.cc | 13 +- 21 files changed, 339 insertions(+), 98 deletions(-) create mode 100644 js/web/test/data/ops/fused-conv.jsonc create mode 100644 onnxruntime/contrib_ops/js/fused_conv.cc diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 44003021293b..5b94a4a51093 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -40,6 +40,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherElements | ai.onnx(11-12,13+) | | | Gelu | com.microsoft(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 40309c1849bc..a4d51e68b6a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -67,6 +67,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['Gelu', [unaryOps.gelu]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index 22b91d680a9b..6481a6b21d72 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -41,12 +41,12 @@ export const activationFnSnippet = if (!activation) { return ''; } - // TODO: add implementations return ''; }; export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} - ${activation ? 'value = activation(value, coords);' : ''} + // TODO uncomment the following line when activation is supported above. + // ${activation ? 'value = activation(value, coords);' : ''} `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 01ddca520dee..fbb936a045b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -242,8 +242,9 @@ export const createConv2DMatMulProgramInfo = ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2], t)} + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, + attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1], + elementsSize[2], t)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 840360223c75..a95d3830f34e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -236,7 +236,9 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + conv2dTransposeCommonSnippet( + isChannelsLast, hasBias, attributes.activation.toLowerCase() as Activation, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 103286941246..0a0f29db6a49 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo} from '../../types'; import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; -import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -440,7 +440,7 @@ export const createMatmulProgramInfo = const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, isVec4); // TODO: fine tune size const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; @@ -473,8 +473,8 @@ export const createMatmulProgramInfo = const dimBOuter: i32 = ${dimBOuter}; const dimInner: i32 = ${dimInner}; ${shaderHelper.declareVariables(...inputVariables, output)} - ${declareFunctions} ${activationFunction} + ${declareFunctions} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7abf022928ad..8bfa722dd090 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActicationSnippet} from './fuse-utils'; +import {getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -22,7 +22,7 @@ export const createGroupedConvProgramInfo = const wShape = inputs[1].dims; const outputChannelsPerGroup = wShape[0] / attributes.group; - const {activationFunction, applyActivation} = getActicationSnippet(attributes); + const {activationFunction, applyActivation} = getActivationSnippet(attributes); const isChannelLast = attributes.format === 'NHWC'; const outputShape = calculateOutputShape( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 92105859a8c0..956ef18eb5cf 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -10,24 +10,25 @@ export interface InternalActivationAttributes { readonly activationCacheKey: string; } -export const getActicationSnippet = - (attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; - case 'Sigmoid': - return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; - case 'Clip': - return { - activationFunction: - `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, isVec4 = false): { + activationFunction: string; applyActivation: string; +} => { + switch (attributes.activation) { + case 'Relu': + return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; + case 'Sigmoid': + return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; + case 'Clip': + return { + activationFunction: `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, + applyActivation: isVec4 ? 'value = clamp(value, vec4(clip_min_), vec4(clip_max_));' : + 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. + default: + return {activationFunction: '', applyActivation: ''}; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc new file mode 100644 index 000000000000..812e9d7c2def --- /dev/null +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "conv without bias addition A", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "T[1]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv without bias addition A", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/contrib_ops/js/fused_conv.cc b/onnxruntime/contrib_ops/js/fused_conv.cc new file mode 100644 index 000000000000..76402f068197 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fused_conv.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/operators/conv.h" +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + onnxruntime::js::Conv); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 4641b006a778..24d327576ecd 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -11,6 +11,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -23,7 +24,9 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index c090ab2a6cc9..d27603e4ab3a 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -4,7 +4,7 @@ #include "core/optimizer/conv_activation_fusion.h" #include - +#include #include "core/common/inlined_containers.h" #include "core/framework/tensorprotoutils.h" #include "core/mlas/inc/mlas.h" @@ -174,9 +174,29 @@ using NTO = NodesToOptimize; class FuseConvActivationAction : public ReplaceWithNew { private: - std::string OpType(const RuntimeState&) const override { return "FusedConv"; } + std::string OpType(const RuntimeState& runtime_state) const override { + const auto& domain = runtime_state.selected_nodes.Target().Domain(); + const auto& op_type = runtime_state.selected_nodes.Target().OpType(); + if (domain == kOnnxDomain) { + if (op_type == "Conv") { + return "FusedConv"; + } + } else if (domain == kMSDomain) { + if (op_type == "NhwcConv") { + return "NhwcFusedConv"; + } + } else if (domain == kMSInternalNHWCDomain) { + if (op_type == "Conv") { + return "Conv"; + } + } + ORT_THROW("Unsupported operator: ", op_type, " and domain: ", domain); + } - std::string Domain(const RuntimeState&) const override { return kMSDomain; } + std::string Domain(const RuntimeState& runtime_state) const override { + auto domain = runtime_state.selected_nodes.Target().Domain(); + return domain == kOnnxDomain ? kMSDomain : domain; + } NodeAttributes ExtraAttributes(const RuntimeState& state) const override { NodeAttributes extra_fused_conv_attributes; @@ -260,8 +280,11 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { const auto name = "ConvAct"; auto action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + const std::string msInternalNHWCDomainConv = SelectorActionRegistry::OpVersionsMapKey("Conv", kMSInternalNHWCDomain); + const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, + + registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {11}}, {msDomainConv, {1}}}, std::move(selector), std::move(action)); #else registry.RegisterAction(name, std::move(action)); diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 7c8bfeaec5f0..6f90eaf07ef4 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -287,12 +287,9 @@ class FuseConvAddActivationAction : public ReplaceWithNew { void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) { auto action = std::make_unique(); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}}, + std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain); + registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}}, std::move(selector), std::move(action)); - auto action_nhwc = std::make_unique(); - auto selector_nhwc = std::make_unique(); - registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}}, - std::move(selector_nhwc), std::move(action_nhwc)); } SelectorActionRegistry CreateSelectorActionRegistry() { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 5a441b1d1701..86b126f2c7c3 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -270,11 +270,12 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = @@ -296,7 +297,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index e182b6c695d2..546d52b6f168 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -3,9 +3,10 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" -#include #include +#include #include +#include #include #include "core/graph/op_identifier_utils.h" @@ -56,9 +57,9 @@ const SelectorActionRegistry::Entry* SelectorActionRegistry::LookUp(const std::s } #if !defined(ORT_MINIMAL_BUILD) -auto SelectorActionRegistry::LookUpByOpType(const std::string& op_type) const +auto SelectorActionRegistry::LookUpByOpTypeAndDomain(const std::string& op_type, const std::string& domain) const -> std::vector> { - const auto [range_begin, range_end] = op_type_to_entry_.equal_range(op_type); + const auto [range_begin, range_end] = op_type_to_entry_.equal_range(OpVersionsMapKey(op_type, domain)); std::vector> result{}; result.reserve(std::distance(range_begin, range_end)); std::transform(range_begin, range_end, std::back_inserter(result), @@ -93,20 +94,15 @@ static Status MatchAndProcess( Status status = Status::OK(); do { - // TODO: for now this just needs to support ONNX and Micrsoft Domain ops. - // If we ever had a transformer that was going to target non-ONNX ops, - // we'd need to rework a few things to include the op domain in the matches - if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) { - break; - } - std::optional node_selection_opt{}; const SelectorActionRegistry::Entry* selector_action_entry_ptr = nullptr; - const auto selector_action_entries = selector_action_registry.LookUpByOpType(node.OpType()); + const auto selector_action_entries = + selector_action_registry.LookUpByOpTypeAndDomain(node.OpType(), node.Domain()); + std::string key = SelectorActionRegistry::OpVersionsMapKey(node.OpType(), node.Domain()); for (const auto& entry : selector_action_entries) { // check the supported versions if specified - const auto& versions = entry->ops_and_versions.find(node.OpType())->second; + const auto& versions = entry->ops_and_versions.find(key)->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node.SinceVersion()) == versions.cend()) { continue; diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h index 7eb162cc693f..5caa949ebbe9 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h @@ -38,8 +38,20 @@ struct NodeSelector { // class to manage a set of selector and associated actions class SelectorActionRegistry { public: + // The key is a string representing the op, optionally specifying the domain using ':' as the + // separator with domain as the first part and operator as the second part, ":" or "". + // For ops in kOnnxDomain, the domain should be left unspecified (""). + // For ops in other domains, the domain should be specified (":"). + // Ex: "Conv", "com.microsoft:Conv", "com.ms.internal.nhwc:Conv" using OpVersionsMap = std::unordered_map>; + // Helper function to create a key to OpVersionsMap using domain and op_type. + static std::string OpVersionsMapKey(std::string_view op_type, std::string_view domain = kOnnxDomain) { + return (domain == kOnnxDomain) + ? std::string{op_type} + : std::string{domain} + ":" + std::string{op_type}; + } + struct Entry { Entry(const std::string& name_in, #if !defined(ORT_MINIMAL_BUILD) @@ -95,14 +107,15 @@ class SelectorActionRegistry { #if !defined(ORT_MINIMAL_BUILD) // return registered Entry or nullptr if not found - auto LookUpByOpType(const std::string& op_type) const -> std::vector>; + auto LookUpByOpTypeAndDomain(const std::string& op_type, + const std::string& domain) const -> std::vector>; #endif // !defined(ORT_MINIMAL_BUILD) private: std::unordered_map name_to_entry_; #if !defined(ORT_MINIMAL_BUILD) - // auxiliary mapping to enable lookup by op type + // auxiliary mapping to enable lookup by op type or "domain:op type" std::unordered_multimap op_type_to_entry_; #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index 2e07124dcd90..68336c996a86 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -16,6 +16,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, @@ -23,6 +24,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Conv, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index fdf3e5b6c6b6..3a01a4aa46be 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -3,23 +3,42 @@ #pragma once +#include +#include + #include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" namespace onnxruntime { namespace js { -template -class Conv : public JsKernel { +class ConvBase : public JsKernel { public: - Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { + ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info), + conv_attrs_(info), + w_is_const_(false) { + std::vector activation_params; TensorShapeVector kernel_shape; + const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size(); + std::vector local_pads(pads_vec_size, 0); + for (size_t i = 0; i < conv_attrs_.pads.size() && i < pads_vec_size; ++i) { + local_pads[i] = gsl::narrow_cast(conv_attrs_.pads[i]); + } + if (conv_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - + if (is_fused_conv) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_attrs_.activation)); + ORT_ENFORCE(info.GetAttrs("activation_params", activation_params).IsOK()); + } else { + conv_attrs_.activation = info.GetAttrOrDefault("activation", ""); + activation_params = info.GetAttrsOrDefault("activation_params", activation_params); + } + const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr; int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); - + auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; + auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_attrs_.dilations.size() == 1 || (conv_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || @@ -30,44 +49,52 @@ class Conv : public JsKernel { "dilations" : [$2], "group" : $3, "kernel_shape" : [$4], - "pads" : [ $5, $6 ], + "pads" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : [], "strides" : [$7], - "w_is_const" : () JS_ARROW(!!HEAP8[$9]) + "w_is_const" : () JS_ARROW(!!HEAP8[$9]), + "activation" : UTF8ToString($10), + "activation_params" : $11 ? Array.from(HEAPF32.subarray($12, $12 + $11)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), + static_cast(kernel_shape_0), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ - "format" : $13 ? "NHWC" : "NCHW", + "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, "dilations" : [ $2, $3 ], "group" : $4, "kernel_shape" : [ $5, $6 ], - "pads" : [ $7, $8, $9, $10 ], - "strides" : [ $11, $12 ], - "w_is_const" : () JS_ARROW(!!HEAP8[$14]) + "pads" : $7 ? Array.from(HEAP32.subarray($8, $8 + $7)) : [], + "strides" : [ $9, $10 ], + "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "activation" : UTF8ToString($13), + "activation_params" : $14 ? Array.from(HEAPF32.subarray($15, $15 + $14)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), - static_cast(conv_attrs_.pads.size() > 2 ? conv_attrs_.pads[2] : 0), - static_cast(conv_attrs_.pads.size() > 3 ? conv_attrs_.pads[3] : 0), + static_cast(kernel_shape_0), + static_cast(kernel_shape_1), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } } @@ -94,5 +121,12 @@ class Conv : public JsKernel { // Tensor w_transposed_; }; +template +class Conv : public ConvBase { + public: + explicit Conv(const OpKernelInfo& info) : ConvBase(info, is_channels_last, is_fused_conv) { + } +}; + } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 2228343e1e6e..f7f0ab22b700 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -15,6 +15,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_KERNEL_EX( ConvTranspose, kOnnxDomain, @@ -22,6 +23,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( ConvTranspose, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 18ef73268005..5d30dc851e00 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -4,26 +4,45 @@ #pragma once #include +#include #include "core/common/gsl.h" #include "core/providers/cpu/nn/conv_transpose_attributes.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { TensorShapeVector kernel_shape; + if (is_fused_convtranspose) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_transpose_attrs_.activation)); + } else { + conv_transpose_attrs_.activation = info.GetAttrOrDefault("activation", ""); + } + if (conv_transpose_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); + std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), + conv_transpose_attrs_.output_shape.end()); + std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), + conv_transpose_attrs_.output_padding.end()); + const auto* local_output_padding_ptr = + local_output_padding.size() > 0 ? local_output_padding.data() : nullptr; + const auto* local_output_shape_ptr = + local_output_shape.size() > 0 ? local_output_shape.data() : nullptr; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_transpose_attrs_.dilations.size() == 1 || (conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || conv_transpose_attrs_.strides.size() == 1) { + auto dilations = conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0; + auto kernel_shape_0 = conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; + auto pads_0 = conv_transpose_attrs_.pads.size() > 0 ? conv_transpose_attrs_.pads[0] : 0; + auto pads_1 = conv_transpose_attrs_.pads.size() > 1 ? conv_transpose_attrs_.pads[1] : 0; + auto strides = conv_transpose_attrs_.strides.size() > 0 ? conv_transpose_attrs_.strides[0] : 0; JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $8 ? "NHWC" : "NCHW", "autoPad" : $1, @@ -34,21 +53,23 @@ class ConvTranspose : public JsKernel { "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), "outputPadding" : $10 ? Array.from(HEAP32.subarray($11, $11 + $10)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [] + "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [], + "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), - static_cast(conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0), + static_cast(dilations), static_cast(conv_transpose_attrs_.group), - static_cast(conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0) ? kernel_shape[0] : 0, - static_cast(conv_transpose_attrs_.pads.size()), - static_cast(conv_transpose_attrs_.pads.size() > 1) ? conv_transpose_attrs_.pads[1] : 0, - static_cast(conv_transpose_attrs_.strides.size() > 0) ? conv_transpose_attrs_.strides[0] : 0, + static_cast(kernel_shape_0), + static_cast(pads_0), + static_cast(pads_1), + static_cast(strides), static_cast(channels_last), reinterpret_cast(&w_is_const_), - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_padding.size() > 0 ? conv_transpose_attrs_.output_padding.data() : nullptr) >> 2, - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_shape.size() > 0 ? conv_transpose_attrs_.output_shape.data() : nullptr) >> 2); + gsl::narrow_cast(local_output_padding.size()), + reinterpret_cast(local_output_padding_ptr) >> 2, + gsl::narrow_cast(local_output_shape.size()), + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } else { constexpr size_t pads_vec_size = 4; constexpr size_t strides_vec_size = 2; @@ -59,8 +80,6 @@ class ConvTranspose : public JsKernel { std::vector local_strides(strides_vec_size, 0); std::vector local_dilations(dialations_vec_size, 0); std::vector local_kernel_shape; - std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), conv_transpose_attrs_.output_shape.end()); - std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), conv_transpose_attrs_.output_padding.end()); if (conv_transpose_attrs_.kernel_shape_specified) { for (size_t i = 0; i < kernel_shape.size() && i < kernel_shape_vec_size; ++i) { local_kernel_shape.push_back(gsl::narrow_cast(kernel_shape[i])); @@ -91,7 +110,8 @@ class ConvTranspose : public JsKernel { "strides" : Array.from(HEAP32.subarray($6, $6 + /* strides_vec_size */ 2)), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), "outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10, $10 + $9)) : [], - "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [] + "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [], + "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), reinterpret_cast(local_dilations.data()) >> 2, @@ -102,9 +122,10 @@ class ConvTranspose : public JsKernel { static_cast(channels_last), reinterpret_cast(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - reinterpret_cast(local_output_padding.size() > 0 ? local_output_padding.data() : nullptr) >> 2, + reinterpret_cast(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - reinterpret_cast(local_output_shape.size() > 0 ? local_output_shape.data() : nullptr) >> 2); + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 46b95a127b75..a6aa4b946f39 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1438,7 +1438,7 @@ TEST_F(GraphTransformationTests, NotWhereFusion) { ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where } -#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS) +#if (defined(USE_CUDA) || defined(USE_JSEP)) && !defined(DISABLE_CONTRIB_OPS) // Conv->Add->Relu will be transformed to FusedConv TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx"; @@ -1618,6 +1618,10 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCudaExecutionProvider); } +#elif defined(USE_JSEP) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kJsExecutionProvider); + } #endif std::map op_to_count_before_fusion = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count_before_fusion[model.second] >= 1); @@ -1632,6 +1636,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { std::set cuda_rocm_supported = {"Relu"}; if (cuda_rocm_supported.find(model.second) == cuda_rocm_supported.end()) { ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); + } else { + ASSERT_EQ(op_to_count_after_fusion[model.second], 0); + } +#elif defined(USE_JSEP) + std::set js_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu"}; + if (js_supported.find(model.second) == js_supported.end()) { + ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); } else { ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); } From 178f7caaebf67bcddee8fae8836656613290b8bc Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Wed, 1 Nov 2023 20:04:22 -0700 Subject: [PATCH 12/12] GQA Memory Efficient Kernel (#17920) Implement Cutlass Memory Efficient Attention Kernel into Group Query Attention Operator. ### Motivation and Context Before this change, Group Query Attention Operator was supported only by Flash-Attention. While this is the most efficient kernel for the operation, it only supports sm >= 80. Cutlass Memory Efficient Attention Kernel supports sm >= 53, allowing us to support a broader range of GPU hardware. --- cmake/onnxruntime_rocm_hipify.cmake | 5 + docs/ContribOperators.md | 6 +- .../contrib_ops/cuda/bert/attention_impl.cu | 2 + .../bert/cutlass_fmha/fmha_launch_template.h | 54 +- .../cutlass_fmha/memory_efficient_attention.h | 2 + .../cuda/bert/group_query_attention.cc | 73 ++- .../cuda/bert/group_query_attention.h | 1 + .../cuda/bert/group_query_attention_helper.h | 88 ++- .../cuda/bert/group_query_attention_impl.cu | 596 ++++++++++++++---- .../cuda/bert/group_query_attention_impl.h | 9 + .../cuda/bert/packed_attention_impl.cu | 2 + .../bert/packed_multihead_attention_impl.cu | 2 + .../core/graph/contrib_ops/bert_defs.cc | 6 +- .../python/transformers/test_flash_attn.py | 309 ++++++--- 14 files changed, 843 insertions(+), 312 deletions(-) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 9bc2bdd208a9..4140eeee0d11 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -94,6 +94,11 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "bert/group_query_attention_helper.h" + "bert/group_query_attention.h" + "bert/group_query_attention.cc" + "bert/group_query_attention_impl.h" + "bert/group_query_attention_impl.cu" ) if (NOT onnxruntime_ENABLE_ATEN) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index ed1049b0bd73..8e86862a62e7 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2422,14 +2422,14 @@ This version of the operator has been available since version 1 of the 'com.micr
When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.
-#### Outputs (1 - 3) +#### Outputs
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
-
present_key (optional) : T
+
present_key : T
present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-
present_value (optional) : T
+
present_value : T
present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index eb9e6d5c6246..16ce3a899fb5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -374,6 +374,7 @@ Status EfficientAttention( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.total_sequence_length; + p.max_sequence_length = parameters.total_sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; @@ -395,6 +396,7 @@ Status EfficientAttention( p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) ? data.scratch : nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index ed330b0fca33..51c3d3d3a458 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -51,25 +51,45 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.num_keys = params.kv_sequence_length; if (params.causal) { - p.custom_mask_type = Attention::CausalFromTopLeft; + p.custom_mask_type = Attention::CausalFromBottomRight; } - // Input format is BxSxNxH, output is BxSxNxH - p.q_strideH = params.qk_head_size; - p.k_strideH = params.qk_head_size; - p.v_strideH = params.v_head_size; - p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; - - p.q_strideM = params.num_heads * params.qk_head_size; - p.k_strideM = params.num_heads * params.qk_head_size; - p.v_strideM = params.num_heads * params.v_head_size; - p.o_strideM = params.num_heads * params.v_head_size; - p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; - - p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; - p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; - p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; - p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + // We use max_sequence_length to calculate KV stride + if (params.is_kv_bsnh) { + // Input Q, K, V format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.max_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } else { + // Input K, V format is BxNxSxH, Input Q is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.max_sequence_length * params.qk_head_size; + p.v_strideH = params.max_sequence_length * params.v_head_size; + p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.qk_head_size; + p.v_strideM = params.v_head_size; + p.o_strideM = params.num_heads * params.v_head_size; + p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys; + + p.q_strideB = params.num_heads * params.qk_head_size * params.sequence_length; + p.k_strideB = params.num_heads * params.qk_head_size * params.max_sequence_length; + p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; + p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; + } } constexpr auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index f725be8d7cf8..f16567bb6f2b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -14,10 +14,12 @@ namespace cuda { struct MemoryEfficientAttentionParams { int32_t sm; bool is_half; + bool is_kv_bsnh = true; int32_t batch_size; int32_t num_heads; int32_t sequence_length; int32_t kv_sequence_length; + int32_t max_sequence_length; int32_t qk_head_size; int32_t v_head_size; bool causal; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 67d750aeac11..8694dc998c7a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -6,9 +6,8 @@ #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" #include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" -// #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" -// #include "contrib_ops/cpu/utils/console_dumper.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -55,6 +54,13 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_flash_attention_ = true; #endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif } template @@ -92,18 +98,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { output_shape[2] = static_cast(parameters.hidden_size); Tensor* output = context->Output(0, output_shape); - std::vector present_dims; - if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { - present_dims = { - parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; - } else { // BNSH - present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; - } - TensorShape present_shape(present_dims); - Tensor* present_key = context->Output(1, present_shape); - Tensor* present_value = context->Output(2, present_shape); - #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, @@ -143,8 +137,47 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto seqlens_k_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif - // only kernel implemented for gqa right now - ORT_ENFORCE(use_flash_attention); +#if USE_MEMORY_EFFICIENT_ATTENTION + int sm = (device_prop.major * 10) + device_prop.minor; + bool use_memory_efficient_attention = + !use_flash_attention && + !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + // allocate buffers + size_t kv_buffer_bytes = 0; + // need a buffer if we must ungroup kv + const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); + if (use_memory_efficient_attention && needs_buff) { + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size); + } + size_t fmha_buffer_bytes = 0; + if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { + fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float)); + } + auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); +#else + constexpr bool use_memory_efficient_attention = false; + auto k_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto v_buffer = GetScratchBuffer(0, context->GetComputeStream()); + auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); +#endif + + std::vector present_dims; + if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + present_dims = { + parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; + } else { // BNSH + present_dims = { + parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; + } + TensorShape present_shape(present_dims); + Tensor* present_key = context->Output(1, present_shape); + Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); data.key = reinterpret_cast(key->Data()); @@ -155,6 +188,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); data.use_flash_attention = use_flash_attention; + data.use_memory_efficient_attention = use_memory_efficient_attention; if (softmax_lse_buffer != nullptr) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); } @@ -167,6 +201,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (seqlens_k_buffer != nullptr) { data.seqlens_k = reinterpret_cast(seqlens_k_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } + if (fmha_buffer != nullptr) { + data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 72c9814fad67..a90418ec2243 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -27,6 +27,7 @@ class GroupQueryAttention final : public CudaKernel { bool is_past_bsnh_; float scale_; bool disable_flash_attention_; + bool disable_memory_efficient_attention_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index be8f5ca0ae3e..8c21de9ced05 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -29,13 +29,13 @@ Status CheckInputs(const Tensor* query, // query (Q) : (B, S, D) // key (K) : (B, S+, D_kv) // value (V) : (B, S+, D_kv) + ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = Q_K_V_BSNH; const auto& query_dims = query->Shape().GetDims(); const auto& key_dims = key->Shape().GetDims(); - const auto& value_dims = value->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -47,10 +47,8 @@ Status CheckInputs(const Tensor* query, int q_hidden_size = static_cast(query_dims[2]); int head_size = static_cast(q_hidden_size) / num_heads; - int kv_sequence_length = sequence_length; - int kv_hidden_size = (key_dims.size() == 3) - ? static_cast(key_dims[2]) - : (kv_num_heads * static_cast(key_dims[3])); + int kv_sequence_length = static_cast(key_dims[1]); + int kv_hidden_size = static_cast(key_dims[2]); int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { @@ -134,63 +132,49 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent"); } - if (key != nullptr) { - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - if (key_dims[2] != value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have same dim 2 (kv_hidden_size)"); - } - - qkv_format = Q_K_V_BSNH; - kv_sequence_length = static_cast(key_dims[1]); - } else { + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing key tensor."); + "Input 'query' and 'key' shall have same dim 0 (batch size)"); } - if (value != nullptr) { - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } - if (static_cast(kv_sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); - } + if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch_size)"); + } - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { + if (static_cast(kv_sequence_length) != value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Missing value tensor."); + "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); + } + + if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); } // When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly. int32_t past_sequence_length = 0; - int present_sequence_length = 0; + int present_sequence_length = kv_sequence_length; if (past_seq_len != nullptr) { + if (past_key == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past KV must be present as share-buffer when using past_seq_len pointer."); + } if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "past_sequence_length tensor must be of one element when using past kv."); @@ -200,6 +184,10 @@ Status CheckInputs(const Tensor* query, } else { past_sequence_length = static_cast(*((*past_seq_len).template Data())); } + if (past_sequence_length + kv_sequence_length > max_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length"); + } present_sequence_length = max_sequence_length; } else if (past_key != nullptr) { past_sequence_length = max_sequence_length; // this is the length of past_key tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index ab3029ca3488..0455825c364a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -37,6 +37,7 @@ limitations under the License. #include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -47,6 +48,8 @@ namespace onnxruntime { namespace contrib { namespace cuda { +////////// Auxiliary Kernels for KV prep + // Kernel for seqlens_k __global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size) { int id = blockDim.x * blockIdx.x + threadIdx.x; @@ -75,7 +78,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int present_head_stride = is_bsnh ? H : present_seqlen * H; // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH + // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L const int past_seqlen = present_seqlen - new_seqlen; @@ -95,33 +98,32 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, } } +// Use when (H*)*num_heads > 1024 template __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int H, + const int num_heads, const T* past_kv, const T* new_kv, T* present_kv, const bool is_bsnh) { - // Use when (H*)*num_heads > 1024 - int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int present_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // past_kv: BPNH or BNPH + // new_kv: BLNH + // present_kv: BTNH or BNTH, where T = P + L + const int past_seqlen = present_seqlen - new_seqlen; - const int present_seqlen = gridDim.x; - const int num_heads = blockDim.y; - const int thread_stride = blockDim.x; - - const int present_batch_stride = present_seqlen * num_heads * H; - const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_seqlen * H; - - // past_kv: BPNH or BNPH - // new_kv: BLNH or BNLH - // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = present_seqlen - new_seqlen; - - while (h < H) { int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { const int past_batch_stride = past_seqlen * num_heads * H; @@ -135,133 +137,477 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h; present_kv[out_offset] = new_kv[in_offset]; } - h += thread_stride; } } +// Concat new to past in present. Supports past BSNH or past BNSH template -Status QkvToContext( +Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; // divide by 4 so kernel can operate on 4 float16 elements at a time. + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(present_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKV<<>>(kv_sequence_length, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = (H * kv_num_heads + 255) / 256; + const dim3 grid(steps, present_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_key), + reinterpret_cast(data.key), + reinterpret_cast(data.present_key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatNewToPastKVLarge<<>>(kv_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.past_value), + reinterpret_cast(data.value), + reinterpret_cast(data.present_value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel to append new kv to kv buffer in place +template +__global__ void ConcatKVInPlace(const int past_seqlen, + const int present_seqlen, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int new_seqlen = gridDim.x; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; +} + +template +__global__ void ConcatKVInPlaceLarge(const int past_seqlen, + const int present_seqlen, + const int H, + const int num_heads, + T* kv_buff, + const T* new_kv, + const bool is_bsnh) { // refers to kv buff; otherwise bnsh + int i = threadIdx.x + (blockDim.x * blockIdx.x); + if (i < H * num_heads) { + const int h = i % H; + const int n = i / H; + const int s = blockIdx.y; + const int b = blockIdx.z; + const int new_seqlen = gridDim.y; + + const int present_batch_stride = present_seqlen * num_heads * H; + const int present_row_stride = is_bsnh ? num_heads * H : H; + const int present_head_stride = is_bsnh ? H : present_seqlen * H; + + // kv_buff: BTNH or BNTH with buffered memory for new + // new_kv: BLNH + + int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + // Note: new KV always BSNH + const int new_batch_stride = new_seqlen * num_heads * H; + const int new_row_stride = num_heads * H; + const int new_head_stride = H; + const int in_offset = b * new_batch_stride + s * new_row_stride + n * new_head_stride + h; + kv_buff[out_offset] = new_kv[in_offset]; + } +} + +// Concat new to kv buffer in place +template +Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + const int H = head_size / 4; + if (H * kv_num_heads <= max_threads_per_block) { + const dim3 grid(kv_sequence_length, batch_size, 1); + const dim3 block(H, kv_num_heads, 1); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlace<<>>(past_sequence_length, + present_sequence_length, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } else { + int steps = int(ceil(float(H * kv_num_heads) / 256.0)); + const dim3 grid(steps, kv_sequence_length, batch_size); + const dim3 block(256, 1, 1); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_key), + reinterpret_cast(data.key), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + ConcatKVInPlaceLarge<<>>(past_sequence_length, + present_sequence_length, + H, + kv_num_heads, + reinterpret_cast(data.present_value), + reinterpret_cast(data.value), + past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Kernel for use with memory efficient kernel... kv_in is grouped and of bnsh or bsnh... kv_out is ungrouped and bsnh +template +__global__ void Ungroup(const T* kv_in, + T* kv_out, + const int in_seqlen, + const int kv_num_heads, + const bool is_bsnh) { + const int h = threadIdx.x; + const int out_n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + + const int out_seqlen = gridDim.x; + const int q_num_heads = blockDim.y; + const int H = blockDim.x; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + const int in_n = out_n / q_kv_head_ratio; + + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; +} + +template +__global__ void UngroupLarge(const T* kv_in, + T* kv_out, + const int H, + const int in_seqlen, + const int q_num_heads, + const int kv_num_heads, + const bool is_bsnh) { + int i = threadIdx.x + (blockDim.x * blockIdx.x); // index along H * q_num_heads elements + if (i < H * q_num_heads) { + const int out_seqlen = gridDim.y; + const int s = blockIdx.y; + const int b = blockIdx.z; + + const int q_kv_head_ratio = q_num_heads / kv_num_heads; + const int out_batch_stride = out_seqlen * q_num_heads * H; + const int out_row_stride = is_bsnh ? q_num_heads * H : H; + const int out_head_stride = is_bsnh ? H : out_seqlen * H; + + const int in_batch_stride = in_seqlen * kv_num_heads * H; + const int in_row_stride = is_bsnh ? kv_num_heads * H : H; + const int in_head_stride = is_bsnh ? H : in_seqlen * H; + + const int h = i % H; + const int out_n = i / H; + const int in_n = out_n / q_kv_head_ratio; + const int out_offset = out_batch_stride * b + out_row_stride * s + out_head_stride * out_n + h; + const int in_offset = in_batch_stride * b + in_row_stride * s + in_head_stride * in_n + h; + kv_out[out_offset] = kv_in[in_offset]; + } +} + +// Ungroup kv or present kv for use in Memory Efficient kernel. If present kv is not null and is BNSH, transposes it. +Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, + float2* k_buff, float2* v_buff, + const float2* k_og, const float2* v_og, + const int buff_seqlen, const int og_seqlen, + const bool is_bsnh, + cudaStream_t stream, + const int max_threads_per_block) { + const int batch_size = parameters.batch_size; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + const int H = head_size / 4; + if (H * num_heads <= max_threads_per_block) { + const dim3 grid(buff_seqlen, batch_size, 1); + const dim3 block(H, num_heads, 1); + Ungroup<<>>(k_og, + k_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + Ungroup<<>>(v_og, + v_buff, + og_seqlen, + kv_num_heads, + is_bsnh); + } else { + int steps = int(ceil(float(H * num_heads) / 256.0)); + const dim3 grid(steps, buff_seqlen, batch_size); + const dim3 block(256, 1, 1); + UngroupLarge<<>>(k_og, + k_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + UngroupLarge<<>>(v_og, + v_buff, + H, + og_seqlen, + num_heads, + kv_num_heads, + is_bsnh); + } + return CUDA_CALL(cudaGetLastError()); +} + +////////// Launch Kernels + +#if USE_FLASH_ATTENTION +template +Status FlashAttention( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, + cudaStream_t stream, contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { - assert(data.use_flash_attention); + GroupQueryAttentionData& data, + float scale) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int present_sequence_length = parameters.present_sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; -#if USE_FLASH_ATTENTION - auto stream = static_cast(ort_stream->GetHandle()); + void* query = reinterpret_cast(const_cast(data.query)); + void* key = reinterpret_cast(const_cast(data.key)); + void* value = reinterpret_cast(const_cast(data.value)); + + bool is_causal = parameters.is_unidirectional; + + if (data.past_key != nullptr && data.past_key == data.present_key) { + // Share buffer case + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + // Launch kernel to copy seqlen + int thr_per_blk = 256; + int blk_in_grid = ceil(float(batch_size) / thr_per_blk); + repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), + reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, kv_sequence_length, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); + + } else { + // Not share buffer or no past (prompt generation) + // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), + batch_size, num_heads, kv_num_heads, head_size, + sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + } + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.kv_sequence_length; + const int past_sequence_length = parameters.past_sequence_length; const int present_sequence_length = parameters.present_sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(head_size)) : parameters.scale; - if (data.use_flash_attention) { - assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); - assert(parameters.num_heads % parameters.kv_num_heads == 0); - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - - bool is_causal = parameters.is_unidirectional; - - if (data.past_key == nullptr && data.present_key == nullptr) { - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(data.softmax_lse), - parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, head_size, - parameters.sequence_length, parameters.kv_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum))); - - } else if (data.past_key == data.present_key) { - // Assume past and present kv share buffer. - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - assert(parameters.past_sequence_length >= 0); - assert(data.past_value != nullptr); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k, 1, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum))); - - } else if (data.present_key != nullptr && (data.past_key != nullptr || kv_sequence_length == present_sequence_length)) { - assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (head_size % 4 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "requires head_size be divisible by 4"); - } - const int H = head_size / 4; - if (H * kv_num_heads <= max_threads_per_block) { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(H, kv_num_heads, 1); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKV<<>>(kv_sequence_length, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } else { - const dim3 grid(present_sequence_length, batch_size, 1); - const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_key), - reinterpret_cast(data.key), - reinterpret_cast(data.present_key), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatNewToPastKVLarge<<>>(kv_sequence_length, - H, - reinterpret_cast(data.past_value), - reinterpret_cast(data.value), - reinterpret_cast(data.present_value), - past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), - batch_size, num_heads, kv_num_heads, head_size, - sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + const void* query = reinterpret_cast(data.query); + const void* key = reinterpret_cast(data.key); + const void* value = reinterpret_cast(data.value); + if (data.past_key != nullptr) { + // Past key case + // concatenate new kv to past kv + if (data.past_key == data.present_key) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); } + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + } else if (num_heads == kv_num_heads) { + // no past or present and no need to ungroup... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + key = reinterpret_cast(data.present_key); + value = reinterpret_cast(data.present_value); + } else { + // intermediate buffer so q and kv have same num heads... still copy kv into present buffer + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + float2* k_buff = reinterpret_cast(data.k); + float2* v_buff = reinterpret_cast(data.v); + const float2* k_og = reinterpret_cast(data.present_key); + const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length, + kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream, + max_threads_per_block)); + key = reinterpret_cast(data.k); + value = reinterpret_cast(data.v); + } + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = batch_size; + p.num_heads = num_heads; + p.sequence_length = sequence_length; + p.kv_sequence_length = past_sequence_length + kv_sequence_length; + p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length; + p.qk_head_size = head_size; + p.v_head_size = head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr; + p.is_attn_bias_batched = false; + p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) + ? data.fmha_buffer + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); + + return Status::OK(); +} +#endif + +////////// API Functions + +template +Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif - return Status::OK(); +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } #endif + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 0bad9eeb6123..8412631078e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -14,19 +14,28 @@ namespace cuda { template struct GroupQueryAttentionData { + // Input Tensors const T* query = nullptr; const T* key = nullptr; const T* value = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; + // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; int* seqlens_k = nullptr; + // Memory Efficient buffers + T* fmha_buffer = nullptr; + T* k = nullptr; + T* v = nullptr; + // Output Tensors T* output = nullptr; T* present_key = nullptr; T* present_value = nullptr; + // Kernel Flags bool use_flash_attention = false; + bool use_memory_efficient_attention = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index aba0efdbd7d5..d7aeef1501cd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -507,10 +507,12 @@ Status FusedScaledDotProductAttentionCutlass( MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; + p.is_kv_bsnh = true; p.batch_size = parameters.batch_size; p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e09fd9e6b36e..3fe9dbf8ed34 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -688,6 +688,7 @@ Status FusedAttentionCutlass( p.num_heads = parameters.num_heads; p.sequence_length = parameters.sequence_length; p.kv_sequence_length = parameters.sequence_length; + p.max_sequence_length = parameters.sequence_length; p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; @@ -702,6 +703,7 @@ Status FusedAttentionCutlass( p.attn_bias = data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; + p.is_kv_bsnh = true; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 76c3f8716ff0..5bc18a4e69b4 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1051,15 +1051,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .Output(2, "present_value", "present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T", - OpSchema::Optional) + "T") .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 04351cd6e678..319fed87dc9e 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -10,7 +10,10 @@ # license information. # ------------------------------------------------------------------------- import math +import os +import platform import random +import unittest import numpy import torch @@ -22,6 +25,8 @@ torch.manual_seed(0) +pipeline_mode = True # Reduces number of tests so pipeline doesn't time out + class Formats: BSNH = 0 @@ -159,7 +164,7 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_no_past(config, causal=False): +def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH): nodes = [ helper.make_node( "GroupQueryAttention", @@ -168,11 +173,12 @@ def create_group_query_attention_graph_no_past(config, causal=False): "key", "value", ], - ["output"], + ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, unidirectional=1 if causal else 0, + is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0, domain="com.microsoft", ), ] @@ -213,6 +219,26 @@ def create_group_query_attention_graph_no_past(config, causal=False): TensorProto.FLOAT16, [config.batch_size, config.sequence_length, config.num_heads * config.head_size], ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + ], + ), ] graph = helper.make_graph( @@ -514,7 +540,6 @@ def generate_token_offset(cu_seqlens, max_seqlen): return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) -# TODO(aciddelgado): rename def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): onnx_model_str = create_packed_multihead_attention_graph(config) qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) @@ -548,8 +573,8 @@ def mha_func(q, k, v, config): return output -def gqa_no_past_func(q, k, v, config, causal=True): - onnx_model_str = create_group_query_attention_graph_no_past(config, causal) +def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH): + onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) @@ -560,7 +585,7 @@ def gqa_no_past_func(q, k, v, config, causal=True): } sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) - ort_output = ort_session.run(None, ort_inputs) + ort_output, _, _ = ort_session.run(None, ort_inputs) ort_output = numpy.array(ort_output) output = torch.tensor(ort_output) return output @@ -689,17 +714,12 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - # causal_mask = torch.triu( - # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 - # ) causal_mask = construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) if causal: # Some rows are completely masked out so we fill them with zero instead of NaN attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) - # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling - # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) else: @@ -1072,12 +1092,6 @@ def parity_check_gqa_past_no_buff( out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() - # print(present_k[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref[0, 0, config.past_sequence_length, :10]) - # print(k_cache_ref.shape) - - # print(present_k - k_cache_ref.detach().cpu().numpy()) - # Make sure past-present buffer updating correctly if past_format == Formats.BSNH: assert numpy.allclose( @@ -1141,84 +1155,185 @@ def parity_check_gqa_past_no_buff( ) +class TestMHA(unittest.TestCase): + def test_packed_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST PACKED MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s, 0, n, n, h) + parity_check_mha(config, True) + + def test_mha(self): + if not torch.cuda.is_available() or platform.system() != "Linux": + return + major, _ = torch.cuda.get_device_capability() + if major < 8: + return + print("-------- TEST MHA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ] + ) + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s, s2 in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s2, 0, n, n, h) + parity_check_mha(config, False) + + +class TestGQA(unittest.TestCase): + def test_gqa_no_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + torch.manual_seed(69) + print("-------- TEST GQA ---------") + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1024, 1024), + (1023, 1024), + (2048, 2048), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + if major < 5 or (major == 5 and minor < 3): + return + print("------- MEMORY EFFICIENT ATTENTION ---------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION --------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True, False]: + config = Config(b, s, s2, 0, n, n2, h) + parity_check_gqa_no_past(config, causal=causal) + + def test_gqa_past(self): + if not torch.cuda.is_available(): + return + major, minor = torch.cuda.get_device_capability() + if major < 5 or (major == 5 and minor < 3): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- TEST GQA PAST ---------") + print("-------- MEMORY EFFICEINT --------") + batches = [2] if pipeline_mode else [1, 2] + seqs = ( + [(1, 128), (3, 1024), (64, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + (1, 128 * 512), + (16, 128 * 512), + (128, 128), + ] + ) + num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + if major < 8 or platform.system() != "Linux": + return + print("------- FLASH ATTENTION -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for causal in [True]: + for past_kv_format in [Formats.BNSH, Formats.BSNH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + causal=causal, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + + if __name__ == "__main__": - print("-------- TEST PACKED MHA ---------") - for b in [5]: - for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s, 0, n, n, h) - parity_check_mha(config, True) - print("-------- TEST MHA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n in [6]: - for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: - config = Config(b, s, s2, 0, n, n, h) - parity_check_mha(config, False) - print("-------- TEST GQA ---------") - for b in [5]: - for s, s2 in [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True, False]: - config = Config(b, s, s2, 0, n, n2, h) - parity_check_gqa_no_past(config, causal=causal) - print("-------- TEST GQA PAST ---------") - random.seed(69) - for b in [2]: - for s, s2 in [ - (1, 128), - (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 512), - (16, 128 * 512), - (128, 128), - ]: - for n, n2 in [(6, 6), (6, 3), (9, 9), (9, 3)]: - for h in [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]: - for causal in [True]: - for past_kv_format in [Formats.BNSH, Formats.BSNH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + unittest.main()