
Fuse Pad even if Cast is present in-between (microsoft#21640)
### Description
This change enhances the existing Pad fusion to fuse Pad even if a Cast
operator is present between the Pad and the Conv/MaxPool/AveragePool node.
The Cast itself is kept as is.
<pre>
/*
 * Before Fusion:
 *     Pad
 *      |
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 * 
 * After Fusion:
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 */
</pre>
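
For intuition, here is a minimal standalone sketch (plain C++, not ONNX Runtime code; the names `MergePads`, `pad_pads`, and `child_pads` are illustrative) of the arithmetic the fusion performs: the Pad node's `pads`, laid out as `[x1_begin, ..., xn_begin, x1_end, ..., xn_end]`, are added onto the child node's spatial-only `pads` attribute, skipping the batch and channel dims.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the index mapping used by the fusion: skip dims 0 (N) and 1 (C),
// then add each spatial begin/end pad from Pad onto the child node's pads.
std::vector<int64_t> MergePads(const std::vector<int64_t>& pad_pads,
                               std::vector<int64_t> child_pads) {
  const size_t rank = pad_pads.size() / 2;       // rank of the padded tensor
  const size_t spatial = child_pads.size() / 2;  // child pads cover spatial dims only
  for (size_t pad_i = 2, child_i = 0; pad_i < rank; ++pad_i, ++child_i) {
    child_pads[child_i] += pad_pads[pad_i];                   // begin side
    child_pads[child_i + spatial] += pad_pads[pad_i + rank];  // end side
  }
  return child_pads;
}

int main() {
  // NCHW input: Pad adds 2 at H/W begin and 1 at H/W end; Conv already pads 1 everywhere.
  const std::vector<int64_t> pad_pads = {0, 0, 2, 2, 0, 0, 1, 1};
  const std::vector<int64_t> conv_pads = {1, 1, 1, 1};
  for (int64_t p : MergePads(pad_pads, conv_pads)) std::cout << p << ' ';  // prints: 3 3 2 2
  std::cout << '\n';
}
```

To observe the fusion end to end, one option (a sketch against the public C++ API; `pad_cast_conv.onnx` is a hypothetical model containing the Pad → Cast → Conv pattern) is to dump the optimized model produced at session creation and confirm that the Pad node is gone while the Cast remains:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "pad_fusion_check");
  Ort::SessionOptions options;
  // Write the graph as it looks after optimization so it can be inspected.
  options.SetOptimizedModelFilePath(ORT_TSTR("pad_cast_conv.optimized.onnx"));
  // Session creation runs the graph optimizers; with default settings this fusion should be applied.
  Ort::Session session(env, ORT_TSTR("pad_cast_conv.onnx"), options);
  return 0;
}
```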


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
sumitsays authored Aug 9, 2024
1 parent e6e4047 commit 702b2e2
Showing 1 changed file with 62 additions and 31 deletions.
93 changes: 62 additions & 31 deletions onnxruntime/core/optimizer/pad_fusion.cc
@@ -8,25 +8,7 @@

namespace onnxruntime {

/*
 * It matches following pattern:
 *     Pad
 *      |
 *   Conv/MaxPool/AveragePool
 */
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
  // if Pad has input axis, don't fuse it.
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
      node.GetOutputEdgesCount() != 1 ||
      node.InputDefs().size() > 3) {
    return false;
  }

  if (graph.NodeProducesGraphOutput(node)) {
    return false;
  }

  const Node& child_node = *node.OutputNodesBegin();
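// Returns true if child_node is a Conv/MaxPool/AveragePool node that the Pad's
// padding can be folded into.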
bool VerifyNotCastChild(const Node& child_node) {
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Conv", {1, 11}) &&
      !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "AveragePool", {1, 7, 10, 11, 19}) &&
      !graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "MaxPool", {1, 8, 10, 11, 12})) {
@@ -54,6 +36,45 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
    return false;
  }

  return true;
}

void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
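  // ONNX Pad lays out "pads" as [x1_begin, ..., xn_begin, x1_end, ..., xn_end],
  // while the child node's "pads" attribute covers only the spatial dims. The
  // loop starts at index 2 to skip the batch and channel dims, which must carry
  // zero padding for this fusion to apply (see the check in PadFusion::Apply).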
  auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
  uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());

  for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
    child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
    uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
    uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
    child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
  }
}
/*
 * Before:
 *     Pad
 *      |
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 *
 * After:
 *    Cast (Optional)
 *      |
 *   Conv/MaxPool/AveragePool
 */
bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
  // if Pad has input axis, don't fuse it.
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Pad", {1, 2, 11, 13, 18, 19}) ||
      node.GetOutputEdgesCount() != 1 ||
      node.InputDefs().size() > 3) {
    return false;
  }

  if (graph.NodeProducesGraphOutput(node)) {
    return false;
  }

  const NodeAttributes& pad_attributes = node.GetAttributes();
  if (pad_attributes.find("mode") != pad_attributes.end() &&
      pad_attributes.at("mode").s() != "constant") {
@@ -83,7 +104,19 @@ bool PadFusion::SatisfyCondition(const Graph& graph, const Node& node, const log
}
}

  return true;
  const Node& child_node = *node.OutputNodesBegin();
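  // If the Pad's single consumer is a Cast, look through it: the Cast must feed
  // exactly one node and must not produce a graph output, and that consumer must
  // in turn be a fusable Conv/MaxPool/AveragePool.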
  if (graph_utils::IsSupportedOptypeVersionAndDomain(child_node, "Cast", {1, 6, 9, 13})) {
    if (child_node.GetOutputEdgesCount() != 1) {
      return false;
    }

    if (graph.NodeProducesGraphOutput(child_node)) {
      return false;
    }
    return VerifyNotCastChild(*child_node.OutputNodesBegin());
  } else {
    return VerifyNotCastChild(child_node);
  }
}

/*
@@ -100,8 +133,6 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
    pads_values.assign(pad_node.GetAttributes().at("pads").ints().begin(), pad_node.GetAttributes().at("pads").ints().end());
  }

  assert(static_cast<uint32_t>(pads_values.size()) == (2 * static_cast<uint32_t>(pad_node.InputDefs()[0]->Shape()->dim_size())));

  uint32_t pads_size = static_cast<uint32_t>(pads_values.size());
  // check if padding is applied only on feature dims
  if (pads_values[0] != 0 || pads_values[1] != 0 || pads_values[pads_size / 2] != 0 ||
@@ -115,18 +146,18 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
  }

  Node& child_node = *graph.GetNode(pad_node.OutputNodesBegin()->Index());
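  // The Pad's single consumer: either the optional Cast or the Conv/MaxPool/AveragePool itself.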
  auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
  uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());

  for (uint32_t pads_index = 2, child_index = 0; pads_index < pads_size / 2; pads_index++, child_index++) {
    child_pads->Set(child_index, child_pads->Get(child_index) + pads_values[pads_index]);
    uint32_t mirrored_child_index = child_index + (child_pads_size / 2);
    uint32_t mirrored_pad_index = pads_index + (pads_size / 2);
    child_pads->Set(mirrored_child_index, child_pads->Get(mirrored_child_index) + pads_values[mirrored_pad_index]);
  }
  // We don't need to cast the pad_constant_value because this fusion requires that constant_pad_value
  // to be zero. See PadFusion::SatisfyCondition for details.
  Node& target_padding_node = (child_node.OpType() == "Cast") ? *graph.GetNode(child_node.OutputNodesBegin()->Index()) : child_node;
  UpdatePaddingAttribute(target_padding_node, pads_values, pads_size);

  graph_utils::RemoveNodeOutputEdges(graph, pad_node);
  graph_utils::ReplaceNodeInput(child_node, 0, *pad_node.MutableInputDefs()[0]);
  // Un-pad the output shape of Cast node
  if (child_node.OpType() == "Cast") {
    auto* cast_output_node_arg = child_node.MutableOutputDefs()[0];
    cast_output_node_arg->SetShape(*pad_node.MutableInputDefs()[0]->Shape());
  }
  graph.RemoveNode(pad_node.Index());
  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
  return Status::OK();
