Skip to content

Commit

Permalink
Add StableHLO to linalg conversions to python bindings (#2660)
Browse files Browse the repository at this point in the history
Testing function:
```
from mlir.dialects import stablehlo
from mlir.ir import Context, Location, Module
import mlir.dialects.arith
from mlir.passmanager import PassManager

mlir_text = """
func.func @dot_general(%arg0: tensor<?x?x?xf32>,
                  %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = "stablehlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #stablehlo.dot<
      lhs_batching_dimensions = [1],
      lhs_contracting_dimensions = [2],
      rhs_batching_dimensions = [2],
      rhs_contracting_dimensions = [1]
    >,
    precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>],
    someattr
  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  func.return %0 : tensor<?x?x?xf32>
}
"""

with Context() as ctx:
    stablehlo.register_dialect(ctx)
    stablehlo.register_stablehlo_passes()


    module = Module.parse(mlir_text)

    pm = PassManager.parse(
        "builtin.module(func.func("
        "shape-legalize-to-stablehlo,"
        "stablehlo-aggressive-folder,"
        "stablehlo-aggressive-simplification,"
        "stablehlo-legalize-to-linalg"
        "))"
    )

    pm.run(module.operation)

    print(f"{module}")
```

Before this change we get an error:
```
    pm = PassManager.parse(
         ^^^^^^^^^^^^^^^^^^
ValueError: MLIR Textual PassPipeline Parser:1:103: error: 'stablehlo-legalize-to-linalg' does not refer to a registered pass or pass pipeline
func.func(shape-legalize-to-stablehlo,stablehlo-aggressive-folder,stablehlo-aggressive-simplification,stablehlo-legalize-to-linalg)
```

Now we get an expected result:
```
#map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
module {
  func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
    %c0 = arith.constant 0 : index
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
    %c0_1 = arith.constant 0 : index
    %dim_2 = tensor.dim %arg1, %c0_1 : tensor<?x?x?xf32>
    %from_elements = tensor.from_elements %dim, %dim_0, %dim_2 : tensor<3xindex>
    %c0_3 = arith.constant 0 : index
    %extracted = tensor.extract %from_elements[%c0_3] : tensor<3xindex>
    %c1_4 = arith.constant 1 : index
    %extracted_5 = tensor.extract %from_elements[%c1_4] : tensor<3xindex>
    %c2 = arith.constant 2 : index
    %extracted_6 = tensor.extract %from_elements[%c2] : tensor<3xindex>
    %0 = tensor.empty(%extracted, %extracted_5, %extracted_6) : tensor<?x?x?xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%1 : tensor<?x?x?xf32>) attrs =  {someattr} {
    ^bb0(%in: f32, %in_7: f32, %out: f32):
      %3 = arith.mulf %in, %in_7 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
    } -> tensor<?x?x?xf32>
    return %2 : tensor<?x?x?xf32>
  }
}
```
  • Loading branch information
mamanain authored Dec 9, 2024
1 parent 36bbe27 commit ef176a1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ cc_library(
hdrs = STABLEHLO_CAPI_HEADERS,
strip_include_prefix = ".",
deps = [
":linalg_passes",
":reference_api",
":reference_configuration",
":stablehlo_ops",
Expand Down Expand Up @@ -949,6 +950,7 @@ cc_library(
hdrs = STABLEHLO_CAPI_HEADERS,
strip_include_prefix = ".",
deps = [
":linalg_passes",
":reference_api",
":reference_configuration",
":stablehlo_ops",
Expand Down
6 changes: 5 additions & 1 deletion stablehlo/integrations/c/StablehloPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ limitations under the License.

#include "stablehlo/integrations/c/StablehloPasses.h"

#include "stablehlo/conversions/linalg/transforms/Passes.h"
#include "stablehlo/transforms/Passes.h"

void mlirRegisterAllStablehloPasses() { mlir::stablehlo::registerPasses(); }
void mlirRegisterAllStablehloPasses() {
mlir::stablehlo::registerPasses();
mlir::stablehlo::registerStablehloLinalgTransformsPasses();
}

0 comments on commit ef176a1

Please sign in to comment.