diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 64b7f8aa4a1bde..cdcf44f1903879 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -3644,6 +3644,11 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { getType(), adaptor.getX(), adaptor.getY(), [](int32_t lhs, int32_t rhs) { return lhs == rhs; }); } + if (getX().getType().getElementType().isInteger(64)) { + return ConstFoldBinaryOp( + getType(), adaptor.getX(), adaptor.getY(), + [](int64_t lhs, int64_t rhs) { return lhs == rhs; }); + } if (getX().getType().getElementType().isF32()) { return ConstFoldBinaryOp( getType(), adaptor.getX(), adaptor.getY(), diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index d50f7a9367fb81..c47149759240d4 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -1470,6 +1470,18 @@ func.func @equal_int() -> tensor<4xi1> { // CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1> +// CHECK-LABEL: @equal_int64 +func.func @equal_int64() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi64> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi64> + + %2 = "tfl.equal"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1> + // CHECK-LABEL: @equal_float func.func @equal_float() -> tensor<4xi1> { %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32>