Skip to content

Commit

Permalink
add launch op
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Aug 25, 2023
1 parent 59187bc commit f7b159c
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 18 deletions.
138 changes: 133 additions & 5 deletions mlir_utils/dialects/ext/gpu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from functools import partial
from typing import Optional, Any

Expand All @@ -8,7 +9,12 @@
GPUModuleOp,
GPUFuncOp,
LaunchFuncOp,
LaunchOp,
ReturnOp,
AllReduceOp,
YieldOp,
TerminatorOp,
WaitOp,
)
from mlir.ir import (
Type,
Expand All @@ -22,10 +28,16 @@
Value,
)

from mlir_utils import types as T
from mlir_utils.dialects.ext.arith import constant
from mlir_utils.dialects.ext.func import FuncBase
from mlir_utils.dialects.gpu import block_id, module_end
from mlir_utils.meta import ModuleMeta, make_maybe_no_args_decorator, maybe_cast
from mlir_utils.meta import (
ModuleMeta,
make_maybe_no_args_decorator,
maybe_cast,
region_op,
)
from mlir_utils.util import get_user_code_loc, get_result_or_results


Expand Down Expand Up @@ -241,6 +253,82 @@ def __init__(
)
)

pass


class LaunchOp(LaunchOp):
    """Extension of the upstream gpu.launch op builder that accepts plain
    (x, y, z) tuples for the grid and block sizes and defaults the location
    to the user's call site.

    Args:
        grid_size: (x, y, z) SSA values for the grid dimensions.
        block_size: (x, y, z) SSA values for the block dimensions.
        async_dependencies: optional list of !gpu.async.token values; when
            non-empty the op is built in async form.
        dynamic_shared_memory_size: optional i32 Value for dynamic shared
            memory, or None.
        loc/ip: MLIR location and insertion point; loc defaults to the
            caller's source location via get_user_code_loc().
    """

    def __init__(
        self,
        grid_size: tuple[Any, Any, Any],
        block_size: tuple[Any, Any, Any],
        async_dependencies=None,
        dynamic_shared_memory_size: Optional[Value] = None,
        *,
        loc=None,
        ip=None,
    ):
        if loc is None:
            loc = get_user_code_loc()
        if async_dependencies is None:
            async_dependencies = []
        # NOTE(review): this produces one async token result per dependency;
        # MLIR's gpu.launch declares at most one optional async token —
        # confirm against the upstream op definition.
        results = [gpu_async_token()] * len(async_dependencies)
        grid_size_x, grid_size_y, grid_size_z = grid_size
        block_size_x, block_size_y, block_size_z = block_size

        super().__init__(
            self.build_generic(
                attributes={},
                results=results,
                operands=[
                    async_dependencies,
                    grid_size_x,
                    grid_size_y,
                    grid_size_z,
                    block_size_x,
                    block_size_y,
                    block_size_z,
                    dynamic_shared_memory_size,
                ],
                # gpu.launch has no successors; its single region is appended
                # by the caller (see launch_ below).
                successors=None,
                regions=None,
                loc=loc,
                ip=ip,
            )
        )


def launch_(
    grid_size: tuple[Any, Any, Any],
    block_size: tuple[Any, Any, Any],
    async_dependencies=None,
    dynamic_shared_memory_size: Optional[Value] = None,
    *,
    loc=None,
    ip=None,
):
    """Build a gpu.launch op with an empty body block ready for population.

    Integer entries in grid_size/block_size are promoted to index constants.

    Returns:
        The constructed LaunchOp (its first region holds one block whose
        arguments are all of index type).
    """
    if loc is None:
        loc = get_user_code_loc()
    # Build fresh lists rather than assigning into the inputs: callers pass
    # tuples (per the annotation), and `size[i] = ...` on a tuple raises
    # TypeError.
    grid_size = [
        constant(s, index=True) if isinstance(s, int) else s for s in grid_size
    ]
    block_size = [
        constant(s, index=True) if isinstance(s, int) else s for s in block_size
    ]
    launch_op = LaunchOp(
        grid_size,
        block_size,
        async_dependencies,
        dynamic_shared_memory_size,
        loc=loc,
        ip=ip,
    )
    # 12 index block args — presumably the gpu.launch region signature
    # (block ids, thread ids, grid dims, block dims — 3 each); confirm
    # against the gpu dialect spec.
    launch_op.regions[0].blocks.append(*[T.index for _ in range(12)])
    return launch_op


launch = region_op(launch_, terminator=lambda *args: TerminatorOp())


class LaunchFuncOp(LaunchFuncOp):
def __init__(
Expand All @@ -251,6 +339,7 @@ def __init__(
kernel_operands: list[Value] = None,
async_dependencies=None,
dynamic_shared_memory_size: Optional[Value] = None,
async_object=None,
*,
loc=None,
ip=None,
Expand All @@ -272,7 +361,6 @@ def __init__(
if async_dependencies is None:
async_dependencies = []
results = [gpu_async_token()] * len(async_dependencies)
async_object = None
grid_size_x, grid_size_y, grid_size_z = grid_size
block_size_x, block_size_y, block_size_z = block_size

Expand Down Expand Up @@ -310,6 +398,7 @@ def __call__(
block_size: tuple[Any, Any, Any],
async_dependencies=None,
dynamic_shared_memory_size: Optional[Value] = None,
stream=None,
):
for size in [grid_size, block_size]:
for i, s in enumerate(size):
Expand All @@ -328,6 +417,7 @@ def __call__(
kernel_operands,
async_dependencies,
dynamic_shared_memory_size,
async_object=stream,
loc=loc,
)
)
Expand All @@ -339,10 +429,20 @@ def __init__(self, func):
self.func = func

def __getitem__(self, item):
    """Bind keyword arguments for self.func from the *names* of the values.

    `obj[a, b]` inspects the caller's frame, finds which caller-local
    variable each element of `item` is bound to (by identity), and uses that
    variable name as the keyword: returns partial(self.func, a=a, b=b).

    NOTE: relies on CPython frame introspection; each value in `item` must
    be bound to exactly one caller-local name or the assert fires.
    """
    # The scraped diff carried both the old slice-based loop and this one;
    # only the frame-inspection version (the commit's post-image) is kept.
    previous_frame = inspect.currentframe().f_back
    # For each arg: every caller-local name referring to that exact object.
    var_names = [
        [
            var_name
            for var_name, var_val in previous_frame.f_locals.items()
            if var_val is arg
        ]
        for arg in item
    ]
    kwargs = {}
    for i, it in enumerate(item):
        assert len(var_names[i]) == 1, "expected unique kwarg"
        kwargs[var_names[i][0]] = it

    return partial(self.func, **kwargs)

Expand Down Expand Up @@ -388,3 +488,31 @@ def gpu_func(
if emit:
func.emit()
return Grid(func)


def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):
    """Construct a raw gpu.all_reduce op reducing `value`.

    The op's region is populated by the caller (see the `all_reduce`
    region_op wrapper); `op` and `uniform` are forwarded verbatim.
    """
    effective_loc = get_user_code_loc() if loc is None else loc
    return AllReduceOp(value, op=op, uniform=uniform, loc=effective_loc, ip=ip)


def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
    """Emit gpu.all_reduce over `value` and return its wrapped result(s)."""
    raw_op = all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
    result = get_result_or_results(raw_op)
    return maybe_cast(result)


all_reduce = region_op(all_reduce__, terminator=YieldOp)


def wait(async_dependencies: Optional[list[Value]] = None, *, loc=None, ip=None):
    """Emit gpu.wait on the given async dependencies.

    Returns the wrapped !gpu.async.token result; `async_dependencies`
    defaults to the empty list.
    """
    if loc is None:
        loc = get_user_code_loc()
    deps = [] if async_dependencies is None else async_dependencies
    token = gpu_async_token()
    wait_op = WaitOp(token, deps, loc=loc, ip=ip)
    return maybe_cast(get_result_or_results(wait_op))
2 changes: 1 addition & 1 deletion mlir_utils/dialects/ext/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __setitem__(self, idx, source):
if not self.has_rank():
raise ValueError("only ranked memref slicing/indexing supported")

idx = list((idx,) if isinstance(idx, int) else idx)
idx = list((idx,) if isinstance(idx, (Scalar, int, Value)) else idx)
for i, d in enumerate(idx):
if isinstance(d, int):
idx[i] = constant(d, index=True, loc=loc)
Expand Down
2 changes: 1 addition & 1 deletion mlir_utils/dialects/ext/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def match(
return maybe_cast(
get_result_or_results(
MatchOp(
T.transform_any_op(),
T.transform_any_op,
target,
ops=ops,
interface=interface,
Expand Down
23 changes: 14 additions & 9 deletions mlir_utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
except ImportError:
warnings.warn("no transform dialect registered; transform extensions won't work")


_index = lambda: IndexType.get()
_bool = lambda: IntegerType.get_signless(1)

Expand Down Expand Up @@ -72,6 +71,14 @@
opaque = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)


def _transform_any_op():
    """Return a fresh `!transform.any_op` transform-dialect type."""
    any_op_type = transform.AnyOpType.get()
    return any_op_type


def _llvm_ptr():
    """Parse and return the opaque LLVM-dialect pointer type."""
    ptr_type = Type.parse("!llvm.ptr")
    return ptr_type


def placeholder_opaque():
    """Return an opaque placeholder type in the `scf` dialect namespace."""
    placeholder = opaque("scf", "placeholder")
    return placeholder

Expand Down Expand Up @@ -103,6 +110,8 @@ def placeholder_opaque():
"cmp64": _cmp64,
"none": _none,
"pdl_operation": _pdl_operation,
"transform_any_op": _transform_any_op,
"llvm_ptr": _llvm_ptr,
}


Expand Down Expand Up @@ -174,13 +183,13 @@ def infer_mlir_type(
if isinstance(py_val, bool):
return _bool()
elif isinstance(py_val, int):
if -(2 ** 31) <= py_val < 2 ** 31:
if -(2**31) <= py_val < 2**31:
return _i32()
elif 2 ** 31 <= py_val < 2 ** 32:
elif 2**31 <= py_val < 2**32:
return _ui32()
elif -(2 ** 63) <= py_val < 2 ** 63:
elif -(2**63) <= py_val < 2**63:
return _i64()
elif 2 ** 63 <= py_val < 2 ** 64:
elif 2**63 <= py_val < 2**64:
return _ui64()
else:
raise RuntimeError(f"Nonrepresentable integer {py_val}.")
Expand Down Expand Up @@ -282,7 +291,3 @@ def memref_type_to_np_dtype(memref_type):

def transform_op(name):
    """Return the transform-dialect operation type for op `name`."""
    op_type = transform.OperationType.get(name)
    return op_type


def transform_any_op():
return transform.AnyOpType.get()
Loading

0 comments on commit f7b159c

Please sign in to comment.