From 702b2e28e0c2a1604914d2e6065903aaf122ce7f Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 9 Aug 2024 06:52:59 -0700
Subject: [PATCH] Fuse Pad even if Cast is present in-between (#21640)

### Description
This change enhances the existing Pad Fusion to fuse Pad even if a Cast operator is present between Pad and Conv/MaxPool/AveragePool. The Cast node itself is kept as is.
/*
 * Before Fusion:
 *     Pad
 *      |
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 *
 * After Fusion:
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 */
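
The core of the rewrite is folding Pad's `pads` attribute into the child operator's `pads` attribute. ONNX Pad lists begin/end pads for every input dimension (`[x1_begin, x2_begin, ..., x1_end, x2_end, ...]`), while Conv/MaxPool/AveragePool list them for the spatial dimensions only; this is why the loop in `UpdatePaddingAttribute` in the diff below starts at index 2, skipping batch and channel. Here is a minimal, self-contained sketch of that arithmetic over plain vectors; the `MergePads` helper is hypothetical and only illustrates the logic, it is not the ONNX Runtime API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Fold ONNX Pad's all-dims `pads` into a Conv/Pool spatial-only `pads`.
// Assumes the batch/channel entries of `pad_values` are zero, which
// PadFusion::SatisfyCondition checks before the fusion fires.
void MergePads(const std::vector<int64_t>& pad_values, std::vector<int64_t>& conv_pads) {
  const size_t pads_size = pad_values.size();     // 2 * rank of the padded input
  const size_t conv_half = conv_pads.size() / 2;  // spatial rank
  // Start at index 2 to skip the batch and channel "begin" pads.
  for (size_t pad_index = 2, conv_index = 0; pad_index < pads_size / 2; ++pad_index, ++conv_index) {
    conv_pads[conv_index] += pad_values[pad_index];                               // begin side
    conv_pads[conv_index + conv_half] += pad_values[pad_index + pads_size / 2];   // end side
  }
}

int main() {
  // NCHW example: Pad adds one pixel on each side of H and W, nothing on N/C.
  std::vector<int64_t> pad_values = {0, 0, 1, 1, 0, 0, 1, 1};
  std::vector<int64_t> conv_pads = {0, 0, 0, 0};  // Conv initially has no padding
  MergePads(pad_values, conv_pads);
  for (int64_t p : conv_pads) std::cout << p << ' ';  // prints: 1 1 1 1
  std::cout << '\n';
}
```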
### Motivation and Context
---
 onnxruntime/core/optimizer/pad_fusion.cc | 93 ++++++++++++++++--------
 1 file changed, 62 insertions(+), 31 deletions(-)

diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc
index e266946b0d9e..3391e20cf0bb 100644
--- a/onnxruntime/core/optimizer/pad_fusion.cc
+++ b/onnxruntime/core/optimizer/pad_fusion.cc
@@ -8,25 +8,7 @@ namespace onnxruntime {
-/*
- * It matches following pattern:
- *   Pad
- *    |
- * Conv/MaxPool/AveragePool
- */
-bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
-  // if Pad has input axis, don't fuse it.
-  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
-      node.GetOutputEdgesCount() != 1 ||
-      node.InputDefs().size() > 3) {
-    return false;
-  }
-
-  if (graph.NodeProducesGraphOutput(node)) {
-    return false;
-  }
-
-  const Node& child_node = *node.OutputNodesBegin();
+bool VerifyNotCastChild(const Node& child_node) {
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
       !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
       !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
     return false;
   }
 
@@ -54,6 +36,45 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
     return false;
   }
 
+  return true;
+}
+
+void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
+  auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
+  uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());
+
+  for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
+    child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
+    uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
+    uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
+    child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
+  }
+}
+/*
+ * Before:
+ *     Pad
+ *      |
+ *    Cast (Optional)
+ *      |
+ *   Conv/MaxPool/AveragePool
+ *
+ * After:
+ *    Cast (Optional)
+ *      |
+ *   Conv/MaxPool/AveragePool
+ */
+bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
+  // if Pad has input axis, don't fuse it.
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
+      node.GetOutputEdgesCount() != 1 ||
+      node.InputDefs().size() > 3) {
+    return false;
+  }
+
+  if (graph.NodeProducesGraphOutput(node)) {
+    return false;
+  }
+
   const NodeAttributes& pad_attributes = node.GetAttributes();
   if (pad_attributes.find("mode") != pad_attributes.end() &&
       pad_attributes.at("mode").s() != "constant") {
@@ -83,7 +104,19 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
     }
   }
 
-  return true;
+  const Node& child_node = *node.OutputNodesBegin();
+  if (graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Cast", {1, 6, 9, 13})) {
+    if (child_node.GetOutputEdgesCount() != 1) {
+      return false;
+    }
+
+    if (graph.NodeProducesGraphOutput(child_node)) {
+      return false;
+    }
+    return VerifyNotCastChild(*child_node.OutputNodesBegin());
+  } else {
+    return VerifyNotCastChild(child_node);
+  }
 }
 
 /*
@@ -100,8 +133,6 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
     pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end());
   }
 
-  assert(static_cast<uint32_t>(pads_values.size()) == (2 * static_cast<uint32_t>(pad_node.InputDefs()[0]->Shape()->dim_size())));
-
   uint32_t pads_size = static_cast<uint32_t>(pads_values.size());
   // check if padding is applied only on feature dims
   if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 ||
@@ -115,18 +146,18 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
   }
 
   Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index());
-  auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
-  uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());
-
-  for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
-    child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
-    uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
-    uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
-    child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
-  }
+  // We don't need to cast the pad_constant_value because this fusion requires the constant_pad_value
+  // to be zero. See PadFusion::SatisfyCondition for details.
+  Node& target_padding_node = (child_node.OpType() == "Cast") ? *graph.GetNode(child_node.OutputNodesBegin()->Index()) : child_node;
+  UpdatePaddingAttribute(target_padding_node, pads_values, pads_size);
 
   graph_utils::RemoveNodeOutputEdges(graph, pad_node);
   graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]);
+  // Un-pad the output shape of the Cast node
+  if (child_node.OpType() == "Cast") {
+    auto* cast_output_node_arg = child_node.MutableOutputDefs()[0];
+    cast_output_node_arg->SetShape(*pad_node.MutableInputDefs()[0]->Shape());
+  }
   graph.RemoveNode(pad_node.Index());
   rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
   return Status::OK();
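
Worked example (illustrative; assumes an NCHW float input of shape (1, 3, 32, 32) and a float16 Cast): Pad(pads = [0, 0, 1, 1, 0, 0, 1, 1]) -> Cast(to = float16) -> Conv(pads = [0, 0, 0, 0]) rewrites to Cast(to = float16) -> Conv(pads = [1, 1, 1, 1]). Because Cast now consumes the unpadded input, its output NodeArg would otherwise still carry the padded shape (1, 3, 34, 34); the SetShape call at the end of Apply resets it to (1, 3, 32, 32).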