From 181b93ea2fff3e5efa1a57d689e9aa173658dd45 Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Sat, 26 Oct 2024 13:34:17 -0700 Subject: [PATCH] Add support for varying vector shapes This PR adds support for expanding operators with varying vector shapes, specifically for the MMA case where either the same dimension has different vector shapes in different mmas or if different instructions are being used. The idea is to insert a reshape operator whenever such a shape mismatch is discovered. The reshape operator lowers to an extract or concatenate operation, depending on the context. Signed-off-by: Harsh Menon --- iree/turbine/kernel/ops/wave_ops.py | 80 +++++++++------- iree/turbine/kernel/wave/codegen.py | 34 +++++++ iree/turbine/kernel/wave/expansion.py | 92 +++++++++++++++++++ .../kernel/wave/index_sequence_analysis.py | 20 +++- iree/turbine/kernel/wave/utils.py | 33 ++++++- lit_tests/kernel/wave/codegen.py | 13 ++- 6 files changed, 231 insertions(+), 41 deletions(-) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 1dad0160..ff098c44 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -41,15 +41,13 @@ def allocate( shape: tuple[IndexExpr], dtype: DataType, address_space: IndexSymbol -) -> "Memory": - ... +) -> "Memory": ... def extract( register: "Register", offsets: tuple[IndexExpr], -) -> "Register": - ... +) -> "Register": ... def extract_slice( @@ -57,34 +55,30 @@ def extract_slice( offsets: tuple[IndexExpr], sizes: tuple[IndexExpr], strides: tuple[IndexExpr], -) -> "Register": - ... +) -> "Register": ... -def shared_memory_barrier(): - ... +def shared_memory_barrier(): ... def read( memory: "Memory", elements_per_thread: Optional[IndexExpr | int] = None, mapping: Optional[IndexMapping] = None, -) -> "Register": - ... +) -> "Register": ... def reduction( axis: IndexExpr, init_args: Sequence["Register"] -) -> Callable[[Callable[[AccT], AccT]], AccT]: - ... +) -> Callable[[Callable[[AccT], AccT]], AccT]: ... -def register(shape: tuple[IndexExpr, ...], dtype: DataType, value: float) -> "Register": - ... +def register( + shape: tuple[IndexExpr, ...], dtype: DataType, value: float +) -> "Register": ... -def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register": - ... +def mma(lhs: "Register", rhs: "Register", acc: "Register") -> "Register": ... def write( @@ -92,50 +86,44 @@ def write( memory: "Memory", elements_per_thread: Optional[IndexExpr | int] = None, mapping: Optional[IndexMapping] = None, -): - ... +): ... -def exp2(src: "Register") -> "Register": - ... +def exp2(src: "Register") -> "Register": ... -def maximum(lhs: "Register", rhs: "Register") -> "Register": - ... +def maximum(lhs: "Register", rhs: "Register") -> "Register": ... def broadcast( arg: "Register", target_shape: Optional[IndexExpr | int] = None -) -> "Register": - ... +) -> "Register": ... def sum( src: "Register", acc: Optional["Register"] = None, dim: Optional[IndexExpr | int] = None, -) -> "Register": - ... +) -> "Register": ... def max( src: "Register", acc: Optional["Register"] = None, dim: Optional[IndexExpr | int] = None, -) -> "Register": - ... +) -> "Register": ... -def shuffle(src: "Register", offset: int, width: int) -> "Register": - ... +def shuffle(src: "Register", offset: int, width: int) -> "Register": ... -def cast(src: "Register", dtype: DataType) -> "Register": - ... +def cast(src: "Register", dtype: DataType) -> "Register": ... 
-def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register": - ... +def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register": ... + + +def reshape(inputs: Sequence["Register"]) -> "Register": ... def define_op(op_name: str) -> Callable[[T], T]: @@ -1400,3 +1388,27 @@ def type(self) -> Register: self.target_shape ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" return Register[*self.target_shape, src_type.dtype] + + +def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]: + return input if isinstance(input, Sequence) else (input,) + + +@define_op("reshape") +@dataclass +class Reshape(CustomOp, ABC): + """ + Represents a reshape operation that reshapes + vectors along the same dimension. + """ + + args: fx.Node | Sequence[fx.Node] + target_vector_shape: dict[IndexSymbol, int] + + @property + def indexing_dims(self) -> list[IndexExpr]: + return get_custom(_to_sequence(self.args)[0]).indexing_dims + + @property + def type(self) -> Register: + return get_custom(_to_sequence(self.args)[0]).type diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 5a790e06..7a4e5091 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -67,6 +67,7 @@ scheduling_group_barrier, cast, permute, + reshape, ) from ..lang.wave_types import IndexMapping, IndexSymbol from ..compiler.base import CodegenError, ValidationError, NDEBUG @@ -1310,3 +1311,36 @@ def handle_permute(emitter: WaveEmitter, node: fx.Node): raise ValidationError("Malformed arguments") from e vector_src = cast_py_value(emitter, register) emitter.bind_node_proxy(node, vector_src) + + +@handle_op(reshape) +def handle_reshape(emitter: WaveEmitter, node: fx.Node): + try: + args, target_vector_shapes = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + custom = get_custom(node) + innermost_dim = custom.type.symbolic_shape[-1] + offset = custom.expanded_dims[innermost_dim] + + # Determine whether to extract or combine. + if len(args) == 1: + # Extract the appropriate slice. + size = ( + target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim] + ) + vector = cast_vector(emitter, args[0]) + result_type = VectorType.get([size], vector.type.element_type) + slice = vector_d.extract_strided_slice( + result_type, + vector, + [offset * size], + [size], + [1], + ) + emitter.bind_node_proxy(node, IRProxyValue(slice)) + return + + raise NotImplementedError( + "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes" + ) diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index 6ba9b52d..55c1db85 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -250,6 +250,20 @@ def _expand_node( new_node.expanded_dims = restricted_dims new_node.fx_node.name = get_expanded_name(node, restricted_dims) + # For reshapes, we need more explicit control over how the arguments are expanded. 
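+    # A reshape bridges ops whose vector shapes differ along the same dimensions, so its
+    # arguments cannot simply reuse the reshape's dim_query: _expand_reshape remaps the
+    # query and, depending on the ratio of the vector shapes, expands the argument once
+    # or gathers several expanded copies to concatenate.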
+    if isinstance(new_node, Reshape):
+        _expand_reshape(
+            new_node,
+            trace,
+            dim_query,
+            dim_scaling,
+            context,
+            get_node_dim_scaling,
+            res_idx,
+        )
+        context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
+        return new_node
+
     # Proceed with expansion of the arguments
     for i, arg in node.node_args.items():
         arg_list = arg
@@ -496,6 +510,84 @@ def _expand_mma_reduction(
     return new_node


+def _expand_reshape(
+    reshape: Reshape,
+    trace: CapturedTrace,
+    dim_query: dict[IndexSymbol, int],
+    dim_scaling: dict[IndexSymbol, int],
+    context: ExpandedNodeMap,
+    get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
+    res_idx: int,
+) -> None:
+    """
+    When expanding a reshape, we have to expand the arguments of the reshape and then concatenate
+    them together for the expanded node. Say we have a node with indexing dims = [M, N] and vector
+    shapes m=8, n=2, and a reshape that maps it to m=4, n=4. We start by expanding the reshape node:
+        node: {m = 0, n = 0}
+            arg: {m = 0, n = 0}
+            arg: {m = 0, n = 1}
+        node: {m = 1, n = 0}
+            arg: {m = 0, n = 0}
+            arg: {m = 0, n = 1}
+        node: {m = 2, n = 0}
+            arg: {m = 1, n = 0}
+            arg: {m = 1, n = 1}
+        node: {m = 3, n = 0}
+            arg: {m = 1, n = 0}
+            arg: {m = 1, n = 1}
+        ...
+    In general, for the (m = i, n = j) expansion of the reshape node, we expand its arguments using
+    the following recipe, where m_src is the vector shape of the argument (source) and m_dst is the
+    vector shape the reshape maps it to (destination):
+    - if m_src > m_dst, several expansions of the reshape map to one expansion of the argument,
+      so we expand the argument only once, along m = i // (m_src / m_dst).
+    - if m_src < m_dst, each expansion of the reshape consumes several expansions of the argument,
+      so we expand the argument m_dst / m_src times, along m = i * (m_dst / m_src), ...
+
+    In situations where the argument has been expanded along the same dimension, we reuse the
+    expanded node by making use of the context.
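+
+    Applied to the example above, m_src = 8 > m_dst = 4 gives the single argument index m = i // 2,
+    while n_src = 2 < n_dst = 4 concatenates the two argument expansions n = 2j and n = 2j + 1.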
+ """ + + dim_combinations = {} + for dim, value in dim_query.items(): + if dim not in reshape.target_vector_shape: + continue + if reshape.vector_shapes[dim] < reshape.target_vector_shape[dim]: + scale_factor = ( + reshape.target_vector_shape[dim] // reshape.vector_shapes[dim] + ) + dim_combinations[dim] = [value // scale_factor] + else: + scale_factor = ( + reshape.vector_shapes[dim] // reshape.target_vector_shape[dim] + ) + begin = value * scale_factor + dim_combinations[dim] = list(range(begin, begin + scale_factor)) + reshape_dim_combinations = list(itertools.product(*dim_combinations.values())) + + new_args = [] + for i, arg_dim_query in enumerate(reshape_dim_combinations): + arg_dim_query = { + dim: val for dim, val in zip(dim_combinations.keys(), arg_dim_query) + } + if isinstance(reshape.args, Sequence): + custom_arg = get_custom(reshape.args[i]) + else: + custom_arg = get_custom(reshape.args) + new_node = _expand_node( + custom_arg, + trace, + arg_dim_query, + get_node_dim_scaling(custom_arg.fx_node), + context, + get_node_dim_scaling, + res_idx, + ) + new_args.append(new_node.fx_node) + + reshape.update_arg("args", new_args) + + def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: """Returns the name of a node with the dimensions appended.""" diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 140c0ac6..9923f873 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ..ops.wave_ops import ( + Allocate, Write, ExtractSlice, get_custom, @@ -12,6 +13,8 @@ MMA, Placeholder, IterArg, + CustomOp, + Reshape, ) from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint from .._support.tracing import CapturedTrace, IndexingContext @@ -182,6 +185,19 @@ def is_contiguous_dim( return is_innermost_dim or all_unit_dims +def add_reshape(custom: CustomOp): + for arg in custom.node_args.values(): + if not isinstance(arg, Sequence): + arg = [arg] + for subarg in arg: + # These are ops in the parent graph that have not had their vector shapes set yet. + if subarg.vector_shapes is None: + continue + if subarg.vector_shapes != custom.vector_shapes: + with custom.graph.inserting_before(custom.fx_node): + Reshape(subarg, custom.vector_shapes).add_to_graph(custom.graph) + + def set_vector_shapes( constraints: Sequence[Constraint], mma_index: dict[MMA, dict[IndexSymbol, int]], @@ -193,8 +209,8 @@ def set_vector_shapes( an MMA slice as well as the anchor node. """ custom = get_custom(node) - # MMA & Reduction nodes already have their vector shapes set. - if isinstance(custom, (MMA, Reduction)): + # MMA, Reduction & Reshape nodes already have their vector shapes set. + if isinstance(custom, (MMA, Reduction, Reshape)): return # Add vector shapes from constraints to all ops. These are global constraints. 
custom.vector_shapes = {} diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index b3f229e5..6ad4798e 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -12,7 +12,7 @@ UnitAttr, Value, ) -from typing import Optional, Callable, Any, List, Tuple +from typing import Optional, Callable, Any, List, Tuple, Sequence from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * @@ -25,6 +25,7 @@ Reduction, GetResult, IterArg, + Reshape, ) from .constraints import ( Constraint, @@ -192,6 +193,22 @@ def simplify_index(index: IndexExpr) -> IndexExpr: return subs_idxc(index.subs(mapping)) +def is_reshape_needed( + node: CustomOp, + node_vector_shapes: dict[IndexSymbol, int], + vector_shapes: dict[IndexSymbol, int], +) -> bool: + for dim in node.type.symbolic_shape: + if dim not in vector_shapes: + continue + if node_vector_shapes[dim] != vector_shapes[dim]: + print( + f"Reshape needed for {node} due to {dim}, {node_vector_shapes[dim]} != {vector_shapes[dim]}" + ) + return True + return False + + def get_mma_dimensional_mapping( trace: CapturedTrace, hardware_constraint: HardwareConstraint, @@ -213,6 +230,7 @@ def is_mma(node): mapping: dict[MMA, dict[IndexSymbol, int]] = {} mma_nodes = trace.walk(is_mma) + last_vector_shapes = {} for node in mma_nodes: custom: MMA = get_custom(node) m, n = custom.acc_type.symbolic_shape[-2:] @@ -234,6 +252,19 @@ def is_mma(node): custom.vector_shapes.update(hardware_constraint.vector_shapes) custom.anchor = custom custom.reduction_dim = k + if last_vector_shapes and last_vector_shapes != custom.vector_shapes: + with custom.graph.inserting_before(custom.fx_node): + for i, arg in custom.node_args.items(): + if is_reshape_needed(arg, custom.vector_shapes, last_vector_shapes): + reshape = Reshape(arg.fx_node, last_vector_shapes).add_to_graph( + custom.graph + ) + custom_reshape = get_custom(reshape) + custom_reshape.vector_shapes = custom.vector_shapes + custom_reshape.anchor = custom + custom.update_arg(i, reshape) + + last_vector_shapes = custom.vector_shapes # Since expansion proceeds bottom-up, we set the vector shapes # of the parent reduction to the vector shapes of the last MMA node. 
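
To make the reshape insertion above concrete, here is a standalone sketch that mirrors the
is_reshape_needed check and the slice arithmetic of handle_reshape outside the compiler. The
per-dimension vector shapes assumed for two different MFMA variants and the helper names are
purely illustrative, not values taken from the hardware constraint:

# Illustrative sketch only (not part of the patch): when is a reshape inserted
# between two MMAs, and which slice of the producer's vector does it extract?
def is_reshape_needed(symbolic_shape, node_vector_shapes, vector_shapes):
    # Mirrors is_reshape_needed in wave/utils.py: any per-dimension mismatch
    # between the two sets of vector shapes requires a reshape.
    return any(
        node_vector_shapes[dim] != vector_shapes[dim]
        for dim in symbolic_shape
        if dim in vector_shapes
    )

def reshape_slice(expanded_offset, target_vector_shape, vector_shape):
    # Mirrors handle_reshape for the extract case: the copy expanded at
    # `expanded_offset` along the innermost dim reads a contiguous slice.
    size = target_vector_shape // vector_shape
    return expanded_offset * size, size  # (start, length)

# Assumed per-dimension vector shapes for two MFMA variants (e.g. 32x32x8 vs 16x16x16).
shapes_a = {"M": 32, "N": 32, "K": 8}
shapes_b = {"M": 16, "N": 16, "K": 16}
assert is_reshape_needed(("M", "K"), shapes_b, shapes_a)
assert reshape_slice(expanded_offset=1, target_vector_shape=32, vector_shape=16) == (2, 2)
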
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index ceff8128..a3a78e46 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -6,7 +6,11 @@ import iree.turbine.kernel.lang as tkl import iree.turbine.kernel.wave as tkw from iree.turbine.kernel.lang.global_symbols import * -from iree.turbine.kernel.wave.utils import run_test +from iree.turbine.kernel.wave.utils import ( + run_test, + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) import torch M = tkl.sym.M @@ -763,11 +767,12 @@ def test_chained_gemm(): constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + mfma_variant = tkw.MMAType.F32_32x32x8_F16 constraints += [ tkw.HardwareConstraint( threads_per_wave=64, waves_per_block=(2, 2, 1), - mma_type=tkw.MMAType.F32_16x16x16_F16, + mma_type=mfma_variant, vector_shapes={B: 0}, ) ] @@ -808,8 +813,8 @@ def repeat( BLOCK_N: 64, BLOCK_K2: 32, BLOCK_B: 1, - LOAD_ELEMS_PER_THREAD: 4, - STORE_ELEMS_PER_THREAD: 4, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), ADDRESS_SPACE: SHARED_ADDRESS_SPACE, ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, },