Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add StableHLO to linalg conversions to python bindings (#2660)
Testing function: ``` from mlir.dialects import stablehlo from mlir.ir import Context, Location, Module import mlir.dialects.arith from mlir.passmanager import PassManager mlir_text = """ func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { %0 = "stablehlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = #stablehlo.dot< lhs_batching_dimensions = [1], lhs_contracting_dimensions = [2], rhs_batching_dimensions = [2], rhs_contracting_dimensions = [1] >, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>], someattr } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> func.return %0 : tensor<?x?x?xf32> } """ with Context() as ctx: stablehlo.register_dialect(ctx) stablehlo.register_stablehlo_passes() module = Module.parse(mlir_text) pm = PassManager.parse( "builtin.module(func.func(" "shape-legalize-to-stablehlo," "stablehlo-aggressive-folder," "stablehlo-aggressive-simplification," "stablehlo-legalize-to-linalg" "))" ) pm.run(module.operation) print(f"{module}") ``` Before this change we get an error: ``` pm = PassManager.parse( ^^^^^^^^^^^^^^^^^^ ValueError: MLIR Textual PassPipeline Parser:1:103: error: 'stablehlo-legalize-to-linalg' does not refer to a registered pass or pass pipeline func.func(shape-legalize-to-stablehlo,stablehlo-aggressive-folder,stablehlo-aggressive-simplification,stablehlo-legalize-to-linalg) ``` Now we get an expected result: ``` #map = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3)> #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3, d0)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> module { func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { %c1 = arith.constant 1 : index %dim = tensor.dim %arg0, %c1 : tensor<?x?x?xf32> %c0 = arith.constant 0 : index %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32> %c0_1 = arith.constant 0 : index %dim_2 = tensor.dim %arg1, %c0_1 : tensor<?x?x?xf32> %from_elements = tensor.from_elements %dim, %dim_0, %dim_2 : tensor<3xindex> %c0_3 = arith.constant 0 : index %extracted = tensor.extract %from_elements[%c0_3] : tensor<3xindex> %c1_4 = arith.constant 1 : index %extracted_5 = tensor.extract %from_elements[%c1_4] : tensor<3xindex> %c2 = arith.constant 2 : index %extracted_6 = tensor.extract %from_elements[%c2] : tensor<3xindex> %0 = tensor.empty(%extracted, %extracted_5, %extracted_6) : tensor<?x?x?xf32> %cst = arith.constant 0.000000e+00 : f32 %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%1 : tensor<?x?x?xf32>) attrs = {someattr} { ^bb0(%in: f32, %in_7: f32, %out: f32): %3 = arith.mulf %in, %in_7 : f32 %4 = arith.addf %out, %3 : f32 linalg.yield %4 : f32 } -> tensor<?x?x?xf32> return %2 : tensor<?x?x?xf32> } } ```
- Loading branch information