Skip to content

Commit

Permalink
Add i64 support for tfl.equal folding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671596710
  • Loading branch information
LukeBoyer authored and tensorflower-gardener committed Sep 6, 2024
1 parent 958555c commit 9ce36fb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3644,6 +3644,11 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
getType(), adaptor.getX(), adaptor.getY(),
[](int32_t lhs, int32_t rhs) { return lhs == rhs; });
}
if (getX().getType().getElementType().isInteger(64)) {
return ConstFoldBinaryOp<DenseIntElementsAttr, int64_t, bool>(
getType(), adaptor.getX(), adaptor.getY(),
[](int64_t lhs, int64_t rhs) { return lhs == rhs; });
}
if (getX().getType().getElementType().isF32()) {
return ConstFoldBinaryOp<DenseIntElementsAttr, float, bool>(
getType(), adaptor.getX(), adaptor.getY(),
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/const-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,18 @@ func.func @equal_int() -> tensor<4xi1> {

// CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1>

// CHECK-LABEL: @equal_int64
func.func @equal_int64() -> tensor<4xi1> {
%0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi64>
%1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi64>

%2 = "tfl.equal"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1>

func.return %2 : tensor<4xi1>
}

// CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1>

// CHECK-LABEL: @equal_float
func.func @equal_float() -> tensor<4xi1> {
%0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32>
Expand Down

0 comments on commit 9ce36fb

Please sign in to comment.