
Interpreter crashes when sorting float data #2601

Open
GleasonK opened this issue Oct 24, 2024 · 0 comments

GleasonK (Member) commented Oct 24, 2024

What happened?

While running some JAX tests through the StableHLO reference interpreter, I hit an assertion failure. See the repro.

Note that the repro relies on the stablehlo-translate changes in #2600.

Steps to reproduce your issue

The following is an IR dump from the JAX test lax_numpy_test.py > testSortStableDescending:

```mlir
// RUN: stablehlo-translate %s --interpret --args="[dense<[0.000000e+00, 1.000000e+00, 0x7FC00000, 0.000000e+00, 2.000000e+00, 0x7FC00000, 0xFF800000, 0x7F800000]> : tensor<8xf32>]"
module @jit_sort attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<8xf32> {mhlo.layout_mode = "default"}) -> (tensor<8xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<0x7FC00000> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = true}> ({
    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
      %1 = stablehlo.compare  EQ, %arg1, %cst_0,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %2 = stablehlo.select %1, %cst_0, %arg1 : tensor<i1>, tensor<f32>
      %3 = stablehlo.compare  NE, %arg1, %arg1,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %4 = stablehlo.select %3, %cst, %2 : tensor<i1>, tensor<f32>
      %5 = stablehlo.compare  EQ, %arg2, %cst_0,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %6 = stablehlo.select %5, %cst_0, %arg2 : tensor<i1>, tensor<f32>
      %7 = stablehlo.compare  NE, %arg2, %arg2,  FLOAT : (tensor<f32>, tensor<f32>) -> tensor<i1>
      %8 = stablehlo.select %7, %cst, %6 : tensor<i1>, tensor<f32>
      %9 = stablehlo.compare  LT, %4, %8,  TOTALORDER : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %9 : tensor<i1>
    }) : (tensor<8xf32>) -> tensor<8xf32>
    return %0 : tensor<8xf32>
  }
}
```
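The comparator region above first canonicalizes each operand (%1/%2 and %5/%6 map `-0.0` to `+0.0`; %3/%4 and %7/%8 map any NaN to the quiet NaN `0x7FC00000`) and then compares the results with `LT` under `TOTALORDER`. The Python sketch here models those intended semantics. It is an illustration under stated assumptions, not interpreter code: IEEE-754 totalOrder is modeled with the common sign-flipped bit-pattern key, and `total_order_key` and `comparator_key` are hypothetical helpers.

```python
import math
import struct

# Bit pattern of %cst in the repro (positive quiet NaN).
CANONICAL_NAN_BITS = 0x7FC00000


def total_order_key(x: float) -> int:
    """Hypothetical helper: an integer key whose ordering models IEEE-754
    totalOrder for float32 (-NaN < -inf < ... < -0 < +0 < ... < +inf < +NaN)."""
    # Reinterpret the float32 bit pattern as a signed 32-bit integer.
    i = struct.unpack("<i", struct.pack("<f", x))[0]
    # For negative values, flip the 31 magnitude bits so that larger
    # magnitudes sort lower; non-negative values already sort correctly.
    return i ^ 0x7FFFFFFF if i < 0 else i


def comparator_key(x: float) -> int:
    """Hypothetical mirror of the comparator body: canonicalize, then key."""
    if math.isnan(x):
        # %3/%4: any NaN becomes 0x7FC00000, whose key is its own bit pattern.
        return CANONICAL_NAN_BITS
    if x == 0.0:
        # %1/%2: both -0.0 and +0.0 compare EQ to 0.0 and become +0.0.
        x = 0.0
    return total_order_key(x)


# Same values as the --args in the RUN line.
data = [0.0, 1.0, math.nan, 0.0, 2.0, math.nan, -math.inf, math.inf]
print(sorted(data, key=comparator_key))
# [-inf, 0.0, 0.0, 1.0, 2.0, inf, nan, nan]
```

Under this model all zeros tie, all NaNs tie, and every other value is totally ordered, so the comparator induces a valid strict weak ordering and the repro input is, in principle, a legitimate thing to sort.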

It fails with:

```text
third_party/crosstool/v18/stable/src/libcxx/include/__debug_utils/strict_weak_ordering_check.h:59: assertion __comp(*(__first + __a), *(__first + __b)) failed: Your comparator is not a valid strict-weak ordering

Program received signal SIGABRT, Aborted.
0x00007ffff7cf9981 in raise () from /usr/grte/v5/lib64/libc.so.6
(gdb) bt
#0  0x00007ffff7cf9981 in raise () from /usr/grte/v5/lib64/libc.so.6
#1  0x00007ffff7cfadf7 in abort () from /usr/grte/v5/lib64/libc.so.6
#2  0x000055555a41345d in std::__u::__libcpp_verbose_abort(char const*, ...) ()
#3  0x0000555556195326 in void std::__u::__check_strict_weak_ordering_sorted<long*, mlir::stablehlo::sortOp(llvm::ArrayRef<mlir::stablehlo::Tensor>, long, bool, mlir::Region&, mlir::stablehlo::Process*, mlir::stablehlo::Scope&)::$_0>(long*, long*, mlir::stablehlo::sortOp(llvm::ArrayRef<mlir::stablehlo::Tensor>, long, bool, mlir::Region&, mlir::stablehlo::Process*, mlir::stablehlo::Scope&)::$_0&) ()
#4  0x000055555618847b in mlir::stablehlo::sortOp(llvm::ArrayRef<mlir::stablehlo::Tensor>, long, bool, mlir::Region&, mlir::stablehlo::Process*, mlir::stablehlo::Scope&) ()
#5  0x0000555556165552 in mlir::stablehlo::eval(mlir::Region&, llvm::ArrayRef<mlir::stablehlo::InterpreterValue>, mlir::stablehlo::InterpreterFallback*, mlir::stablehlo::Process*, mlir::stablehlo::Scope*) ()
#6  0x000055555612966f in mlir::stablehlo::evalModule(mlir::ModuleOp, llvm::ArrayRef<mlir::stablehlo::InterpreterValue>, mlir::stablehlo::InterpreterConfiguration const&) ()
```
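The issue does not identify the root cause. The libc++ debug check in frame #3 fires when the comparator given to the standard library sort inside `sortOp` violates the strict-weak-ordering axioms, notably when "neither a < b nor b < a" fails to be transitive. One hypothesis worth checking (an assumption, not confirmed in this issue) is that the `TOTALORDER` compare type is not honored on this path, so `LT` behaves like plain IEEE `<`, under which NaN is unordered with every value. The Python sketch uses a hypothetical brute-force checker, `is_strict_weak_ordering` (cubic in the input size and unrelated to how libc++ implements its check), to show that plain `<` fails on the repro's input while a totalOrder-style key passes. `ieee_less`, `total_order_key`, and `total_order_less` are likewise hypothetical helpers.

```python
import itertools
import math
import struct


def is_strict_weak_ordering(less, values) -> bool:
    """Hypothetical brute-force check of the strict-weak-ordering axioms."""
    def equiv(a, b):
        # a and b are equivalent (incomparable) if neither is less than the other.
        return not less(a, b) and not less(b, a)

    for a in values:
        if less(a, a):  # irreflexivity
            return False
    for a, b, c in itertools.product(values, repeat=3):
        if less(a, b) and less(b, c) and not less(a, c):  # transitivity
            return False
        if equiv(a, b) and equiv(b, c) and not equiv(a, c):  # transitive equivalence
            return False
    return True


def ieee_less(a: float, b: float) -> bool:
    # IEEE-754 `<`: any comparison involving NaN is False.
    return a < b


def total_order_key(x: float) -> int:
    # Sign-flipped float32 bit pattern, modeling IEEE-754 totalOrder.
    i = struct.unpack("<i", struct.pack("<f", x))[0]
    return i ^ 0x7FFFFFFF if i < 0 else i


def total_order_less(a: float, b: float) -> bool:
    return total_order_key(a) < total_order_key(b)


# Same values as the --args in the RUN line.
data = [0.0, 1.0, math.nan, 0.0, 2.0, math.nan, -math.inf, math.inf]
print(is_strict_weak_ordering(ieee_less, data))         # False
print(is_strict_weak_ordering(total_order_less, data))  # True
```

With plain `<`, 1.0 is equivalent to NaN and NaN is equivalent to 2.0, yet 1.0 < 2.0, so equivalence is not transitive; that is exactly the kind of violation the libc++ check reports.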

Version information

No response
