From 6adaad288cac7bfebc22cc23659b5db2f87e3d24 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Wed, 18 Sep 2024 17:06:51 +0800 Subject: [PATCH 1/6] wip --- paddle2onnx/mapper/exporter.cc | 65 ++++++++++++++++++++++++++++++++-- paddle2onnx/mapper/exporter.h | 5 +++ tests/test_ifelse.py | 34 +++++++++++++++--- 3 files changed, 98 insertions(+), 6 deletions(-) diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index 3917d0888..dd4de145e 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -19,6 +19,7 @@ #include +#include "onnx_helper.h" #include "onnxoptimizer/optimize.h" #include "paddle2onnx/optimizer/convert_fp32_to_fp16.h" #include "paddle2onnx/optimizer/eliminate_non_transpose.h" @@ -315,9 +316,42 @@ namespace paddle2onnx } temp_outputs.push_back(std::move(MakeValueInfo(out_info[index]))); } + std::cout << "Enter ExportConditionalBlock" << std::endl; return std::move(ExportBlock(parser, sub_block_idx, temp_parameters, temp_inputs, temp_outputs)); } + ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant(const PaddleParser &parser, + int32_t block_id, + int32_t op_id, + const std::string &output_names, + const std::string &out_name) + { + ONNX_NAMESPACE::GraphProto graph; + graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id)); + auto op = parser.GetOpDesc(block_id, op_id); // fill_constant + OnnxHelper temp_helper; + + std::vector> temp_inputs; + auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); + std::cout << "target output name: " << output_names << std::endl; + temp_inputs.push_back(std::move(MakeValueInfo(out_info[0]))); + + + auto node = temp_helper.MakeNode("Identity", {output_names}, {out_name}); + + *(graph.add_input()) = (*MakeValueInfo(out_info[0])); + *(graph.add_output()) = (*MakeValueInfo(out_info[0])); + // ONNX_NAMESPACE::ValueInfoProto value_info = *graph.add_output(); + // value_info.set_name(out_name); + *(graph.add_node()) = (*node); + + for (auto &item : temp_helper.value_infos) + { + *(graph.add_value_info()) = (*item.get()); + } + return std::move(graph); + } + ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock(const PaddleParser &parser, int32_t block_id, std::vector> ¶meters, @@ -333,6 +367,7 @@ namespace paddle2onnx for (auto op_id = 0; op_id < num_ops; ++op_id) { auto op = parser.GetOpDesc(block_id, op_id); + std::cout <<"op name: "<< op.type() << std::endl; if (op.type() == "feed") { continue; @@ -363,22 +398,48 @@ namespace paddle2onnx auto conditional_block_cood_it = sub_block_map_.find(else_node_name); Assert(conditional_block_cood_it != sub_block_map_.end(), "Don't find select_input else_input node."); auto conditional_block_cood = conditional_block_cood_it->second; - auto else_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, else_node_name); + ONNX_NAMESPACE::GraphProto else_graph, then_graph; + auto else_node = parser.GetOpDesc(conditional_block_cood.first, conditional_block_cood.second); + if (else_node.type().find("conditional_block") != std::string::npos) { + else_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, else_node_name); + std::cout << "detect conditional_block" << std::endl; + } else { + std::string output_name = MapperHelper::Get()->GenName("fill_constant.identity"); + else_graph = ExportFillConstant(parser, conditional_block_cood.first, conditional_block_cood.second, else_node_name, output_name); + std::cout << "detect fill_constant" << std::endl; + // *(op -> mutable_input) = output_name; + } // 构建 then 分支图 auto then_node_name = input_info[1].name; conditional_block_cood_it = sub_block_map_.find(then_node_name); Assert(conditional_block_cood_it != sub_block_map_.end(), "Don't find select_input then_input node."); conditional_block_cood = conditional_block_cood_it->second; - auto then_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, then_node_name); + auto then_node = parser.GetOpDesc(conditional_block_cood.first, conditional_block_cood.second); + if (then_node.type().find("conditional_block") != std::string::npos) { + then_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, then_node_name); + std::cout << "detect conditional_block" << std::endl; + } else { + std::string output_name = MapperHelper::Get()->GenName("fill_constant.identity"); + then_graph = ExportFillConstant(parser, conditional_block_cood.first, conditional_block_cood.second, then_node_name, output_name); + std::cout << "detect fill_constant" << std::endl; + // *(op -> mutable_input + 1) = output_name; + } + std::cout << "else_node_name: " << else_node_name << std::endl; + std::cout << "then_node_name: " << then_node_name << std::endl; auto cond_info = parser.GetOpInput(block_id, op_id, "Mask"); auto output_info = parser.GetOpOutput(block_id, op_id, "Out"); auto cond_name = temp_helper.AutoCast(cond_info[0].name, cond_info[0].dtype, P2ODataType::BOOL); + std::cout << "cond_name: " << cond_name << std::endl; auto node = temp_helper.MakeNode("If", {cond_name}, {output_info[0].name}); AddAttribute(node, "then_branch", then_graph); AddAttribute(node, "else_branch", else_graph); continue; + } else if (op.type() == "fill_constant") + { + auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); + sub_block_map_[out_info[0].name] = {block_id, op_id}; } ExportOp(parser, &temp_helper, opset_version_, block_id, op_id, verbose_); } diff --git a/paddle2onnx/mapper/exporter.h b/paddle2onnx/mapper/exporter.h index 9467a9cde..64ad60bce 100644 --- a/paddle2onnx/mapper/exporter.h +++ b/paddle2onnx/mapper/exporter.h @@ -110,6 +110,11 @@ namespace paddle2onnx int32_t block_id, int32_t op_id, const std::string &output_names); + ONNX_NAMESPACE::GraphProto ExportFillConstant(const PaddleParser &parser, + int32_t block_id, + int32_t op_id, + const std::string &output_names, + const std::string &out_name); ONNX_NAMESPACE::GraphProto ExportBlock(const PaddleParser &parser, int32_t block_id, std::vector> ¶meters, diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index b04079843..47e143ca5 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -64,8 +64,34 @@ def test_ifelse_2_false(): obj.set_input_data("input_data", paddle.to_tensor(2), paddle.to_tensor(1)) obj.run() +class BaseNet3(paddle.nn.Layer): + def __init__(self): + super(BaseNet3, self).__init__() + + def forward(self, inputs): + if inputs == 1: + return 1 + else: + return 2 + +def test_ifelse_3_true(): + op = BaseNet3() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(1)) + obj.run() + +def test_ifelse_3_false(): + op = BaseNet3() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(2)) + obj.run() + if __name__ == "__main__": - test_ifelse_1_true() - test_ifelse_1_false() - test_ifelse_2_true() - test_ifelse_2_false() \ No newline at end of file + # test_ifelse_1_true() + # test_ifelse_1_false() + # test_ifelse_2_true() + # test_ifelse_2_false() + test_ifelse_3_true() + test_ifelse_3_false() From 495d8883eaba47652709366485f5ae0f61f270d2 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Wed, 18 Sep 2024 23:54:41 +0800 Subject: [PATCH 2/6] fix --- paddle2onnx/mapper/exporter.cc | 45 +++++++++------------------------- paddle2onnx/mapper/exporter.h | 4 +-- tests/onnxbase.py | 6 ++++- tests/test_ifelse.py | 34 ++++++++++++++++++++++--- 4 files changed, 49 insertions(+), 40 deletions(-) diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index 5027f76de..de16780ce 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -321,39 +321,30 @@ namespace paddle2onnx } temp_outputs.push_back(std::move(MakeValueInfo(out_info[index]))); } - std::cout << "Enter ExportConditionalBlock" << std::endl; return std::move(ExportBlock(parser, sub_block_idx, temp_parameters, temp_inputs, temp_outputs)); } ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant(const PaddleParser &parser, + OnnxHelper *temp_helper, int32_t block_id, int32_t op_id, - const std::string &output_names, - const std::string &out_name) + const std::string &output_names) { ONNX_NAMESPACE::GraphProto graph; graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id)); auto op = parser.GetOpDesc(block_id, op_id); // fill_constant - OnnxHelper temp_helper; - std::vector> temp_inputs; auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); - std::cout << "target output name: " << output_names << std::endl; - temp_inputs.push_back(std::move(MakeValueInfo(out_info[0]))); - - auto node = temp_helper.MakeNode("Identity", {output_names}, {out_name}); - - *(graph.add_input()) = (*MakeValueInfo(out_info[0])); *(graph.add_output()) = (*MakeValueInfo(out_info[0])); - // ONNX_NAMESPACE::ValueInfoProto value_info = *graph.add_output(); - // value_info.set_name(out_name); - *(graph.add_node()) = (*node); - for (auto &item : temp_helper.value_infos) - { - *(graph.add_value_info()) = (*item.get()); + for (auto &item: temp_helper->nodes) { + if (item -> output(0) == output_names) { + *(graph.add_node()) = (*item.get()); + break; + } } + return std::move(graph); } @@ -372,7 +363,6 @@ namespace paddle2onnx for (auto op_id = 0; op_id < num_ops; ++op_id) { auto op = parser.GetOpDesc(block_id, op_id); - std::cout <<"op name: "<< op.type() << std::endl; if (op.type() == "feed") { continue; @@ -401,42 +391,31 @@ namespace paddle2onnx // 构建 else 分支图 auto else_node_name = input_info[0].name; auto conditional_block_cood_it = sub_block_map_.find(else_node_name); - Assert(conditional_block_cood_it != sub_block_map_.end(), "Don't find select_input else_input node."); + Assert(conditional_block_cood_it != sub_block_map_.end(), "Con't find select_input else_input node."); auto conditional_block_cood = conditional_block_cood_it->second; ONNX_NAMESPACE::GraphProto else_graph, then_graph; auto else_node = parser.GetOpDesc(conditional_block_cood.first, conditional_block_cood.second); if (else_node.type().find("conditional_block") != std::string::npos) { else_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, else_node_name); - std::cout << "detect conditional_block" << std::endl; } else { - std::string output_name = MapperHelper::Get()->GenName("fill_constant.identity"); - else_graph = ExportFillConstant(parser, conditional_block_cood.first, conditional_block_cood.second, else_node_name, output_name); - std::cout << "detect fill_constant" << std::endl; - // *(op -> mutable_input) = output_name; + else_graph = ExportFillConstant(parser, &temp_helper, conditional_block_cood.first, conditional_block_cood.second, else_node_name); } // 构建 then 分支图 auto then_node_name = input_info[1].name; conditional_block_cood_it = sub_block_map_.find(then_node_name); - Assert(conditional_block_cood_it != sub_block_map_.end(), "Don't find select_input then_input node."); + Assert(conditional_block_cood_it != sub_block_map_.end(), "Con't find select_input then_input node."); conditional_block_cood = conditional_block_cood_it->second; auto then_node = parser.GetOpDesc(conditional_block_cood.first, conditional_block_cood.second); if (then_node.type().find("conditional_block") != std::string::npos) { then_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, then_node_name); - std::cout << "detect conditional_block" << std::endl; } else { - std::string output_name = MapperHelper::Get()->GenName("fill_constant.identity"); - then_graph = ExportFillConstant(parser, conditional_block_cood.first, conditional_block_cood.second, then_node_name, output_name); - std::cout << "detect fill_constant" << std::endl; - // *(op -> mutable_input + 1) = output_name; + then_graph = ExportFillConstant(parser, &temp_helper, conditional_block_cood.first, conditional_block_cood.second, then_node_name); } - std::cout << "else_node_name: " << else_node_name << std::endl; - std::cout << "then_node_name: " << then_node_name << std::endl; auto cond_info = parser.GetOpInput(block_id, op_id, "Mask"); auto output_info = parser.GetOpOutput(block_id, op_id, "Out"); auto cond_name = temp_helper.AutoCast(cond_info[0].name, cond_info[0].dtype, P2ODataType::BOOL); - std::cout << "cond_name: " << cond_name << std::endl; auto node = temp_helper.MakeNode("If", {cond_name}, {output_info[0].name}); AddAttribute(node, "then_branch", then_graph); AddAttribute(node, "else_branch", else_graph); diff --git a/paddle2onnx/mapper/exporter.h b/paddle2onnx/mapper/exporter.h index 64ad60bce..c111254c0 100644 --- a/paddle2onnx/mapper/exporter.h +++ b/paddle2onnx/mapper/exporter.h @@ -111,10 +111,10 @@ namespace paddle2onnx int32_t op_id, const std::string &output_names); ONNX_NAMESPACE::GraphProto ExportFillConstant(const PaddleParser &parser, + OnnxHelper *temp_helper, int32_t block_id, int32_t op_id, - const std::string &output_names, - const std::string &out_name); + const std::string &output_names); ONNX_NAMESPACE::GraphProto ExportBlock(const PaddleParser &parser, int32_t block_id, std::vector> ¶meters, diff --git a/tests/onnxbase.py b/tests/onnxbase.py index ca5016ab6..32a1e2a78 100755 --- a/tests/onnxbase.py +++ b/tests/onnxbase.py @@ -64,7 +64,11 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): # Convert Paddle Tensor to Numpy array if type(expect) == list: expect = expect[0] - expect = expect.numpy() + + if isinstance(expect, paddle.Tensor): + expect = expect.numpy() + else: + expect = np.array(expect) # For result_shape is (1) and expect_shape shape is () expect = expect.squeeze() diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index 47e143ca5..6efa6e3ae 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -88,10 +88,36 @@ def test_ifelse_3_false(): obj.set_input_data("input_data", paddle.to_tensor(2)) obj.run() +class BaseNet4(paddle.nn.Layer): + def __init__(self): + super(BaseNet4, self).__init__() + + def forward(self, inputs): + if inputs == 1: + return inputs + 1 + else: + return 2 + +def test_ifelse_4_true(): + op = BaseNet4() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(1)) + obj.run() + +def test_ifelse_4_false(): + op = BaseNet3() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(2)) + obj.run() + if __name__ == "__main__": - # test_ifelse_1_true() - # test_ifelse_1_false() - # test_ifelse_2_true() - # test_ifelse_2_false() + test_ifelse_1_true() + test_ifelse_1_false() + test_ifelse_2_true() + test_ifelse_2_false() test_ifelse_3_true() test_ifelse_3_false() + test_ifelse_4_true() + test_ifelse_4_false() From 7f0375bf1708f82784633ced51425d66437a86e7 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Thu, 19 Sep 2024 11:10:10 +0800 Subject: [PATCH 3/6] update due to comment --- paddle2onnx/mapper/exporter.cc | 26 ++++++++++++-------------- tests/test_ifelse.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index de16780ce..07ddd1a20 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -19,7 +19,6 @@ #include -#include "onnx_helper.h" #include "onnxoptimizer/optimize.h" #include "paddle2onnx/optimizer/convert_fp32_to_fp16.h" #include "paddle2onnx/optimizer/eliminate_non_transpose.h" @@ -86,7 +85,7 @@ namespace paddle2onnx { return true; } - + auto logger = P2OLogger(); logger << "Oops, there are some operators not supported yet, including "; for (auto &item : unsupported_ops) @@ -325,19 +324,17 @@ namespace paddle2onnx } ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant(const PaddleParser &parser, - OnnxHelper *temp_helper, - int32_t block_id, - int32_t op_id, - const std::string &output_names) + OnnxHelper *temp_helper, + int32_t block_id, + int32_t op_id, + const std::string &output_names) { ONNX_NAMESPACE::GraphProto graph; graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id)); auto op = parser.GetOpDesc(block_id, op_id); // fill_constant - auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); *(graph.add_output()) = (*MakeValueInfo(out_info[0])); - for (auto &item: temp_helper->nodes) { if (item -> output(0) == output_names) { *(graph.add_node()) = (*item.get()); @@ -382,31 +379,32 @@ namespace paddle2onnx } else if (op.type() == "select_input") { - // 如果找到,则输出对应的值;否则输出错误信息 - // 遍历输入Tensor auto input_info = parser.GetOpInput(block_id, op_id, "X"); Assert(input_info.size() == 2, "Only support when number of select_input's input_node is 2."); - // 构建 else 分支图 + // Build else sub graph auto else_node_name = input_info[0].name; auto conditional_block_cood_it = sub_block_map_.find(else_node_name); - Assert(conditional_block_cood_it != sub_block_map_.end(), "Con't find select_input else_input node."); + Assert(conditional_block_cood_it != sub_block_map_.end(), "Can't find select_input else_input node."); auto conditional_block_cood = conditional_block_cood_it->second; ONNX_NAMESPACE::GraphProto else_graph, then_graph; auto else_node = parser.GetOpDesc(conditional_block_cood.first, conditional_block_cood.second); + if (else_node.type().find("conditional_block") != std::string::npos) { else_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, else_node_name); } else { else_graph = ExportFillConstant(parser, &temp_helper, conditional_block_cood.first, conditional_block_cood.second, else_node_name); } - // 构建 then 分支图 + // Build then sub graph auto then_node_name = input_info[1].name; conditional_block_cood_it = sub_block_map_.find(then_node_name); - Assert(conditional_block_cood_it != sub_block_map_.end(), "Con't find select_input then_input node."); + Assert(conditional_block_cood_it != sub_block_map_.end(), "Can't find select_input then_input node."); conditional_block_cood = conditional_block_cood_it->second; auto then_node = parser.GetOpDesc(conditional_block_cood.first, conditional_block_cood.second); + + // use node.type() to make sure correctness if (then_node.type().find("conditional_block") != std::string::npos) { then_graph = ExportConditionalBlock(parser, conditional_block_cood.first, conditional_block_cood.second, then_node_name); } else { diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index 6efa6e3ae..b90fadb37 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -46,7 +46,7 @@ def __init__(self): def forward(self, cond, inputs): if cond == 1: - return inputs * 1, inputs * 2 + return inputs * 1, inputs * 2 else: return inputs * 3, inputs * 4 @@ -106,7 +106,31 @@ def test_ifelse_4_true(): obj.run() def test_ifelse_4_false(): - op = BaseNet3() + op = BaseNet4() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(2)) + obj.run() + +class BaseNet5(paddle.nn.Layer): + def __init__(self): + super(BaseNet5, self).__init__() + + def forward(self, inputs): + if inputs == 1: + return 1, 2 + else: + return 2, 3 + +def test_ifelse_5_true(): + op = BaseNet5() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(1)) + obj.run() + +def test_ifelse_5_false(): + op = BaseNet5() op.eval() obj = APIOnnx(op, 'ifelse', [11]) obj.set_input_data("input_data", paddle.to_tensor(2)) @@ -121,3 +145,5 @@ def test_ifelse_4_false(): test_ifelse_3_false() test_ifelse_4_true() test_ifelse_4_false() + test_ifelse_5_true() + test_ifelse_5_false() From 75428fd25475faa2333a0a512e9b49c57c4f9b92 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Fri, 20 Sep 2024 11:10:53 +0800 Subject: [PATCH 4/6] Add missing implementation --- paddle2onnx/mapper/exporter.cc | 105 +++++++++++++++++++-------------- 1 file changed, 62 insertions(+), 43 deletions(-) diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index 2941d729a..b86e93f10 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -191,41 +191,41 @@ void ModelExporter::SetOpsetVersion(const PaddleParser &parser, inline ONNX_NAMESPACE::Version ModelExporter::GetIRVersion() const { int ir_version = 0; switch (opset_version_) { - case 7: - case 8: - ir_version = 3; - break; - case 9: - ir_version = 4; - break; - case 10: - ir_version = 5; - break; - case 11: - ir_version = 6; - break; - case 12: - case 13: - case 14: - ir_version = 7; - break; - case 15: - case 16: - case 17: - case 18: - ir_version = 8; - break; - case 19: - case 20: - ir_version = 9; - break; - case 21: - ir_version = 10; - break; - default: - P2OLogger(verbose_) << "The Opset Version must be between 7 and 21." - << std::endl; - Assert(false, "Due to opset version, the model exporting is aborted."); + case 7: + case 8: + ir_version = 3; + break; + case 9: + ir_version = 4; + break; + case 10: + ir_version = 5; + break; + case 11: + ir_version = 6; + break; + case 12: + case 13: + case 14: + ir_version = 7; + break; + case 15: + case 16: + case 17: + case 18: + ir_version = 8; + break; + case 19: + case 20: + ir_version = 9; + break; + case 21: + ir_version = 10; + break; + default: + P2OLogger(verbose_) << "The Opset Version must be between 7 and 21." + << std::endl; + Assert(false, "Due to opset version, the model exporting is aborted."); } return static_cast(ir_version); } @@ -262,9 +262,10 @@ void ModelExporter::ExportParameters( } } -ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock( - const PaddleParser &parser, int32_t block_id, int32_t op_id, - const std::string &output_names) { +ONNX_NAMESPACE::GraphProto +ModelExporter::ExportConditionalBlock(const PaddleParser &parser, + int32_t block_id, int32_t op_id, + const std::string &output_names) { auto op = parser.GetOpDesc(block_id, op_id); // Get sub_block_idx @@ -299,6 +300,24 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock( temp_inputs, temp_outputs)); } +ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant( + const PaddleParser &parser, OnnxHelper *temp_helper, int32_t block_id, + int32_t op_id, const std::string &output_names) { + ONNX_NAMESPACE::GraphProto graph; + graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id)); + auto op = parser.GetOpDesc(block_id, op_id); // fill_constant + auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); + + *(graph.add_output()) = (*MakeValueInfo(out_info[0])); + for (auto &item : temp_helper->nodes) { + if (item->output(0) == output_names) { + *(graph.add_node()) = (*item.get()); + break; + } + } + + return std::move(graph); +} ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock( const PaddleParser &parser, int32_t block_id, std::vector> ¶meters, @@ -378,8 +397,8 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock( AddAttribute(node, "else_branch", else_graph); continue; } else if (op.type() == "fill_constant") { - auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); - sub_block_map_[out_info[0].name] = {block_id, op_id}; + auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); + sub_block_map_[out_info[0].name] = {block_id, op_id}; } ExportOp(parser, &temp_helper, opset_version_, block_id, op_id, verbose_); } @@ -780,8 +799,8 @@ std::string ModelExporter::Run( return out; } -ONNX_NAMESPACE::ModelProto ModelExporter::Optimize( - const ONNX_NAMESPACE::ModelProto &model) { +ONNX_NAMESPACE::ModelProto +ModelExporter::Optimize(const ONNX_NAMESPACE::ModelProto &model) { ONNX_NAMESPACE::optimization::Optimizer::passes .registerPass(); ONNX_NAMESPACE::optimization::Optimizer::passes @@ -809,4 +828,4 @@ ONNX_NAMESPACE::ModelProto ModelExporter::Optimize( return ONNX_NAMESPACE::optimization::Optimize(model, passes); } -} // namespace paddle2onnx +} // namespace paddle2onnx From e0fb72d01fb2822a3b55f75698397a9df717a3c2 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Fri, 20 Sep 2024 11:20:27 +0800 Subject: [PATCH 5/6] Restore code format --- paddle2onnx/mapper/exporter.cc | 77 +++++++++++++++++----------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index b86e93f10..74c1d7686 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -191,41 +191,41 @@ void ModelExporter::SetOpsetVersion(const PaddleParser &parser, inline ONNX_NAMESPACE::Version ModelExporter::GetIRVersion() const { int ir_version = 0; switch (opset_version_) { - case 7: - case 8: - ir_version = 3; - break; - case 9: - ir_version = 4; - break; - case 10: - ir_version = 5; - break; - case 11: - ir_version = 6; - break; - case 12: - case 13: - case 14: - ir_version = 7; - break; - case 15: - case 16: - case 17: - case 18: - ir_version = 8; - break; - case 19: - case 20: - ir_version = 9; - break; - case 21: - ir_version = 10; - break; - default: - P2OLogger(verbose_) << "The Opset Version must be between 7 and 21." - << std::endl; - Assert(false, "Due to opset version, the model exporting is aborted."); + case 7: + case 8: + ir_version = 3; + break; + case 9: + ir_version = 4; + break; + case 10: + ir_version = 5; + break; + case 11: + ir_version = 6; + break; + case 12: + case 13: + case 14: + ir_version = 7; + break; + case 15: + case 16: + case 17: + case 18: + ir_version = 8; + break; + case 19: + case 20: + ir_version = 9; + break; + case 21: + ir_version = 10; + break; + default: + P2OLogger(verbose_) << "The Opset Version must be between 7 and 21." + << std::endl; + Assert(false, "Due to opset version, the model exporting is aborted."); } return static_cast(ir_version); } @@ -262,10 +262,9 @@ void ModelExporter::ExportParameters( } } -ONNX_NAMESPACE::GraphProto -ModelExporter::ExportConditionalBlock(const PaddleParser &parser, - int32_t block_id, int32_t op_id, - const std::string &output_names) { +ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock( + const PaddleParser &parser, int32_t block_id, int32_t op_id, + const std::string &output_names) { auto op = parser.GetOpDesc(block_id, op_id); // Get sub_block_idx From 0431b045fdaad10003d0d3eb31e065795a9c8d19 Mon Sep 17 00:00:00 2001 From: starfall <1186454801@qq.com> Date: Fri, 20 Sep 2024 11:22:40 +0800 Subject: [PATCH 6/6] Restore code format --- paddle2onnx/mapper/exporter.cc | 4 ++-- paddle2onnx/mapper/exporter.h | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index 74c1d7686..474b693f9 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -798,8 +798,8 @@ std::string ModelExporter::Run( return out; } -ONNX_NAMESPACE::ModelProto -ModelExporter::Optimize(const ONNX_NAMESPACE::ModelProto &model) { +ONNX_NAMESPACE::ModelProto ModelExporter::Optimize( + const ONNX_NAMESPACE::ModelProto &model) { ONNX_NAMESPACE::optimization::Optimizer::passes .registerPass(); ONNX_NAMESPACE::optimization::Optimizer::passes diff --git a/paddle2onnx/mapper/exporter.h b/paddle2onnx/mapper/exporter.h index b0ed6ac69..f3d8577d9 100644 --- a/paddle2onnx/mapper/exporter.h +++ b/paddle2onnx/mapper/exporter.h @@ -106,10 +106,10 @@ class ModelExporter { ONNX_NAMESPACE::GraphProto ExportConditionalBlock( const PaddleParser &parser, int32_t block_id, int32_t op_id, const std::string &output_names); - ONNX_NAMESPACE::GraphProto - ExportFillConstant(const PaddleParser &parser, OnnxHelper *temp_helper, - int32_t block_id, int32_t op_id, - const std::string &output_names); + ONNX_NAMESPACE::GraphProto ExportFillConstant( + const PaddleParser &parser, OnnxHelper *temp_helper, + int32_t block_id, int32_t op_id, + const std::string &output_names); ONNX_NAMESPACE::GraphProto ExportBlock( const PaddleParser &parser, int32_t block_id, std::vector> ¶meters,