From 69c9f58480599f39092de04305536b76ed766925 Mon Sep 17 00:00:00 2001 From: Xyzhao Date: Thu, 17 Aug 2023 17:18:33 +0800 Subject: [PATCH 1/2] fix: Preserve empty initializer inputs to avoid misjudgment of parameters by some nodes --- onnxsim/onnxsim.cpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/onnxsim/onnxsim.cpp b/onnxsim/onnxsim.cpp index f2024a5..7a2f288 100644 --- a/onnxsim/onnxsim.cpp +++ b/onnxsim/onnxsim.cpp @@ -244,6 +244,13 @@ std::vector RunOp(onnx::ModelProto& model, std::vector input_names; std::vector input_tps; + onnx::ModelProto op_model; + op_model.set_ir_version(model.ir_version()); + for (const auto& x : model.opset_import()) { + *op_model.add_opset_import() = x; + } + *op_model.mutable_graph()->add_node() = op; + for (const auto& input : op.input()) { if (std::find(input_names.begin(), input_names.end(), input) != input_names.end()) { @@ -253,16 +260,16 @@ std::vector RunOp(onnx::ModelProto& model, if (input.empty()) { continue; } - input_names.push_back(input); + auto in_tp = FindInitializerByName(model, input); + if (in_tp.dims().size() == 1 && in_tp.dims()[0] == 0) { + *op_model.mutable_graph()->add_initializer() = in_tp; + continue; + } + input_names.push_back(input); input_tps.push_back(in_tp); } - onnx::ModelProto op_model; - op_model.set_ir_version(model.ir_version()); - for (const auto& x : model.opset_import()) { - *op_model.add_opset_import() = x; - } - *op_model.mutable_graph()->add_node() = op; + for (const auto& x : input_names) { // skip "" which represents the unset optional input if (x.empty()) { From 685d7c78b101620f0a1c1f921fff560c1335e9dd Mon Sep 17 00:00:00 2001 From: Xyzhao Date: Fri, 18 Aug 2023 10:18:30 +0800 Subject: [PATCH 2/2] fix: Avoid repeatedly adding initializer --- onnxsim/onnxsim.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxsim/onnxsim.cpp b/onnxsim/onnxsim.cpp index 7a2f288..064a80c 100644 --- a/onnxsim/onnxsim.cpp +++ b/onnxsim/onnxsim.cpp @@ -243,6 +243,7 @@ std::vector RunOp(onnx::ModelProto& model, const onnx::NodeProto& op) { std::vector input_names; std::vector input_tps; + std::set initializer_names; onnx::ModelProto op_model; op_model.set_ir_version(model.ir_version()); @@ -260,9 +261,12 @@ std::vector RunOp(onnx::ModelProto& model, if (input.empty()) { continue; } - + if (initializer_names.find(input) != initializer_names.end()) { + continue; + } auto in_tp = FindInitializerByName(model, input); if (in_tp.dims().size() == 1 && in_tp.dims()[0] == 0) { + initializer_names.insert(input); *op_model.mutable_graph()->add_initializer() = in_tp; continue; }