Skip to content

Commit

Permalink
Only update modified if the graph was modified.
Browse files Browse the repository at this point in the history
Current behavior forces all L2 optimizers to loop until they hit the max number of iterations.
  • Loading branch information
skottmckay committed Jul 25, 2024
1 parent ae3ec2e commit 97bfdfc
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ Note: This fusion doesn't consider the following case:
LayerNormalization
*/

Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
InlinedVector<std::reference_wrapper<Node>> nodes_to_remove;
Expand Down Expand Up @@ -299,12 +300,15 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
// Assign provider to this new node. Provider should be same as the provider for old node.
skip_layer_norm_node.SetExecutionProviderType(ln_node.GetExecutionProviderType());
}

for (const auto& node : nodes_to_remove) {
graph_utils::RemoveNodeOutputEdges(graph, node);
graph.RemoveNode(node.get().Index());
}

modified = true;
if (!nodes_to_remove.empty()) {
modified = true;
}

return Status::OK();
}
Expand Down

0 comments on commit 97bfdfc

Please sign in to comment.