diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py index dc3504e0..e035a98a 100644 --- a/iree/turbine/kernel/wave/decompose_reduce_ops.py +++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py @@ -48,8 +48,6 @@ def determine_shuffle_config( """ access_pattern = index[reduction_dim] - elements_per_thread = access_pattern.size - cluster_size = vector_shapes[reduction_dim] // elements_per_thread # Since we are only concerned with what happens within a subgroup, # we can ignore the TID_1 and TID_2 components of the index. We can @@ -67,11 +65,11 @@ def determine_shuffle_config( offset = access_pattern.start.subs({k: 0 for k in ignore}) offset = subs_idxc(offset) offset_table = [offset.subs({THREAD_0: i}) for i in range(subgroup_size)] - # Determine the thread ids participating in the shuffle. + unique_offsets = list(dict.fromkeys(offset_table)) + cluster_size = len(set(offset_table)) thread_ids = [] - for i in range(cluster_size): - thread_ids.append(offset_table.index(i * elements_per_thread)) - + for thread_offset in unique_offsets: + thread_ids.append(offset_table.index(thread_offset)) cluster_stride = [x - y for x, y in zip(thread_ids[1:], thread_ids[:-1])] assert all_equal(cluster_stride), f"Cluster stride must be equal across threads." return cluster_size, cluster_stride[0] diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index f79e10f2..b6ce10ad 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -219,6 +219,9 @@ def set_vector_shapes( or node in mma_slices[mma][MMA_LHS] or node in mma_slices[mma][MMA_RHS] ): + # Ensure that the operators indexing dims are present in the anchor. + if not set(custom.indexing_dims).issubset(mma.indexing_dims): + continue custom.anchor = mma custom.vector_shapes = custom.vector_shapes | mma.vector_shapes return diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 04d0d2d9..129f7551 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -51,7 +51,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): # Anchor Indicies and Conflict resolution helpers ################################################################# -anchorOpTypes = (Read, Write, MMA, ReduceOp) +anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape) noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) legalSubtypes = (IterArg,) nonPropagatableTypes = anchorOpTypes + noHandleTypes @@ -201,9 +201,13 @@ def determine_thread_shapes(trace: CapturedTrace): index_sizes, set([]) ).union(bwd_slice) elif isinstance(custom, MMA): - lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op) - rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op) - acc_slice = capture_forward_slice(custom.acc, propagatable_op) + lhs_bwd_slice = set([custom.lhs]) + if propagatable_op(custom.lhs): + lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op) + rhs_bwd_slice = set([custom.rhs]) + if propagatable_op(custom.rhs): + rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op) + acc_slice = capture_forward_slice(custom.fx_node, propagatable_op) if not isinstance(get_custom(custom.acc), MMA): acc_slice = acc_slice.union( capture_backward_slice(custom.acc, propagatable_op) @@ -220,6 +224,15 @@ def determine_thread_shapes(trace: CapturedTrace): thread_size_to_ops[rhs_index] = thread_size_to_ops.get( rhs_index, set([]) ).union(rhs_bwd_slice) + elif isinstance(custom, Reshape): + # The reshape op acts like a barrier for the MMA preventing + # the mma from propagating the thread shapes of its reshaped + # operands backwards. + bwd_size = get_dim_sizes(custom.args.index) + bwd_slice = capture_backward_slice(custom.args, propagatable_op) + thread_size_to_ops[bwd_size] = thread_size_to_ops.get( + bwd_size, set([]) + ).union(bwd_slice) # Go through each index-size buckets, and apply the index-size to ops in the bucket. cummulative_set = set() diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 12ed2151..bcaa772f 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -265,17 +265,18 @@ def is_mma(node): # in the backward slice of the lhs and rhs upto a previous mma (if one exists). # So we check for the previous node of the first operator in the slice to see # if it is an MMA and if so check if a reshape is required. - def add_reshape_if_needed(mma: MMA, prev_mma: MMA): + def add_reshape_if_needed(mma: MMA, prev_mma: MMA, arg_index: int): with mma.graph.inserting_before(mma.fx_node): - for i, arg in mma.node_args.items(): - if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes): - reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph( - custom.graph - ) - custom_reshape = get_custom(reshape) - custom_reshape.vector_shapes = custom.vector_shapes - custom_reshape.anchor = custom - custom.update_arg(i, reshape) + arg = mma.lhs if arg_index == 0 else mma.rhs + arg = get_custom(arg) + if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes): + reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph( + custom.graph + ) + custom_reshape = get_custom(reshape) + custom_reshape.vector_shapes = custom.vector_shapes + custom_reshape.anchor = custom + custom.update_arg(arg_index, reshape) def find_mma_in_slice(node: CustomOp) -> Optional[MMA]: """ @@ -295,10 +296,10 @@ def find_mma_in_slice(node: CustomOp) -> Optional[MMA]: custom_mma = get_custom(mma) prev_mma = find_mma_in_slice(custom_mma.lhs) if prev_mma: - add_reshape_if_needed(custom_mma, prev_mma) + add_reshape_if_needed(custom_mma, prev_mma, 0) prev_mma = find_mma_in_slice(custom_mma.rhs) if prev_mma: - add_reshape_if_needed(custom_mma, prev_mma) + add_reshape_if_needed(custom_mma, prev_mma, 1) return mapping, mma_slices @@ -641,7 +642,8 @@ def bfs( filter_fn: Callable[[fx.node], bool], ) -> set[fx.Node]: """ - Run BFS on the graph to capture the forward slice of a node. + Run BFS on the graph. The filter function is not applied to + the incoming node. """ visited: set[fx.Node] = set() queue: list[fx.Node] = []