diff --git a/paddle2onnx/mapper/exporter.cc b/paddle2onnx/mapper/exporter.cc index 2b43fed9a..474b693f9 100644 --- a/paddle2onnx/mapper/exporter.cc +++ b/paddle2onnx/mapper/exporter.cc @@ -299,6 +299,24 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportConditionalBlock( temp_inputs, temp_outputs)); } +ONNX_NAMESPACE::GraphProto ModelExporter::ExportFillConstant( + const PaddleParser &parser, OnnxHelper *temp_helper, int32_t block_id, + int32_t op_id, const std::string &output_names) { + ONNX_NAMESPACE::GraphProto graph; + graph.set_name("PaddlePaddle fill_constant Graph " + std::to_string(op_id)); + auto op = parser.GetOpDesc(block_id, op_id); // fill_constant + auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); + + *(graph.add_output()) = (*MakeValueInfo(out_info[0])); + for (auto &item : temp_helper->nodes) { + if (item->output(0) == output_names) { + *(graph.add_node()) = (*item.get()); + break; + } + } + + return graph; +} ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock( const PaddleParser &parser, int32_t block_id, std::vector> &parameters, @@ -328,23 +346,45 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock( Assert(input_info.size() == 2, "Only support when number of select_input's input_node is 2."); + // Build else sub graph auto else_node_name = input_info[0].name; auto conditional_block_cood_it = sub_block_map_.find(else_node_name); Assert(conditional_block_cood_it != sub_block_map_.end(), - "Don't find select_input else_input node."); + "Can't find select_input else_input node."); auto conditional_block_cood = conditional_block_cood_it->second; - auto else_graph = - ExportConditionalBlock(parser, conditional_block_cood.first, - conditional_block_cood.second, else_node_name); + ONNX_NAMESPACE::GraphProto else_graph, then_graph; + auto else_node = parser.GetOpDesc(conditional_block_cood.first, + conditional_block_cood.second); + + if (else_node.type().find("conditional_block") != std::string::npos) { + else_graph = 
ExportConditionalBlock( + parser, conditional_block_cood.first, conditional_block_cood.second, + else_node_name); + } else { + else_graph = ExportFillConstant( + parser, &temp_helper, conditional_block_cood.first, + conditional_block_cood.second, else_node_name); + } + // Build then sub graph auto then_node_name = input_info[1].name; conditional_block_cood_it = sub_block_map_.find(then_node_name); Assert(conditional_block_cood_it != sub_block_map_.end(), - "Don't find select_input then_input node."); + "Can't find select_input then_input node."); conditional_block_cood = conditional_block_cood_it->second; - auto then_graph = - ExportConditionalBlock(parser, conditional_block_cood.first, - conditional_block_cood.second, then_node_name); + auto then_node = parser.GetOpDesc(conditional_block_cood.first, + conditional_block_cood.second); + + // use node.type() to make sure correctness + if (then_node.type().find("conditional_block") != std::string::npos) { + then_graph = ExportConditionalBlock( + parser, conditional_block_cood.first, conditional_block_cood.second, + then_node_name); + } else { + then_graph = ExportFillConstant( + parser, &temp_helper, conditional_block_cood.first, + conditional_block_cood.second, then_node_name); + } auto cond_info = parser.GetOpInput(block_id, op_id, "Mask"); auto output_info = parser.GetOpOutput(block_id, op_id, "Out"); @@ -355,6 +395,9 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportBlock( AddAttribute(node, "then_branch", then_graph); AddAttribute(node, "else_branch", else_graph); continue; + } else if (op.type() == "fill_constant") { + auto out_info = parser.GetOpOutput(block_id, op_id, "Out"); + sub_block_map_[out_info[0].name] = {block_id, op_id}; } ExportOp(parser, &temp_helper, opset_version_, block_id, op_id, verbose_); } @@ -784,4 +827,4 @@ ONNX_NAMESPACE::ModelProto ModelExporter::Optimize( return ONNX_NAMESPACE::optimization::Optimize(model, passes); } -} // namespace paddle2onnx +} // namespace paddle2onnx diff --git 
a/paddle2onnx/mapper/exporter.h b/paddle2onnx/mapper/exporter.h index a64f6d039..f3d8577d9 100644 --- a/paddle2onnx/mapper/exporter.h +++ b/paddle2onnx/mapper/exporter.h @@ -106,6 +106,10 @@ class ModelExporter { ONNX_NAMESPACE::GraphProto ExportConditionalBlock( const PaddleParser &parser, int32_t block_id, int32_t op_id, const std::string &output_names); + ONNX_NAMESPACE::GraphProto ExportFillConstant( + const PaddleParser &parser, OnnxHelper *temp_helper, + int32_t block_id, int32_t op_id, + const std::string &output_names); ONNX_NAMESPACE::GraphProto ExportBlock( const PaddleParser &parser, int32_t block_id, std::vector> &parameters, diff --git a/tests/onnxbase.py b/tests/onnxbase.py index ca5016ab6..32a1e2a78 100755 --- a/tests/onnxbase.py +++ b/tests/onnxbase.py @@ -64,7 +64,11 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): # Convert Paddle Tensor to Numpy array if type(expect) == list: expect = expect[0] - expect = expect.numpy() + + if isinstance(expect, paddle.Tensor): + expect = expect.numpy() + else: + expect = np.array(expect) # For result_shape is (1) and expect_shape shape is () expect = expect.squeeze() diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index b04079843..b90fadb37 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -46,7 +46,7 @@ def __init__(self): def forward(self, cond, inputs): if cond == 1: - return inputs * 1, inputs * 2 + return inputs * 1, inputs * 2 else: return inputs * 3, inputs * 4 @@ -64,8 +64,86 @@ def test_ifelse_2_false(): obj.set_input_data("input_data", paddle.to_tensor(2), paddle.to_tensor(1)) obj.run() +class BaseNet3(paddle.nn.Layer): + def __init__(self): + super(BaseNet3, self).__init__() + + def forward(self, inputs): + if inputs == 1: + return 1 + else: + return 2 + +def test_ifelse_3_true(): + op = BaseNet3() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(1)) + obj.run() + +def test_ifelse_3_false(): + op = BaseNet3() + op.eval() + 
obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(2)) + obj.run() + +class BaseNet4(paddle.nn.Layer): + def __init__(self): + super(BaseNet4, self).__init__() + + def forward(self, inputs): + if inputs == 1: + return inputs + 1 + else: + return 2 + +def test_ifelse_4_true(): + op = BaseNet4() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(1)) + obj.run() + +def test_ifelse_4_false(): + op = BaseNet4() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(2)) + obj.run() + +class BaseNet5(paddle.nn.Layer): + def __init__(self): + super(BaseNet5, self).__init__() + + def forward(self, inputs): + if inputs == 1: + return 1, 2 + else: + return 2, 3 + +def test_ifelse_5_true(): + op = BaseNet5() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(1)) + obj.run() + +def test_ifelse_5_false(): + op = BaseNet5() + op.eval() + obj = APIOnnx(op, 'ifelse', [11]) + obj.set_input_data("input_data", paddle.to_tensor(2)) + obj.run() + if __name__ == "__main__": test_ifelse_1_true() test_ifelse_1_false() test_ifelse_2_true() - test_ifelse_2_false() \ No newline at end of file + test_ifelse_2_false() + test_ifelse_3_true() + test_ifelse_3_false() + test_ifelse_4_true() + test_ifelse_4_false() + test_ifelse_5_true() + test_ifelse_5_false()