From ccf9f4e3efef996627e87d32c2b0aff28feb5a26 Mon Sep 17 00:00:00 2001 From: Luke Boyer Date: Sat, 17 Aug 2024 13:41:28 -0700 Subject: [PATCH] Add direct legalization for min and max. PiperOrigin-RevId: 664210716 --- .../stablehlo/tests/tfl_legalize_hlo.mlir | 28 ++++++++++++++++ .../transforms/tflite_legalize_hlo.cc | 22 +++++++------ .../tflite_legalize_hlo_patterns.td | 32 ++++++++++++++++--- 3 files changed, 67 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index 85593fe8916cc9..62a792a0878946 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -2940,6 +2940,32 @@ func.func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1 // ----- +//===----------------------------------------------------------------------===// +// mhlo binary element-wise ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: maximum +func.func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: "tfl.maximum"(%arg0, %arg1) +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: minimum +func.func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: "tfl.minimum"(%arg0, %arg1) +// CHECK-NOT: mhlo + +// ----- + // CHECK-LABEL: clamp func.func @clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { %0 = "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor @@ -2949,3 +2975,5 @@ func.func @clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> // CHECK-NEXT: %0 = "tfl.minimum"(%arg1, %arg2) // CHECK-NEXT: %1 = "tfl.maximum"(%0, %arg0) // CHECK-NEXT: return %1 : tensor + + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index 2538569bcb42fe..6b8cdfa89ff0d5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -294,17 +294,19 @@ void LegalizeHloToTfLitePass::runOnOperation() { target.addDynamicallyLegalOp(IsCompareLegal); target.addIllegalOp< + // go/keep-sorted start // clang-format off - // go/keep-sorted start - mhlo::ClampOp, - mhlo::DotGeneralOp, - mhlo::DotOp, - mhlo::DynamicReshapeOp, - mhlo::RemOp, - mhlo::ReshapeOp, - mhlo::ShiftRightArithmeticOp, - mhlo::ShiftRightLogicalOp, - mhlo::TransposeOp + mhlo::ClampOp, + mhlo::DotGeneralOp, + mhlo::DotOp, + mhlo::DynamicReshapeOp, + mhlo::MaxOp, + mhlo::MinOp, + mhlo::RemOp, + mhlo::ReshapeOp, + mhlo::ShiftRightArithmeticOp, + mhlo::ShiftRightLogicalOp, + mhlo::TransposeOp // clang-format on // go/keep-sorted end >(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td index 9f76dd9c223ed2..fb9647c174ebcf 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td @@ -33,11 +33,6 @@ def LegalizeTranspose : Pat<(MHLO_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, (CreateTFLCastToInt32Op (TFL_ConstOp $perm)))>; - -def : Pat<(MHLO_ShiftRightArithmeticOp $l, $r), (TFL_RightShiftOp $l, $r)>; -def : Pat<(MHLO_ShiftRightLogicalOp $l, $r), (TFL_RightShiftOp $l, $r)>; -def : Pat<(MHLO_RemOp $l, $r), (TFL_FloorModOp $l, $r)>; - def LegalizeReshape : Pat<(MHLO_ReshapeOp:$output $input), (TFL_ReshapeOp $input, (CreateTFLCastToInt32Op (ShapeToConst $output)))>; @@ -107,6 +102,33 @@ def LegalizeXor : Pat< TFL_IntTensor:$r), (TFL_BitwiseXorOp $l, $r)>; +//===----------------------------------------------------------------------===// +// binary element-wise ops +//===----------------------------------------------------------------------===// + +def : Pat< + (MHLO_ShiftRightArithmeticOp $l, $r), + (TFL_RightShiftOp $l, $r)>; + +def : Pat< + (MHLO_ShiftRightLogicalOp $l, $r), + (TFL_RightShiftOp $l, $r)>; + +def : Pat< + (MHLO_RemOp $l, $r), + (TFL_FloorModOp $l, $r)>; + +// Binary ops with no attrs. +foreach pair = [ + [MHLO_MaxOp, TFL_MaximumOp], + [MHLO_MinOp, TFL_MinimumOp], +] in { + def : Pat< + (pair[0] $l, $r), + (pair[1] $l, $r)>; +} + + //===----------------------------------------------------------------------===// // comparison ops //===----------------------------------------------------------------------===//