Add support for varying vector shapes
This PR adds support for expanding operators
with varying vector shapes, specifically for the
MMA case where the same dimension has different
vector shapes in different MMAs or where different
MMA instructions are being used.

The idea is to insert a reshape operator whenever such a shape
mismatch is discovered.  The reshape operator lowers to an extract or
concatenate operation, depending on the context.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Oct 29, 2024
1 parent 2b45c0f commit 181b93e
Showing 6 changed files with 231 additions and 41 deletions.
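A minimal sketch of the idea, with hypothetical shapes (not taken from this commit): when a producer's per-dimension vector shapes disagree with what a consumer expects, the inserted reshape either extracts a narrower slice from the wider register or concatenates several registers into a wider one.

# Hypothetical per-dimension vector shapes, chosen only to illustrate the rule.
producer_vector_shapes = {"M": 16, "K": 32}   # how the value was produced
consumer_vector_shapes = {"M": 16, "K": 16}   # what the consuming MMA expects

for dim, want in consumer_vector_shapes.items():
    have = producer_vector_shapes[dim]
    if have == want:
        continue
    if have > want:
        print(f"{dim}: extract one of {have // want} slices from the wider register")
    else:
        print(f"{dim}: concatenate {want // have} registers into a wider one")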
80 changes: 46 additions & 34 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -41,101 +41,89 @@

 def allocate(
     shape: tuple[IndexExpr], dtype: DataType, address_space: IndexSymbol
-) -> "Memory":
-    ...
+) -> "Memory": ...


 def extract(
     register: "Register",
     offsets: tuple[IndexExpr],
-) -> "Register":
-    ...
+) -> "Register": ...


 def extract_slice(
     register: "Register",
     offsets: tuple[IndexExpr],
     sizes: tuple[IndexExpr],
     strides: tuple[IndexExpr],
-) -> "Register":
-    ...
+) -> "Register": ...


-def shared_memory_barrier():
-    ...
+def shared_memory_barrier(): ...


 def read(
     memory: "Memory",
     elements_per_thread: Optional[IndexExpr | int] = None,
     mapping: Optional[IndexMapping] = None,
-) -> "Register":
-    ...
+) -> "Register": ...


 def reduction(
     axis: IndexExpr, init_args: Sequence["Register"]
-) -> Callable[[Callable[[AccT], AccT]], AccT]:
-    ...
+) -> Callable[[Callable[[AccT], AccT]], AccT]: ...


-def register(shape: tuple[IndexExpr, ...], dtype: DataType, value: float) -> "Register":
-    ...
+def register(
+    shape: tuple[IndexExpr, ...], dtype: DataType, value: float
+) -> "Register": ...


-def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register":
-    ...
+def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register": ...


 def write(
     register_: "Register",
     memory: "Memory",
     elements_per_thread: Optional[IndexExpr | int] = None,
     mapping: Optional[IndexMapping] = None,
-):
-    ...
+): ...


-def exp2(src: "Register") -> "Register":
-    ...
+def exp2(src: "Register") -> "Register": ...


-def maximum(lhs: "Register", rhs: "Register") -> "Register":
-    ...
+def maximum(lhs: "Register", rhs: "Register") -> "Register": ...


 def broadcast(
     arg: "Register", target_shape: Optional[IndexExpr | int] = None
-) -> "Register":
-    ...
+) -> "Register": ...


 def sum(
     src: "Register",
     acc: Optional["Register"] = None,
     dim: Optional[IndexExpr | int] = None,
-) -> "Register":
-    ...
+) -> "Register": ...


 def max(
     src: "Register",
     acc: Optional["Register"] = None,
     dim: Optional[IndexExpr | int] = None,
-) -> "Register":
-    ...
+) -> "Register": ...


-def shuffle(src: "Register", offset: int, width: int) -> "Register":
-    ...
+def shuffle(src: "Register", offset: int, width: int) -> "Register": ...


-def cast(src: "Register", dtype: DataType) -> "Register":
-    ...
+def cast(src: "Register", dtype: DataType) -> "Register": ...


-def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register":
-    ...
+def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register": ...
+
+
+def reshape(inputs: Sequence["Register"]) -> "Register": ...


def define_op(op_name: str) -> Callable[[T], T]:
@@ -1400,3 +1388,27 @@ def type(self) -> Register:
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
return Register[*self.target_shape, src_type.dtype]


def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
return input if isinstance(input, Sequence) else (input,)


@define_op("reshape")
@dataclass
class Reshape(CustomOp, ABC):
"""
Represents a reshape operation that reshapes
vectors along the same dimension.
"""

args: fx.Node | Sequence[fx.Node]
target_vector_shape: dict[IndexSymbol, int]

@property
def indexing_dims(self) -> list[IndexExpr]:
return get_custom(_to_sequence(self.args)[0]).indexing_dims

@property
def type(self) -> Register:
return get_custom(_to_sequence(self.args)[0]).type
34 changes: 34 additions & 0 deletions iree/turbine/kernel/wave/codegen.py
@@ -67,6 +67,7 @@
scheduling_group_barrier,
cast,
permute,
reshape,
)
from ..lang.wave_types import IndexMapping, IndexSymbol
from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -1310,3 +1311,36 @@ def handle_permute(emitter: WaveEmitter, node: fx.Node):
raise ValidationError("Malformed arguments") from e
vector_src = cast_py_value(emitter, register)
emitter.bind_node_proxy(node, vector_src)


@handle_op(reshape)
def handle_reshape(emitter: WaveEmitter, node: fx.Node):
try:
args, target_vector_shapes = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e
custom = get_custom(node)
innermost_dim = custom.type.symbolic_shape[-1]
offset = custom.expanded_dims[innermost_dim]

# Determine whether to extract or combine.
if len(args) == 1:
# Extract the appropriate slice.
size = (
target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
)
vector = cast_vector(emitter, args[0])
result_type = VectorType.get([size], vector.type.element_type)
slice = vector_d.extract_strided_slice(
result_type,
vector,
[offset * size],
[size],
[1],
)
emitter.bind_node_proxy(node, IRProxyValue(slice))
return

raise NotImplementedError(
"reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes"
)
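The slice arithmetic above can be checked with a small standalone sketch (plain Python over a list; the real handler emits vector_d.extract_strided_slice on an MLIR vector, and the names below are invented): expansion i of the reshape reads size contiguous elements starting at i * size.

def extract_reshape_slice(source, size, expansion_index):
    # Mirrors offsets=[expansion_index * size], sizes=[size], strides=[1] above.
    start = expansion_index * size
    return source[start:start + size]

source_vector = list(range(8))                     # stands in for a vector<8xf16> register
print(extract_reshape_slice(source_vector, 4, 0))  # [0, 1, 2, 3]
print(extract_reshape_slice(source_vector, 4, 1))  # [4, 5, 6, 7]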
92 changes: 92 additions & 0 deletions iree/turbine/kernel/wave/expansion.py
@@ -250,6 +250,20 @@ def _expand_node(
new_node.expanded_dims = restricted_dims
new_node.fx_node.name = get_expanded_name(node, restricted_dims)

# For reshapes, we need more explicit control over how the arguments are expanded.
if isinstance(new_node, Reshape):
_expand_reshape(
new_node,
trace,
dim_query,
dim_scaling,
context,
get_node_dim_scaling,
res_idx,
)
context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
return new_node

# Proceed with expansion of the arguments
for i, arg in node.node_args.items():
arg_list = arg
@@ -496,6 +510,84 @@ def _expand_mma_reduction(
return new_node


def _expand_reshape(
reshape: Reshape,
trace: CapturedTrace,
dim_query: dict[IndexSymbol, int],
dim_scaling: dict[IndexSymbol, int],
context: ExpandedNodeMap,
get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
res_idx: int,
) -> CustomOp:
"""
When expanding a reshape, we have to expand the arguments of the reshape and then concatenate them together
for the expanded node. Say we have a node with indexing dims = [M, N] with vector shapes m=8, n=2 and
the reshape wants to map it to m=4, n=4. So we start by expanding the node
node: {m = 0, n = 0}
arg: {m = 0, n = 0}
arg: {m = 0, n = 1}
node: {m = 1, n = 0}
arg: {m = 0, n = 0}
arg: {m = 0, n = 1}
node: {m = 2, n = 0}
arg: {m = 1, n = 0}
arg: {m = 1, n = 1}
node: {m = 3, n = 0}
arg: {m = 1, n = 0}
arg: {m = 1, n = 1}
...
In general,
For the (m = i, n = j) expansion of the reshape node, we expand the arguments of the reshape node
using the following recipe:
- if m_src < m_dst, => we have a one to many mapping from source to destination
so we expand the arguments along m = i // (m_dst / m_src) and we expand the argument only once.
- if m_src > m_dst, => we have a many to one mapping from source to destination
so we expand the arguments along m = i * (m_src / m_dst), ... and we expand the argument m_src / m_dst times.
In situations where the argument has been expanded along the same dimension, we reuse the expanded node
by making use of the context.
"""

dim_combinations = {}
for dim, value in dim_query.items():
if dim not in reshape.target_vector_shape:
continue
if reshape.vector_shapes[dim] < reshape.target_vector_shape[dim]:
scale_factor = (
reshape.target_vector_shape[dim] // reshape.vector_shapes[dim]
)
dim_combinations[dim] = [value // scale_factor]
else:
scale_factor = (
reshape.vector_shapes[dim] // reshape.target_vector_shape[dim]
)
begin = value * scale_factor
dim_combinations[dim] = list(range(begin, begin + scale_factor))
reshape_dim_combinations = list(itertools.product(*dim_combinations.values()))

new_args = []
for i, arg_dim_query in enumerate(reshape_dim_combinations):
arg_dim_query = {
dim: val for dim, val in zip(dim_combinations.keys(), arg_dim_query)
}
if isinstance(reshape.args, Sequence):
custom_arg = get_custom(reshape.args[i])
else:
custom_arg = get_custom(reshape.args)
new_node = _expand_node(
custom_arg,
trace,
arg_dim_query,
get_node_dim_scaling(custom_arg.fx_node),
context,
get_node_dim_scaling,
res_idx,
)
new_args.append(new_node.fx_node)

reshape.update_arg("args", new_args)


def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str:
"""Returns the name of a node with the dimensions appended."""

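The expansion recipe in the _expand_reshape docstring can be exercised with a standalone helper (pure Python; the function name and shape values are hypothetical): given the reshape's expansion index along one dimension, it returns the expansion indices of the argument that feed it, matching the two branches above.

def arg_expansion_indices(query, reshape_shape, target_shape):
    if reshape_shape < target_shape:
        # One-to-many: several expansions of the reshape reuse one argument expansion.
        return [query // (target_shape // reshape_shape)]
    # Many-to-one: one expansion of the reshape consumes several argument expansions.
    scale = reshape_shape // target_shape
    return list(range(query * scale, query * scale + scale))

print(arg_expansion_indices(3, reshape_shape=4, target_shape=8))  # [1]
print(arg_expansion_indices(0, reshape_shape=8, target_shape=4))  # [0, 1]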
20 changes: 18 additions & 2 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -5,13 +5,16 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ops.wave_ops import (
Allocate,
Write,
ExtractSlice,
get_custom,
Reduction,
MMA,
Placeholder,
IterArg,
CustomOp,
Reshape,
)
from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint
from .._support.tracing import CapturedTrace, IndexingContext
@@ -182,6 +185,19 @@ def is_contiguous_dim(
return is_innermost_dim or all_unit_dims


def add_reshape(custom: CustomOp):
for arg in custom.node_args.values():
if not isinstance(arg, Sequence):
arg = [arg]
for subarg in arg:
# These are ops in the parent graph that have not had their vector shapes set yet.
if subarg.vector_shapes is None:
continue
if subarg.vector_shapes != custom.vector_shapes:
with custom.graph.inserting_before(custom.fx_node):
Reshape(subarg, custom.vector_shapes).add_to_graph(custom.graph)


def set_vector_shapes(
constraints: Sequence[Constraint],
mma_index: dict[MMA, dict[IndexSymbol, int]],
@@ -193,8 +209,8 @@ def set_vector_shapes(
an MMA slice as well as the anchor node.
"""
custom = get_custom(node)
-    # MMA & Reduction nodes already have their vector shapes set.
-    if isinstance(custom, (MMA, Reduction)):
+    # MMA, Reduction & Reshape nodes already have their vector shapes set.
+    if isinstance(custom, (MMA, Reduction, Reshape)):
return
# Add vector shapes from constraints to all ops. These are global constraints.
custom.vector_shapes = {}
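A toy illustration of the insertion rule in add_reshape (plain Python dictionaries standing in for fx graph nodes; the node names and shapes are invented): a reshape record is spliced in just before any consumer whose vector shapes disagree with one of its arguments.

read_0 = {"name": "read_0", "vector_shapes": {"M": 16, "K": 32}, "args": []}
mma_0 = {"name": "mma_0", "vector_shapes": {"M": 16, "K": 16}, "args": [read_0]}
graph = [read_0, mma_0]

for arg in list(mma_0["args"]):
    if arg["vector_shapes"] is None or arg["vector_shapes"] == mma_0["vector_shapes"]:
        continue
    reshape_node = {
        "name": f"reshape_{arg['name']}",
        "vector_shapes": dict(mma_0["vector_shapes"]),
        "args": [arg],
    }
    graph.insert(graph.index(mma_0), reshape_node)  # insert right before the consumer

print([node["name"] for node in graph])  # ['read_0', 'reshape_read_0', 'mma_0']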