Skip to content

Commit

Permalink
Add support for chained matmuls
Browse files Browse the repository at this point in the history
This PR adds support for chained matmuls. In
order to do this, we implement the following
- Support for op-specific vector_shapes
- Support for op-specific dim expansions (reduction
  dims of mmas will now always have chained expansions,
  even if they are not the induction variable of
  reduction ops)
- Adds a permute operator

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Oct 24, 2024
1 parent 67b253a commit 2785de4
Show file tree
Hide file tree
Showing 13 changed files with 727 additions and 102 deletions.
126 changes: 119 additions & 7 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def cast(src: "Register", dtype: DataType) -> "Register":
...


def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register":
...


def expansion_conflict(src: "Register", vector_shapes: dict[IndexSymbol, int]):
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
Expand Down Expand Up @@ -260,6 +268,7 @@ def new_function(*args: Any, **kwargs: dict[str, Any]):
def get_custom(node: fx.Node) -> "CustomOp":
"""Get the corresponding CustomOp for a given fx.Node."""
if isinstance(node, CustomOp):
breakpoint()
print("Careful! You passed a custom op where an fx.Node was required.")
return node
if not isinstance(node, fx.Node):
Expand Down Expand Up @@ -399,6 +408,10 @@ def copy(
new_node.index = copy.deepcopy(self.fx_node.index)
if new_name:
new_node.name = new_name
if hasattr(self.fx_node, "vector_shapes"):
new_node.vector_shapes = self.fx_node.vector_shapes
if hasattr(self.fx_node, "reduction_dim"):
new_node.reduction_dim = self.fx_node.reduction_dim
return get_custom(new_node)

def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
Expand Down Expand Up @@ -520,12 +533,29 @@ def expanded_dims(self, value: dict[IndexSymbol, int]):
raise ValueError("Expanded dims must be a dict")
self.fx_node.expanded_dims = value

def post_expansion(self, constraints: list["Constraint"]) -> None:
@property
def anchor(self) -> fx.Node:
"""
Hook for post-expansion operations. This is called after the arguments
of the node are expanded.
The anchor is a node that provides information to the node
such as vector_shapes, indexing information etc.
"""
pass
if hasattr(self.fx_node, "anchor"):
return self.fx_node.anchor
return None

@anchor.setter
def anchor(self, value: fx.Node):
self.fx_node.anchor = value

@property
def vector_shapes(self) -> dict[IndexSymbol, int]:
if hasattr(self.fx_node, "vector_shapes"):
return self.fx_node.vector_shapes
return None

@vector_shapes.setter
def vector_shapes(self, value: dict[IndexSymbol, int]):
self.fx_node.vector_shapes = value

def align_index(self, constraints: list["Constraint"]) -> None:
"""
Expand Down Expand Up @@ -883,6 +913,15 @@ def align_index(self, constraints: list["Constraint"]) -> None:

self.index = align_index_vars(self.index, constraints)

@property
def reduction_dim(self) -> IndexSymbol:
if hasattr(self.fx_node, "reduction_dim"):
return self.fx_node.reduction_dim

@reduction_dim.setter
def reduction_dim(self, value: IndexSymbol):
self.fx_node.reduction_dim = value


@define_op("read")
@dataclass
Expand Down Expand Up @@ -1033,9 +1072,11 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
return_vals = output.return_vals[0]
return (
[
get_custom(val).acc_index
if isinstance(get_custom(val), MMA)
else val.index
(
get_custom(val).acc_index
if isinstance(get_custom(val), MMA)
else val.index
)
for val in return_vals
]
if isinstance(return_vals, (Sequence))
Expand Down Expand Up @@ -1309,3 +1350,74 @@ def indexing_dims(self) -> list[IndexSymbol]:
def type(self) -> Memory:
src_shape = get_custom(self.arg).type.symbolic_shape
return Register[*src_shape, self.dtype]


@define_op("permute")
@dataclass
class Permute(CustomOp, ABC):
"""
Represents a permute operation that
permutes arg into the target shape.
"""

arg: fx.Node
target_shape: Sequence[IndexExpr]

@property
def indexing_dims(self) -> list[IndexExpr]:
return self.target_shape

@property
def type(self) -> Register:
src_type = get_custom(self.arg).type
return Register[*self.target_shape, src_type.dtype]

@property
def index(self) -> Optional[dict[IndexSymbol, IndexSequence]]:
"""
Computes the permuted index based on the target shape.
"""
src_shape = get_custom(self.arg).type.symbolic_shape
target_shape = self.target_shape
dim_map: dict[IndexSymbol, IndexSymbol] = {}
for dim_src, dim_target in zip(src_shape, target_shape):
dim_map[dim_target] = dim_src
src_index = get_custom(self.arg).index
target_index = dict(src_index)
for dim_target, dim_src in dim_map.items():
target_index[dim_target] = src_index[dim_src]
return target_index

@index.setter
def index(self, value: Any):
"""
Updates the index of the node based on a per-dimension index sequence.
"""
if value is None:
return
if isinstance(value, dict):
assert all(
isinstance(v, IndexSequence) for v in value.values()
), f"Index must be a dict with values of type IndexSequence"
self.fx_node.index = {}
for dim, key in value.items():
self.fx_node.index[dim] = key
elif isinstance(value, list):
self.fx_node.index = value
else:
raise ValueError("Index must be a dict")


@define_op("expansion_conflict")
@dataclass
class ExpansionConflict(CustomOp, ABC):
arg: fx.Node
vector_shapes: dict[IndexSymbol, int]

@property
def indexing_dims(self) -> list[IndexExpr]:
return []

@property
def type(self) -> Register:
return self.arg.type
11 changes: 11 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
scheduling_barrier,
scheduling_group_barrier,
cast,
permute,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
Expand Down Expand Up @@ -1299,3 +1300,13 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node):
)

emitter.bind_node_proxy(node, IRProxyValue(casted_vector))


@handle_op(permute)
def handle_permute(emitter: WaveEmitter, node: fx.Node):
try:
register, _ = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
vector_src = cast_py_value(emitter, register)
emitter.bind_node_proxy(node, vector_src)
Loading

0 comments on commit 2785de4

Please sign in to comment.