diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index 19e3f64c..4ee52d7b 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -134,6 +134,14 @@ def cast(src: "Register", dtype: DataType) -> "Register":
     ...
 
 
+def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register":
+    ...
+
+
+def expansion_conflict(src: "Register", vector_shapes: dict[IndexSymbol, int]):
+    ...
+
+
 def define_op(op_name: str) -> Callable[[T], T]:
     def decorator(cls: T) -> T:
         cls.tkw_op_name = op_name
@@ -399,6 +408,10 @@ def copy(
         new_node.index = copy.deepcopy(self.fx_node.index)
         if new_name:
             new_node.name = new_name
+        if hasattr(self.fx_node, "vector_shapes"):
+            new_node.vector_shapes = self.fx_node.vector_shapes
+        if hasattr(self.fx_node, "reduction_dim"):
+            new_node.reduction_dim = self.fx_node.reduction_dim
         return get_custom(new_node)
 
     def replace_all_uses_with(self, new_node: CustomOp | fx.Node):
@@ -520,12 +533,29 @@ def expanded_dims(self, value: dict[IndexSymbol, int]):
             raise ValueError("Expanded dims must be a dict")
         self.fx_node.expanded_dims = value
 
-    def post_expansion(self, constraints: list["Constraint"]) -> None:
+    @property
+    def anchor(self) -> fx.Node:
         """
-        Hook for post-expansion operations. This is called after the arguments
-        of the node are expanded.
+        The anchor is a node that provides information to this node,
+        such as vector_shapes and indexing information.
         """
-        pass
+        if hasattr(self.fx_node, "anchor"):
+            return self.fx_node.anchor
+        return None
+
+    @anchor.setter
+    def anchor(self, value: fx.Node):
+        self.fx_node.anchor = value
+
+    @property
+    def vector_shapes(self) -> dict[IndexSymbol, int]:
+        if hasattr(self.fx_node, "vector_shapes"):
+            return self.fx_node.vector_shapes
+        return None
+
+    @vector_shapes.setter
+    def vector_shapes(self, value: dict[IndexSymbol, int]):
+        self.fx_node.vector_shapes = value
 
     def align_index(self, constraints: list["Constraint"]) -> None:
         """
@@ -883,6 +913,15 @@ def align_index(self, constraints: list["Constraint"]) -> None:
 
         self.index = align_index_vars(self.index, constraints)
 
+    @property
+    def reduction_dim(self) -> IndexSymbol:
+        if hasattr(self.fx_node, "reduction_dim"):
+            return self.fx_node.reduction_dim
+
+    @reduction_dim.setter
+    def reduction_dim(self, value: IndexSymbol):
+        self.fx_node.reduction_dim = value
+
 
 @define_op("read")
 @dataclass
@@ -1033,9 +1072,11 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]:
             return_vals = output.return_vals[0]
             return (
                 [
-                    get_custom(val).acc_index
-                    if isinstance(get_custom(val), MMA)
-                    else val.index
+                    (
+                        get_custom(val).acc_index
+                        if isinstance(get_custom(val), MMA)
+                        else val.index
+                    )
                     for val in return_vals
                 ]
                 if isinstance(return_vals, (Sequence))
@@ -1309,3 +1350,74 @@ def indexing_dims(self) -> list[IndexSymbol]:
     def type(self) -> Memory:
         src_shape = get_custom(self.arg).type.symbolic_shape
         return Register[*src_shape, self.dtype]
+
+
+@define_op("permute")
+@dataclass
+class Permute(CustomOp, ABC):
+    """
+    Represents a permute operation that
+    permutes arg into the target shape.
+    """
+
+    arg: fx.Node
+    target_shape: Sequence[IndexExpr]
+
+    @property
+    def indexing_dims(self) -> list[IndexExpr]:
+        return self.target_shape
+
+    @property
+    def type(self) -> Register:
+        src_type = get_custom(self.arg).type
+        return Register[*self.target_shape, src_type.dtype]
+
+    @property
+    def index(self) -> Optional[dict[IndexSymbol, IndexSequence]]:
+        """
+        Computes the permuted index based on the target shape.
+        """
+        src_shape = get_custom(self.arg).type.symbolic_shape
+        target_shape = self.target_shape
+        dim_map: dict[IndexSymbol, IndexSymbol] = {}
+        for dim_src, dim_target in zip(src_shape, target_shape):
+            dim_map[dim_target] = dim_src
+        src_index = get_custom(self.arg).index
+        target_index = dict(src_index)
+        for dim_target, dim_src in dim_map.items():
+            target_index[dim_target] = src_index[dim_src]
+        return target_index
+
+    @index.setter
+    def index(self, value: Any):
+        """
+        Updates the index of the node based on a per-dimension index sequence.
+        """
+        if value is None:
+            return
+        if isinstance(value, dict):
+            assert all(
+                isinstance(v, IndexSequence) for v in value.values()
+            ), "Index must be a dict with values of type IndexSequence"
+            self.fx_node.index = {}
+            for dim, key in value.items():
+                self.fx_node.index[dim] = key
+        elif isinstance(value, list):
+            self.fx_node.index = value
+        else:
+            raise ValueError("Index must be a dict or a list")
+
+
+@define_op("expansion_conflict")
+@dataclass
+class ExpansionConflict(CustomOp, ABC):
+    arg: fx.Node
+    vector_shapes: dict[IndexSymbol, int]
+
+    @property
+    def indexing_dims(self) -> list[IndexExpr]:
+        return []
+
+    @property
+    def type(self) -> Register:
+        return self.arg.type
diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 5ada0335..5a790e06 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -66,6 +66,7 @@
     scheduling_barrier,
     scheduling_group_barrier,
     cast,
+    permute,
 )
 from ..lang.wave_types import IndexMapping, IndexSymbol
 from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -1299,3 +1300,13 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node):
     )
     emitter.bind_node_proxy(node, IRProxyValue(casted_vector))
+
+
+@handle_op(permute)
+def handle_permute(emitter: WaveEmitter, node: fx.Node):
+    try:
+        register, _ = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+    vector_src = cast_py_value(emitter, register)
+    emitter.bind_node_proxy(node, vector_src)
diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py
index f83b8fad..42f44493 100644
--- a/iree/turbine/kernel/wave/expansion.py
+++ b/iree/turbine/kernel/wave/expansion.py
@@ -117,10 +117,7 @@ def expand_graph(
     Create a graph that represents the expanded version of the wave function.
     The expansion is done in the dimensions specified by the constraints.
""" - if isinstance(constraints_or_scaling, dict): - dim_scaling = constraints_or_scaling - else: - dim_scaling = get_dim_scaling(constraints_or_scaling) + get_node_dim_scaling = partial(get_dim_scaling, constraints_or_scaling) # Start from the back and expand in the corresponding indexing dimensions of a node # Then proceed to the operands @@ -146,6 +143,7 @@ def expand_graph( if node.__class__ not in leaf_nodes: continue + dim_scaling = get_node_dim_scaling(node) for dim_combination in get_dim_combinations(dim_scaling, node.indexing_dims): expand_dims = { dim: val for dim, val in zip(dim_scaling.keys(), dim_combination) @@ -157,6 +155,7 @@ def expand_graph( expand_dims, dim_scaling, expansion_context, + get_node_dim_scaling, ) @@ -166,6 +165,7 @@ def _expand_node( dim_query: dict[IndexSymbol, int], dim_scaling: dict[IndexSymbol, int], context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], res_idx: int = 0, ) -> CustomOp: """Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs.""" @@ -177,8 +177,9 @@ def _expand_node( elem, trace, dim_query, - dim_scaling, + get_node_dim_scaling(elem), context, + get_node_dim_scaling, res_idx, ).fx_node ) @@ -187,8 +188,23 @@ def _expand_node( if (node, get_indexed_dims(dim_query, node), res_idx) in context: logger.debug(f"Already expanded node: {node} in {dim_query}") return context[(node, get_indexed_dims(dim_query, node), res_idx)] + elif isinstance(node, MMA): + if hasattr(node.graph, "parent_op") and node.reduction_dim not in dim_query: + reduction: Reduction = get_custom(node.graph.parent_op) + if reduction.axis != node.reduction_dim: + return _expand_mma_reduction( + node, + trace, + dim_query, + get_node_dim_scaling(node), + context, + get_node_dim_scaling, + res_idx, + ) elif isinstance(node, Reduction): - return _expand_reduction(node, trace, dim_query, dim_scaling, context, res_idx) + return _expand_reduction( + node, trace, dim_query, dim_scaling, context, get_node_dim_scaling, res_idx + ) elif isinstance(node, Getitem): res_idx = node.res_idx elif isinstance(node, GetResult) and not isinstance(node, Getitem): @@ -230,15 +246,34 @@ def _expand_node( # Proceed with expansion of the arguments for i, arg in node.node_args.items(): - if is_expandable(arg): - new_arg = _expand_node( - arg, - trace, - restricted_dims, - dim_scaling, - context, - res_idx, - ) + new_arg = None + if not isinstance(arg, Sequence): + if is_expandable(arg): + new_arg = _expand_node( + arg, + trace, + restricted_dims, + get_node_dim_scaling(arg), + context, + get_node_dim_scaling, + res_idx, + ) + new_node.update_arg(i, new_arg) + else: + new_arg = [] + for subarg in arg: + if is_expandable(subarg): + new_subarg = _expand_node( + subarg, + trace, + restricted_dims, + get_node_dim_scaling(subarg), + context, + get_node_dim_scaling, + res_idx, + ) + new_arg.append(new_subarg.fx_node) + assert len(new_arg) == len(arg), "All subargs must be expanded" new_node.update_arg(i, new_arg) context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node @@ -251,6 +286,7 @@ def _expand_reduction( dim_query: dict[IndexSymbol, int], dim_scaling: dict[IndexSymbol, int], context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], res_idx: int = 0, ) -> CustomOp: """Expand a reduction in a specific dimension and recursively proceed to its inputs.""" @@ -258,6 +294,7 @@ def _expand_reduction( users = reduction.users expand_dims: list[IndexSymbol] = [] for user 
in users:
+        dim_scaling.update(get_node_dim_scaling(user))
         for indexing_dim in user.indexing_dims:
             if indexing_dim not in expand_dims:
                 expand_dims.append(indexing_dim)
@@ -295,21 +332,24 @@
                     arg,
                     trace,
                     dims,
-                    dim_scaling,
+                    get_node_dim_scaling(arg),
                     context,
+                    get_node_dim_scaling,
                     res_idx,
                 )
             )
 
     # Proceed with expansion outside the reduction
     for init_arg in reduction.init_args:
+        custom_init_arg = get_custom(init_arg)
         new_init_args.append(
             _expand_node(
-                get_custom(init_arg),
+                custom_init_arg,
                 trace,
                 dims,
-                dim_scaling,
+                get_node_dim_scaling(custom_init_arg),
                 context,
+                get_node_dim_scaling,
                 res_idx,
             )
         )
@@ -325,6 +365,7 @@
         trace,
         dim_scaling,
         context,
+        get_node_dim_scaling,
         res_idx,
     )
     # Even though we expanded the reduction in multiple dimensions, we only return
@@ -332,6 +373,92 @@
     return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)]
 
 
+def _expand_mma_reduction(
+    mma: MMA,
+    trace: CapturedTrace,
+    dim_query: dict[IndexSymbol, int],
+    dim_scaling: dict[IndexSymbol, int],
+    context: ExpandedNodeMap,
+    get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
+    res_idx: int,
+) -> CustomOp:
+
+    logger.debug(f"Expanding MMA reduction: {mma} in dims: {dim_query}")
+    idxc = IndexingContext.current()
+    for dim in mma.indexing_dims:
+        if dim not in dim_scaling and mma.vector_shapes[dim] > 0:
+            tile_size = idxc.get_static_value(dim)
+            dim_scaling[dim] = tile_size // mma.vector_shapes[dim]
+
+    expand_dims = set(mma.indexing_dims) - set([mma.reduction_dim])
+
+    # Store the accumulator value for expansion.
+    if not hasattr(_expand_mma_reduction, "acc"):
+        _expand_mma_reduction.acc = mma.acc
+    if not hasattr(_expand_mma_reduction, "mma"):
+        _expand_mma_reduction.mma = mma
+
+    context_key = (
+        _expand_mma_reduction.mma,
+        get_indexed_dims(dim_query, expand_dims),
+        res_idx,
+    )
+
+    # Expand the MMA node along the reduction dimension, chaining the
+    # accumulator of each new node to the previous one.
+    user = _expand_mma_reduction.mma
+    for scale_idx in range(dim_scaling[mma.reduction_dim]):
+        if isinstance(user, Output):
+            continue
+
+        dims = dim_query
+        dims[mma.reduction_dim] = scale_idx
+        # Temporarily replace the loop carried arg here to avoid
+        # duplicated expansion. Otherwise we have the following situation:
+        # Suppose we have:
+        #   mma_0_0_0(..., acc_0_0_0)
+        #   mma_0_0_1(..., mma_0_0_0)
+        # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
+        # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
+        # To avoid this we temporarily replace the use of it with a dummy
+        # placeholder which will not trigger further expansion.
+        index = user.get_node_arg_index(get_custom(user.acc))
+        dummy = Placeholder("dummy").add_to_graph(user.graph)
+        dummy.type = None
+
+        saved_arg = user.node_args[index]
+        user.update_arg(index, dummy)
+        new_node = _expand_node(
+            user,
+            trace,
+            dims,
+            get_node_dim_scaling(user),
+            context,
+            get_node_dim_scaling,
+        )
+
+        # Chain the accumulator: after the first expansion the new node
+        # accumulates into the previous node, while the first expansion
+        # uses the original accumulator.
+        if scale_idx > 0:
+            new_node.update_arg(index, user)
+        else:
+            new_node.update_arg(index, _expand_mma_reduction.acc)
+        # Restore the original accumulator argument of the user.
+        user.update_arg(index, saved_arg)
+        # Erase the dummy placeholder.
+ user.graph.erase_node(dummy) + # Update the user to the new node. + user = new_node + + context[context_key] = new_node + return new_node + + def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: """Returns the name of a node with the dimensions appended.""" @@ -355,9 +482,14 @@ def _contains(elem, container): return elem in container -def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int]: - """Get the number of expansions for the dimensions based on the constraints.""" +def get_dim_scaling( + constraints: Sequence[Constraint], node: fx.Node +) -> dict[IndexSymbol, int]: + """Get the number of expansions for the dimensions based on the constraints for a specific node.""" dim_scaling: dict[IndexSymbol, int] = {} + if node.vector_shapes is None: + return dim_scaling + hardware_constraints: list[HardwareConstraint] = [ constraint for constraint in constraints @@ -373,12 +505,9 @@ def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int] ): hw_cons = hardware_constraints[0] tile_size = idxc.get_static_value(constraint.tile_size) - if not _contains(constraint.dim, hw_cons.vector_shapes): - raise ValueError( - f"Attempting to determine vector shape for unmapped dimension {constraint.dim}" - ) - - vector_size = hw_cons.vector_shapes[constraint.dim] + if constraint.dim not in node.vector_shapes: + continue + vector_size = node.vector_shapes[constraint.dim] # No dim scaling for dims with 0 vector size. if vector_size == 0: @@ -399,6 +528,7 @@ def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int] "Tile size must be divisible by wave count and vector size" ) dim_scaling[constraint.dim] = tile_size // wave_count // vector_size + return dim_scaling @@ -408,6 +538,7 @@ def _handle_reduction_dim( trace: CapturedTrace, dim_scaling: dict[IndexSymbol, int], context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], res_idx: int, ): # Rediscover iter args @@ -422,6 +553,7 @@ def _handle_reduction_dim( # Users of the loop carried nodes will be duplicated for idx, carried_node in enumerate(iter_args): # The initial nodes are expanded in the first dimension, so we start from 1 + dim_scaling = get_node_dim_scaling(carried_node) for scale_idx in range(1, dim_scaling[reduction.axis]): for user in carried_node.users: if isinstance(user, Output): @@ -450,6 +582,7 @@ def _handle_reduction_dim( dims, dim_scaling, context, + get_node_dim_scaling, res_idx, ) diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 21ed0d97..7f417665 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -4,7 +4,15 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ..ops.wave_ops import Write, ExtractSlice, get_custom, Reduction +from ..ops.wave_ops import ( + Write, + ExtractSlice, + get_custom, + Reduction, + MMA, + Placeholder, + IterArg, +) from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint from .._support.tracing import CapturedTrace, IndexingContext from .._support.indexing import IndexSymbol, IndexSequence @@ -27,15 +35,10 @@ def get_vector_shape( - hardware_constraint: HardwareConstraint, + vector_shapes: dict[IndexSymbol, int], symbolic_shape: list[IndexSymbol], ) -> list[int]: - assert all( - dim in hardware_constraint.vector_shapes for dim in symbolic_shape - ), "Missing vector shape in hardware constraint" - vector_shapes = [ - max(hardware_constraint.vector_shapes[dim], 1) for dim in symbolic_shape - ] + vector_shapes = [max(vector_shapes[dim], 1) for dim in symbolic_shape] return vector_shapes @@ -81,7 +84,9 @@ def has_strided_access(node: fx.Node) -> bool: for dim in custom.index } - shape = get_vector_shape(hw_constraint, custom.register_type.symbolic_shape) + shape = get_vector_shape( + custom.vector_shapes, custom.register_type.symbolic_shape + ) elements_per_thread = subs_idxc(custom.elements_per_thread) max_stride_dim, max_stride = max( [(dim, seq.stride) for dim, seq in simplified_index.items()], @@ -124,6 +129,7 @@ def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]): mma_index, mma_slices = get_mma_dimensional_mapping( trace, get_hardware_constraint(constraints) ) + trace.walk(partial(set_vector_shapes, constraints, mma_index, mma_slices)) trace.walk(partial(set_node_index, constraints, mma_index, mma_slices)) @@ -166,10 +172,49 @@ def is_contiguous_dim( return is_innermost_dim or all_unit_dims +def set_vector_shapes( + constraints: Sequence[Constraint], + mma_index: dict[MMA, dict[IndexSymbol, int]], + mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]], + node: fx.Node, +): + """ + Set the vector shapes for the specific op based on whether the op lies in + an MMA slice as well as the anchor node. + """ + custom = get_custom(node) + # MMA & Reduction nodes already have their vector shapes set. + if isinstance(custom, (MMA, Reduction)): + return + # Add vector shapes from constraints to all ops. These are global constraints. + custom.vector_shapes = {} + hw_constraint = get_hardware_constraint(constraints) + if hw_constraint.vector_shapes: + custom.vector_shapes = hw_constraint.vector_shapes + + if len(mma_slices) == 1: + # If there is just one MMA slice, there is no ambiguity in the vector shapes + # and we set that singular MMA op as the anchor for all ops. + mma = list(mma_slices.keys())[0] + custom.anchor = mma + custom.vector_shapes = custom.vector_shapes | mma.vector_shapes + return + + for mma in mma_slices: + if ( + node in mma_slices[mma][MMA_ACC] + or node in mma_slices[mma][MMA_LHS] + or node in mma_slices[mma][MMA_RHS] + ): + custom.anchor = mma + custom.vector_shapes = custom.vector_shapes | mma.vector_shapes + return + + def set_node_index( constraints: Sequence[Constraint], - mma_index: dict[IndexSymbol, int], - mma_slices: dict[IndexSymbol, list[fx.Node]], + mma_index: dict[MMA, dict[IndexSymbol, int]], + mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]], node: fx.Node, ): """ @@ -198,10 +243,11 @@ def set_node_index( # The semantics of elements_per_thread are that it represents the number of # elements that are loaded contiguously from memory. 
custom = get_custom(node) + anchor = custom.anchor elements_per_thread = getattr(custom, "elements_per_thread", None) - if isinstance(custom, Reduction): + if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg): return for dim in custom.indexing_dims: @@ -209,8 +255,8 @@ def set_node_index( for constraint in sorted_constraints: if isinstance(constraint, HardwareConstraint): inputs = None - if dim in mma_index: - inputs = (mma_index[dim], elements_per_thread, None) + if anchor and dim in mma_index[anchor]: + inputs = (mma_index[anchor][dim], elements_per_thread, None) else: # Assumes vector shapes are associated with workgroup dims. if dim not in workgroup_constraints: @@ -240,9 +286,13 @@ def set_node_index( # dependence in the dimensional index. # TODO: Evaluate if this is a valid case. continue - index_seq = constraint.apply(dim, *inputs, dim in mma_index) - if dim in mma_index: - index_seq = specialize_index_sequence(index_seq, mma_slices, custom) + index_seq = constraint.apply( + dim, *inputs, anchor and dim in mma_index[anchor] + ) + if anchor and dim in mma_index[anchor]: + index_seq = specialize_index_sequence( + index_seq, mma_slices[anchor], custom + ) elif constraint.dim == dim: if index_seq is None: @@ -262,7 +312,6 @@ def set_post_expansion_indices(trace: CapturedTrace, constraints: list[Constrain """ Add offsets to the indices based on the expanded dims. """ - hw_cons = get_hardware_constraint(constraints) def apply_offset(node: fx.Node): custom = get_custom(node) @@ -270,7 +319,7 @@ def apply_offset(node: fx.Node): return False for dim, scale in custom.expanded_dims.items(): if dim in custom.index: - custom.index[dim].start += scale * hw_cons.vector_shapes[dim] + custom.index[dim].start += scale * custom.vector_shapes[dim] return False trace.walk(apply_offset) diff --git a/iree/turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py index 4872529f..17b0a98a 100644 --- a/iree/turbine/kernel/wave/iree_utils.py +++ b/iree/turbine/kernel/wave/iree_utils.py @@ -10,6 +10,34 @@ from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM +def get_chain_mmt_asm( + query_type: str, key_type: str, value_type: str, output_type: str +) -> str: + B, M, K1, input_dtype = query_type.split("x") + B, K2, K1, input_dtype = key_type.split("x") + B, N, K2, input_dtype = value_type.split("x") + B, M, N, output_dtype = output_type.split("x") + intermediate_output_type = f"{B}x{K2}x{M}x{output_dtype}" + intermediate_cast_type = f"{B}x{K2}x{M}x{input_dtype}" + transposed_cast_type = f"{B}x{M}x{K2}x{input_dtype}" + return f""" + func.func @chain_mmt(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{ + %c0 = arith.constant 0.0 : f32 + %init = tensor.empty() : tensor<{intermediate_output_type}> + %inital_result = linalg.fill ins(%c0 : f32) outs(%init : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}> + %result = linalg.batch_matmul_transpose_b ins(%key, %query : tensor<{key_type}>, tensor<{query_type}>) + outs(%inital_result : tensor<{intermediate_output_type}>) -> tensor<{intermediate_output_type}> + %trunc = arith.truncf %result : tensor<{intermediate_output_type}> to tensor<{intermediate_cast_type}> + %init2 = tensor.empty() : tensor<{transposed_cast_type}> + %transpose = linalg.transpose ins(%trunc: tensor<{intermediate_cast_type}>) outs(%init2: tensor<{transposed_cast_type}>) permutation=[0, 2, 1] + %init3 = tensor.empty() : tensor<{output_type}> + 
%inital_result3 = linalg.fill ins(%c0 : f32) outs(%init3 : tensor<{output_type}>) -> tensor<{output_type}> + %result2 = linalg.batch_matmul_transpose_b ins(%transpose, %value: tensor<{transposed_cast_type}>, tensor<{value_type}>) + outs(%inital_result3 : tensor<{output_type}>) -> tensor<{output_type}> + return %result2 : tensor<{output_type}> + }}""" + + def get_mmt_asm( lhs_type: str, rhs_type: str, @@ -104,6 +132,12 @@ def generate_iree_ref( rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) asm = get_mmt_asm(lhs_type, rhs_type, acc_type, batch=True) + elif kernel_type == "chain_mmt": + query_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) + key_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) + value_type = get_type_str(kernel_inputs[2].shape, kernel_inputs[2].dtype) + output_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) + asm = get_chain_mmt_asm(query_type, key_type, value_type, output_type) elif kernel_type.startswith(conv_str): lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 86faf74a..04d0d2d9 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -88,14 +88,22 @@ def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]: # Determine the correct indexSize for binaryOp and insert broadcasting. dst_op = lhs if lhs_dim_set > rhs_dim_set else rhs broadcast_idx, broadcast_src = (1, rhs) if lhs_dim_set > rhs_dim_set else (0, lhs) - broadcast = Broadcast(broadcast_src.fx_node, dst_op.type) with custom_node.graph.inserting_before(custom_node.fx_node): - broadcast.add_to_graph(custom_node.graph) - custom_node.update_arg(broadcast_idx, broadcast.fx_node) - propagated_resolutions = capture_forward_slice(broadcast.fx_node, propagatable_op) + broadcast = Broadcast(broadcast_src.fx_node, dst_op.type).add_to_graph( + custom_node.graph + ) + custom_broadcast = get_custom(broadcast) + custom_broadcast.vector_shapes = broadcast_src.vector_shapes + custom_broadcast.anchor = broadcast_src.anchor + custom_node.update_arg(broadcast_idx, custom_broadcast.fx_node) + propagated_resolutions = capture_forward_slice( + custom_broadcast.fx_node, propagatable_op + ) for node in propagated_resolutions: get_custom(node).index = dst_op.index - resolved_resolutions = capture_backward_slice(broadcast.fx_node, propagatable_op) + resolved_resolutions = capture_backward_slice( + custom_broadcast.fx_node, propagatable_op + ) return propagated_resolutions.union(resolved_resolutions) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index f3c9201d..d17069f1 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -194,23 +194,23 @@ def simplify_index(index: IndexExpr) -> IndexExpr: def get_mma_dimensional_mapping( trace: CapturedTrace, hardware_constraint: HardwareConstraint, -) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]: +) -> tuple[ + dict[MMA, dict[IndexSymbol, int]], dict[MMA, dict[IndexSymbol, list[fx.Node]]] +]: """ Given a trace, determine the MMA dimensional mapping for all the MMA operations in the graph. 
For example, if we have acc = tkw.mma(a_reg, b_reg, acc) where a_reg has shape UxV, b has shape SxV and acc has shape UxS, we map U to the MMA M dimension (0), S to the MMA N dimension (1) and - V to the MMA K dimension (2). - - Also update the vector shapes in the hardware constraint based on the - discovered MMA dimensions. + V to the MMA K dimension (2). We maintain this map per mma node and + also update the vector_shapes of the mma node based on this information. """ def is_mma(node): return isinstance(get_custom(node), MMA) - mapping: dict[IndexSymbol, int] = {} + mapping: dict[MMA, dict[IndexSymbol, int]] = {} mma_nodes = trace.walk(is_mma) for node in mma_nodes: custom: MMA = get_custom(node) @@ -219,18 +219,30 @@ def is_mma(node): rhs_shape = custom.rhs_type.symbolic_shape acc_shape = custom.acc_type.symbolic_shape k = ((set(lhs_shape) & set(rhs_shape)) - set(acc_shape)).pop() - mapping[m] = 0 - mapping[n] = 1 - mapping[k] = 2 - # Update vector shapes in hardware constraint. - M, N, K = hardware_constraint.mma_matrix_shapes - if not hardware_constraint.vector_shapes: - hardware_constraint.vector_shapes = {} - hardware_constraint.vector_shapes[m] = M - hardware_constraint.vector_shapes[n] = N - hardware_constraint.vector_shapes[k] = K - - return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes]) + if custom not in mapping: + mapping[custom] = {} + mapping[custom][m] = 0 + mapping[custom][n] = 1 + mapping[custom][k] = 2 + custom.vector_shapes = { + m: hardware_constraint.mma_matrix_shapes[0], + n: hardware_constraint.mma_matrix_shapes[1], + k: hardware_constraint.mma_matrix_shapes[2], + } + if hardware_constraint.vector_shapes: + custom.vector_shapes.update(hardware_constraint.vector_shapes) + custom.anchor = custom + custom.reduction_dim = k + + # Since expansion proceeds bottom-up, we set the vector shapes + # of the parent reduction to the vector shapes of the last MMA node. + if hasattr(custom.graph, "parent_op"): + reduction = get_custom(custom.graph.parent_op) + reduction.vector_shapes = custom.vector_shapes + reduction.anchor = custom + + mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes} + return mapping, mma_slices def get_hardware_vector_size( @@ -503,7 +515,9 @@ def get_users( """ users = [] for user in node.users: - custom = get_custom(user) + custom = user + if not isinstance(custom, CustomOp): + custom = get_custom(user) if isinstance(custom, Reduction): # Map init arg to iter arg reduction = custom @@ -553,6 +567,9 @@ def get_inputs( # Map get result to output reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name] inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) + elif isinstance(custom, Reduction): + reduction_subgraph = custom.get_root_graph().subgraphs[custom.subgraph_name] + inputs.append(custom.outputs(reduction_subgraph)) else: # Default handling for other ops. for input in node.all_input_nodes: @@ -602,16 +619,18 @@ def capture_backward_slice( return bfs(node, lambda x, y: get_inputs(x, y), filter_fn) -def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: +def capture_mma_slices(mma: MMA) -> dict[IndexSymbol, list[fx.Node]]: """ Given an index sequence, specialize it to a LHS, RHS or ACC index sequence based on whether the node is used as the LHS, RHS or ACC in the MMA node. 
""" mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]} - for mma in mma_nodes: - mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs) - mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs) - mma_slices[MMA_ACC] += capture_forward_slice(mma.acc) + is_not_mma = lambda x: not isinstance(get_custom(x), MMA) + mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs, is_not_mma) + mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs, is_not_mma) + mma_slices[MMA_ACC] += capture_forward_slice(mma.fx_node, is_not_mma).union( + capture_backward_slice(mma.acc, is_not_mma) + ) return mma_slices diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py index 4b302e87..c4c02ccd 100644 --- a/lit_tests/kernel/wave/barriers.py +++ b/lit_tests/kernel/wave/barriers.py @@ -115,13 +115,13 @@ def test_read_write_equal_sizes(): # CHECK-SAME: (%read_0_1, %allocate, 4, None) # CHECK-NEXT: %shared_memory_barrier # CHECK-NEXT: %read_shared_0_0 - # CHECK-SAME: (%allocate, 4, None, [%write_shared_0_0]) + # CHECK-SAME: (%allocate, 4, None, [%write_shared_0_0] # CHECK-NEXT: %read_shared_1_1 - # CHECK-SAME: (%allocate, 4, None, [%write_shared_1_1]) + # CHECK-SAME: (%allocate, 4, None, [%write_shared_1_1] # CHECK-NEXT: %read_shared_1_0 - # CHECK-SAME: (%allocate, 4, None, [%write_shared_1_0]) + # CHECK-SAME: (%allocate, 4, None, [%write_shared_1_0] # CHECK-NEXT: %read_shared_0_1 - # CHECK-SAME: (%allocate, 4, None, [%write_shared_0_1]) + # CHECK-SAME: (%allocate, 4, None, [%write_shared_0_1] # CHECK-NEXT: %write_0_0 # CHECK-SAME: (%read_shared_0_0, %c, 4, None) # CHECK-NEXT: %write_1_1 diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 5778b689..35713c0c 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -750,6 +750,78 @@ def repeat( # CHECK: return +@run_test +def test_chained_gemm(): + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + BLOCK_K2 = tkl.sym.BLOCK_K2 + + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={B: 0}, + ) + ] + + @tkw.wave(constraints) + def chained_gemm( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[B, M, N, tkl.f32](0.0) + + @tkw.reduction(K2, init_args=[c_reg]) + def repeat( + acc: tkl.Register[B, M, N, tkl.f32] + ) -> tkl.Register[B, M, N, tkl.f32]: + inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + kq_reg = tkw.mma(k_reg, q_reg, inner_acc) + qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2]) + qk_cast_reg = tkw.cast(qk_reg, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(qk_cast_reg, v_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 128, + N: 128, + K1: 32, + K2: 256, + B: 8, + BLOCK_M: 64, + BLOCK_N: 
64, + BLOCK_K2: 32, + BLOCK_B: 1, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + ): + q = torch.randn(8, 64, 64, dtype=torch.float16) + k = torch.randn(8, 256, 64, dtype=torch.float16) + v = torch.zeros(8, 128, 256, dtype=torch.float16) + output = torch.zeros(8, 64, 128, dtype=torch.float32) + print(chained_gemm(q, k, v, output).module_op) + + @run_test def test_gemm_pipelined(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index 5635c8a9..de8a3b40 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -256,10 +256,10 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) @@ -442,10 +442,10 @@ def test_batched_gemm(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) - # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) - # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) - # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 
Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(B, M, N), dtype=f32, value=0.0, index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0, register_0_1_0, register_1_0_0, register_1_1_0], subgraph_name=region_0, implicit_captures=[a, b]) # CHECK-NEXT: get_result(value=reduction, res_idx=3) # CHECK-NEXT: get_result(value=reduction, res_idx=2) @@ -590,7 +590,7 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: placeholder(_name=a # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c - # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) + # CHECK-NEXT: register(shape=(M, N), dtype=f32, value=0.0, index={M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16)}) # CHECK-NEXT: reduction(axis=K, init_args=[register_0_0_0] # CHECK-NEXT: get_result(value=reduction, res_idx=0) # CHECK-NEXT: write(register_=getresult_0_0_0 diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py index 16dc03cb..36594f22 100644 --- a/lit_tests/kernel/wave/index_sequence_analysis.py +++ b/lit_tests/kernel/wave/index_sequence_analysis.py @@ -186,13 +186,13 @@ def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py index 7371a7bb..a6a61a17 100644 --- a/lit_tests/kernel/wave/minimize_global_loads.py +++ b/lit_tests/kernel/wave/minimize_global_loads.py @@ -137,13 +137,13 @@ 
def test_gemm(): # CHECK-NEXT: placeholder(_name=b # CHECK-NEXT: placeholder(_name=c # CHECK-NEXT: register - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16) + 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16)}) # CHECK-NEXT: register( - # CHECK-SAME: index={M: $WG0*BLOCK_M + Mod($T0, 16), N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) + # CHECK-SAME: index={M: $WG0*BLOCK_M + 4*floor((Mod($T0, 64))/16) : 4 : 16, N: $WG1*BLOCK_N + BLOCK_N/2 + Mod($T0, 16) + 16}) # CHECK-NEXT: allocate( # CHECK-NEXT: allocate( # CHECK-NEXT: reduction( diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py new file mode 100644 index 00000000..4d5bc26b --- /dev/null +++ b/tests/kernel/wave/wave_attention_test.py @@ -0,0 +1,187 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import pytest +import torch +import unittest +import iree.turbine.kernel as tk +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.iree_utils import generate_iree_ref +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from iree.turbine.kernel.wave.constraints import MMAType +import os +import json +from torch.testing import assert_close + +_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) +require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") +# Whether to dump the generated MLIR module. +test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) +# Whether to use scheduling group barriers (needs LLVM fix). +enable_scheduling_barriers = int(os.environ.get("WAVE_USE_SCHED_BARRIERS", 0)) + +# Add test shapes for validation and performance testing. 
+perf_test = lambda *a: pytest.param(*a, marks=pytest.mark.perf_only)
+default_test_shapes = {}
+# Order of shapes: (B, M, N, K1, K2)
+default_test_shapes["test_attention"] = [
+    (8, 128, 128, 64, 256),
+]
+default_test_shapes["test_attention"] += [
+    perf_test(x) for x in default_test_shapes["test_attention"]
+]
+
+
+user_specified_test_shapes = ""
+
+test_params_path = os.environ.get("TEST_PARAMS_PATH", None)
+
+if test_params_path:
+    with open(test_params_path, "r") as file:
+        user_specified_test_shapes = json.load(file)
+
+
+def get_test_shapes(test_name: str) -> list[tuple[int]]:
+    if test_name in user_specified_test_shapes:
+        return user_specified_test_shapes[test_name]
+    return default_test_shapes[test_name]
+
+
+@require_e2e
+@pytest.mark.parametrize("shape", get_test_shapes("test_attention"))
+@pytest.mark.parametrize("enable_scheduling", [False])
+@pytest.mark.parametrize(
+    "mfma_variant",
+    [
+        MMAType.F32_16x16x16_F16,
+    ],
+)
+def testChainedGemm(
+    shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request
+):
+    run_bench = request.config.getoption("--runperf")
+    dump_perf = request.config.getoption("--dump-perf-files-path")
+    # Input sizes
+    B = tkl.sym.B
+    M = tkl.sym.M
+    N = tkl.sym.N
+    K1 = tkl.sym.K1
+    K2 = tkl.sym.K2
+    # Workgroup tile sizes
+    BLOCK_B = tkl.sym.BLOCK_B
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    BLOCK_K2 = tkl.sym.BLOCK_K2
+    # Address space (for GPU, shared(1) or global(0))
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+    # Other hyperparameters
+    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
+    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD
+
+    # Expose user-constraints
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
+    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=mfma_variant,
+            vector_shapes={B: 0},
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def chained_gemm(
+        q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
+        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
+        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+    ):
+        c_reg = tkl.Register[B, M, N, tkl.f32](0.0)
+
+        @tkw.reduction(K2, init_args=[c_reg])
+        def repeat(
+            acc: tkl.Register[B, M, N, tkl.f32]
+        ) -> tkl.Register[B, M, N, tkl.f32]:
+            inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0)
+            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            kq_reg = tkw.mma(k_reg, q_reg, inner_acc)
+            qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2])
+            qk_cast_reg = tkw.cast(qk_reg, tkl.f16)
+            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            acc = tkw.mma(qk_cast_reg, v_reg, acc)
+            return acc
+
+        # repeat represents the results of the loop
+        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    hyperparams = {
+        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
+        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
+        BLOCK_B: 1,
+        BLOCK_M: 64,
+        BLOCK_N: 64,
+        BLOCK_K2: 32,
+        B: shape[0],
+        M: shape[1],
+        N: shape[2],
+        K1: shape[3],
+        K2: shape[4],
+        READ_SHARED_DELAY: 1,
+        WRITE_SHARED_DELAY: 1,
+        READ_GLOBAL_DELAY: 2,
+        WRITE_GLOBAL_DELAY: 2,
+        MMA_DELAY: 1,
+        SHARED_MEMORY_UNITS: 4,
+        GLOBAL_MEMORY_UNITS: 4,
+        MMA_UNITS: 4,
+    }
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+    if run_bench:
+        config["benchmark_batch_size"] = 10
+        config["benchmark_repetitions"] = 3
+    if dump_perf is not None:
+        perf_filename = request.node.name + ".json"
+        config["benchmark_results_file"] = os.path.join(
+            dump_perf, "tk_" + perf_filename
+        )
+
+    with tk.gen.TestLaunchContext(
+        hyperparams,
+        canonicalize=True,
+        run=True,
+        run_bench=run_bench,
+        run_config=config,
+        schedule=enable_scheduling,
+        use_scheduling_barriers=enable_scheduling_barriers,
+    ):
+        q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16)
+        k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16)
+        v = torch.randn(shape[0], shape[2], shape[4], dtype=torch.float16)
+        output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
+        mb = chained_gemm(q, k, v, output)
+
+        if test_dump_generated_mlir:
+            filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
+            with open(filename, "w") as f:
+                f.write(mb.module_op.get_asm())
+
+        iree_ref = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
+        generate_iree_ref(
+            "chain_mmt", [q, k, v], [iree_ref], config, run_bench=run_bench
+        )
+        assert_close(output, iree_ref)