Add support for fusing reshape ops around TFL_BatchMatMulOp when the input has more than one contracting dim.

PiperOrigin-RevId: 672172067
vamsimanchala authored and tensorflower-gardener committed Sep 8, 2024
1 parent b580857 commit 859fc9f
Showing 5 changed files with 151 additions and 13 deletions.
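
In short, the new pattern keeps the unit broadcast dimension on the LHS instead of flattening it away and re-adding it with a reshape after the matmul. Below is a minimal standalone sketch (plain C++, not part of the commit; shapes taken from the new FuseBMMOutputReshape_WithTwoLHSContractionDims test) of the before/after shape flow.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  using Shape = std::vector<int64_t>;
  // Before: LHS 1x128x8x256 is reshaped to 128x2048, matmul with 2048x1792
  // yields 128x1792, and a trailing reshape restores the broadcast dim.
  Shape lhs_before{128, 2048}, rhs{2048, 1792}, bmm_before{128, 1792};
  // After: LHS 1x128x8x256 is reshaped to 1x128x2048 instead; the batch
  // matmul with the same RHS yields 1x128x1792 directly, so the output
  // reshape folds away.
  Shape lhs_after{1, 128, 2048}, bmm_after{1, 128, 1792};
  assert(lhs_before.back() == rhs.front() && lhs_after.back() == rhs.front());
  assert(bmm_after == Shape({1, bmm_before[0], bmm_before[1]}));
  return 0;
}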
40 changes: 40 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -842,6 +842,46 @@ func.func @FuseReshapeAroundBMMNagativeTest2(%arg0: tensor<2x1536xf32>) -> tenso
// CHECK: return %3 : tensor<2x768xf32>
}

// CHECK-LABEL: @FuseBMMOutputReshape_WithTwoLHSContractionDims
func.func @FuseBMMOutputReshape_WithTwoLHSContractionDims(%arg0: tensor<8x256x1792xf32>, %arg1: tensor<1x128x8x256xf32>) -> (tensor<1x128x1792xf32>){
%cst = arith.constant dense<[1, 128, 1792]> : tensor<3xi32>
%cst_0 = arith.constant dense<[2048, 1792]> : tensor<2xi32>
%cst_1 = arith.constant dense<[128, 2048]> : tensor<2xi32>
%0 = "tfl.reshape"(%arg1, %cst_1) : (tensor<1x128x8x256xf32>, tensor<2xi32>) -> tensor<128x2048xf32>
%1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<8x256x1792xf32>, tensor<2xi32>) -> tensor<2048x1792xf32>
%2 = "tfl.batch_matmul"(%0, %1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<128x2048xf32>, tensor<2048x1792xf32>) -> tensor<128x1792xf32>
%3 = "tfl.reshape"(%2, %cst) : (tensor<128x1792xf32>, tensor<3xi32>) -> tensor<1x128x1792xf32>
return %3 : tensor<1x128x1792xf32>
// CHECK: %cst = arith.constant dense<[1, 128, 2048]> : tensor<3xi32>
// CHECK: %cst_0 = arith.constant dense<[2048, 1792]> : tensor<2xi32>
// CHECK: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<8x256x1792xf32>, tensor<2xi32>) -> tensor<2048x1792xf32>
// CHECK: %1 = "tfl.reshape"(%arg1, %cst) : (tensor<1x128x8x256xf32>, tensor<3xi32>) -> tensor<1x128x2048xf32>
// CHECK: %2 = "tfl.batch_matmul"(%1, %0) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x128x2048xf32>, tensor<2048x1792xf32>) -> tensor<1x128x1792xf32>
// CHECK: return %2 : tensor<1x128x1792xf32>
}

// CHECK-LABEL: @FuseBMMOutputReshape_WithThreeLHSContractionDims
func.func @FuseBMMOutputReshape_WithThreeLHSContractionDims(%arg0: tensor<2x8x256x1792xf32>, %arg1: tensor<1x2x128x8x256xf32>) -> (tensor<1x128x1792xf32>){
%cst = arith.constant dense<[1, 128, 1792]> : tensor<3xi32>
%cst_0 = arith.constant dense<[4096, 1792]> : tensor<2xi32>
%cst_1 = arith.constant dense<[128, 4096]> : tensor<2xi32>
%cst_2 = arith.constant dense<[0, 2, 1, 3, 4]> : tensor<5xi32>
%0 = "tfl.transpose"(%arg1, %cst_2) : (tensor<1x2x128x8x256xf32>, tensor<5xi32>) -> tensor<1x128x2x8x256xf32>
%1 = "tfl.reshape"(%0, %cst_1) : (tensor<1x128x2x8x256xf32>, tensor<2xi32>) -> tensor<128x4096xf32>
%2 = "tfl.reshape"(%arg0, %cst_0) : (tensor<2x8x256x1792xf32>, tensor<2xi32>) -> tensor<4096x1792xf32>
%3 = "tfl.batch_matmul"(%1, %2) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<128x4096xf32>, tensor<4096x1792xf32>) -> tensor<128x1792xf32>
%4 = "tfl.reshape"(%3, %cst) : (tensor<128x1792xf32>, tensor<3xi32>) -> tensor<1x128x1792xf32>
return %4 : tensor<1x128x1792xf32>
// CHECK: %cst = arith.constant dense<[1, 128, 4096]> : tensor<3xi32>
// CHECK: %cst_0 = arith.constant dense<[4096, 1792]> : tensor<2xi32>
// CHECK: %cst_1 = arith.constant dense<[0, 2, 1, 3, 4]> : tensor<5xi32>
// CHECK: %0 = "tfl.transpose"(%arg1, %cst_1) : (tensor<1x2x128x8x256xf32>, tensor<5xi32>) -> tensor<1x128x2x8x256xf32>
// CHECK: %1 = "tfl.reshape"(%arg0, %cst_0) : (tensor<2x8x256x1792xf32>, tensor<2xi32>) -> tensor<4096x1792xf32>
// CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x128x2x8x256xf32>, tensor<3xi32>) -> tensor<1x128x4096xf32>
// CHECK: %3 = "tfl.batch_matmul"(%2, %1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x128x4096xf32>, tensor<4096x1792xf32>) -> tensor<1x128x1792xf32>
// CHECK: return %3 : tensor<1x128x1792xf32>
}

// CHECK-LABEL: @FuseReshapeAroundBMMRHS
func.func @FuseReshapeAroundBMMRHS(%arg0: tensor<1x3x6x5x1024xf32>) -> tensor<1x3x6x5x8192xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "inputs", outputs = "Identity_1"}} {
%cst = arith.constant dense_resource<__elided__> : tensor<1x1024x8192xf32>
47 changes: 40 additions & 7 deletions tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc
@@ -235,21 +235,54 @@ bool HasSameStridedShape(TFL::Conv3DOp op, ArrayRef<int64_t> pre_pad_shape) {

using ::llvm::cast;

-// Return true if the product of dimension values of a subsection of the
-// tensor is equal to the non-contracting dimension after a reshape
-bool BroadcastDimsProductEqual(Value input, Value output,
-size_t agg_start_idx) {
// Predicate to check if the product of last few dimensions in LHS is equal to
// the last dimension in RHS.
// agg_start_idx is the index in LHS from where the subsection will start.
bool ContractingDimsProductEqual(Value input, Value output,
size_t agg_start_idx) {
ArrayRef<int64_t> input_shape =
mlir::cast<ShapedType>(input.getType()).getShape();
ArrayRef<int64_t> output_shape =
mlir::cast<ShapedType>(output.getType()).getShape();

-int64_t agg_value = 1;
-for (size_t i = agg_start_idx; i < input_shape.size() - 1; ++i) {
int agg_value = 1;
for (size_t i = agg_start_idx; i < input_shape.size(); ++i) {
agg_value *= input_shape[i];
}

-return (agg_value == output_shape[agg_start_idx]);
return (agg_value == output_shape[output_shape.size() - 1]);
}

// Returns true if the product of the non-broadcasting, non-contracting
// dimension values of `original` equals the corresponding non-contracting
// dimension of the reshaped value `updated`.
bool NonBroadcastingNonContractingDimsProductEqual(Value original,
Value updated, bool is_lhs,
size_t agg_start_idx,
size_t agg_end_idx = 0) {
ArrayRef<int64_t> original_shape =
mlir::cast<ShapedType>(original.getType()).getShape();
ArrayRef<int64_t> updated_shape =
mlir::cast<ShapedType>(updated.getType()).getShape();

int64_t agg_value = 1;
// If agg_end_idx is not supplied, assume there is a single contracting
// dimension and skip it.
if (agg_end_idx == 0) {
if (is_lhs) {
agg_end_idx = original_shape.size() - 2;
} else {
agg_end_idx = original_shape.size() - 1;
}
}
for (size_t i = agg_start_idx; i <= agg_end_idx; ++i) {
agg_value *= original_shape[i];
}

if (is_lhs) {
return (agg_value == updated_shape[updated_shape.size() - 2]);
} else {
return (agg_value == updated_shape[updated_shape.size() - 1]);
}
}
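
A standalone illustration of the arithmetic these two predicates perform, using the shapes from the new two-contracting-dims test (the pass code itself operates on mlir::Value; the Product helper below is illustrative only):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Product of shape[begin, end), mirroring the aggregation loops above.
int64_t Product(const std::vector<int64_t>& shape, size_t begin, size_t end) {
  return std::accumulate(shape.begin() + begin, shape.begin() + end,
                         int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  // LHS before the reshape: tensor<1x128x8x256xf32>; after: tensor<128x2048xf32>.
  std::vector<int64_t> original{1, 128, 8, 256};
  std::vector<int64_t> reshaped{128, 2048};
  // ContractingDimsProductEqual with agg_start_idx = 2: the contracting dims
  // {8, 256} must multiply out to the trailing dim of the reshaped value.
  assert(Product(original, 2, original.size()) == reshaped.back());
  // NonBroadcastingNonContractingDimsProductEqual with is_lhs = true,
  // agg_start_idx = 0, agg_end_idx = 1: the dims {1, 128} must multiply out
  // to the second-to-last dim of the reshaped value.
  assert(Product(original, 0, 2) == reshaped[reshaped.size() - 2]);
  return 0;
}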

// Returns whether the given type `a` is broadcast-compatible with `b`.
52 changes: 46 additions & 6 deletions tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -1538,8 +1538,18 @@ def FuseLeakyReluConst : Pat<

// Return true if the product of dimension values of a subsection of the tensor
// is equal to the non-contracting dimension after a reshape
-class BroadcastDimsProductEqual<int agg_start_idx> : Constraint<CPred<
-"TFL::BroadcastDimsProductEqual($0, $1, "# agg_start_idx #")">>;
class NonBroadcastingNonContractingLhsDimsProductEqual<int agg_start_idx, int agg_end_idx=0> : Constraint<CPred<
"TFL::NonBroadcastingNonContractingDimsProductEqual($0, $1, true, "# agg_start_idx #","# agg_end_idx #")">>;

// Predicate to check if the last dimensions of two values are equal.
def TrailingDimValuesEqual : Constraint<
CPred<"mlir::cast<ShapedType>($0.getType()).getShape().back()"
"== mlir::cast<ShapedType>($1.getType()).getShape().back()">>;

// Predicate to check if the product of last few dimensions in LHS is equal to
// the last dimension in RHS.
class ContractingDimsProductEqual<int agg_start_idx> : Constraint<CPred<
"TFL::ContractingDimsProductEqual($0, $1, "# agg_start_idx #")">>;

// Returns true if the dimensions of a subsection of two tensors is equal
class AreTensorSubSectionShapesEqual<int skip_first, int skip_last> : Constraint<CPred<
@@ -1566,10 +1576,39 @@ def FuseReshapesAroundBatchMatMulLHS: Pat<
(TFL_BatchMatMulOp $input, $rhs, $adj_x, $adj_y, $bool_attr),
[(HasRank<2> $rhs),
(HasRank<2> $initial_shape_change),
-(BroadcastDimsProductEqual<0> $input, $initial_shape_change),
-(BroadcastDimsProductEqual<0> $final_shape_change, $bmm_tmp_output),
(TrailingDimValuesEqual $input, $initial_shape_change),
(NonBroadcastingNonContractingLhsDimsProductEqual<0> $input, $initial_shape_change),
(NonBroadcastingNonContractingLhsDimsProductEqual<0> $final_shape_change, $bmm_tmp_output),
(AreTensorSubSectionShapesEqual<0, 1> $input, $final_shape_change)]>;

// Pattern to fuse/fold the reshape of the TFL_BatchMatMulOp output that
// expands the output dimensions to add a unit broadcast dimension.
// This pattern assumes that the input has more than one contracting dimension.
//
// This pattern is applied when:
// 1. The rank of rhs is 2.
// 2. The original input reshape a) drops the leading broadcast dim and
// b) flattens the contracting dims.
def FuseOutputReshape_BatchMatMulWithFlattenedContractingDims: Pat<
(TFL_ReshapeOp:$final_shape_change
(TFL_BatchMatMulOp:$bmm_tmp_output
(TFL_ReshapeOp:$initial_shape_change
$input, (Arith_ConstantOp I32ElementsAttr:$s0)),
$rhs, $adj_x, $adj_y, $bool_attr),
(Arith_ConstantOp $s1)),
(TFL_BatchMatMulOp
(TFL_ReshapeOp $input, (Arith_ConstantOp (GetExpandedShape<1> $s0))),
$rhs, $adj_x, $adj_y, $bool_attr),
[(HasRankAtLeast<4> $input),
(HasRank<2> $rhs),
(HasRank<2> $initial_shape_change),
(IsBroadcastDimEqualToOne $input),
(IsBroadcastDimEqualToOne $final_shape_change),
(ContractingDimsProductEqual<2> $input, $initial_shape_change),
(NonBroadcastingNonContractingLhsDimsProductEqual<0,1> $input, $initial_shape_change),
(NonBroadcastingNonContractingLhsDimsProductEqual<0,1> $final_shape_change, $bmm_tmp_output)]>;


// Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp
// This pattern is applied when the rank of rhs is 3
// and the broadcast dimension is [1]
@@ -1585,8 +1624,9 @@ def FuseReshapesAroundBatchMatMulLHS1: Pat<
(HasRank<3> $initial_shape_change),
(IsBroadcastDimEqualToOne $rhs),
(IsBroadcastDimEqualToOne $input),
-(BroadcastDimsProductEqual<1> $input, $initial_shape_change),
-(BroadcastDimsProductEqual<1> $final_shape_change, $bmm_tmp_output),
(TrailingDimValuesEqual $input, $initial_shape_change),
(NonBroadcastingNonContractingLhsDimsProductEqual<1> $input, $initial_shape_change),
(NonBroadcastingNonContractingLhsDimsProductEqual<1> $final_shape_change, $bmm_tmp_output),
(AreTensorSubSectionShapesEqual<1, 1> $input, $final_shape_change)]>;


20 changes: 20 additions & 0 deletions tensorflow/compiler/mlir/lite/utils/utils.h
@@ -295,6 +295,26 @@ inline ShapedType GetTransposedType(Value input,
return transposed_type;
}

// Return the resultant shape if the shape of the supplied attribute/value is
// expanded by n leading 1s.
inline DenseElementsAttr GetExpandedShape(DenseElementsAttr input_val_attr,
int n) {
SmallVector<int32_t> expanded_shape;
expanded_shape.reserve(input_val_attr.getNumElements() + n);
for (int i = 0; i < n; ++i) {
expanded_shape.push_back(1);
}
expanded_shape.insert(expanded_shape.end(),
input_val_attr.getValues<int32_t>().begin(),
input_val_attr.getValues<int32_t>().end());

return mlir::DenseElementsAttr::get(
RankedTensorType::get(
{static_cast<int>(expanded_shape.size())},
mlir::IntegerType::get(input_val_attr.getContext(), 32)),
llvm::ArrayRef(expanded_shape));
}
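
For example (an illustrative sketch on plain integers; the real helper operates on a DenseElementsAttr, and ExpandShape below is a hypothetical stand-in), expanding the 2-D reshape target used in the new test by one leading 1 produces the 3-D target that the fused pattern substitutes:

#include <cassert>
#include <cstdint>
#include <vector>

// Same expansion as GetExpandedShape, but on a plain vector of dims.
std::vector<int32_t> ExpandShape(std::vector<int32_t> shape, int n) {
  shape.insert(shape.begin(), n, 1);  // prepend n leading 1s
  return shape;
}

int main() {
  // GetExpandedShape<1> turns the original LHS reshape target [128, 2048]
  // into [1, 128, 2048], i.e. the shape fed to the fused tfl.batch_matmul.
  assert(ExpandShape({128, 2048}, 1) == (std::vector<int32_t>{1, 128, 2048}));
  return 0;
}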

// Returns shape of a ranked tensor.
// Precondition: output_val's is ranked tensor.
// Returns a truncated shape when `truncate` is set to true.
5 changes: 5 additions & 0 deletions tensorflow/compiler/mlir/lite/utils/utils.td
@@ -26,6 +26,11 @@ def CreateNoneValue : NativeCodeCall<
// if called without a ranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;

// Return the resultant shape if the shape of the supplied attribute/value is
// expanded by n leading 1s.
class GetExpandedShape<int n> : NativeCodeCall<
"GetExpandedShape($0.cast<DenseElementsAttr>(), " # n # ")">;

// Constraint that values in list attribute are all ones.
def IsAllOnesConstant : Constraint<CPred<"TFL::IsAllOnesConstant($0)">>;

