diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 19e3f64c..466c9455 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -134,6 +134,14 @@ def cast(src: "Register", dtype: DataType) -> "Register": ... +def permute(src: "Register", target_shape: Sequence[IndexExpr]) -> "Register": + ... + + +def expansion_conflict(src: "Register", vector_shapes: dict[IndexSymbol, int]): + ... + + def define_op(op_name: str) -> Callable[[T], T]: def decorator(cls: T) -> T: cls.tkw_op_name = op_name @@ -399,6 +407,10 @@ def copy( new_node.index = copy.deepcopy(self.fx_node.index) if new_name: new_node.name = new_name + if hasattr(self.fx_node, "vector_shapes"): + new_node.vector_shapes = self.fx_node.vector_shapes + if hasattr(self.fx_node, "reduction_dim"): + new_node.reduction_dim = self.fx_node.reduction_dim return get_custom(new_node) def replace_all_uses_with(self, new_node: CustomOp | fx.Node): @@ -520,12 +532,29 @@ def expanded_dims(self, value: dict[IndexSymbol, int]): raise ValueError("Expanded dims must be a dict") self.fx_node.expanded_dims = value - def post_expansion(self, constraints: list["Constraint"]) -> None: + @property + def anchor(self) -> fx.Node: """ - Hook for post-expansion operations. This is called after the arguments - of the node are expanded. + The anchor is a node that provides information to the node + such as vector_shapes, indexing information etc. """ - pass + if hasattr(self.fx_node, "anchor"): + return self.fx_node.anchor + return None + + @anchor.setter + def anchor(self, value: fx.Node): + self.fx_node.anchor = value + + @property + def vector_shapes(self) -> dict[IndexSymbol, int]: + if hasattr(self.fx_node, "vector_shapes"): + return self.fx_node.vector_shapes + return None + + @vector_shapes.setter + def vector_shapes(self, value: dict[IndexSymbol, int]): + self.fx_node.vector_shapes = value def align_index(self, constraints: list["Constraint"]) -> None: """ @@ -883,6 +912,15 @@ def align_index(self, constraints: list["Constraint"]) -> None: self.index = align_index_vars(self.index, constraints) + @property + def reduction_dim(self) -> IndexSymbol: + if hasattr(self.fx_node, "reduction_dim"): + return self.fx_node.reduction_dim + + @reduction_dim.setter + def reduction_dim(self, value: IndexSymbol): + self.fx_node.reduction_dim = value + @define_op("read") @dataclass @@ -1033,9 +1071,11 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]: return_vals = output.return_vals[0] return ( [ - get_custom(val).acc_index - if isinstance(get_custom(val), MMA) - else val.index + ( + get_custom(val).acc_index + if isinstance(get_custom(val), MMA) + else val.index + ) for val in return_vals ] if isinstance(return_vals, (Sequence)) @@ -1309,3 +1349,74 @@ def indexing_dims(self) -> list[IndexSymbol]: def type(self) -> Memory: src_shape = get_custom(self.arg).type.symbolic_shape return Register[*src_shape, self.dtype] + + +@define_op("permute") +@dataclass +class Permute(CustomOp, ABC): + """ + Represents a permute operation that + permutes arg into the target shape. + """ + + arg: fx.Node + target_shape: Sequence[IndexExpr] + + @property + def indexing_dims(self) -> list[IndexExpr]: + return self.target_shape + + @property + def type(self) -> Register: + src_type = get_custom(self.arg).type + return Register[*self.target_shape, src_type.dtype] + + @property + def index(self) -> Optional[dict[IndexSymbol, IndexSequence]]: + """ + Computes the permuted index based on the target shape. + """ + src_shape = get_custom(self.arg).type.symbolic_shape + target_shape = self.target_shape + dim_map: dict[IndexSymbol, IndexSymbol] = {} + for dim_src, dim_target in zip(src_shape, target_shape): + dim_map[dim_target] = dim_src + src_index = get_custom(self.arg).index + target_index = dict(src_index) + for dim_target, dim_src in dim_map.items(): + target_index[dim_target] = src_index[dim_src] + return target_index + + @index.setter + def index(self, value: Any): + """ + Updates the index of the node based on a per-dimension index sequence. + """ + if value is None: + return + if isinstance(value, dict): + assert all( + isinstance(v, IndexSequence) for v in value.values() + ), f"Index must be a dict with values of type IndexSequence" + self.fx_node.index = {} + for dim, key in value.items(): + self.fx_node.index[dim] = key + elif isinstance(value, list): + self.fx_node.index = value + else: + raise ValueError("Index must be a dict") + + +@define_op("expansion_conflict") +@dataclass +class ExpansionConflict(CustomOp, ABC): + arg: fx.Node + vector_shapes: dict[IndexSymbol, int] + + @property + def indexing_dims(self) -> list[IndexExpr]: + return [] + + @property + def type(self) -> Register: + return self.arg.type diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 5ada0335..5a790e06 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -66,6 +66,7 @@ scheduling_barrier, scheduling_group_barrier, cast, + permute, ) from ..lang.wave_types import IndexMapping, IndexSymbol from ..compiler.base import CodegenError, ValidationError, NDEBUG @@ -1299,3 +1300,13 @@ def handle_cast(emitter: WaveEmitter, node: fx.Node): ) emitter.bind_node_proxy(node, IRProxyValue(casted_vector)) + + +@handle_op(permute) +def handle_permute(emitter: WaveEmitter, node: fx.Node): + try: + register, _ = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + vector_src = cast_py_value(emitter, register) + emitter.bind_node_proxy(node, vector_src) diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py index f83b8fad..fe9d3458 100644 --- a/iree/turbine/kernel/wave/expansion.py +++ b/iree/turbine/kernel/wave/expansion.py @@ -117,10 +117,7 @@ def expand_graph( Create a graph that represents the expanded version of the wave function. The expansion is done in the dimensions specified by the constraints. """ - if isinstance(constraints_or_scaling, dict): - dim_scaling = constraints_or_scaling - else: - dim_scaling = get_dim_scaling(constraints_or_scaling) + get_node_dim_scaling = partial(get_dim_scaling, constraints_or_scaling) # Start from the back and expand in the corresponding indexing dimensions of a node # Then proceed to the operands @@ -146,6 +143,7 @@ def expand_graph( if node.__class__ not in leaf_nodes: continue + dim_scaling = get_node_dim_scaling(node) for dim_combination in get_dim_combinations(dim_scaling, node.indexing_dims): expand_dims = { dim: val for dim, val in zip(dim_scaling.keys(), dim_combination) @@ -157,6 +155,7 @@ def expand_graph( expand_dims, dim_scaling, expansion_context, + get_node_dim_scaling, ) @@ -166,6 +165,7 @@ def _expand_node( dim_query: dict[IndexSymbol, int], dim_scaling: dict[IndexSymbol, int], context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], res_idx: int = 0, ) -> CustomOp: """Expand a single node or list of nodes in specific dimensions and recursively proceed to its inputs.""" @@ -177,8 +177,9 @@ def _expand_node( elem, trace, dim_query, - dim_scaling, + get_node_dim_scaling(elem), context, + get_node_dim_scaling, res_idx, ).fx_node ) @@ -187,8 +188,23 @@ def _expand_node( if (node, get_indexed_dims(dim_query, node), res_idx) in context: logger.debug(f"Already expanded node: {node} in {dim_query}") return context[(node, get_indexed_dims(dim_query, node), res_idx)] + elif isinstance(node, MMA): + if hasattr(node.graph, "parent_op") and node.reduction_dim not in dim_query: + reduction: Reduction = get_custom(node.graph.parent_op) + if reduction.axis != node.reduction_dim: + return _expand_mma_reduction( + node, + trace, + dim_query, + get_node_dim_scaling(node), + context, + get_node_dim_scaling, + res_idx, + ) elif isinstance(node, Reduction): - return _expand_reduction(node, trace, dim_query, dim_scaling, context, res_idx) + return _expand_reduction( + node, trace, dim_query, dim_scaling, context, get_node_dim_scaling, res_idx + ) elif isinstance(node, Getitem): res_idx = node.res_idx elif isinstance(node, GetResult) and not isinstance(node, Getitem): @@ -230,16 +246,35 @@ def _expand_node( # Proceed with expansion of the arguments for i, arg in node.node_args.items(): - if is_expandable(arg): - new_arg = _expand_node( - arg, - trace, - restricted_dims, - dim_scaling, - context, - res_idx, - ) - new_node.update_arg(i, new_arg) + new_arg = None + if not isinstance(arg, Sequence): + if is_expandable(arg): + new_arg = _expand_node( + arg, + trace, + restricted_dims, + get_node_dim_scaling(arg), + context, + get_node_dim_scaling, + res_idx, + ) + new_node.update_arg(i, new_arg) + else: + new_arg = [] + for subarg in arg: + if is_expandable(subarg): + new_subarg = _expand_node( + subarg, + trace, + restricted_dims, + get_node_dim_scaling(subarg), + context, + get_node_dim_scaling, + res_idx, + ) + new_arg.append(new_subarg) + if len(new_arg) == len(arg): + new_node.update_arg(i, new_arg) context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node return new_node @@ -251,6 +286,7 @@ def _expand_reduction( dim_query: dict[IndexSymbol, int], dim_scaling: dict[IndexSymbol, int], context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], res_idx: int = 0, ) -> CustomOp: """Expand a reduction in a specific dimension and recursively proceed to its inputs.""" @@ -258,6 +294,7 @@ def _expand_reduction( users = reduction.users expand_dims: list[IndexSymbol] = [] for user in users: + dim_scaling.update(get_node_dim_scaling(user)) for indexing_dim in user.indexing_dims: if indexing_dim not in expand_dims: expand_dims.append(indexing_dim) @@ -295,21 +332,24 @@ def _expand_reduction( arg, trace, dims, - dim_scaling, + get_node_dim_scaling(arg), context, + get_node_dim_scaling, res_idx, ) ) # Proceed with expansion outside the reduction for init_arg in reduction.init_args: + custom_init_arg = get_custom(init_arg) new_init_args.append( _expand_node( - get_custom(init_arg), + custom_init_arg, trace, dims, - dim_scaling, + get_node_dim_scaling(custom_init_arg), context, + get_node_dim_scaling, res_idx, ) ) @@ -325,6 +365,7 @@ def _expand_reduction( trace, dim_scaling, context, + get_node_dim_scaling, res_idx, ) # Even though we expanded the reduction in multiple dimensions, we only return @@ -332,6 +373,81 @@ def _expand_reduction( return context[(reduction, get_indexed_dims(dim_query, expand_dims), res_idx)] +def _expand_mma_reduction( + mma: MMA, + trace: CapturedTrace, + dim_query: dict[IndexSymbol, int], + dim_scaling: dict[IndexSymbol, int], + context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], + res_idx: int, +) -> CustomOp: + + # Compute dim scaling for the reduction dimension. + idxc = IndexingContext.current() + for dim in mma.indexing_dims: + if dim not in dim_scaling and mma.vector_shapes[dim] > 0: + tile_size = idxc.get_static_value(dim) + dim_scaling[dim] = tile_size // mma.vector_shapes[dim] + + expand_dims = set(mma.indexing_dims) - set([mma.reduction_dim]) + context_key = (mma, get_indexed_dims(dim_query, expand_dims), res_idx) + + # Store the accumulator value for expansion. + if not hasattr(_expand_mma_reduction, "acc"): + _expand_mma_reduction.acc = mma.acc + + # Clone the mma node + user = mma + for scale_idx in range(dim_scaling[mma.reduction_dim]): + if isinstance(user, Output): + continue + + dims = dim_query + dims[mma.reduction_dim] = scale_idx + # Temporarily replace the loop carried arg here to avoid + # duplicated expansion. Otherwise we have the following situation: + # Suppose we have: + # mma_0_0_0(..., acc_0_0_0) + # mma_0_0_1(..., mma_0_0_0) + # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg + # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node. + # To avoid this we temporarily replace the use of it with a dummy + # placeholder which will not trigger further expansion. + index = user.get_node_arg_index(get_custom(user.acc)) + dummy = Placeholder("dummy").add_to_graph(user.graph) + dummy.type = None + + saved_arg = user.node_args[index] + user.update_arg(index, dummy) + new_node = _expand_node( + user, + trace, + dims, + get_node_dim_scaling(user), + context, + get_node_dim_scaling, + ) + + # Update the new node accumulator with the user, except the first one. + if scale_idx > 0: + new_node.update_arg(index, user) + else: + if scale_idx == 0: + # Replace acc in new_node with _expand_mma_reduction.acc + new_node.update_arg(index, _expand_mma_reduction.acc) + # Replace acc in user with saved_arg + # Remove the dummy placeholder. + user.update_arg(index, saved_arg) + # Erase the dummy placeholder.. + user.graph.erase_node(dummy) + # Update the user to the new node. + user = new_node + + context[context_key] = new_node + return new_node + + def get_expanded_name(node: CustomOp, dims: dict[IndexSymbol, int]) -> str: """Returns the name of a node with the dimensions appended.""" @@ -355,9 +471,14 @@ def _contains(elem, container): return elem in container -def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int]: - """Get the number of expansions for the dimensions based on the constraints.""" +def get_dim_scaling( + constraints: Sequence[Constraint], node: fx.Node +) -> dict[IndexSymbol, int]: + """Get the number of expansions for the dimensions based on the constraints for a specific node.""" dim_scaling: dict[IndexSymbol, int] = {} + if node.vector_shapes is None: + return dim_scaling + hardware_constraints: list[HardwareConstraint] = [ constraint for constraint in constraints @@ -373,12 +494,9 @@ def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int] ): hw_cons = hardware_constraints[0] tile_size = idxc.get_static_value(constraint.tile_size) - if not _contains(constraint.dim, hw_cons.vector_shapes): - raise ValueError( - f"Attempting to determine vector shape for unmapped dimension {constraint.dim}" - ) - - vector_size = hw_cons.vector_shapes[constraint.dim] + if constraint.dim not in node.vector_shapes: + continue + vector_size = node.vector_shapes[constraint.dim] # No dim scaling for dims with 0 vector size. if vector_size == 0: @@ -399,6 +517,7 @@ def get_dim_scaling(constraints: Sequence[Constraint]) -> dict[IndexSymbol, int] "Tile size must be divisible by wave count and vector size" ) dim_scaling[constraint.dim] = tile_size // wave_count // vector_size + return dim_scaling @@ -408,6 +527,7 @@ def _handle_reduction_dim( trace: CapturedTrace, dim_scaling: dict[IndexSymbol, int], context: ExpandedNodeMap, + get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]], res_idx: int, ): # Rediscover iter args @@ -422,6 +542,7 @@ def _handle_reduction_dim( # Users of the loop carried nodes will be duplicated for idx, carried_node in enumerate(iter_args): # The initial nodes are expanded in the first dimension, so we start from 1 + dim_scaling = get_node_dim_scaling(carried_node) for scale_idx in range(1, dim_scaling[reduction.axis]): for user in carried_node.users: if isinstance(user, Output): @@ -450,6 +571,7 @@ def _handle_reduction_dim( dims, dim_scaling, context, + get_node_dim_scaling, res_idx, ) diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 21ed0d97..7f417665 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -4,7 +4,15 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ..ops.wave_ops import Write, ExtractSlice, get_custom, Reduction +from ..ops.wave_ops import ( + Write, + ExtractSlice, + get_custom, + Reduction, + MMA, + Placeholder, + IterArg, +) from .constraints import Constraint, HardwareConstraint, WorkgroupConstraint from .._support.tracing import CapturedTrace, IndexingContext from .._support.indexing import IndexSymbol, IndexSequence @@ -27,15 +35,10 @@ def get_vector_shape( - hardware_constraint: HardwareConstraint, + vector_shapes: dict[IndexSymbol, int], symbolic_shape: list[IndexSymbol], ) -> list[int]: - assert all( - dim in hardware_constraint.vector_shapes for dim in symbolic_shape - ), "Missing vector shape in hardware constraint" - vector_shapes = [ - max(hardware_constraint.vector_shapes[dim], 1) for dim in symbolic_shape - ] + vector_shapes = [max(vector_shapes[dim], 1) for dim in symbolic_shape] return vector_shapes @@ -81,7 +84,9 @@ def has_strided_access(node: fx.Node) -> bool: for dim in custom.index } - shape = get_vector_shape(hw_constraint, custom.register_type.symbolic_shape) + shape = get_vector_shape( + custom.vector_shapes, custom.register_type.symbolic_shape + ) elements_per_thread = subs_idxc(custom.elements_per_thread) max_stride_dim, max_stride = max( [(dim, seq.stride) for dim, seq in simplified_index.items()], @@ -124,6 +129,7 @@ def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]): mma_index, mma_slices = get_mma_dimensional_mapping( trace, get_hardware_constraint(constraints) ) + trace.walk(partial(set_vector_shapes, constraints, mma_index, mma_slices)) trace.walk(partial(set_node_index, constraints, mma_index, mma_slices)) @@ -166,10 +172,49 @@ def is_contiguous_dim( return is_innermost_dim or all_unit_dims +def set_vector_shapes( + constraints: Sequence[Constraint], + mma_index: dict[MMA, dict[IndexSymbol, int]], + mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]], + node: fx.Node, +): + """ + Set the vector shapes for the specific op based on whether the op lies in + an MMA slice as well as the anchor node. + """ + custom = get_custom(node) + # MMA & Reduction nodes already have their vector shapes set. + if isinstance(custom, (MMA, Reduction)): + return + # Add vector shapes from constraints to all ops. These are global constraints. + custom.vector_shapes = {} + hw_constraint = get_hardware_constraint(constraints) + if hw_constraint.vector_shapes: + custom.vector_shapes = hw_constraint.vector_shapes + + if len(mma_slices) == 1: + # If there is just one MMA slice, there is no ambiguity in the vector shapes + # and we set that singular MMA op as the anchor for all ops. + mma = list(mma_slices.keys())[0] + custom.anchor = mma + custom.vector_shapes = custom.vector_shapes | mma.vector_shapes + return + + for mma in mma_slices: + if ( + node in mma_slices[mma][MMA_ACC] + or node in mma_slices[mma][MMA_LHS] + or node in mma_slices[mma][MMA_RHS] + ): + custom.anchor = mma + custom.vector_shapes = custom.vector_shapes | mma.vector_shapes + return + + def set_node_index( constraints: Sequence[Constraint], - mma_index: dict[IndexSymbol, int], - mma_slices: dict[IndexSymbol, list[fx.Node]], + mma_index: dict[MMA, dict[IndexSymbol, int]], + mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]], node: fx.Node, ): """ @@ -198,10 +243,11 @@ def set_node_index( # The semantics of elements_per_thread are that it represents the number of # elements that are loaded contiguously from memory. custom = get_custom(node) + anchor = custom.anchor elements_per_thread = getattr(custom, "elements_per_thread", None) - if isinstance(custom, Reduction): + if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg): return for dim in custom.indexing_dims: @@ -209,8 +255,8 @@ def set_node_index( for constraint in sorted_constraints: if isinstance(constraint, HardwareConstraint): inputs = None - if dim in mma_index: - inputs = (mma_index[dim], elements_per_thread, None) + if anchor and dim in mma_index[anchor]: + inputs = (mma_index[anchor][dim], elements_per_thread, None) else: # Assumes vector shapes are associated with workgroup dims. if dim not in workgroup_constraints: @@ -240,9 +286,13 @@ def set_node_index( # dependence in the dimensional index. # TODO: Evaluate if this is a valid case. continue - index_seq = constraint.apply(dim, *inputs, dim in mma_index) - if dim in mma_index: - index_seq = specialize_index_sequence(index_seq, mma_slices, custom) + index_seq = constraint.apply( + dim, *inputs, anchor and dim in mma_index[anchor] + ) + if anchor and dim in mma_index[anchor]: + index_seq = specialize_index_sequence( + index_seq, mma_slices[anchor], custom + ) elif constraint.dim == dim: if index_seq is None: @@ -262,7 +312,6 @@ def set_post_expansion_indices(trace: CapturedTrace, constraints: list[Constrain """ Add offsets to the indices based on the expanded dims. """ - hw_cons = get_hardware_constraint(constraints) def apply_offset(node: fx.Node): custom = get_custom(node) @@ -270,7 +319,7 @@ def apply_offset(node: fx.Node): return False for dim, scale in custom.expanded_dims.items(): if dim in custom.index: - custom.index[dim].start += scale * hw_cons.vector_shapes[dim] + custom.index[dim].start += scale * custom.vector_shapes[dim] return False trace.walk(apply_offset) diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 86faf74a..04d0d2d9 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -88,14 +88,22 @@ def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]: # Determine the correct indexSize for binaryOp and insert broadcasting. dst_op = lhs if lhs_dim_set > rhs_dim_set else rhs broadcast_idx, broadcast_src = (1, rhs) if lhs_dim_set > rhs_dim_set else (0, lhs) - broadcast = Broadcast(broadcast_src.fx_node, dst_op.type) with custom_node.graph.inserting_before(custom_node.fx_node): - broadcast.add_to_graph(custom_node.graph) - custom_node.update_arg(broadcast_idx, broadcast.fx_node) - propagated_resolutions = capture_forward_slice(broadcast.fx_node, propagatable_op) + broadcast = Broadcast(broadcast_src.fx_node, dst_op.type).add_to_graph( + custom_node.graph + ) + custom_broadcast = get_custom(broadcast) + custom_broadcast.vector_shapes = broadcast_src.vector_shapes + custom_broadcast.anchor = broadcast_src.anchor + custom_node.update_arg(broadcast_idx, custom_broadcast.fx_node) + propagated_resolutions = capture_forward_slice( + custom_broadcast.fx_node, propagatable_op + ) for node in propagated_resolutions: get_custom(node).index = dst_op.index - resolved_resolutions = capture_backward_slice(broadcast.fx_node, propagatable_op) + resolved_resolutions = capture_backward_slice( + custom_broadcast.fx_node, propagatable_op + ) return propagated_resolutions.union(resolved_resolutions) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index f3c9201d..d17069f1 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -194,23 +194,23 @@ def simplify_index(index: IndexExpr) -> IndexExpr: def get_mma_dimensional_mapping( trace: CapturedTrace, hardware_constraint: HardwareConstraint, -) -> tuple[dict[IndexSymbol, int], dict[IndexSymbol, list[fx.Node]]]: +) -> tuple[ + dict[MMA, dict[IndexSymbol, int]], dict[MMA, dict[IndexSymbol, list[fx.Node]]] +]: """ Given a trace, determine the MMA dimensional mapping for all the MMA operations in the graph. For example, if we have acc = tkw.mma(a_reg, b_reg, acc) where a_reg has shape UxV, b has shape SxV and acc has shape UxS, we map U to the MMA M dimension (0), S to the MMA N dimension (1) and - V to the MMA K dimension (2). - - Also update the vector shapes in the hardware constraint based on the - discovered MMA dimensions. + V to the MMA K dimension (2). We maintain this map per mma node and + also update the vector_shapes of the mma node based on this information. """ def is_mma(node): return isinstance(get_custom(node), MMA) - mapping: dict[IndexSymbol, int] = {} + mapping: dict[MMA, dict[IndexSymbol, int]] = {} mma_nodes = trace.walk(is_mma) for node in mma_nodes: custom: MMA = get_custom(node) @@ -219,18 +219,30 @@ def is_mma(node): rhs_shape = custom.rhs_type.symbolic_shape acc_shape = custom.acc_type.symbolic_shape k = ((set(lhs_shape) & set(rhs_shape)) - set(acc_shape)).pop() - mapping[m] = 0 - mapping[n] = 1 - mapping[k] = 2 - # Update vector shapes in hardware constraint. - M, N, K = hardware_constraint.mma_matrix_shapes - if not hardware_constraint.vector_shapes: - hardware_constraint.vector_shapes = {} - hardware_constraint.vector_shapes[m] = M - hardware_constraint.vector_shapes[n] = N - hardware_constraint.vector_shapes[k] = K - - return mapping, capture_mma_slices([get_custom(x) for x in mma_nodes]) + if custom not in mapping: + mapping[custom] = {} + mapping[custom][m] = 0 + mapping[custom][n] = 1 + mapping[custom][k] = 2 + custom.vector_shapes = { + m: hardware_constraint.mma_matrix_shapes[0], + n: hardware_constraint.mma_matrix_shapes[1], + k: hardware_constraint.mma_matrix_shapes[2], + } + if hardware_constraint.vector_shapes: + custom.vector_shapes.update(hardware_constraint.vector_shapes) + custom.anchor = custom + custom.reduction_dim = k + + # Since expansion proceeds bottom-up, we set the vector shapes + # of the parent reduction to the vector shapes of the last MMA node. + if hasattr(custom.graph, "parent_op"): + reduction = get_custom(custom.graph.parent_op) + reduction.vector_shapes = custom.vector_shapes + reduction.anchor = custom + + mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes} + return mapping, mma_slices def get_hardware_vector_size( @@ -503,7 +515,9 @@ def get_users( """ users = [] for user in node.users: - custom = get_custom(user) + custom = user + if not isinstance(custom, CustomOp): + custom = get_custom(user) if isinstance(custom, Reduction): # Map init arg to iter arg reduction = custom @@ -553,6 +567,9 @@ def get_inputs( # Map get result to output reduction_subgraph = reduction.graph.subgraphs[reduction.subgraph_name] inputs.append(reduction.outputs(reduction_subgraph)[custom.res_idx]) + elif isinstance(custom, Reduction): + reduction_subgraph = custom.get_root_graph().subgraphs[custom.subgraph_name] + inputs.append(custom.outputs(reduction_subgraph)) else: # Default handling for other ops. for input in node.all_input_nodes: @@ -602,16 +619,18 @@ def capture_backward_slice( return bfs(node, lambda x, y: get_inputs(x, y), filter_fn) -def capture_mma_slices(mma_nodes: list[MMA]) -> dict[IndexSymbol, list[fx.Node]]: +def capture_mma_slices(mma: MMA) -> dict[IndexSymbol, list[fx.Node]]: """ Given an index sequence, specialize it to a LHS, RHS or ACC index sequence based on whether the node is used as the LHS, RHS or ACC in the MMA node. """ mma_slices = {x: [] for x in [MMA_LHS, MMA_RHS, MMA_ACC]} - for mma in mma_nodes: - mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs) - mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs) - mma_slices[MMA_ACC] += capture_forward_slice(mma.acc) + is_not_mma = lambda x: not isinstance(get_custom(x), MMA) + mma_slices[MMA_LHS] += capture_backward_slice(mma.lhs, is_not_mma) + mma_slices[MMA_RHS] += capture_backward_slice(mma.rhs, is_not_mma) + mma_slices[MMA_ACC] += capture_forward_slice(mma.fx_node, is_not_mma).union( + capture_backward_slice(mma.acc, is_not_mma) + ) return mma_slices diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 5778b689..7bfe46b8 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -750,6 +750,76 @@ def repeat( # CHECK: return +@run_test +def test_chained_gemm(): + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + BLOCK_K2 = tkl.sym.BLOCK_K2 + + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + vector_shapes={B: 0}, + ) + ] + + @tkw.wave(constraints) + def chained_gemm( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[B, M, N, tkl.f32](0.0) + + @tkw.reduction(K2, init_args=[c_reg]) + def repeat(acc: tkl.Register[B, M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + inner_acc = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + kq_reg = tkw.mma(k_reg, q_reg, inner_acc) + qk_reg = tkw.permute(kq_reg, target_shape=[B, M, K2]) + qk_cast_reg = tkw.cast(qk_reg, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(qk_cast_reg, v_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 128, + N: 128, + K1: 64, + K2: 256, + B: 8, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K2: 32, + BLOCK_B: 1, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + ): + q = torch.randn(8, 64, 64, dtype=torch.float16) + k = torch.randn(8, 256, 64, dtype=torch.float16) + v = torch.zeros(8, 128, 256, dtype=torch.float16) + output = torch.zeros(8, 64, 128, dtype=torch.float32) + print(chained_gemm(q, k, v, output).module_op) + + @run_test def test_gemm_pipelined(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]