mlir-python-extras

The missing pieces (as far as boilerplate reduction goes) of the MLIR python bindings.

TL;DR

Full example at examples/mwe.py (i.e., go there if you want to copy-paste).

Turn this

K = 10
memref_i64 = T.memref(K, K, T.i64)

@func
@canonicalize(using=scf)
def memfoo(A: memref_i64, B: memref_i64, C: memref_i64):
    one = constant(1)
    two = constant(2)
    if one > two:
        three = constant(3)
    else:
        for i in range(0, K):
            for j in range(0, K):
                C[i, j] = A[i, j] * B[i, j]

into this

func.func @memfoo(%arg0: memref<10x10xi64>, %arg1: memref<10x10xi64>, %arg2: memref<10x10xi64>) {
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %0 = arith.cmpi ugt, %c1_i32, %c2_i32 : i32
  scf.if %0 {
    %c3_i32 = arith.constant 3 : i32
  } else {
    %c0 = arith.constant 0 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    scf.for %arg3 = %c0 to %c10 step %c1 {
      scf.for %arg4 = %c0 to %c10 step %c1 {
        %1 = memref.load %arg0[%arg3, %arg4] : memref<10x10xi64>
        %2 = memref.load %arg1[%arg3, %arg4] : memref<10x10xi64>
        %3 = arith.muli %1, %2 : i64
        memref.store %3, %arg2[%arg3, %arg4] : memref<10x10xi64>
      }
    }
  }
  return
}

then run it like this

module = backend.compile(
    ctx.module,
    kernel_name=memfoo.__name__,
    pipeline=Pipeline().bufferize().lower_to_llvm(),
)

A = np.random.randint(0, 10, (K, K))
B = np.random.randint(0, 10, (K, K))
C = np.zeros((K, K), dtype=int)

backend.load(module).memfoo(A, B, C)
assert np.array_equal(A * B, C)

5s Intro

This is not a Python compiler, just a (hopefully) nice way to emit MLIR using Python.

The few main features/affordances:

  1. region_ops (like @func above)
     
    1. These are decorators around ops (bindings for MLIR operations) that have regions (e.g., in_parallel). They turn decorated functions, by executing them "eagerly", into an instance of such an op, e.g.,
      @func
      def foo(x: T.i32):
         return
      becomes func.func @foo(%arg0: i32) { }; if the region-carrying op produces a result, the identifier for the Python function (foo) becomes the corresponding ir.Value of the result (if the op doesn't produce a result, then the identifier becomes the corresponding ir.OpView).

      See mlir_extras.util.op_region_builder for details.
       
  2. @canonicalize (like @canonicalize(using=scf) above)
     
    1. These are decorators that rewrite the Python AST. They transform a select few forms (basically only ifs) into a more "canonical" form that maps more easily to MLIR. If that scares you, fear not; they are not essential, and all target MLIR can still be emitted without them (by using the slightly more verbose region_op).

      See mlir_extras.ast.canonicalize for details.
       
  3. mlir_extras.types (like T.memref(K, K, T.i64) above)
     
    1. These are just convenient wrappers around upstream type constructors. Note, because MLIR types are uniqued to an ir.Context, these are all actually functions that return the type (yes, even T.i64, which uses __getattr__ on the module).

      See mlir_extras.types for details.
       
  4. Pipeline()
     
    1. This is just a (generated) wrapper around available upstream passes; it can be used to build pass pipelines (by str(Pipeline())). It is mainly convenient with IDEs/editors that will tab-complete the available methods on the Pipeline class (which correspond to passes). Note, if your host bindings don't register some upstream passes, then this will generate "illegal" pass pipelines.

      See mlir_extras._configuration.generate_pass_pipeline.py for details on generation and mlir_extras.runtime.passes.py for the passes themselves.
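To make the str(Pipeline()) idea concrete, here is a minimal sketch of a chainable builder that renders to MLIR's textual pass-pipeline syntax. ToyPipeline and add_pass are hypothetical names invented for this illustration; this is not the generated Pipeline class, and it only knows about two upstream passes (canonicalize and cse).

```python
class ToyPipeline:
    """Hypothetical stand-in for a pass-pipeline builder (not the library's class)."""

    def __init__(self, anchor="builtin.module"):
        self._anchor = anchor  # op that the pipeline is anchored on
        self._passes = []

    def add_pass(self, name, **options):
        # Render options as `key=value` pairs inside braces, as in MLIR's
        # textual pipeline syntax; Python underscores become dashes.
        if options:
            rendered = " ".join(
                f"{key.replace('_', '-')}={value}" for key, value in options.items()
            )
            name = f"{name}{{{rendered}}}"
        self._passes.append(name)
        return self  # returning self is what makes the calls chainable

    # One method per pass is what lets an editor tab-complete pass names.
    def canonicalize(self):
        return self.add_pass("canonicalize")

    def cse(self):
        return self.add_pass("cse")

    def __str__(self):
        return f"{self._anchor}({','.join(self._passes)})"


pipeline_str = str(ToyPipeline().canonicalize().cse())
print(pipeline_str)  # builtin.module(canonicalize,cse)
```

The point of the design is that the builder never touches MLIR itself; it only assembles a string, which is why a pass that the host bindings did not register is caught later, when the string is parsed.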
       

Note, also, there are no docs (because ain't no one got time for that) but that shouldn't be a problem because the package is designed such that you can use/reuse only the pieces/parts you want/understand. But, open an issue if something isn't clear.
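As a rough illustration of the "eager" behavior described for region_ops in the feature list, the sketch below uses plain Python with no MLIR dependency. Everything in it (ToyOp, toy_func, emit) is hypothetical and made up for this example; the real logic lives in mlir_extras.util.op_region_builder. What it shows is only the mechanism: the decorator runs the function body at decoration time, and the function's name ends up bound to an op-like object rather than to a callable.

```python
import inspect


class ToyOp:
    """Hypothetical stand-in for an op that owns a region (not an MLIR class)."""

    def __init__(self, name, arg_names, body):
        self.name = name
        self.arg_names = arg_names
        self.body = body  # lines recorded while the decorated function ran

    def __str__(self):
        args = ", ".join(f"%{a}" for a in self.arg_names)
        lines = "".join(f"  {line}\n" for line in self.body)
        return f"toy.func @{self.name}({args}) {{\n{lines}}}"


_recorded = []  # "ops" emitted while a decorated function executes


def emit(text):
    _recorded.append(text)


def toy_func(fn):
    # Execute the body immediately, passing placeholder "values" as arguments.
    arg_names = list(inspect.signature(fn).parameters)
    _recorded.clear()
    fn(*(f"%{a}" for a in arg_names))
    # The name `fn` was defined under is rebound to this object.
    return ToyOp(fn.__name__, arg_names, list(_recorded))


@toy_func
def foo(x):
    emit(f"toy.print {x}")


print(type(foo).__name__)  # ToyOp: foo is no longer a function
print(foo)
```

Because the body runs once at decoration time, ordinary Python control flow inside it is evaluated by Python, not emitted as IR, which is the gap that @canonicalize is described as filling for ifs.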

Install

This package is meant to work in concert with host bindings. Practically speaking, that means you need to have some package installed that includes the MLIR Python bindings.

So

$ HOST_MLIR_PYTHON_PACKAGE_PREFIX=<YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX> pip install git+https://github.com/makslevental/mlir-python-extras

where YOUR_HOST_MLIR_PYTHON_PACKAGE_PREFIX is (as it says) the package prefix for your chosen host bindings. When in doubt about this prefix, it is everything up until ir when you import your bindings, e.g., in import torch_mlir.ir, torch_mlir is the HOST_MLIR_PYTHON_PACKAGE_PREFIX for the torch-mlir bindings.
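If you are unsure which prefix applies in your environment, a small check like the following can help. It is only a sketch: find_host_prefix is a hypothetical helper (not part of this package), and the candidate list is an assumption you should replace with the prefixes you actually expect.

```python
import importlib.util


def find_host_prefix(candidates):
    """Return the first prefix for which `<prefix>.ir` can be located, else None."""
    for prefix in candidates:
        try:
            # find_spec imports the parent package as a side effect and
            # raises ModuleNotFoundError if that parent does not exist.
            spec = importlib.util.find_spec(f"{prefix}.ir")
        except ImportError:
            spec = None
        if spec is not None:
            return prefix
    return None


# Candidate prefixes are assumptions; prints None if neither is installed.
print(find_host_prefix(["mlir", "torch_mlir"]))
```

Whatever it prints is the value to pass as HOST_MLIR_PYTHON_PACKAGE_PREFIX in the install command above.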

If you don't have any such package, but you want to experiment anyway, you can install the "stock" upstream bindings first:

$ pip install mlir-python-bindings -f https://makslevental.github.io/wheels/

and then

$ pip install git+https://github.com/makslevental/mlir-python-extras

Examples/Demo

Check tests for a plethora of example code.