diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 1dad0160..ff098c44 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -41,15 +41,13 @@ def allocate( shape: tuple[IndexExpr], dtype: DataType, address_space: IndexSymbol -) -> "Memory": - ... +) -> "Memory": ... def extract( register: "Register", offsets: tuple[IndexExpr], -) -> "Register": - ... +) -> "Register": ... def extract_slice( @@ -57,34 +55,30 @@ def extract_slice( offsets: tuple[IndexExpr], sizes: tuple[IndexExpr], strides: tuple[IndexExpr], -) -> "Register": - ... +) -> "Register": ... -def shared_memory_barrier(): - ... +def shared_memory_barrier(): ... def read( memory: "Memory", elements_per_thread: Optional[IndexExpr | int] = None, mapping: Optional[IndexMapping] = None, -) -> "Register": - ... +) -> "Register": ... def reduction( axis: IndexExpr, init_args: Sequence["Register"] -) -> Callable[[Callable[[AccT], AccT]], AccT]: - ... +) -> Callable[[Callable[[AccT], AccT]], AccT]: ... -def register(shape: tuple[IndexExpr, ...], dtype: DataType, value: float) -> "Register": - ... +def register( + shape: tuple[IndexExpr, ...], dtype: DataType, value: float +) -> "Register": ... -def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register": - ... +def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register": ... def write( @@ -92,50 +86,44 @@ def write( memory: "Memory", elements_per_thread: Optional[IndexExpr | int] = None, mapping: Optional[IndexMapping] = None, -): - ... +): ... -def exp2(src: "Register") -> "Register": - ... +def exp2(src: "Register") -> "Register": ... -def maximum(lhs: "Register", rhs: "Register") -> "Register": - ... +def maximum(lhs: "Register", rhs: "Register") -> "Register": ... def broadcast( arg: "Register", target_shape: Optional[IndexExpr | int] = None -) -> "Register": - ... +) -> "Register": ... def sum( src: "Register", acc: Optional["Register"] = None, dim: Optional[IndexExpr | int] = None, -) -> "Register": - ... +) -> "Register": ... def max( src: "Register", acc: Optional["Register"] = None, dim: Optional[IndexExpr | int] = None, -) -> "Register": - ... +) -> "Register": ... -def shuffle(src: "Register", offset: int, width: int) -> "Register": - ... +def shuffle(src: "Register", offset: int, width: int) -> "Register": ... -def cast(src: "Register", dtype: DataType) -> "Register": - ... +def cast(src: "Register", dtype: DataType) -> "Register": ... -def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register": - ... +def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register": ... + + +def reshape(inputs: Sequence["Register"]) -> "Register": ... def define_op(op_name: str) -> Callable[[T], T]: @@ -1400,3 +1388,27 @@ def type(self) -> Register: self.target_shape ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" return Register[*self.target_shape, src_type.dtype] + + +def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]: + return input if isinstance(input, Sequence) else (input,) + + +@define_op("reshape") +@dataclass +class Reshape(CustomOp, ABC): + """ + Represents a reshape operation that reshapes + vectors along the same dimension. 
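+    A reshape is inserted when the producer and consumer of a value were
+    assigned different vector shapes along a shared dimension (for example,
+    chained MMAs whose intrinsics tile that dimension differently). During
+    expansion it either selects a slice of a single argument or concatenates
+    several arguments so that the result matches target_vector_shape.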
+ """ + + args: fx.Node | Sequence[fx.Node] + target_vector_shape: dict[IndexSymbol, int] + + @property + def indexing_dims(self) -> list[IndexExpr]: + return get_custom(_to_sequence(self.args)[0]).indexing_dims + + @property + def type(self) -> Register: + return get_custom(_to_sequence(self.args)[0]).type diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 5a790e06..7a4e5091 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -67,6 +67,7 @@ scheduling_group_barrier, cast, permute, + reshape, ) from ..lang.wave_types import IndexMapping, IndexSymbol from ..compiler.base import CodegenError, ValidationError, NDEBUG @@ -1310,3 +1311,36 @@ def handle_permute(emitter: WaveEmitter, node: fx.Node): raise ValidationError("Malformed arguments") from e vector_src = cast_py_value(emitter, register) emitter.bind_node_proxy(node, vector_src) + + +@handle_op(reshape) +def handle_reshape(emitter: WaveEmitter, node: fx.Node): + try: + args, target_vector_shapes = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + custom = get_custom(node) + innermost_dim = custom.type.symbolic_shape[-1] + offset = custom.expanded_dims[innermost_dim] + + # Determine whether to extract or combine. + if len(args) == 1: + # Extract the appropriate slice. + size = ( + target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim] + ) + vector = cast_vector(emitter, args[0]) + result_type = VectorType.get([size], vector.type.element_type) + slice = vector_d.extract_strided_slice( + result_type, + vector, + [offset * size], + [size], + [1], + ) + emitter.bind_node_proxy(node, IRProxyValue(slice)) + return + + raise NotImplementedError( + "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes" + ) diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 6ba9b52d..55c1db85 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -250,6 +250,20 @@ def _expand_node( new_node.expanded_dims = restricted_dims new_node.fx_node.name = get_expanded_name(node, restricted_dims) + # For reshapes, we need more explicit control over how the arguments are expanded. + if isinstance(new_node, Reshape): + _expand_reshape( + new_node, + trace, + dim_query, + dim_scaling, + context, + get_node_dim_scaling, + res_idx, + ) + context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node + return new_node + # Proceed with expansion of the arguments for i, arg in node.node_args.items(): arg_list = arg @@ -496,6 +510,84 @@ def _expand_mma_reduction( return new_node +def _expand_reshape( + reshape: Reshape, + trace: CapturedTrace, + dim_query: dict[IndexSymbol, int], + dim_scaling: dict[IndexSymbol, int], + context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], + res_idx: int, +) -> CustomOp: + """ + When expanding a reshape, we have to expand the arguments of the reshape and then concatenate them together + for the expanded node. Say we have a node with indexing dims = [M, N] with vector shapes m=8, n=2 and + the reshape wants to map it to m=4, n=4. So we start by expanding the node + node: {m = 0, n = 0} + arg: {m = 0, n = 0} + arg: {m = 0, n = 1} + node: {m = 1, n = 0} + arg: {m = 0, n = 0} + arg: {m = 0, n = 1} + node: {m = 2, n = 0} + arg: {m = 1, n = 0} + arg: {m = 1, n = 1} + node: {m = 3, n = 0} + arg: {m = 1, n = 0} + arg: {m = 1, n = 1} + ... 
+ In general, + For the (m = i, n = j) expansion of the reshape node, we expand the arguments of the reshape node + using the following recipe: + - if m_src < m_dst, => we have a one to many mapping from source to destination + so we expand the arguments along m = i // (m_dst / m_src) and we expand the argument only once. + - if m_src > m_dst, => we have a many to one mapping from source to destination + so we expand the arguments along m = i * (m_src / m_dst), ... and we expand the argument m_dst / m_src times. + + In situations where the argument has been expanded along the same dimension, we reuse the expanded node + by making use of the context. + """ + + dim_combinations = {} + for dim, value in dim_query.items(): + if dim not in reshape.target_vector_shape: + continue + if reshape.vector_shapes[dim] < reshape.target_vector_shape[dim]: + scale_factor = ( + reshape.target_vector_shape[dim] // reshape.vector_shapes[dim] + ) + dim_combinations[dim] = [value // scale_factor] + else: + scale_factor = ( + reshape.vector_shapes[dim] // reshape.target_vector_shape[dim] + ) + begin = value * scale_factor + dim_combinations[dim] = list(range(begin, begin + scale_factor)) + reshape_dim_combinations = list(itertools.product(*dim_combinations.values())) + + new_args = [] + for i, arg_dim_query in enumerate(reshape_dim_combinations): + arg_dim_query = { + dim: val for dim, val in zip(dim_combinations.keys(), arg_dim_query) + } + if isinstance(reshape.args, Sequence): + custom_arg = get_custom(reshape.args[i]) + else: + custom_arg = get_custom(reshape.args) + new_node = _expand_node( + custom_arg, + trace, + arg_dim_query, + get_node_dim_scaling(custom_arg.fx_node), + context, + get_node_dim_scaling, + res_idx, + ) + new_args.append(new_node.fx_node) + + reshape.update_arg("args", new_args) + + def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: """Returns the name of a node with the dimensions appended.""" diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 140c0ac6..9923f873 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ..ops.wave_ops import ( + Allocate, Write, ExtractSlice, get_custom, @@ -12,6 +13,8 @@ MMA, Placeholder, IterArg, + CustomOp, + Reshape, ) from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint from .._support.tracing import CapturedTrace, IndexingContext @@ -182,6 +185,19 @@ def is_contiguous_dim( return is_innermost_dim or all_unit_dims +def add_reshape(custom: CustomOp): + for arg in custom.node_args.values(): + if not isinstance(arg, Sequence): + arg = [arg] + for subarg in arg: + # These are ops in the parent graph that have not had their vector shapes set yet. + if subarg.vector_shapes is None: + continue + if subarg.vector_shapes != custom.vector_shapes: + with custom.graph.inserting_before(custom.fx_node): + Reshape(subarg, custom.vector_shapes).add_to_graph(custom.graph) + + def set_vector_shapes( constraints: Sequence[Constraint], mma_index: dict[MMA, dict[IndexSymbol, int]], @@ -193,8 +209,8 @@ def set_vector_shapes( an MMA slice as well as the anchor node. """ custom = get_custom(node) - # MMA & Reduction nodes already have their vector shapes set. - if isinstance(custom, (MMA, Reduction)): + # MMA, Reduction & Reshape nodes already have their vector shapes set. 
+ if isinstance(custom, (MMA, Reduction, Reshape)): return # Add vector shapes from constraints to all ops. These are global constraints. custom.vector_shapes = {} diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index b3f229e5..6ad4798e 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -12,7 +12,7 @@ UnitAttr, Value, ) -from typing import Optional, Callable, Any, List, Tuple +from typing import Optional, Callable, Any, List, Tuple, Sequence from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * @@ -25,6 +25,7 @@ Reduction, GetResult, IterArg, + Reshape, ) from .constraints import ( Constraint, @@ -192,6 +193,22 @@ def simplify_index(index: IndexExpr) -> IndexExpr: return subs_idxc(index.subs(mapping)) +def is_reshape_needed( + node: CustomOp, + node_vector_shapes: dict[IndexSymbol, int], + vector_shapes: dict[IndexSymbol, int], +) -> bool: + for dim in node.type.symbolic_shape: + if dim not in vector_shapes: + continue + if node_vector_shapes[dim] != vector_shapes[dim]: + print( + f"Reshape needed for {node} due to {dim}, {node_vector_shapes[dim]} != {vector_shapes[dim]}" + ) + return True + return False + + def get_mma_dimensional_mapping( trace: CapturedTrace, hardware_constraint: HardwareConstraint, @@ -213,6 +230,7 @@ def is_mma(node): mapping: dict[MMA, dict[IndexSymbol, int]] = {} mma_nodes = trace.walk(is_mma) + last_vector_shapes = {} for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] @@ -234,6 +252,19 @@ def is_mma(node): custom.vector_shapes.update(hardware_constraint.vector_shapes) custom.anchor = custom custom.reduction_dim = k + if last_vector_shapes and last_vector_shapes != custom.vector_shapes: + with custom.graph.inserting_before(custom.fx_node): + for i, arg in custom.node_args.items(): + if is_reshape_needed(arg, custom.vector_shapes, last_vector_shapes): + reshape = Reshape(arg.fx_node, last_vector_shapes).add_to_graph( + custom.graph + ) + custom_reshape = get_custom(reshape) + custom_reshape.vector_shapes = custom.vector_shapes + custom_reshape.anchor = custom + custom.update_arg(i, reshape) + + last_vector_shapes = custom.vector_shapes # Since expansion proceeds bottom-up, we set the vector shapes # of the parent reduction to the vector shapes of the last MMA node. 
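
To make the _expand_reshape recipe from expansion.py above concrete, here is a small self-contained sketch of just the index arithmetic. It uses plain string keys in place of IndexSymbol and no fx graph or CustomOp machinery; the helper name reshape_arg_queries is invented for illustration, but its loop mirrors the dim_combinations computation in the diff.

import itertools


def reshape_arg_queries(
    dim_query: dict[str, int],
    vector_shapes: dict[str, int],
    target_vector_shape: dict[str, int],
) -> list[dict[str, int]]:
    """For one expansion (dim_query) of a Reshape node, list the dim queries
    of the argument slices it gathers, following _expand_reshape."""
    dim_combinations: dict[str, list[int]] = {}
    for dim, value in dim_query.items():
        if dim not in target_vector_shape:
            continue
        if vector_shapes[dim] < target_vector_shape[dim]:
            # One-to-many: several reshape expansions share one coarser source slice.
            scale = target_vector_shape[dim] // vector_shapes[dim]
            dim_combinations[dim] = [value // scale]
        else:
            # Many-to-one: one reshape expansion concatenates several finer source slices.
            scale = vector_shapes[dim] // target_vector_shape[dim]
            dim_combinations[dim] = list(range(value * scale, (value + 1) * scale))
    return [
        dict(zip(dim_combinations.keys(), combo))
        for combo in itertools.product(*dim_combinations.values())
    ]


# The {m = 3, n = 0} row of the docstring example: the reshape node's own
# vector shapes are m=4, n=4 and its argument was produced with m=8, n=2, so
# it reads the coarser m=1 slice and concatenates the n=0 and n=1 pieces.
assert reshape_arg_queries(
    {"M": 3, "N": 0}, {"M": 4, "N": 4}, {"M": 8, "N": 2}
) == [{"M": 1, "N": 0}, {"M": 1, "N": 1}]
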
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index ceff8128..a3a78e46 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -6,7 +6,11 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.lang.global_symbols import * -from iree.turbine.kernel.wave.utils import run_test +from iree.turbine.kernel.wave.utils import ( + run_test, + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) import torch M = tkl.sym.M @@ -763,11 +767,12 @@ def test_chained_gemm(): constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + mfma_variant = tkw.MMAType.F32_32x32x8_F16 constraints += [ tkw.HardwareConstraint( threads_per_wave=64, waves_per_block=(2, 2, 1), - mma_type=tkw.MMAType.F32_16x16x16_F16, + mma_type=mfma_variant, vector_shapes={B: 0}, ) ] @@ -808,8 +813,8 @@ def repeat( BLOCK_N: 64, BLOCK_K2: 32, BLOCK_B: 1, - LOAD_ELEMS_PER_THREAD: 4, - STORE_ELEMS_PER_THREAD: 4, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), ADDRESS_SPACE: SHARED_ADDRESS_SPACE, ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, },
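
A closing note on the codegen side: the chained GEMM test now uses MMAType.F32_32x32x8_F16, and with that variant the two MMAs no longer agree on the vector shape of the shared K2 dimension, which exercises the new reshape insertion and the extract path of handle_reshape. The sketch below is a plain-Python stand-in for the vector_d.extract_strided_slice call in that handler; the function name and the concrete numbers are illustrative, while the size and offset arithmetic follow the handler.

def extract_reshape_slice(
    source: list[float],
    offset: int,               # custom.expanded_dims[innermost_dim]
    target_vector_shape: int,  # target_vector_shapes[innermost_dim] (argument's shapes)
    own_vector_shape: int,     # custom.vector_shapes[innermost_dim] (reshape node's shapes)
) -> list[float]:
    """Plain-Python analogue of the extract path in handle_reshape: take
    `size` contiguous elements of the source register starting at
    offset * size, i.e. extract_strided_slice with unit stride."""
    size = target_vector_shape // own_vector_shape
    return source[offset * size : offset * size + size]


# Illustrative numbers: with a target/own ratio of 8 // 4 the result is 2 wide,
# so expansion offset 1 of a 4-element source keeps elements [2.0, 3.0]
# (offsets [2], sizes [2], strides [1] in the MLIR call).
assert extract_reshape_slice([0.0, 1.0, 2.0, 3.0], 1, 8, 4) == [2.0, 3.0]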