Skip to content

Commit

Permalink
Support Graph Input and Initializer for GatherToSplit Fusion (microsoft#18412)
Browse files Browse the repository at this point in the history

Support graph input and initializer for GatherToSplit fusion. Previously
the fusion required that the Gather nodes consume the output of some other
node, which cannot be a graph input or an initializer.

This helps model training with such cases, so that we will not have
GatherGrad in the final graph. GatherGrad's kernel implementation is
highly inefficient.
  • Loading branch information
centwang authored Nov 15, 2023
1 parent d738ff1 commit b0699d9
Show file tree
Hide file tree
Showing 2 changed files with 243 additions and 45 deletions.
65 changes: 45 additions & 20 deletions onnxruntime/core/optimizer/gather_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

namespace onnxruntime {

bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const {
bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis,
int64_t& indices_n_dims) const {
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) ||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
return false;
Expand Down Expand Up @@ -53,6 +54,22 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

InlinedVector<const NodeArg*> node_args;
for (auto node_arg : graph.GetInputs()) {
if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) {
node_args.push_back(node_arg);
}
}

for (auto entry : graph.GetAllInitializedTensors()) {
if (graph.GetConsumerNodes(entry.first).size() > 1) {
auto node_arg = graph.GetNodeArg(entry.first);
if (node_arg) {
node_args.push_back(node_arg);
}
}
}

for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion
Expand All @@ -73,19 +90,26 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
size_t output_count = node.GetOutputEdgesCount();
if (output_count <= 1) continue;

auto shape = node.MutableOutputDefs()[0]->Shape();
node_args.push_back(node.OutputDefs()[0]);
}

for (const NodeArg* node_arg : node_args) {
auto shape = node_arg->Shape();
if (!shape) continue;
int64_t rank = static_cast<int64_t>(shape->dim_size());

bool can_fuse = true;
bool first_edge = true;
int64_t split_axis = 0;
int64_t indices_n_dims = -1;
InlinedVector<NodeArg*> gather_outputs(output_count, nullptr);
auto consumers = graph.GetConsumerNodes(node_arg->Name());
size_t consumer_count = consumers.size();
InlinedVector<NodeArg*> gather_outputs(consumer_count, nullptr);
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
for (auto it = node.OutputNodesBegin(); it != node.OutputNodesEnd(); ++it) {
for (auto consumer : consumers) {
int64_t index, axis, dims;
if (!IsSupportedGather(graph, *it, index, axis, dims)) {
if (!consumer || consumer->InputDefs()[0] != node_arg ||
!IsSupportedGather(graph, *consumer, index, axis, dims)) {
can_fuse = false;
break;
}
Expand All @@ -99,7 +123,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
if (axis < 0) axis += rank;
if (first_edge) {
auto dim = shape->dim(static_cast<int>(axis));
if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(output_count)) {
if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast<int64_t>(consumer_count)) {
can_fuse = false;
break;
}
Expand All @@ -109,21 +133,21 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
can_fuse = false;
break;
}
if (index < 0) index += static_cast<int64_t>(output_count);
if (index < 0 || index >= static_cast<int64_t>(output_count) || gather_outputs[static_cast<size_t>(index)]) {
if (index < 0) index += static_cast<int64_t>(consumer_count);
if (index < 0 || index >= static_cast<int64_t>(consumer_count) || gather_outputs[static_cast<size_t>(index)]) {
can_fuse = false;
break;
}
Node& gather_node = *graph.GetNode(it->Index());
Node& gather_node = *graph.GetNode(consumer->Index());
nodes_to_fuse.emplace_back(gather_node);
gather_outputs[static_cast<size_t>(index)] = gather_node.MutableOutputDefs()[0];
}

if (!can_fuse) continue;

ONNX_NAMESPACE::TypeProto split_output_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type());
const ONNX_NAMESPACE::TensorProto_DataType element_type =
static_cast<ONNX_NAMESPACE::TensorProto_DataType>(node_arg->TypeAsProto()->tensor_type().elem_type());
split_output_type.mutable_tensor_type()->set_elem_type(element_type);
for (int64_t i = 0; i < rank; ++i) {
if (i == split_axis) {
Expand All @@ -136,16 +160,17 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
InlinedVector<NodeArg*> split_outputs;
bool add_squeeze_node = indices_n_dims == 0;
if (add_squeeze_node) {
for (size_t i = 0; i < output_count; ++i) {
for (size_t i = 0; i < consumer_count; ++i) {
split_outputs.emplace_back(
&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type));
}
}

Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
{node.MutableOutputDefs()[0]}, add_squeeze_node ? split_outputs : gather_outputs);
Node& split_node =
graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes",
{graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs);
split_node.AddAttribute("axis", split_axis);
split_node.SetExecutionProviderType(node.GetExecutionProviderType());
split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());

// Squeeze-11, Squeeze-13, Split-13, Split-18 have different schemas.
int onnx_opset_version = -1;
Expand All @@ -155,16 +180,16 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le

if (onnx_opset_version < 13) {
if (add_squeeze_node) {
for (size_t i = 0; i < output_count; ++i) {
for (size_t i = 0; i < consumer_count; ++i) {
Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]});
squeeze_node.AddAttribute("axes", std::vector<int64_t>{split_axis});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
}
}
} else {
if (onnx_opset_version >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(output_count));
split_node.AddAttribute("num_outputs", static_cast<int64_t>(consumer_count));
}

if (add_squeeze_node) {
Expand All @@ -176,11 +201,11 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t));
NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto);

for (size_t i = 0; i < output_count; ++i) {
for (size_t i = 0; i < consumer_count; ++i) {
Node& squeeze_node =
graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze",
"Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]});
squeeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType());
}
}
}
Expand Down
Loading

0 comments on commit b0699d9

Please sign in to comment.