Skip to content

Commit

Permalink
Add direct legalization for min and max.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 664210716
  • Loading branch information
LukeBoyer authored and tensorflower-gardener committed Aug 17, 2024
1 parent 43ab21f commit ccf9f4e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2940,6 +2940,32 @@ func.func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1

// -----

//===----------------------------------------------------------------------===//
// mhlo binary element-wise ops
//===----------------------------------------------------------------------===//

// CHECK-LABEL: maximum
func.func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
func.return %0 : tensor<4xf32>
}

// CHECK: "tfl.maximum"(%arg0, %arg1)
// CHECK-NOT: mhlo

// -----

// CHECK-LABEL: minimum
func.func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
func.return %0 : tensor<4xf32>
}

// CHECK: "tfl.minimum"(%arg0, %arg1)
// CHECK-NOT: mhlo

// -----

// CHECK-LABEL: clamp
func.func @clamp(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0 = "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
Expand All @@ -2949,3 +2975,5 @@ func.func @clamp(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>) ->
// CHECK-NEXT: %0 = "tfl.minimum"(%arg1, %arg2)
// CHECK-NEXT: %1 = "tfl.maximum"(%0, %arg0)
// CHECK-NEXT: return %1 : tensor<f32>


Original file line number Diff line number Diff line change
Expand Up @@ -294,17 +294,19 @@ void LegalizeHloToTfLitePass::runOnOperation() {
target.addDynamicallyLegalOp<mhlo::CompareOp>(IsCompareLegal);

target.addIllegalOp<
// go/keep-sorted start
// clang-format off
// go/keep-sorted start
mhlo::ClampOp,
mhlo::DotGeneralOp,
mhlo::DotOp,
mhlo::DynamicReshapeOp,
mhlo::RemOp,
mhlo::ReshapeOp,
mhlo::ShiftRightArithmeticOp,
mhlo::ShiftRightLogicalOp,
mhlo::TransposeOp
mhlo::ClampOp,
mhlo::DotGeneralOp,
mhlo::DotOp,
mhlo::DynamicReshapeOp,
mhlo::MaxOp,
mhlo::MinOp,
mhlo::RemOp,
mhlo::ReshapeOp,
mhlo::ShiftRightArithmeticOp,
mhlo::ShiftRightLogicalOp,
mhlo::TransposeOp
// clang-format on
// go/keep-sorted end
>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ def LegalizeTranspose : Pat<(MHLO_TransposeOp $arg, $perm),
(TFL_TransposeOp $arg,
(CreateTFLCastToInt32Op (TFL_ConstOp $perm)))>;


def : Pat<(MHLO_ShiftRightArithmeticOp $l, $r), (TFL_RightShiftOp $l, $r)>;
def : Pat<(MHLO_ShiftRightLogicalOp $l, $r), (TFL_RightShiftOp $l, $r)>;
def : Pat<(MHLO_RemOp $l, $r), (TFL_FloorModOp $l, $r)>;

def LegalizeReshape : Pat<(MHLO_ReshapeOp:$output $input),
(TFL_ReshapeOp $input,
(CreateTFLCastToInt32Op (ShapeToConst $output)))>;
Expand Down Expand Up @@ -107,6 +102,33 @@ def LegalizeXor : Pat<
TFL_IntTensor:$r),
(TFL_BitwiseXorOp $l, $r)>;

//===----------------------------------------------------------------------===//
// binary element-wise ops
//===----------------------------------------------------------------------===//

def : Pat<
(MHLO_ShiftRightArithmeticOp $l, $r),
(TFL_RightShiftOp $l, $r)>;

def : Pat<
(MHLO_ShiftRightLogicalOp $l, $r),
(TFL_RightShiftOp $l, $r)>;

def : Pat<
(MHLO_RemOp $l, $r),
(TFL_FloorModOp $l, $r)>;

// Binary ops with no attrs.
foreach pair = [
[MHLO_MaxOp, TFL_MaximumOp],
[MHLO_MinOp, TFL_MinimumOp],
] in {
def : Pat<
(pair[0] $l, $r),
(pair[1] $l, $r)>;
}


//===----------------------------------------------------------------------===//
// comparison ops
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit ccf9f4e

Please sign in to comment.