diff --git a/onnxsim/onnxsim.cpp b/onnxsim/onnxsim.cpp index f2024a5..064a80c 100644 --- a/onnxsim/onnxsim.cpp +++ b/onnxsim/onnxsim.cpp @@ -243,6 +243,14 @@ std::vector RunOp(onnx::ModelProto& model, const onnx::NodeProto& op) { std::vector input_names; std::vector input_tps; + std::set initializer_names; + + onnx::ModelProto op_model; + op_model.set_ir_version(model.ir_version()); + for (const auto& x : model.opset_import()) { + *op_model.add_opset_import() = x; + } + *op_model.mutable_graph()->add_node() = op; for (const auto& input : op.input()) { if (std::find(input_names.begin(), input_names.end(), input) != @@ -253,16 +261,19 @@ std::vector RunOp(onnx::ModelProto& model, if (input.empty()) { continue; } - input_names.push_back(input); + if (initializer_names.find(input) != initializer_names.end()) { + continue; + } auto in_tp = FindInitializerByName(model, input); + if (in_tp.dims().size() == 1 && in_tp.dims()[0] == 0) { + initializer_names.insert(input); + *op_model.mutable_graph()->add_initializer() = in_tp; + continue; + } + input_names.push_back(input); input_tps.push_back(in_tp); } - onnx::ModelProto op_model; - op_model.set_ir_version(model.ir_version()); - for (const auto& x : model.opset_import()) { - *op_model.add_opset_import() = x; - } - *op_model.mutable_graph()->add_node() = op; + for (const auto& x : input_names) { // skip "" which represents the unset optional input if (x.empty()) {