Skip to content

Commit

Permalink
Merge GatherToSplitFusion and microsoft#19218 to a General Fusion (mi…
Browse files Browse the repository at this point in the history
…crosoft#19600)

microsoft#19218 tried to fuse Gather/Slice to Split, but the logic has problem.
Scalar value or 1-dim value of indices in Gather node will produce
different result, scalar value will produce a result tensor by removing
the axis dim, will 1-dim indices value will keep that dim, even when the
dim value is 1. For example,

Node
    |-> Gather(indices=[0], axis=axis)
    |-> Gather(indices=[1], axis=axis)
    |-> Slice(index=2, axis=axis)
is same as
Node
   |-> Split(axis=axis)

But
Node
    |-> Gather(indices=0, axis=axis)
    |-> Gather(indices=1, axis=axis)
    |-> Slice(index=2, axis=axis)
is same as
Node
    |-> Split(axis=axis)
        ||-> Squeeze(axis=axis)
        ||-> Squeeze(axis=axis)
        ||->

Previous PR doesn't take such case related to Squeeze/Unsqueeze into
account.

This PR merges microsoft#19218 and GatherToSplitFusion to a general fusion, which
relaxes the limit the number of Gather and Slice node number, check all
Gather and Slice consumers, if the indices of Gather and start/end of
Slice can cover the specific dim of the input tensor, then we can fuse
them to a Split, and adding Squeeze if necessary according to the dim
count of the indices tensor in Gather.

@rui-ren, please check if the fix can still be applied to your model.
  • Loading branch information
centwang authored Feb 29, 2024
1 parent 7455dd1 commit d2e6dd2
Show file tree
Hide file tree
Showing 7 changed files with 352 additions and 916 deletions.
318 changes: 193 additions & 125 deletions onnxruntime/core/optimizer/gather_fusion.cc

Large diffs are not rendered by default.

16 changes: 10 additions & 6 deletions onnxruntime/core/optimizer/gather_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,23 @@
namespace onnxruntime {

/**
@Class GatherToSplitFusion
@Class GatherSliceToSplitFusion
Fuse multiple Gather nodes that comsuming one output to one Split node.
Fuse multiple Gather/Slice nodes that comsuming one output to one Split node.
*/
class GatherToSplitFusion : public GraphTransformer {
class GatherSliceToSplitFusion : public GraphTransformer {
public:
GatherToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {}
GatherSliceToSplitFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

private:
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const;
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
InlinedVector<bool>& consumed, int64_t& start, bool& need_squeeze) const;

bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size,
InlinedVector<bool>& consumed, int64_t& start, int64_t& end) const;
};

/**
Expand Down
344 changes: 0 additions & 344 deletions onnxruntime/core/optimizer/gather_slice_fusion.cc

This file was deleted.

Loading

0 comments on commit d2e6dd2

Please sign in to comment.