From aa990cb9bd35e682d18c1c697ebf646255b7c08e Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 24 Jul 2024 17:54:08 +1000 Subject: [PATCH 1/8] Changes to add all required ops for priority model - Add Concat (#21423) - Add DepthToSpace (#21426) - Add LeakyRelu (#21453) - Add test scripts (#21427) - Add ability to set coreml flags from python (#21434) Also updated partitioning utils to support dropping constant initializers from a ComputeCapability's inputs. We copy these to a CoreML model so don't need the originals. If they remain as inputs ORT can't free them as they appear to be in use. Misc changes - Fix SkipLayerNormFusion incorrectly setting `modified` - causes unnecessary loops of the L2 transformers --- .../core/optimizer/skip_layer_norm_fusion.cc | 8 +- .../builders/impl/activation_op_builder.cc | 13 +- .../coreml/builders/impl/builder_utils.cc | 22 ++++ .../coreml/builders/impl/builder_utils.h | 20 +++ .../coreml/builders/impl/concat_op_builder.cc | 86 ++++++++---- .../builders/impl/depthtospace_op_builder.cc | 124 +++++++++++++++--- .../coreml/builders/op_builder_factory.cc | 23 ++-- .../coreml/coreml_execution_provider.cc | 4 +- .../DebugMLProgram.md | 2 + .../mlprogram_test_scripts/concat_test.py | 33 +++++ .../convtranspose_test.py | 42 ++++++ .../depthtospace_test.py | 51 +++++++ .../coreml/mlprogram_test_scripts/div_test.py | 103 +++++++++++++++ .../dump_mlprogram_model.py | 0 .../mlprogram_test_scripts/gridsample_test.py | 114 ++++++++++++++++ .../mlprogram_test_scripts/resize_test.py | 51 +++++++ .../core/providers/partitioning_utils.cc | 39 +++--- .../core/providers/partitioning_utils.h | 25 ++-- .../providers/qnn/qnn_execution_provider.cc | 2 +- .../python/onnxruntime_pybind_state.cc | 70 +++++----- .../test/optimizer/qdq_transformer_test.cc | 3 +- .../cpu/tensor/space_depth_ops_test.cc | 31 +++++ .../apple/coreml_supported_mlprogram_ops.md | 3 + 23 files changed, 735 insertions(+), 134 deletions(-) rename onnxruntime/core/providers/coreml/{ => mlprogram_test_scripts}/DebugMLProgram.md (97%) create mode 100644 onnxruntime/core/providers/coreml/mlprogram_test_scripts/concat_test.py create mode 100644 onnxruntime/core/providers/coreml/mlprogram_test_scripts/convtranspose_test.py create mode 100644 onnxruntime/core/providers/coreml/mlprogram_test_scripts/depthtospace_test.py create mode 100644 onnxruntime/core/providers/coreml/mlprogram_test_scripts/div_test.py rename onnxruntime/core/providers/coreml/{ => mlprogram_test_scripts}/dump_mlprogram_model.py (100%) create mode 100644 onnxruntime/core/providers/coreml/mlprogram_test_scripts/gridsample_test.py create mode 100644 onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index cf70a7d821d7..d8f49124a2fa 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -168,7 +168,8 @@ Note: This fusion doesn't consider the following case: LayerNormalization */ -Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { +Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); InlinedVector> nodes_to_remove; @@ -299,13 +300,14 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, 
bool& modified, int graph_le // Assign provider to this new node. Provider should be same as the provider for old node. skip_layer_norm_node.SetExecutionProviderType(ln_node.GetExecutionProviderType()); } + + modified = !nodes_to_remove.empty(); + for (const auto& node : nodes_to_remove) { graph_utils::RemoveNodeOutputEdges(graph, node); graph.RemoveNode(node.get().Index()); } - modified = true; - return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 0e2171551370..c8670cd54625 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -83,12 +83,16 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, using namespace CoreML::Specification::MILSpec; // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation std::string_view coreml_op_type; + bool add_alpha = false; if (op_type == "Sigmoid") { coreml_op_type = "sigmoid"; } else if (op_type == "Tanh") { coreml_op_type = "tanh"; } else if (op_type == "Relu") { coreml_op_type = "relu"; + } else if (op_type == "LeakyRelu") { + coreml_op_type = "leaky_relu"; + add_alpha = true; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -96,6 +100,13 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); + + if (add_alpha) { + NodeAttrHelper helper(node); + const auto alpha = helper.Get("alpha", 0.01f); + AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha)); + } + AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); @@ -198,7 +209,7 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp #if defined(COREML_ENABLE_MLPROGRAM) if (input_params.create_mlprogram) { - if (op_type == "PRelu" || op_type == "LeakyRelu") { + if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable return false; } } else diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index ebb3f97895f0..f1cfbd305443 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -314,6 +314,28 @@ void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std: (*op.mutable_inputs())[input_name] = std::move(arg); } +void AddOperationInputs(MILSpec::Operation& op, std::string_view input_name, + const std::vector& value_names) { + MILSpec::Argument arg; + for (const auto& value : value_names) { + arg.mutable_arguments()->Add()->set_name(std::string(value)); + } + + (*op.mutable_inputs())[input_name] = std::move(arg); +} + +void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, const std::string& output_name, + int32_t element_type, std::optional> shape) { + auto& outputs = *op.mutable_outputs(); + auto& output_arg = *outputs.Add(); + output_arg.set_name(output_name); + + MILSpec::ValueType& value = 
*output_arg.mutable_type(); + MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(element_type), shape, /*convert_scalar*/ true); +} + void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output, std::optional override_element_type) { auto& outputs = *op.mutable_outputs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index f012e6af0d71..25e30577cf1e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -129,6 +129,26 @@ COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& n void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, std::string_view input_name, std::string_view value_name); +/// +/// Add a variadic input argument to a MILSpec::Operation +/// +/// Operation to update. +/// The input name defined by the spec for the operation. +/// The input value names. +void AddOperationInputs(COREML_SPEC::MILSpec::Operation& op, std::string_view input_name, + const std::vector& value_names); + +/// Add an output to a MILSpec::Operation for an intermediate operation when the implementation is composed of +/// multiple MLProgram operations. In this case we don't have a NodeArg for the output. +/// +/// Operation to update. +/// Name of the intermediate output. Create using ModelBuilder::GetUniqueName. +/// onnx::TensorProto_DataType element type of the output. +/// int32_t as that is what TensorShapeProto uses to store the value. +/// Shape of the output if known. +void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, const std::string& output_name, + int32_t element_type, std::optional> shape); + /// /// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg. 
/// diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index 34193318a026..551d8222cc06 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -4,6 +4,7 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,27 +19,52 @@ class ConcatOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - - layer->mutable_concat()->set_sequenceconcat(false); - - for (const auto* input : node.InputDefs()) { - LOGS(logger, VERBOSE) << "input name " << input->Name(); - *layer->mutable_input()->Add() = input->Name(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; // NOLINT + + NodeAttrHelper helper(node); + const auto axis = helper.GetInt64("axis"); // required + const auto interleave = false; + + std::unique_ptr op = model_builder.CreateOperation(node, "concat"); + std::vector input_names; + for (const auto* input : node.InputDefs()) { + input_names.emplace_back(input->Name()); + } + AddOperationInputs(*op, "values", input_names); + AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", *axis)); + AddOperationInput(*op, "interleave", model_builder.AddScalarConstant(op->type(), "interleave", interleave)); + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + layer->mutable_concat()->set_sequenceconcat(false); + + for (const auto* input : node.InputDefs()) { + LOGS(logger, VERBOSE) << "input name " << input->Name(); + *layer->mutable_input()->Add() = input->Name(); + } + + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); } - - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); - - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, +bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); if (input_defs.size() < 2) { @@ -50,23 +76,25 @@ bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa if (!GetShape(*input_defs[0], input_shape, logger)) return false; - auto rank = input_shape.size(); - if (rank != 4) { - // For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis - // Instead of concat on axis 0, it will concat on axis 1 - // Disable Concat support for 3d tensor for now - // 
TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d - LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is " - << rank << "d shape"; - return false; - } - - NodeAttrHelper helper(node); - auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); - if (rank != axis + 3) { - LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis - << ", actual rank: " << rank; - return false; + if (!input_params.create_mlprogram) { + auto rank = input_shape.size(); + if (rank != 4) { + // For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis + // Instead of concat on axis 0, it will concat on axis 1 + // Disable Concat support for 3d tensor for now + // TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d + LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is " + << rank << "d shape"; + return false; + } + + NodeAttrHelper helper(node); + auto axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); + if (rank != axis + 3) { + LOGS(logger, VERBOSE) << "Concat only support axis to be -3, actual axis: " << axis + << ", actual rank: " << rank; + return false; + } } return true; diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index 1eba312b2577..bec2461ffbc5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -4,6 +4,7 @@ #include "core/common/safeint.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -18,52 +19,133 @@ class DepthToSpaceOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - + [[maybe_unused]] const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); const auto& input_name = input_defs[0]->Name(); - const auto& output_name = output_defs[0]->Name(); - uint64_t blocksize = SafeInt(node.GetAttributes().at("blocksize").i()); + NodeAttrHelper helper(node); + int64_t blocksize = *helper.GetInt64("blocksize"); // required attribute + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; // NOLINT + + const auto mode = helper.Get("mode", "DCR"); + + if (mode == "DCR") { + // DCR is directly supported + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.depth_to_space + // Validated with depth_to_space.py. 
+ auto op = model_builder.CreateOperation(node, "depth_to_space"); + AddOperationInput(*op, "x", input_name); + AddOperationInput(*op, "block_size", model_builder.AddScalarConstant(op->type(), "blocksize", blocksize)); + AddOperationOutput(*op, *output_defs[0]); + model_builder.AddOperation(std::move(op)); + } else { + // CRD is manual. there may be a perf cost from the Reshape's (typically that happens on CPU) but if the input + // is a fixed size hopefully CoreML is smart enough to handle that aspect during model compilation instead + // of execution. + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#depthtospace + // b, c, h, w = x.shape + // tmp = np.reshape(x, [b, c // (blocksize ** 2), blocksize, blocksize, h, w]) + // tmp = np.transpose(tmp, [0, 1, 4, 2, 5, 3]) + // y = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize]) + // + // CoreML has a 5D limit, so we merge the batch dim into the channel dim as that doesn't change the data + // movement. + // First reshape is to [b * c // (blocksize ** 2), blocksize, blocksize, h, w] + // Transpose is to [0, 3, 1, 4, 2] + + // we checked shape was static in IsOpSupportedImpl so this should never fail + std::vector input_shape; + ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Failed to get input shape"); + const int32_t elem_type = static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + // reshape to [b * c // (blocksize ** 2), blocksize, blocksize, h, w] + auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre"); + std::vector shape1 = {input_shape[0] * input_shape[1] / (blocksize * blocksize), + blocksize, blocksize, input_shape[2], input_shape[3]}; + AddOperationInput(*reshape1, "x", input_name); + AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape", shape1)); + const auto& reshape1_output = model_builder.GetUniqueName(node, "reshape1"); + AddIntermediateOperationOutput(*reshape1, reshape1_output, elem_type, shape1); + + // transpose to [0, 3, 1, 4, 2] + auto transpose = model_builder.CreateOperation(node, "transpose"); + std::vector perm = {0, 3, 1, 4, 2}; + std::vector shape2 = {shape1[0], shape1[3], shape1[1], shape1[4], shape1[2]}; + AddOperationInput(*transpose, "x", reshape1_output); + AddOperationInput(*transpose, "perm", model_builder.AddConstant(transpose->type(), "perm", perm)); + const auto& transpose_output = model_builder.GetUniqueName(node, "transpose"); + AddIntermediateOperationOutput(*transpose, transpose_output, elem_type, shape2); + + // reshape to [b, c // (blocksize ** 2), h * blocksize, w * blocksize] + auto reshape2 = model_builder.CreateOperation(node, "reshape", "post"); + std::vector shape3 = {input_shape[0], + input_shape[1] / (blocksize * blocksize), + input_shape[2] * blocksize, + input_shape[3] * blocksize}; + AddOperationInput(*reshape2, "x", transpose_output); + AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape", shape3)); + + AddOperationOutput(*reshape2, *output_defs[0]); + + model_builder.AddOperation(std::move(reshape1)); + model_builder.AddOperation(std::move(transpose)); + model_builder.AddOperation(std::move(reshape2)); + } + } else // NOLINT +#endif // if defined(COREML_ENABLE_MLPROGRAM) + { + const auto& output_name = output_defs[0]->Name(); + std::unique_ptr layer = model_builder.CreateNNLayer(node); - auto* coreml_depthtospace = layer->mutable_reorganizedata(); - coreml_depthtospace->set_blocksize(blocksize); - 
coreml_depthtospace->set_mode(CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType:: - ReorganizeDataLayerParams_ReorganizationType_DEPTH_TO_SPACE); + auto* coreml_depthtospace = layer->mutable_reorganizedata(); + coreml_depthtospace->set_blocksize(static_cast(blocksize)); + coreml_depthtospace->set_mode(CoreML::Specification::ReorganizeDataLayerParams_ReorganizationType:: + ReorganizeDataLayerParams_ReorganizationType_DEPTH_TO_SPACE); - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, +bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) { + LOGS(logger, VERBOSE) << "DepthToSpace: no input shape"; return false; } - const auto input_rank = input_shape.size(); - if (input_rank < 4) { - LOGS(logger, VERBOSE) << "DepthToSpace does not support input shape of " << input_rank << "d shape."; - } + // ONNX and CoreML both require 4D input so no need to check the shape here. NodeAttrHelper helper(node); - if (node.SinceVersion() >= 11) { - // For now, only DCR mode DepthToSpace is supported - const auto mode = helper.Get("mode", "DCR"); + const auto mode = helper.Get("mode", "DCR"); + + if (input_params.create_mlprogram) { + if (mode == "CRD" && !IsStaticShape(input_shape)) { + // we need to manually implement the logic with a Reshape, so we need to know the shape to do that + LOGS(logger, VERBOSE) << "DepthToSpace: CRD mode requires static shape"; + return false; + } + } else { if (mode != "DCR") { - LOGS(logger, VERBOSE) << "The mode: " << mode << "of DepthToSpace is not supported in CoreML EP for now."; + LOGS(logger, VERBOSE) << "DepthToSpace: " << mode << " mode is not supported"; return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index 535712f09601..b0006b24e7d7 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -15,28 +15,28 @@ namespace coreml { static OpBuilderRegistrations CreateOpBuilderRegistrations() { OpBuilderRegistrations op_registrations; + // Activations + CreateActivationOpBuilder("Sigmoid", op_registrations); + CreateActivationOpBuilder("Tanh", op_registrations); + CreateActivationOpBuilder("Relu", op_registrations); + CreateActivationOpBuilder("PRelu", op_registrations); + CreateActivationOpBuilder("LeakyRelu", op_registrations); + // Unary ops - CreateUnaryOpBuilder("Sqrt", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); + CreateUnaryOpBuilder("Sqrt", op_registrations); // Binary elementwise ops CreateBinaryOpBuilder("Add", op_registrations); + CreateBinaryOpBuilder("Div", op_registrations); CreateBinaryOpBuilder("Mul", op_registrations); CreateBinaryOpBuilder("Pow", op_registrations); CreateBinaryOpBuilder("Sub", op_registrations); - CreateBinaryOpBuilder("Div", op_registrations); - - // Activations - CreateActivationOpBuilder("Sigmoid", op_registrations); - 
CreateActivationOpBuilder("Tanh", op_registrations); - CreateActivationOpBuilder("Relu", op_registrations); - CreateActivationOpBuilder("PRelu", op_registrations); - CreateActivationOpBuilder("LeakyRelu", op_registrations); // Pooling ops + CreatePoolOpBuilder("AveragePool", op_registrations); CreatePoolOpBuilder("GlobalAveragePool", op_registrations); CreatePoolOpBuilder("GlobalMaxPool", op_registrations); - CreatePoolOpBuilder("AveragePool", op_registrations); CreatePoolOpBuilder("MaxPool", op_registrations); // Reduction ops @@ -54,6 +54,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateFlattenOpBuilder("Flatten", op_registrations); CreateGatherOpBuilder("Gather", op_registrations); CreateGemmOpBuilder("Gemm", op_registrations); + CreateGridSampleOpBuilder("GridSample", op_registrations); CreateLRNOpBuilder("LRN", op_registrations); CreateGemmOpBuilder("MatMul", op_registrations); CreatePadOpBuilder("Pad", op_registrations); @@ -66,8 +67,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateSqueezeOpBuilder("Squeeze", op_registrations); CreateTransposeOpBuilder("Transpose", op_registrations); - CreateGridSampleOpBuilder("GridSample", op_registrations); - return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index a92fef81ac39..f2cd4d01174d 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -83,7 +83,9 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, - gen_metadef_name, COREML, kCoreMLExecutionProvider); + gen_metadef_name, COREML, kCoreMLExecutionProvider, + nullptr, + /*drop_constant_initializers*/ true); const auto num_of_partitions = result.size(); const auto num_of_supported_nodes = std::transform_reduce( diff --git a/onnxruntime/core/providers/coreml/DebugMLProgram.md b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/DebugMLProgram.md similarity index 97% rename from onnxruntime/core/providers/coreml/DebugMLProgram.md rename to onnxruntime/core/providers/coreml/mlprogram_test_scripts/DebugMLProgram.md index e41a51559430..b7a54466ab8d 100644 --- a/onnxruntime/core/providers/coreml/DebugMLProgram.md +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/DebugMLProgram.md @@ -25,6 +25,8 @@ https://apple.github.io/coremltools/docs-guides/source/model-intermediate-langua Usage is reasonably intuitive. The below example defines a model with 2 inputs and a matmul operator. The model is printed, and run with randomly generated inputs. The output from doing so is printed. +There are additional test scripts in this directory for different operators. 
+ ```python import numpy as np import coremltools as ct diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/concat_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/concat_test.py new file mode 100644 index 000000000000..430a2b3fa3ed --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/concat_test.py @@ -0,0 +1,33 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +a_shape = (1, 1, 3, 3) + + +@mb.program( + input_specs=[mb.TensorSpec(shape=a_shape), mb.TensorSpec(shape=a_shape), mb.TensorSpec(shape=a_shape)], + opset_version=target, +) +def prog(x, y, z): + axis = mb.const(val=1) + interleave = mb.const(val=False) + z = mb.concat(values=(x, y, z), axis=axis, interleave=interleave) + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +x = np.random.rand(*a_shape) +y = np.random.rand(*a_shape) +z = np.random.rand(*a_shape) + +# spec = m.get_spec() +# print(spec) + +print(m.predict({"x": x, "y": y, "z": z})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/convtranspose_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/convtranspose_test.py new file mode 100644 index 000000000000..2c8cbc4948a6 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/convtranspose_test.py @@ -0,0 +1,42 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +x_shape = (1, 3, 4, 4) +w_shape = (3, 3, 3, 3) + + +@mb.program(input_specs=[mb.TensorSpec(shape=x_shape)], opset_version=target) +def prog(x): + weight = mb.const(name="weight", val=np.ones(w_shape, dtype=np.float32)) + output_shape = mb.const(name="output_shape", val=np.array([1, 3, 4, 4])) + # pad = mb.const(val=np.zeros((4), dtype=np.int32)) + strides = mb.const(name="strides", val=np.ones((2), dtype=np.int32)) + dilations = mb.const(name="dilations", val=np.ones((2), dtype=np.int32)) + z = mb.conv_transpose( + x=x, weight=weight, strides=strides, dilations=dilations, output_shape=output_shape + ) # , pad=pad + + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +# spec = m.get_spec() +# print(spec) + +m.save("ConvTranspose.mlpackage") +# construct MLModel with compute_units=ComputeUnit.CPU and run predict +m_cpu = ct.models.MLModel("ConvTranspose.mlpackage", compute_units=ct.ComputeUnit.CPU_ONLY) +m_all = ct.models.MLModel("ConvTranspose.mlpackage", compute_units=ct.ComputeUnit.ALL) + +x = np.ones(x_shape, dtype=np.float32) +print("CPU_ONLY") +print(m_cpu.predict({"x": x})) +print("ALL") +print(m_all.predict({"x": x})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/depthtospace_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/depthtospace_test.py new file mode 100644 index 000000000000..593d9e8bbf66 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/depthtospace_test.py @@ -0,0 +1,51 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +# replicate example from https://github.com/onnx/onnx/blob/main/docs/Operators.md#depthtospace +# to prove CoreML mode is DCR +x_shape = (1, 8, 2, 3) + + 
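+# Editorial sketch for cross-checking (helper name is illustrative, not from the
+# original script): the DCR-mode reference decomposition from the ONNX DepthToSpace
+# spec. Running `x` from below through this and comparing against the CoreML
+# predictions confirms the CoreML op behaves as DCR.
+def onnx_depth_to_space_dcr(x, blocksize):
+    b, c, h, w = x.shape
+    tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w])
+    tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
+    return np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize])
+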
+@mb.program(input_specs=[mb.TensorSpec(shape=x_shape)], opset_version=target) +def prog(x): + block_size = mb.const(name="block_size", val=2) + z = mb.depth_to_space(x=x, block_size=block_size) + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +# spec = m.get_spec() +# print(spec) + +m.save("DepthToSpace.mlpackage") + +# also check for differences between CPU_ONLY and ALL +m_cpu = ct.models.MLModel("DepthToSpace.mlpackage", compute_units=ct.ComputeUnit.CPU_ONLY) +m_all = ct.models.MLModel("DepthToSpace.mlpackage", compute_units=ct.ComputeUnit.ALL) + +x = np.array( + [ + [ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[9.0, 10.0, 11.0], [12.0, 13.0, 14.0]], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + [[27.0, 28.0, 29.0], [30.0, 31.0, 32.0]], + [[36.0, 37.0, 38.0], [39.0, 40.0, 41.0]], + [[45.0, 46.0, 47.0], [48.0, 49.0, 50.0]], + [[54.0, 55.0, 56.0], [57.0, 58.0, 59.0]], + [[63.0, 64.0, 65.0], [66.0, 67.0, 68.0]], + ] + ] +).astype(np.float32) + +print("CPU_ONLY") +print(m_cpu.predict({"x": x})) +print("ALL") +print(m_all.predict({"x": x})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/div_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/div_test.py new file mode 100644 index 000000000000..a0423511598f --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/div_test.py @@ -0,0 +1,103 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb +from coremltools.models import datatypes +from coremltools.models.neural_network import NeuralNetworkBuilder +from coremltools.models.utils import save_spec + +input_dim = (1,) +output_dim = (1,) + + +def mlprogram(): + target = ct.target.iOS15 + + @mb.program(input_specs=[mb.TensorSpec(shape=input_dim), mb.TensorSpec(shape=input_dim)], opset_version=target) + def prog(x, y): + return mb.real_div(x=x, y=y) + + # print(prog) + + # Convert to ML program + m = ct.convert(prog, minimum_deployment_target=target) + + x = np.array([2], dtype=np.float32) + y = np.array([2047], dtype=np.float32) + + # spec = m.get_spec() + # print(spec) + + print(m.predict({"x": x, "y": y})) + + +# implement Div with coremltools approach of x * (1/y) +def nn(): + input_features = [("x", datatypes.Array(*input_dim)), ("y_inv", datatypes.Array(*input_dim))] + output_features = [("final", datatypes.Array(*output_dim))] + + # Build a simple neural network with 1 inner product layer + builder = NeuralNetworkBuilder(input_features, output_features) + builder.add_elementwise( + name="x_multiply_inverse_of_y", + input_names=["x", "y_inv"], + output_name="final", + mode="MULTIPLY", + ) + + save_spec(builder.spec, "network.mlmodel") + m = ct.models.MLModel("network.mlmodel") + + x = np.array([2], dtype=np.float32) + y = np.array([1 / 2047], dtype=np.float32) + print(m.predict({"x": x, "y_inv": y})) + + +def nn_scale(): + input_features = [ + ("x", datatypes.Array(*input_dim)), + ("y_inv", datatypes.Array(*input_dim)), + ("z", datatypes.Array(*input_dim)), + ] + output_features = [("final", datatypes.Array(*output_dim))] + + builder = NeuralNetworkBuilder(input_features, output_features) + + builder.add_elementwise( + name="div_implemented_as_x_multiply_inverse_of_y", + input_names=["x", "y_inv"], + output_name="div_result", + mode="MULTIPLY", + ) + + builder.add_elementwise( + name="apply_scaling_factor", + input_names=["div_result", "z"], + output_name="final", + mode="MULTIPLY", + 
) + + from coremltools.models.utils import save_spec + + save_spec(builder.spec, "network.mlmodel") + m = ct.models.MLModel("network.mlmodel") + + a = 2 + b = 2047 + # scaling factor to test working around coremltools inaccuracy. + # weirdly even a scaling factor of 1 fixes the problem from https://github.com/microsoft/onnxruntime/issues/21170 + c = 1000 + + x = np.array([a], dtype=np.float32) + y = np.array([1 / b / c], dtype=np.float32) + z = np.array([c], dtype=np.float32) + print(m.predict({"x": x, "y_inv": y, "z": z})) + + +print("NN") +nn() + +print("\nNN with scaling") +nn_scale() + +print("\nML Program") +mlprogram() diff --git a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/dump_mlprogram_model.py similarity index 100% rename from onnxruntime/core/providers/coreml/dump_mlprogram_model.py rename to onnxruntime/core/providers/coreml/mlprogram_test_scripts/dump_mlprogram_model.py diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/gridsample_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/gridsample_test.py new file mode 100644 index 000000000000..5ce79c204c00 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/gridsample_test.py @@ -0,0 +1,114 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +x_shape = (2, 2, 3, 2) +grid_shape = (2, 3, 2, 2) + + +@mb.program(input_specs=[mb.TensorSpec(shape=x_shape), mb.TensorSpec(shape=grid_shape)], opset_version=target) +def prog(x, grid): + sampling = mb.const(name="sampling_mode", val="bilinear") + padding_mode = mb.const(name="pmode", val="reflection") + pad = mb.const(name="pval", val=np.float32(0)) + coord_mode = mb.const(name="coord_mode", val="normalized_minus_one_to_one") + align_corners = mb.const(name="align_corners", val=False) + z = mb.resample( + x=x, + coordinates=grid, + sampling_mode=sampling, + padding_mode=padding_mode, + padding_value=pad, + coordinates_mode=coord_mode, + align_corners=align_corners, + ) + + return z + + +# print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +# spec = m.get_spec() +# print(spec) + +m.save("GridSample.mlpackage") +# construct MLModel with compute_units=ComputeUnit.CPU and run predict +m_cpu = ct.models.MLModel("GridSample.mlpackage", compute_units=ct.ComputeUnit.CPU_ONLY) +m_all = ct.models.MLModel("GridSample.mlpackage", compute_units=ct.ComputeUnit.ALL) + +# GridSampleTest.test_grid_sample_20_4D_bilinear_reflection_no_align_corners +# ORT produces different output for this test. 
ORT output is generated by pytorch +x = ( + np.array( + [ + -0.173652, + -1.513725, + -0.704586, + -1.952375, + -0.699404, + -0.806298, + 1.640852, + -0.138969, + -0.695411, + -1.352111, + 0.568797, + -0.564294, + -0.056468, + 0.641604, + -0.438370, + 0.450167, + -1.091401, + 1.669729, + -0.908544, + 0.244467, + 0.172109, + 1.156741, + -0.617128, + 1.155460, + ] + ) + .astype(np.float32) + .reshape(x_shape) +) + +grid = ( + np.array( + [ + 0.252250, + -0.151452, + 0.824706, + -0.588292, + -0.591147, + -0.155082, + -0.732938, + 0.457493, + -0.439559, + 0.492330, + 0.696447, + 0.700722, + -0.220298, + 0.654884, + -0.635434, + -1.195619, + -0.114204, + -0.870080, + -0.929674, + 0.305035, + 1.025429, + -0.472240, + -0.067881, + -0.869393, + ] + ) + .astype(np.float32) + .reshape(grid_shape) +) + + +print(m_cpu.predict({"x": x, "grid": grid})) +print(m_all.predict({"x": x, "grid": grid})) diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py new file mode 100644 index 000000000000..3cfe2658f945 --- /dev/null +++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py @@ -0,0 +1,51 @@ +import coremltools as ct +import numpy as np +from coremltools.converters.mil import Builder as mb + +target = ct.target.iOS15 + +x_shape = (1, 1, 3, 6) + +use_scale = False # set this to test upsample vs resize + + +@mb.program(input_specs=[mb.TensorSpec(shape=x_shape)], opset_version=target) +def prog(x): + global use_scale + + if use_scale: + align = mb.const(val=False) + scale_h = mb.const(val=float(1 / 3)) + scale_w = mb.const(val=float(1 / 3)) + z = mb.upsample_bilinear(x=x, scale_factor_height=scale_h, scale_factor_width=scale_w, align_corners=align) + else: + size_h = mb.const(val=1) + size_w = mb.const(val=2) + sampling_mode = mb.const(val="UNALIGN_CORNERS") + z = mb.resize_bilinear(x=x, target_size_height=size_h, target_size_width=size_w, sampling_mode=sampling_mode) + + return z + + +print(prog) + +# Convert to ML program +m = ct.convert(prog, minimum_deployment_target=target, compute_precision=ct.precision.FLOAT32) + +x = np.array( + [ + [ + [ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + ] + ] + ], + dtype=np.float32, +) + +# spec = m.get_spec() +# print(spec) + +print(m.predict({"x": x})) diff --git a/onnxruntime/core/providers/partitioning_utils.cc b/onnxruntime/core/providers/partitioning_utils.cc index c45f5cd0848d..83c08f3dbd25 100644 --- a/onnxruntime/core/providers/partitioning_utils.cc +++ b/onnxruntime/core/providers/partitioning_utils.cc @@ -88,8 +88,6 @@ It is required to ensure we do not break up a QDQ node unit during partitioning. @param graph_viewer GraphViewer that IExecutionProvider::GetCapability is called with. @param is_node_supported_fn Callback to check whether a node is supported. @param on_group_closed_fn Callback to indicate a completed partition node group. -@param debug_output Print diagnostic output about the partitions and reasons for partition breaks. - No-op in a release build. @return The partition node groups. 
*/ std::vector> CreateSupportedPartitionNodeGroups( @@ -97,12 +95,7 @@ std::vector> CreateSupportedPartitionNodeGroups( const IsNodeSupportedFn& is_node_supported_fn, const OnGroupClosedFn& on_group_closed_fn, const std::string& execution_provider_type, - const std::unordered_map* node_unit_map, - bool debug_output) { -#ifdef NDEBUG - ORT_UNUSED_PARAMETER(debug_output); -#endif - + const std::unordered_map* node_unit_map) { ORT_ENFORCE(is_node_supported_fn, "Node support test is required."); /* @@ -146,12 +139,10 @@ std::vector> CreateSupportedPartitionNodeGroups( auto close_group = [&]() { if (!supported_group.empty()) { #ifndef NDEBUG - if (debug_output) { - LOGS_DEFAULT(VERBOSE) << "New partition node group.\n" - << "Unsupported nodes on group border: " - << NodeGroupDebugString(nodes_to_process_with_next_group, true) << "\n" - << "Nodes in group: " << NodeGroupDebugString(supported_group); - } + LOGS_DEFAULT(VERBOSE) << "New partition node group.\n" + << "Unsupported nodes on group border: " + << NodeGroupDebugString(nodes_to_process_with_next_group, true) << "\n" + << "Nodes in group: " << NodeGroupDebugString(supported_group); #endif // if no on_group_closed_fn callback was given, keep the partition @@ -163,7 +154,7 @@ std::vector> CreateSupportedPartitionNodeGroups( } #ifndef NDEBUG else { - LOGS_DEFAULT_IF(debug_output, VERBOSE) << "Discarded partition node group."; + LOGS_DEFAULT(VERBOSE) << "Discarded partition node group."; } #endif @@ -291,7 +282,8 @@ InlinedHashSet CreateExcludedNodeSet(const GraphViewer& graph_viewe std::unique_ptr MakeComputeCapability(const GraphViewer& graph_viewer, const std::vector& group, const GenerateMetadefNameFn& generate_metadef_name, - const std::string& execution_provider_name) { + const std::string& execution_provider_name, + bool drop_constant_initializers) { std::unordered_set node_set; node_set.reserve(group.size()); node_set.insert(group.cbegin(), group.cend()); @@ -354,6 +346,10 @@ std::unique_ptr MakeComputeCapability(const GraphViewer& grap meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; for (const auto& input : ordered_subgraph_inputs) { + if (drop_constant_initializers && graph_viewer.IsConstantInitializer(input->Name(), true)) { + continue; + } + meta_def->inputs.push_back(input->Name()); } @@ -374,13 +370,12 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::string& execution_provider_name, const std::string& execution_provider_type, const std::unordered_map* node_unit_map, - bool debug_output) { + bool drop_constant_initializers) { const auto groups = CreateSupportedPartitionNodeGroups(graph_viewer, is_node_supported_fn, on_partition_closed_fn, execution_provider_type, - node_unit_map, - debug_output); + node_unit_map); std::vector> partitions{}; partitions.reserve(groups.size()); @@ -390,7 +385,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, std::back_inserter(partitions), [&](const auto& supported_partition) { return MakeComputeCapability(graph_viewer, supported_partition, generate_metadef_name_fn, - execution_provider_name); + execution_provider_name, drop_constant_initializers); }); return partitions; @@ -404,7 +399,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::string& execution_provider_name, const std::string& execution_provider_type, const std::unordered_map* node_unit_map, - bool debug_output) { + bool drop_constant_initializers) { const auto excluded_nodes = CreateExcludedNodeSet(graph_viewer, stop_ops); const bool check_excluded_nodes = 
!excluded_nodes.empty(); @@ -419,7 +414,7 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, execution_provider_name, execution_provider_type, node_unit_map, - debug_output); + drop_constant_initializers); } } // namespace utils diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h index c3f6b104e3f6..235a88cfdb8a 100644 --- a/onnxruntime/core/providers/partitioning_utils.h +++ b/onnxruntime/core/providers/partitioning_utils.h @@ -62,9 +62,10 @@ Create the supported partitions for the execution provider. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. @param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. Should be created by EP calling GetAllNodeUnits. -@param debug_output Print diagnostic output about the partitions and reasons for partition breaks. - No-op in a release build. - +@param drop_constant_initializer Drop constant initializers from input to a ComputeCapability. + Set to true if constant initializers have been copied into a compiled model to allow + ORT to free the initializer. If the initializer remains as an input it will appear to + still be in-use. @returns ComputeCapability instances for all partitions assigned to the execution provider. */ std::vector> @@ -74,8 +75,8 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name_fn, const std::string& execution_provider_name, const std::string& execution_provider_type, - const std::unordered_map* node_unit_map = nullptr, - bool debug_output = false); + const std::unordered_map* node_unit_map, + bool drop_constant_initializers = false); /** Create the supported partitions for the execution provider. @@ -88,9 +89,10 @@ Create the supported partitions for the execution provider. @param execution_provider_type ExecutionProviderType of the EP creating this ComputeCapability instance. @param node_unit_map Map of each Node in the graph_viewer to its NodeUnit. Provide if EP handles QDQ format models. Should be created by EP calling GetAllNodeUnits. -@param debug_output Print diagnostic output about the partitions and reasons for partition breaks. - No-op in a release build. - +@param drop_constant_initializer Drop constant initializers from input to a ComputeCapability. + Set to true if constant initializers have been copied into a compiled model to allow + ORT to free the initializer. If the initializer remains as an input it will appear to + still be in-use. @returns ComputeCapability instances for all partitions assigned to the execution provider. */ std::vector> @@ -100,8 +102,8 @@ CreateSupportedPartitions(const GraphViewer& graph_viewer, const GenerateMetadefNameFn& generate_metadef_name, const std::string& execution_provider_name, const std::string& execution_provider_type, - const std::unordered_map* node_unit_map = nullptr, - bool debug_output = false); + const std::unordered_map* node_unit_map, + bool drop_constant_initializers = false); /** Create a ComputeCapability instance from the group of nodes. @@ -120,7 +122,8 @@ Will automatically determine the inputs and outputs required. 
std::unique_ptr MakeComputeCapability(const GraphViewer& graph_viewer, const std::vector& group, const GenerateMetadefNameFn& generate_metadef_name, - const std::string& execution_provider_name); + const std::string& execution_provider_name, + bool drop_constant_initializers); /** Create the set of nodes to exclude based on a set of stop ops. diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 0ddaa9769421..e90671cf2e77 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -660,7 +660,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // Create partitions from supported nodes. std::vector> partitions = utils::CreateSupportedPartitions( - graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map, true); + graph_viewer, supported_nodes, {}, gen_metadef_name, QNN, kQnnExecutionProvider, &node_unit_map); // Filter out partitions that consist of a single QuantizeLinear or DequantizeLinear node. // We also count the number of supported nodes in all valid partitions. diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 6b5daf8cb882..0106e4489537 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -40,6 +40,10 @@ #include // for CUDNN_MAJOR #endif +#if defined(USE_COREML) +#include "core/providers/coreml/coreml_provider_factory.h" +#endif + #include // Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct, @@ -1161,7 +1165,30 @@ std::unique_ptr CreateExecutionProviderInstance( #if !defined(__APPLE__) LOGS_DEFAULT(WARNING) << "CoreML execution provider can only be used to generate ORT format model in this build."; #endif - return onnxruntime::CoreMLProviderFactoryCreator::Create(0)->CreateProvider(); + uint32_t coreml_flags = 0; + + const auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + const ProviderOptions& options = it->second; + auto flags = options.find("flags"); + if (flags != options.end()) { + const auto& flags_str = flags->second; + + if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) { + coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY; + } + + if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) { + coreml_flags |= COREMLFlags::COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES; + } + + if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) { + coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM; + } + } + } + + return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); #endif } else if (type == kXnnpackExecutionProvider) { #if defined(USE_XNNPACK) @@ -1893,9 +1920,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") } res << ")"; - return std::string(res.str()); - }, - "converts the node into a readable string") + return std::string(res.str()); }, "converts the node into a readable string") .def_property_readonly( "shape", [](const onnxruntime::NodeArg& na) -> std::vector { auto shape = na.Shape(); @@ -1914,9 +1939,7 @@ including arg name, arg type (contains both type and shape).)pbdoc") arr[i] = py::none(); } } - return arr; - }, - "node shape (assuming the node holds a tensor)"); + return arr; }, "node shape (assuming the node holds a 
tensor)"); py::class_ sessionObjectInitializer(m, "SessionObjectInitializer"); py::class_(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc") @@ -2108,50 +2131,34 @@ including arg name, arg type (contains both type and shape).)pbdoc") return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs(); }) .def( - "get_providers", [](const PyInferenceSession* sess) -> const std::vector& { - return sess->GetSessionHandle()->GetRegisteredProviderTypes(); - }, - py::return_value_policy::reference_internal) + "get_providers", [](const PyInferenceSession* sess) -> const std::vector& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal) .def( - "get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { - return sess->GetSessionHandle()->GetAllProviderOptions(); - }, - py::return_value_policy::reference_internal) + "get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal) .def_property_readonly( "session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { auto session_options = std::make_unique(); session_options->value = sess->GetSessionHandle()->GetSessionOptions(); - return session_options.release(); - }, - py::return_value_policy::take_ownership) + return session_options.release(); }, py::return_value_policy::take_ownership) .def_property_readonly( "inputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelInputs(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) + return *(res.second); }, py::return_value_policy::reference_internal) .def_property_readonly( "outputs_meta", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetModelOutputs(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) + return *(res.second); }, py::return_value_policy::reference_internal) .def_property_readonly( "overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector& { auto res = sess->GetSessionHandle()->GetOverridableInitializers(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) + return *(res.second); }, py::return_value_policy::reference_internal) .def_property_readonly( "model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& { auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); - return *(res.second); - }, - py::return_value_policy::reference_internal) + return *(res.second); }, py::return_value_policy::reference_internal) .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void { Status status; // release GIL to allow multiple python threads to invoke Run() in parallel. 
@@ -2161,8 +2168,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
         else
           status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
         if (!status.IsOK())
-          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
-      })
+          throw std::runtime_error("Error in execution: " + status.ErrorMessage()); })
       .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list {
 #if !defined(ORT_MINIMAL_BUILD)
         auto results = sess->GetSessionHandle()->GetTuningResults();
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index 1638851daf65..bc8b9d892160 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -3525,7 +3525,8 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
   const auto compute_capability = utils::MakeComputeCapability(
       whole_graph_viewer, nodes, []() { return "sub_graph"; },
-      "Test Provider");
+      "Test Provider",
+      /*drop_constant_initializers*/ false);

   const GraphViewer partial_graph_viewer(graph, *compute_capability->sub_graph);
   ASSERT_EQ(3, partial_graph_viewer.NumberOfNodes());
diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
index 5222380d9ca5..a0c1d675f506 100644
--- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc
@@ -373,5 +373,36 @@ TEST(TensorOpTest, DepthToSpaceTest_5) {
   test.Run();
 }

+TEST(TensorOpTest, DepthToSpaceTest_CRD_Batched) {
+  OpTester test("DepthToSpace", 11);  // create an opset 11 model with the "mode" attribute present, set to "CRD"
+  constexpr int64_t blocksize = 2;
+  test.AddAttribute("blocksize", blocksize);
+  test.AddAttribute("mode", "CRD");
+
+  constexpr int64_t N = 2, C = 4, H = 2, W = 3;
+  std::vector<float> X = {0., 1., 2.,
+                          3., 4., 5.,
+                          9., 10., 11.,
+                          12., 13., 14.,
+                          18., 19., 20.,
+                          21., 22., 23.,
+                          27., 28., 29.,
+                          30., 31., 32.};
+
+  // append the same data but in reverse order so we can tell if the batch output is wrong
+  X.insert(X.end(), X.rbegin(), X.rend());
+
+  test.AddInput<float>("input", {N, C, H, W}, X);
+
+  std::vector<float> result = {0., 9., 1., 10., 2., 11.,
+                               18., 27., 19., 28., 20., 29.,
+                               3., 12., 4., 13., 5., 14.,
+                               21., 30., 22., 31., 23., 32.};
+  result.insert(result.end(), result.rbegin(), result.rend());
+
+  test.AddOutput<float>("output", {2, 1, 4, 6}, result);
+  test.Run();
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
index 5609033fc3e3..5af9a2512f7c 100644
--- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
+++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
@@ -6,13 +6,16 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
 |ai.onnx:Add||
 |ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
 |ai.onnx:Clip||
+|ai.onnx:Concat||
 |ai.onnx:Conv|Only 1D/2D Conv is supported.
Bias if provided must be constant.| |ai.onnx:ConvTranspose|Weight and bias must be constant.
padding_type of SAME_UPPER/SAME_LOWER is not supported.
kernel_shape must have default values.
output_shape is not supported.
output_padding must have default values.|
+|ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.|
 |ai.onnx:Div||
 |ai.onnx:Gemm|Input B must be constant.|
 |ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
 |ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
 |ai.onnx:GridSample|4D input.
'mode' of 'linear' or 'zeros'.
(mode==linear && padding_mode==reflection && align_corners==0) is not supported.|
+|ai.onnx:LeakyRelu||
 |ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.|
 |ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
 |ai.onnx:Mul||

From 464c83604c968b35e37647c88bc0107f9062a297 Mon Sep 17 00:00:00 2001
From: Scott McKay
Date: Wed, 24 Jul 2024 20:20:18 +1000
Subject: [PATCH 2/8] Fix QNN failure.

Drop SkipLayerNormFusion change - need to investigate test failures.

Update session state and allocation planner to handle ORT format models where an EP drops constant
initializers.
---
 .../core/framework/allocation_planner.cc      | 32 ++++++++++++++++---
 .../core/framework/session_state_utils.cc     | 15 ++++++++-
 .../core/optimizer/skip_layer_norm_fusion.cc  |  8 ++---
 .../providers/qnn/qnn_execution_provider.cc   |  3 +-
 4 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 5dca4cf6c165..737f3f606ba8 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -225,7 +225,8 @@ class PlannerImpl {
   }

   int& UseCount(OrtValueIndex n) {
-    ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size());
+    ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(),
+                "invalid value index: ", n, " against size ", ort_value_info_.size());
     return ort_value_info_[n].usecount;
   }

   int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); }
@@ -642,9 +643,21 @@ class PlannerImpl {
     }

     // All initializers should be treated as input
+    //
+    // Special case: ORT format model where an EP takes nodes and copies initializers into the compiled model.
+    // Those initializers become unused so don't end up in ort_value_name_idx_map_, but as we don't run
+    // Graph::Resolve with an ORT format model they will still exist in GetAllInitializedTensors.
+    // We can ignore lookup failures in this case.
+    const bool unresolved_graph = graph_viewer_.GetGraph().GraphResolveNeeded();
     for (const auto& pair : graph_viewer_.GetAllInitializedTensors()) {
       const auto& initializer_name = pair.first;
-      UseCount(initializer_name)++;
+      OrtValueIndex index = -1;
+      auto status = ort_value_name_idx_map_.GetIdx(initializer_name, index);
+      if (status.IsOK()) {
+        UseCount(initializer_name)++;
+      } else {
+        ORT_ENFORCE(unresolved_graph, status.ErrorMessage());
+      }
     }

     for (auto& stream_execution_order : stream_nodes_) {
@@ -709,10 +722,21 @@ class PlannerImpl {
     }

     // All initializers should be treated as input
+    //
+    // Special case: ORT format model where an EP takes nodes and copies initializers into the compiled model.
+    // Those initializers become unused so don't end up in ort_value_name_idx_map_, but as we don't run
+    // Graph::Resolve with an ORT format model they will still exist in GetAllInitializedTensors.
+    // We can ignore lookup failures in this case.
+ const bool unresolved_graph = graph_viewer_.GetGraph().GraphResolveNeeded(); for (const auto& pair : graph_viewer_.GetAllInitializedTensors()) { const auto& initializer_name = pair.first; - OrtValueIndex index = Index(initializer_name); - ProcessDef(index, graph_viewer_.GetNodeArg(pair.first)); + OrtValueIndex index = -1; + auto status = ort_value_name_idx_map_.GetIdx(initializer_name, index); + if (status.IsOK()) { + ProcessDef(index, graph_viewer_.GetNodeArg(initializer_name)); + } else { + ORT_ENFORCE(unresolved_graph, status.ErrorMessage()); + } } InlinedHashSet set_node_arg_has_explicit_consumer; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 059de8e3c8c4..f3890399ce4f 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -228,9 +228,22 @@ common::Status SaveInitializedTensors( id_to_initialized_tensor.reserve(initialized_tensor_set.size()); user_supplied_initializer_ids.reserve(initialized_tensor_set.size()); + // Special case: ORT format model where an EP takes nodes and copies initializers into the compiled model. + // Those initializers become unused so don't end up in ort_value_name_idx_map, but as we don't run + // Graph::Resolve with an ORT format model they will still exist in GetAllInitializedTensors. + // We can ignore lookup failures in this case. + const bool unresolved_graph = graph.GetGraph().GraphResolveNeeded(); for (const auto& entry : initialized_tensor_set) { int ort_value_index; - ORT_RETURN_IF_ERROR(ort_value_name_idx_map.GetIdx(entry.first, ort_value_index)); + + if (auto status = ort_value_name_idx_map.GetIdx(entry.first, ort_value_index); !status.IsOK()) { + if (unresolved_graph) { + continue; + } + + return status; + } + if (use_user_supplied_initializer(entry.first)) { user_supplied_initializer_ids.insert(ort_value_index); } diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index d8f49124a2fa..cf70a7d821d7 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -168,8 +168,7 @@ Note: This fusion doesn't consider the following case: LayerNormalization */ -Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { +Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); InlinedVector> nodes_to_remove; @@ -300,14 +299,13 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le // Assign provider to this new node. Provider should be same as the provider for old node. 
skip_layer_norm_node.SetExecutionProviderType(ln_node.GetExecutionProviderType()); } - - modified = !nodes_to_remove.empty(); - for (const auto& node : nodes_to_remove) { graph_utils::RemoveNodeOutputEdges(graph, node); graph.RemoveNode(node.get().Index()); } + modified = true; + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index e90671cf2e77..51920dd5e21a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -565,7 +565,8 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, supported_groups.begin(), supported_groups.end(), std::back_inserter(result), [&](const auto& supported_partition) { - return utils::MakeComputeCapability(graph_viewer, supported_partition, gen_metadef_name, QNN); + return utils::MakeComputeCapability(graph_viewer, supported_partition, gen_metadef_name, QNN, + /*drop_constant_initializers*/false); // TODO: could this be set to true? }); const size_t num_of_partitions = result.size(); From 63b23faa5413d02c6a48a14807f78af816bacad9 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 24 Jul 2024 21:45:22 +1000 Subject: [PATCH 3/8] waste of time --- onnxruntime/core/providers/qnn/qnn_execution_provider.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 51920dd5e21a..539b456cb657 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -566,7 +566,7 @@ static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, std::back_inserter(result), [&](const auto& supported_partition) { return utils::MakeComputeCapability(graph_viewer, supported_partition, gen_metadef_name, QNN, - /*drop_constant_initializers*/false); // TODO: could this be set to true? + /*drop_constant_initializers*/ false); // TODO: could this be set to true? }); const size_t num_of_partitions = result.size(); From a1ac35e872a9509410b8327267cd8863452f20da Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 25 Jul 2024 10:59:41 +1000 Subject: [PATCH 4/8] Address PR comments. 
Refine the handling of initializers that become unused in an ORT format model
---
 .lintrunner.toml                              |  1 +
 include/onnxruntime/core/graph/graph.h        | 24 ++++++++------
 .../core/framework/allocation_planner.cc      | 32 +++----------------
 .../core/framework/session_state_utils.cc     | 15 +--------
 onnxruntime/core/graph/graph.cc               | 12 +++++++
 .../coreml/builders/impl/builder_utils.cc     | 12 +++----
 .../coreml/builders/impl/builder_utils.h      |  6 ++--
 .../coreml/builders/impl/concat_op_builder.cc |  9 +++---
 .../builders/impl/gridsample_op_builder.cc    |  4 +--
 onnxruntime/core/session/inference_session.cc |  5 +++
 .../apple/coreml_supported_mlprogram_ops.md   |  4 +--
 11 files changed, 55 insertions(+), 69 deletions(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index e6d06b34726f..cdc5d5dc469c 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -137,6 +137,7 @@ exclude_patterns = [
     'onnxruntime/core/mickey/gemm/**', # CUTLASS based libs recommends NO automatic code formatting
     'winml/lib/Api.Image/shaders/**', # Contains data chunks
     'onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h', # Bool Switches hang Clang
+    'onnxruntime/core/providers/coreml/mlprogram_test_scripts', # test scripts only
 ]
 command = [
     'python',
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index 9289e14c17dd..c51f38553c3b 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -1408,6 +1408,11 @@ class Graph {  // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
   RuntimeOptimizationRecordContainer& MutableRuntimeOptimizations() {
     return runtime_optimizations_;
   }
+
+  // We don't run Graph::Resolve() on an ORT format model, but a compiling EP may copy initializers to its
+  // compiled model during partitioning, leaving them unused in the ORT Graph. To allow the memory to be freed
+  // we need to manually run the cleanup that would usually happen as part of Graph::Resolve.
+  Status RemovedUnusedInitializersOrtFormat();
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

 // This friendship relationship should only be used to call Graph::Graph and
@@ -1541,12 +1546,6 @@ class Graph {  // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi

   common::Status PerformTypeAndShapeInferencing(const ResolveOptions& options);

-  // Recursively find all subgraphs including nested subgraphs
-  void FindAllSubgraphs(std::vector<Graph*>& subgraphs);
-
-  // Iterate this Graph instance and all subgraphs, calling the provided function for each.
- common::Status ForThisAndAllSubgraphs(const std::vector& subgraphs, std::function func); - common::Status InferAndVerifyTypeMatch(Node& node, const ONNX_NAMESPACE::OpSchema& op, const ResolveOptions& options); // perform type and shape inferencing on the subgraph and Resolve to validate @@ -1576,9 +1575,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // Implementation for initializer replacement Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); - // Clear all unused initializers and NodeArgs - void CleanUnusedInitializersAndNodeArgs(const std::unordered_set* initializer_names_to_preserve = nullptr); - std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, const ArgNameToTypeMap& name_to_type_map); @@ -1587,6 +1583,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + + // Recursively find all subgraphs including nested subgraphs + void FindAllSubgraphs(std::vector& subgraphs); + + // Iterate this Graph instance and all subgraphs, calling the provided function for each. + common::Status ForThisAndAllSubgraphs(const std::vector& subgraphs, std::function func); + + // Clear all unused initializers and NodeArgs + void CleanUnusedInitializersAndNodeArgs(const std::unordered_set* initializer_names_to_preserve = nullptr); + Status PopulateNodeArgToProducerConsumerLookupsFromNodes(); template diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 737f3f606ba8..5dca4cf6c165 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -225,8 +225,7 @@ class PlannerImpl { } int& UseCount(OrtValueIndex n) { - ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), - "invalid value index: ", n, " against size ", ort_value_info_.size()); + ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size()); return ort_value_info_[n].usecount; } int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); } @@ -643,21 +642,9 @@ class PlannerImpl { } // All initializers should be treated as input - // - // Special case: ORT format model where an EP takes nodes and copies initializers into the compiled model. - // Those initializers become unused so don't end up in ort_value_name_idx_map_, but as we don't run - // Graph::Resolve with an ORT format model they will still exist in GetAllInitializedTensors. - // We can ignore lookup failures in this case. - const bool unresolved_graph = graph_viewer_.GetGraph().GraphResolveNeeded(); for (const auto& pair : graph_viewer_.GetAllInitializedTensors()) { const auto& initializer_name = pair.first; - OrtValueIndex index = -1; - auto status = ort_value_name_idx_map_.GetIdx(initializer_name, index); - if (status.IsOK()) { - UseCount(initializer_name)++; - } else { - ORT_ENFORCE(unresolved_graph, status.ErrorMessage()); - } + UseCount(initializer_name)++; } for (auto& stream_execution_order : stream_nodes_) { @@ -722,21 +709,10 @@ class PlannerImpl { } // All initializers should be treated as input - // - // Special case: ORT format model where an EP takes nodes and copies initializers into the compiled model. 
- // Those initializers become unused so don't end up in ort_value_name_idx_map_, but as we don't run - // Graph::Resolve with an ORT format model they will still exist in GetAllInitializedTensors. - // We can ignore lookup failures in this case. - const bool unresolved_graph = graph_viewer_.GetGraph().GraphResolveNeeded(); for (const auto& pair : graph_viewer_.GetAllInitializedTensors()) { const auto& initializer_name = pair.first; - OrtValueIndex index = -1; - auto status = ort_value_name_idx_map_.GetIdx(initializer_name, index); - if (status.IsOK()) { - ProcessDef(index, graph_viewer_.GetNodeArg(initializer_name)); - } else { - ORT_ENFORCE(unresolved_graph, status.ErrorMessage()); - } + OrtValueIndex index = Index(initializer_name); + ProcessDef(index, graph_viewer_.GetNodeArg(pair.first)); } InlinedHashSet set_node_arg_has_explicit_consumer; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index f3890399ce4f..059de8e3c8c4 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -228,22 +228,9 @@ common::Status SaveInitializedTensors( id_to_initialized_tensor.reserve(initialized_tensor_set.size()); user_supplied_initializer_ids.reserve(initialized_tensor_set.size()); - // Special case: ORT format model where an EP takes nodes and copies initializers into the compiled model. - // Those initializers become unused so don't end up in ort_value_name_idx_map, but as we don't run - // Graph::Resolve with an ORT format model they will still exist in GetAllInitializedTensors. - // We can ignore lookup failures in this case. - const bool unresolved_graph = graph.GetGraph().GraphResolveNeeded(); for (const auto& entry : initialized_tensor_set) { int ort_value_index; - - if (auto status = ort_value_name_idx_map.GetIdx(entry.first, ort_value_index); !status.IsOK()) { - if (unresolved_graph) { - continue; - } - - return status; - } - + ORT_RETURN_IF_ERROR(ort_value_name_idx_map.GetIdx(entry.first, ort_value_index)); if (use_user_supplied_initializer(entry.first)) { user_supplied_initializer_ids.insert(ort_value_index); } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 442a0db933d6..5f539b3da595 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3387,6 +3387,18 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor.name(), &t)); } } + +Status Graph::RemovedUnusedInitializersOrtFormat() { + std::vector all_subgraphs; + FindAllSubgraphs(all_subgraphs); + auto cleanup_func = [](Graph& graph) { + graph.CleanUnusedInitializersAndNodeArgs(nullptr); + return Status::OK(); + }; + + auto result = ForThisAndAllSubgraphs(all_subgraphs, cleanup_func); + return result; +} #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const std::string& Graph::Name() const noexcept { diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index f1cfbd305443..e02186d3aee8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -309,26 +309,26 @@ COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& n void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std::string_view value_name) { MILSpec::Argument arg; - 
arg.mutable_arguments()->Add()->set_name(std::string(value_name)); + arg.mutable_arguments()->Add()->set_name(value_name.data(), value_name.size()); (*op.mutable_inputs())[input_name] = std::move(arg); } -void AddOperationInputs(MILSpec::Operation& op, std::string_view input_name, - const std::vector& value_names) { +void AddOperationVariadicInput(MILSpec::Operation& op, std::string_view input_name, + const std::vector& value_names) { MILSpec::Argument arg; for (const auto& value : value_names) { - arg.mutable_arguments()->Add()->set_name(std::string(value)); + arg.mutable_arguments()->Add()->set_name(value.data(), value.size()); } (*op.mutable_inputs())[input_name] = std::move(arg); } -void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, const std::string& output_name, +void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::string_view output_name, int32_t element_type, std::optional> shape) { auto& outputs = *op.mutable_outputs(); auto& output_arg = *outputs.Add(); - output_arg.set_name(output_name); + output_arg.set_name(output_name.data(), output_name.size()); MILSpec::ValueType& value = *output_arg.mutable_type(); MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 25e30577cf1e..475ce79b0a81 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -135,8 +135,8 @@ void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, /// Operation to update. /// The input name defined by the spec for the operation. /// The input value names. -void AddOperationInputs(COREML_SPEC::MILSpec::Operation& op, std::string_view input_name, - const std::vector& value_names); +void AddOperationVariadicInput(COREML_SPEC::MILSpec::Operation& op, std::string_view input_name, + const std::vector& value_names); /// Add an output to a MILSpec::Operation for an intermediate operation when the implementation is composed of /// multiple MLProgram operations. In this case we don't have a NodeArg for the output. @@ -146,7 +146,7 @@ void AddOperationInputs(COREML_SPEC::MILSpec::Operation& op, std::string_view in /// onnx::TensorProto_DataType element type of the output. /// int32_t as that is what TensorShapeProto uses to store the value. /// Shape of the output if known. 
-void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, const std::string& output_name, +void AddIntermediateOperationOutput(COREML_SPEC::MILSpec::Operation& op, std::string_view output_name, int32_t element_type, std::optional> shape); /// diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index 551d8222cc06..9ea0030290ab 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -39,14 +39,13 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, for (const auto* input : node.InputDefs()) { input_names.emplace_back(input->Name()); } - AddOperationInputs(*op, "values", input_names); + AddOperationVariadicInput(*op, "values", input_names); AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", *axis)); AddOperationInput(*op, "interleave", model_builder.AddScalarConstant(op->type(), "interleave", interleave)); AddOperationOutput(*op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(op)); - - } else -#endif // defined(COREML_ENABLE_MLPROGRAM) + } else // NOLINT +#endif // defined(COREML_ENABLE_MLPROGRAM) { std::unique_ptr layer = model_builder.CreateNNLayer(node); @@ -82,7 +81,7 @@ bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa // For some reason, the concat in CoreML running on 3d tensor will concat on wrong axis // Instead of concat on axis 0, it will concat on axis 1 // Disable Concat support for 3d tensor for now - // TODO, add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d + // TODO: add ExpandDims and Squeeze, 3d -ExpandDims-> 4d -> Concat -Squeeze-> 3d LOGS(logger, VERBOSE) << "Concat only support 4d shape for now, input is " << rank << "d shape"; return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc index bfc665e0ac71..9caec290ea5a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gridsample_op_builder.cc @@ -19,8 +19,8 @@ std::string_view GetMode(const NodeAttrHelper& helper) { // opset 20+ uses linear, nearest, cubic // bilinear is what CoreML uses, so prefer that // bicubic/cubic isn't supported - - const auto& mode = helper.Get("mode", "linear"); + static const std::string default_mode = "linear"; // static in case we ever return the default as a string_view + const auto& mode = helper.Get("mode", default_mode); if (mode == "linear") { return "bilinear"; } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cc3a9943ca0a..5ad2f0846779 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1603,6 +1603,11 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, logger, GraphPartitioner::Mode::kOrtFormatLoad)); +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + // a compiling EP (e.g. CoreML) may copy initializers to its own memory. run the cleanup of unused initializers + // so that they can be freed. 
+ ORT_RETURN_IF_ERROR(graph.RemovedUnusedInitializersOrtFormat()); +#endif return Status::OK(); } diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 5af9a2512f7c..d2a961f17bd6 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -6,7 +6,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:Add|| |ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:Clip|| -|ai:onnx:Concat|| +|ai.onnx:Concat|| |ai.onnx:Conv|Only 1D/2D Conv is supported.
Bias if provided must be constant.| |ai.onnx:ConvTranspose|Weight and bias must be constant.
padding_type of SAME_UPPER/SAME_LOWER is not supported.
kernel_shape must have default values.
output_shape is not supported.
output_padding must have default values.|
 |ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.|
@@ -27,4 +27,4 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
 |ai.onnx:Sub||
 |ai.onnx:Sigmoid||
 |ai:onnx:Tanh||
-|ai:onnx:Transpose||
+|ai.onnx:Transpose||

From dcf0c17ca796268392dd34e98d7813e198d60608 Mon Sep 17 00:00:00 2001
From: Scott McKay
Date: Thu, 25 Jul 2024 12:25:54 +1000
Subject: [PATCH 5/8] Fix exclude

---
 .lintrunner.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index cdc5d5dc469c..e1b24b2955b0 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -137,7 +137,7 @@ exclude_patterns = [
     'onnxruntime/core/mickey/gemm/**', # CUTLASS based libs recommends NO automatic code formatting
     'winml/lib/Api.Image/shaders/**', # Contains data chunks
     'onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h', # Bool Switches hang Clang
-    'onnxruntime/core/providers/coreml/mlprogram_test_scripts', # test scripts only
+    'onnxruntime/core/providers/coreml/mlprogram_test_scripts/**', # test scripts only
 ]
 command = [
     'python',

From 84854c8c470ecb5fe6044e5b69236e1a57ea115f Mon Sep 17 00:00:00 2001
From: Scott McKay
Date: Thu, 25 Jul 2024 13:57:49 +1000
Subject: [PATCH 6/8] Fix inclusion of CleanUnusedInitializersAndNodeArgs

---
 onnxruntime/core/graph/graph.cc | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 5f539b3da595..7e48e2f86a7b 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -4134,6 +4134,9 @@ void Graph::ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const
   }
 }

+#endif  // !defined(ORT_MINIMAL_BUILD)
+
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
 void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::string>* initializer_names_to_preserve) {
   // Node Args being used
   std::unordered_set used_args;
@@ -4265,8 +4268,7 @@ void Graph::CleanUnusedInitializersAndNodeArgs(const std::unordered_set<std::string>* initializer_names_to_preserve) {

From: Scott McKay
Date: Thu, 25 Jul 2024 15:06:07 +1000
Subject: [PATCH 7/8] Move the other required funcs

---
 onnxruntime/core/graph/graph.cc | 42 ++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 7e48e2f86a7b..e950d68947b9 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -3254,27 +3254,6 @@ Status Graph::PerformTypeAndShapeInferencing(const ResolveOptions& options) {
   return Status::OK();
 }

-void Graph::FindAllSubgraphs(std::vector<Graph*>& subgraphs) {
-  for (auto& node : Nodes()) {
-    for (auto& subgraph : node.MutableSubgraphs()) {
-      subgraphs.push_back(subgraph.get());
-      subgraph->FindAllSubgraphs(subgraphs);
-    }
-  }
-}
-
-Status Graph::ForThisAndAllSubgraphs(const std::vector<Graph*>& subgraphs, std::function<Status(Graph&)> func) {
-  auto status = func(*this);
-  ORT_RETURN_IF_ERROR(status);
-
-  for (auto& subgraph : subgraphs) {
-    status = func(*subgraph);
-    ORT_RETURN_IF_ERROR(status);
-  }
-
-  return status;
-}
-
 Status Graph::Resolve(const ResolveOptions& options) {
   if (parent_graph_) {
     // Resolve must start at the top level graph in-order to handle outer scope
@@ -3388,6 +3367,27 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) {
   }
 }

+void Graph::FindAllSubgraphs(std::vector<Graph*>& subgraphs) {
+  for (auto& node : Nodes()) {
+    for (auto& subgraph : node.MutableSubgraphs()) {
+      subgraphs.push_back(subgraph.get());
+      subgraph->FindAllSubgraphs(subgraphs);
+    }
+  }
+}
+
+Status Graph::ForThisAndAllSubgraphs(const std::vector<Graph*>& subgraphs, std::function<Status(Graph&)> func) {
+  auto status = func(*this);
+  ORT_RETURN_IF_ERROR(status);
+
+  for (auto& subgraph : subgraphs) {
+    status = func(*subgraph);
+    ORT_RETURN_IF_ERROR(status);
+  }
+
+  return status;
+}
+
 Status Graph::RemovedUnusedInitializersOrtFormat() {
   std::vector<Graph*> all_subgraphs;
   FindAllSubgraphs(all_subgraphs);

From 9193172e58c429d6b3f48d5d0947a0ad98572c68 Mon Sep 17 00:00:00 2001
From: Scott McKay
Date: Thu, 25 Jul 2024 20:14:58 +1000
Subject: [PATCH 8/8] Fix python lint

---
 .../core/providers/coreml/mlprogram_test_scripts/resize_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py
index 3cfe2658f945..f83dc6ddfe02 100644
--- a/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py
+++ b/onnxruntime/core/providers/coreml/mlprogram_test_scripts/resize_test.py
@@ -11,7 +11,7 @@

 @mb.program(input_specs=[mb.TensorSpec(shape=x_shape)], opset_version=target)
 def prog(x):
-    global use_scale
+    global use_scale  # noqa
     if use_scale:
         align = mb.const(val=False)
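
Editorial note (not part of the patch series): two standalone sketches may help readers follow the
initializer changes above. First, the motivation for dropping constant initializers. Once a compiling
EP has copied a weight into its own compiled model, any ComputeCapability that still lists that weight
as an input keeps its use count above zero, so ORT can never release the original copy. The toy program
below models only that accounting; it uses no ORT types, and every name in it is illustrative rather
than an ORT API.

// Toy illustration only -- not ORT code. Shows how a non-zero use count
// pins an initializer that a compiling EP has already copied elsewhere.
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> copied_initializers = {"conv.weight", "conv.bias"};

  for (const bool drop_constant_initializers : {false, true}) {
    std::map<std::string, int> use_count;

    // The compiled partition lists the copied initializers as inputs only
    // when they are not dropped; listed inputs get their use count bumped.
    if (!drop_constant_initializers) {
      for (const auto& name : copied_initializers) {
        ++use_count[name];
      }
    }

    std::cout << "drop_constant_initializers=" << std::boolalpha
              << drop_constant_initializers << '\n';
    for (const auto& name : copied_initializers) {
      std::cout << "  " << name
                << (use_count[name] == 0 ? ": freeable" : ": pinned in memory") << '\n';
    }
  }
}

With the flag false the weights stay pinned for the lifetime of the session even though the compiled
model has its own copy; with the flag true they become freeable once partitioning completes, which is
presumably the point of a compiling EP such as CoreML copying initializers to its own memory.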
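Second, the traversal that RemovedUnusedInitializersOrtFormat relies on. Patch 7 only relocates
FindAllSubgraphs and ForThisAndAllSubgraphs between build-flag regions, which can obscure what the pair
does: collect every nested subgraph (e.g. If/Loop bodies, including subgraphs of subgraphs), then apply
one callable to the root graph and all of them. The sketch below shows the same pattern with toy types
(a plain bool in place of Status, and a cut-down Graph struct, so it is not the real class); it also
folds the collect and visit steps into one call, where the real code keeps them separate.

// Standalone illustration of the recursive subgraph-visitor pattern.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Graph {
  std::string name;
  std::vector<std::unique_ptr<Graph>> subgraphs;  // stand-in for per-node subgraphs

  void FindAllSubgraphs(std::vector<Graph*>& out) {
    for (auto& sg : subgraphs) {
      out.push_back(sg.get());
      sg->FindAllSubgraphs(out);  // recurse so nested subgraphs are included
    }
  }

  bool ForThisAndAllSubgraphs(const std::function<bool(Graph&)>& fn) {
    std::vector<Graph*> all;
    FindAllSubgraphs(all);
    if (!fn(*this)) return false;  // root first, then every collected subgraph
    for (Graph* sg : all) {
      if (!fn(*sg)) return false;
    }
    return true;
  }
};

int main() {
  Graph root{"main", {}};
  root.subgraphs.push_back(std::make_unique<Graph>(Graph{"if_then", {}}));
  root.subgraphs[0]->subgraphs.push_back(std::make_unique<Graph>(Graph{"nested_loop", {}}));

  // e.g. run an "unused initializer cleanup" over the whole hierarchy
  root.ForThisAndAllSubgraphs([](Graph& g) {
    std::cout << "cleaning " << g.name << '\n';
    return true;
  });
}

In the real code this is what lets a single cleanup function, CleanUnusedInitializersAndNodeArgs, run
over the main graph and every nested subgraph, and it is that combined sweep the ORT format load path
triggers after partitioning so the copied initializers can actually be freed.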