Skip to content

Commit

Permalink
Revamp StableHLO simplification patterns.
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK committed Oct 11, 2024
1 parent 15324c4 commit 7029b03
Show file tree
Hide file tree
Showing 11 changed files with 1,330 additions and 1,040 deletions.
16 changes: 16 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,21 @@ gentbl_cc_library(
],
)

# Runs mlir-tblgen with --gen-rewriters over the DRR pattern definitions in
# StablehloAggressiveSimplificationPatterns.td, producing the generated
# rewriter header StablehloAggressiveSimplificationPatterns.h.inc.
# The StableHLO op TableGen files are needed so the .td patterns can
# reference StableHLO op definitions.
gentbl_cc_library(
name = "stablehlo_aggressive_simplification_inc_gen",
tbl_outs = [
(
["--gen-rewriters"],
"stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td",
deps = [
":stablehlo_ops_td_files",
],
)

gentbl_cc_library(
name = "stablehlo_legalize_deprecated_ops_inc_gen",
tbl_outs = [
Expand Down Expand Up @@ -1127,6 +1142,7 @@ cc_library(
":chlo_ops",
":chlo_rewriters_inc_gen",
":linalg_passes",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
Expand Down
156 changes: 156 additions & 0 deletions stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,143 @@
// RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s

////////
// AddOp

// Verifies that stablehlo.add of identical integer and float constants is
// folded to a single constant (1 + 1 == 2 for both i32 and f32).
// CHECK-LABEL: @add_fold_cst
func.func @add_fold_cst() -> (tensor<i32>, tensor<f32>) {
%cst = stablehlo.constant dense<1> : tensor<i32>
%cst_1 = stablehlo.constant dense<1.0> : tensor<f32>
// CHECK: stablehlo.constant dense<2> : tensor<i32>
// CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
%0 = stablehlo.add %cst, %cst : tensor<i32>
%1 = stablehlo.add %cst_1, %cst_1 : tensor<f32>
return %0, %1 : tensor<i32>, tensor<f32>
}

// -----

////////
// BroadcastInDimOp

// Verifies that broadcast_in_dim of splat constants folds to a splat
// constant of the result type. Covers scalar->1-D broadcasts (i32 and f32)
// and a splat tensor broadcast with permuted dims ([1, 0]). The %arg0
// parameter is unused by the ops; it only pins the function signature
// checked by CHECK-SAME.
// CHECK-LABEL: func.func @broadcast_in_dim_fold_splat
// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>)
func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>)
-> (tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32>) {
%c0 = stablehlo.constant dense<5> : tensor<i32>
%c1 = stablehlo.constant dense<3.0> : tensor<f32>
%c2 = stablehlo.constant dense<1> : tensor<1x3xi32>

%0 = stablehlo.broadcast_in_dim %c0, dims = [] : (tensor<i32>) -> tensor<6xi32>
%1 = stablehlo.broadcast_in_dim %c1, dims = [] : (tensor<f32>) -> tensor<3xf32>
%2 = stablehlo.broadcast_in_dim %c2, dims = [1, 0] : (tensor<1x3xi32>) -> tensor<3x3xi32>

// CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<5> : tensor<6xi32>
// CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<3.000000e+00> : tensor<3xf32>
// CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<1> : tensor<3x3xi32>

// CHECK-NEXT: return [[R0]], [[R1]], [[R2]]
return %0, %1, %2 : tensor<6xi32>, tensor<3xf32>, tensor<3x3xi32>
}

// -----

////////
// CompareOp

// Verifies constant folding of stablehlo.compare for EQ/GT/GE/LE under both
// SIGNED and UNSIGNED comparison semantics. Note the UNSIGNED cases: the
// i32 value -1 reinterprets as the maximum unsigned value, so
// GT(5, -1) and LE(-1, 5) fold to false (matching [[FALSE]] in the
// expected return below).
// CHECK-LABEL: func.func @compare_folds
func.func @compare_folds()
-> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) {
%cn1 = stablehlo.constant dense<-1> : tensor<i32>
%c0 = stablehlo.constant dense<0> : tensor<i32>
%c4 = stablehlo.constant dense<4> : tensor<i32>
%c5 = stablehlo.constant dense<5> : tensor<i32>

%0 = stablehlo.compare EQ, %cn1, %cn1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%1 = stablehlo.compare GT, %c5, %c5, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = stablehlo.compare GE, %c4, %cn1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = stablehlo.compare LE, %c4, %c5, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

%4 = stablehlo.compare EQ, %cn1, %cn1, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%5 = stablehlo.compare GT, %c5, %cn1, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = stablehlo.compare GE, %c5, %c4, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%7 = stablehlo.compare LE, %cn1, %c5, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

// CHECK-DAG: [[FALSE:%.+]] = stablehlo.constant dense<false> : tensor<i1>
// CHECK-DAG: [[TRUE:%.+]] = stablehlo.constant dense<true> : tensor<i1>

// CHECK-NEXT: return [[TRUE]], [[FALSE]], [[TRUE]], [[TRUE]], [[TRUE]], [[FALSE]], [[TRUE]], [[FALSE]]
return %0, %1, %2, %3, %4, %5, %6, %7 :
tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
}


// -----

////////
// ConcatenateOp

// Verifies constant folding of stablehlo.concatenate: 1-D concatenation of
// two and three operands (dim 0), and 2-D concatenation along both the row
// dimension (dim 0) and the column dimension (dim 1).
// CHECK-LABEL: func.func @concatenate_fold
func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32>) {
%c0 = stablehlo.constant dense<[0, 1]> : tensor<2xi32>
%c1 = stablehlo.constant dense<[2, 3, 4]> : tensor<3xi32>
%c2 = stablehlo.constant dense<[5]> : tensor<1xi32>

%c3 = stablehlo.constant dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi32>
%c4 = stablehlo.constant dense<[[6, 7, 8]]> : tensor<1x3xi32>
%c5 = stablehlo.constant dense<[[11, 12], [13, 14]]> : tensor<2x2xi32>

%0 = stablehlo.concatenate %c0, %c1, %c2, dim = 0 : (tensor<2xi32>, tensor<3xi32>, tensor<1xi32>) -> tensor<6xi32>
%1 = stablehlo.concatenate %c0, %c2, dim = 0 : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32>

%2 = stablehlo.concatenate %c3, %c4, dim = 0 : (tensor<2x3xi32>, tensor<1x3xi32>) -> tensor<3x3xi32>
%3 = stablehlo.concatenate %c3, %c5, dim = 1 : (tensor<2x3xi32>, tensor<2x2xi32>) -> tensor<2x5xi32>

// CHECK-DAG: [[R0:%.+]] = stablehlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32>
// CHECK-DAG: [[R1:%.+]] = stablehlo.constant dense<[0, 1, 5]> : tensor<3xi32>
// CHECK-DAG: [[R2:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2\], \[3, 4, 5\], \[6, 7, 8\]\]}}> : tensor<3x3xi32>
// CHECK-DAG: [[R3:%.+]] = stablehlo.constant dense<{{\[\[0, 1, 2, 11, 12\], \[3, 4, 5, 13, 14\]\]}}> : tensor<2x5xi32>
// CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]]
return %0, %1, %2, %3 : tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, tensor<2x5xi32>
}

// -----

////////
// MulOp

// Verifies that stablehlo.multiply of identical integer and float constants
// is folded to a single constant (2 * 2 == 4 for both i32 and f32).
// CHECK-LABEL: @mul_fold_cst
func.func @mul_fold_cst() -> (tensor<i32>, tensor<f32>) {
%cst = stablehlo.constant dense<2> : tensor<i32>
%cst_1 = stablehlo.constant dense<2.0> : tensor<f32>
// CHECK: stablehlo.constant dense<4> : tensor<i32>
// CHECK: stablehlo.constant dense<4.0{{.*}}> : tensor<f32>
%0 = stablehlo.multiply %cst, %cst : tensor<i32>
%1 = stablehlo.multiply %cst_1, %cst_1 : tensor<f32>
return %0, %1 : tensor<i32>, tensor<f32>
}

// -----

////////
// SubtractOp

// Verifies that stablehlo.subtract of integer and float constants is folded
// to a single constant (3 - 1 == 2 for both i32 and f32).
// CHECK-LABEL: @subtract_fold_cst
func.func @subtract_fold_cst() -> (tensor<i32>, tensor<f32>) {
%cst = stablehlo.constant dense<1> : tensor<i32>
%cst_1 = stablehlo.constant dense<3> : tensor<i32>
%cst_2 = stablehlo.constant dense<1.0> : tensor<f32>
%cst_3 = stablehlo.constant dense<3.0> : tensor<f32>
// CHECK: stablehlo.constant dense<2> : tensor<i32>
// CHECK: stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
%0 = stablehlo.subtract %cst_1, %cst : tensor<i32>
%1 = stablehlo.subtract %cst_3, %cst_2 : tensor<f32>
return %0, %1 : tensor<i32>, tensor<f32>
}

// -----

////////
// IotaOp

// CHECK-LABEL: func @eval_iota
func.func @eval_iota() -> (tensor<3x4x5xi32>, tensor<3x4x5xi32>, tensor<3x4x5xi32>) {
Expand Down Expand Up @@ -41,6 +179,24 @@ func.func @eval_iota_zero_dimension() -> (tensor<0xi32>, tensor<5x0x2xi32>) {

// -----

////////
// ReshapeOp

// Verifies that stablehlo.reshape of constants is folded: the element data
// is preserved and only the result type changes (scalar -> tensor<1xi32>,
// tensor<4xi32> -> tensor<2x2xi32>).
// Fix: the label previously read `func @reshape` and only matched
// `func.func @reshape_fold` by substring accident; spell out the full
// symbol name like the other labels in this file (e.g. @concatenate_fold).
// CHECK-LABEL: func.func @reshape_fold
func.func @reshape_fold() -> (tensor<1xi32>, tensor<2x2xi32>) {
%c0 = stablehlo.constant dense<2> : tensor<i32>
%c1 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%0 = stablehlo.reshape %c0 : (tensor<i32>) -> tensor<1xi32>
%1 = stablehlo.reshape %c1 : (tensor<4xi32>) -> tensor<2x2xi32>

// CHECK-DAG: [[CST1:%.+]] = stablehlo.constant dense<2> : tensor<1xi32>
// CHECK-DAG: [[CST2:%.+]] = stablehlo.constant dense<{{\[\[1, 2\], \[3, 4\]\]}}> : tensor<2x2xi32>
// CHECK-NEXT: return [[CST1]], [[CST2]]
return %0, %1 : tensor<1xi32>, tensor<2x2xi32>
}

// -----

// CHECK-LABEL: func @eval_convert_f32_to_i64
func.func @eval_convert_f32_to_i64() -> tensor<2xi64> {
// CHECK-NOT: stablehlo.convert
Expand Down
Loading

0 comments on commit 7029b03

Please sign in to comment.