Skip to content

Commit

Permalink
Fix SkipLayerNormFusion incorrectly setting modified every time it ru…
Browse files Browse the repository at this point in the history
…ns (#21502)

### Description
<!-- Describe your changes. -->
Current behavior forces all L2 optimizers to loop until they hit the max
number of iterations.

Only update modified if the graph was modified.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Fix unnecessary loops of L2 optimizers during model loading.
  • Loading branch information
skottmckay authored Jul 26, 2024
1 parent c464ab3 commit e5302b2
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ Note: This fusion doesn't consider the following case:
LayerNormalization
*/

Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
Expand Down Expand Up @@ -299,12 +300,15 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
// Assign provider to this new node. Provider should be same as the provider for old node.
skip_layer_norm_node.SetExecutionProviderType(ln_node.GetExecutionProviderType());
}

for (const auto& node : nodes_to_remove) {
graph_utils::RemoveNodeOutputEdges(graph, node);
graph.RemoveNode(node.get().Index());
}

modified = true;
if (!nodes_to_remove.empty()) {
modified = true;
}

return Status::OK();
}
Expand Down

0 comments on commit e5302b2

Please sign in to comment.