Skip to content

Commit

Permalink
Add support for chained matmuls
Browse files Browse the repository at this point in the history
This PR adds support for chained matmuls. In
order to do this, we implement the following
- Support for op-specific vector_shapes
- Support for op-specific dim expansions (reduction
  dims of mmas will now always have chained expansions,
  even if they are not the induction variable of
  reduction ops)
- Adds a permute operator
- Adds anchor and vector_shapes to each op
- Adds reduction_dim to MMA ops

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Oct 24, 2024
1 parent 67b253a commit c443bf3
Show file tree
Hide file tree
Showing 13 changed files with 690 additions and 102 deletions.
89 changes: 82 additions & 7 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def cast(src: "Register", dtype: DataType) -> "Register":
...


def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register":
...


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
Expand Down Expand Up @@ -399,6 +403,10 @@ def copy(
new_node.index = copy.deepcopy(self.fx_node.index)
if new_name:
new_node.name = new_name
if hasattr(self.fx_node, "vector_shapes"):
new_node.vector_shapes = self.fx_node.vector_shapes
if hasattr(self.fx_node, "reduction_dim"):
new_node.reduction_dim = self.fx_node.reduction_dim
return get_custom(new_node)

def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
Expand Down Expand Up @@ -520,12 +528,29 @@ def expanded_dims(self, value: dict[IndexSymbol, int]):
raise ValueError("Expanded dims must be a dict")
self.fx_node.expanded_dims = value

def post_expansion(self, constraints: list["Constraint"]) -> None:
@property
def anchor(self) -> fx.Node:
"""
Hook for post-expansion operations. This is called after the arguments
of the node are expanded.
The anchor is a node that provides information to the node
such as vector_shapes, indexing information etc.
"""
pass
if hasattr(self.fx_node, "anchor"):
return self.fx_node.anchor
return None

@anchor.setter
def anchor(self, value: fx.Node):
self.fx_node.anchor = value

@property
def vector_shapes(self) -> dict[IndexSymbol, int]:
if hasattr(self.fx_node, "vector_shapes"):
return self.fx_node.vector_shapes
return None

@vector_shapes.setter
def vector_shapes(self, value: dict[IndexSymbol, int]):
self.fx_node.vector_shapes = value

def align_index(self, constraints: list["Constraint"]) -> None:
"""
Expand Down Expand Up @@ -883,6 +908,15 @@ def align_index(self, constraints: list["Constraint"]) -> None:

self.index = align_index_vars(self.index, constraints)

@property
def reduction_dim(self) -> IndexSymbol:
if hasattr(self.fx_node, "reduction_dim"):
return self.fx_node.reduction_dim

@reduction_dim.setter
def reduction_dim(self, value: IndexSymbol):
self.fx_node.reduction_dim = value


@define_op("read")
@dataclass
Expand Down Expand Up @@ -1033,9 +1067,11 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
return_vals = output.return_vals[0]
return (
[
get_custom(val).acc_index
if isinstance(get_custom(val), MMA)
else val.index
(
get_custom(val).acc_index
if isinstance(get_custom(val), MMA)
else val.index
)
for val in return_vals
]
if isinstance(return_vals, (Sequence))
Expand Down Expand Up @@ -1309,3 +1345,42 @@ def indexing_dims(self) -> list[IndexSymbol]:
def type(self) -> Memory:
src_shape = get_custom(self.arg).type.symbolic_shape
return Register[*src_shape, self.dtype]


@define_op("permute")
@dataclass
class Permute(CustomOp, ABC):
"""
Represents a permute operation that
permutes arg into the target shape.
"""

arg: fx.Node
target_shape: Sequence[IndexExpr]

@property
def indexing_dims(self) -> list[IndexExpr]:
return self.target_shape

@property
def type(self) -> Register:
src_type = get_custom(self.arg).type
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
return Register[*self.target_shape, src_type.dtype]

@property
def index(self) -> Optional[dict[IndexSymbol, IndexSequence]]:
"""
Computes the permuted index based on the target shape.
"""
src_type = get_custom(self.arg).type
dim_map = {
tgt: src for src, tgt in zip(src_type.symbolic_shape, self.target_shape)
}
return {tgt: get_custom(self.arg).index[src] for tgt, src in dim_map.items()}

@index.setter
def index(self, value: Any):
CustomOp.index.fset(self, value)
11 changes: 11 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
scheduling_barrier,
scheduling_group_barrier,
cast,
permute,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
Expand Down Expand Up @@ -1299,3 +1300,13 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node):
)

emitter.bind_node_proxy(node, IRProxyValue(casted_vector))


@handle_op(permute)
def handle_permute(emitter: WaveEmitter, node: fx.Node):
try:
register, _ = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
vector_src = cast_py_value(emitter, register)
emitter.bind_node_proxy(node, vector_src)
Loading

0 comments on commit c443bf3

Please sign in to comment.