diff --git a/iree/turbine/kernel/wave/expansion.py b/iree/turbine/kernel/wave/expansion.py
index 726b0c96..517ea733 100644
--- a/iree/turbine/kernel/wave/expansion.py
+++ b/iree/turbine/kernel/wave/expansion.py
@@ -553,6 +553,53 @@ def get_dim_scaling(
     return dim_scaling
 
 
+def _expand_mma_tiled_reduction(
+    mma: MMA,
+    trace: CapturedTrace,
+    dim_query: dict[IndexSymbol, int],
+    dim_scaling: dict[IndexSymbol, int],
+    context: ExpandedNodeMap,
+    get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
+    res_idx: int,
+) -> CustomOp:
+    latest_reduced_op = mma
+    for scale_idx in range(1, dim_scaling[mma.reduction_dim]):
+        dim_query[mma.reduction_dim] = scale_idx
+        # Temporarily replace the loop carried arg here to avoid
+        # duplicated expansion. Otherwise we have the following situation:
+        # Suppose we have:
+        # mma_0_0_0(..., acc_0_0_0)
+        # mma_0_0_1(..., mma_0_0_0)
+        # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
+        # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
+        # To avoid this we temporarily replace the use of it with a dummy
+        # placeholder which will not trigger further expansion.
+        dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
+        dummy.type = None
+
+        saved_acc = latest_reduced_op.acc
+        latest_reduced_op.update_arg("acc", dummy)
+        new_node = _expand_node(
+            latest_reduced_op,
+            trace,
+            dim_query,
+            dim_scaling,
+            context,
+            get_node_dim_scaling,
+            res_idx,
+        )
+
+        # This expansion always happens, user should never be reused
+        assert new_node != latest_reduced_op
+        # Update MMA_{t} to accumulate on MMA_{t-1}, and then save
+        # current MMA_{t} to outputs for use in next loop.
+        latest_reduced_op.update_arg("acc", saved_acc)
+        new_node.update_arg("acc", latest_reduced_op)
+        latest_reduced_op.graph.erase_node(dummy)
+        latest_reduced_op = new_node
+    return latest_reduced_op
+
+
 def _handle_reduction_dim(
     reduction: Reduction,
     output: Output,
@@ -596,39 +643,14 @@ def get_output_index(custom: CustomOp):
         op_output_index = get_output_index(root_op)
         if dim_scaling[reduction.axis] <= 1:
             continue
-        for scale_idx in range(1, dim_scaling[reduction.axis]):
-            dims[reduction.axis] = scale_idx
-            # Temporarily replace the loop carried arg here to avoid
-            # duplicated expansion. Otherwise we have the following situation:
-            # Suppose we have:
-            # mma_0_0_0(..., acc_0_0_0)
-            # mma_0_0_1(..., mma_0_0_0)
-            # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
-            # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
-            # To avoid this we temporarily replace the use of it with a dummy
-            # placeholder which will not trigger further expansion.
-            dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
-            dummy.type = None
-
-            saved_acc = latest_reduced_op.acc
-            latest_reduced_op.update_arg("acc", dummy)
-            new_node = _expand_node(
-                latest_reduced_op,
-                trace,
-                dims,
-                dim_scaling,
-                context,
-                get_node_dim_scaling,
-                res_idx,
-            )
-
-            # This expansion always happens, user should never be reused
-            assert new_node != latest_reduced_op
-            # Update MMA_{t} to accumulate on MMA_{t-1}, and then save
-            # current MMA_{t} to outputs for use in next loop.
-            latest_reduced_op.update_arg("acc", saved_acc)
-            new_node.update_arg("acc", latest_reduced_op)
-            latest_reduced_op.graph.erase_node(dummy)
-            latest_reduced_op = new_node
+        latest_reduced_op = _expand_mma_tiled_reduction(
+            root_op,
+            trace,
+            dims,
+            dim_scaling,
+            context,
+            get_node_dim_scaling,
+            res_idx,
+        )
         new_outputs[op_output_index] = latest_reduced_op.fx_node
     output.update_arg("return_vals", new_outputs)
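For context, here is a minimal standalone sketch of the accumulator-chaining pattern that the extracted `_expand_mma_tiled_reduction` helper implements: each copy expanded along the reduction dimension accumulates on the previous copy, and the last copy becomes the new loop-carried value handed back to `_handle_reduction_dim`. The `ToyMMA` and `expand_tiled_reduction` names below are hypothetical stand-ins for illustration only; they are not the wave IR classes, and the real pass clones nodes via `_expand_node` with the dummy-placeholder trick shown in the diff.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class ToyMMA:
    """Toy stand-in for an MMA node; it only tracks its accumulator input."""

    name: str
    acc: ToyMMA | str


def expand_tiled_reduction(mma: ToyMMA, scaling: int) -> ToyMMA:
    """Clone `mma` (scaling - 1) times along the reduction dim, chaining each
    clone's acc to the previous node, and return the last clone (the new
    loop-carried value)."""
    latest = mma
    for scale_idx in range(1, scaling):
        # The real pass produces the clone via _expand_node while the acc is
        # temporarily swapped for a dummy placeholder; the toy just builds a
        # fresh node that accumulates on the previous one.
        latest = ToyMMA(name=f"{mma.name[:-1]}{scale_idx}", acc=latest)
    return latest


if __name__ == "__main__":
    root = ToyMMA(name="mma_0_0_0", acc="acc_0_0_0")
    last = expand_tiled_reduction(root, scaling=4)
    # Walk the chain back to the original accumulator:
    # mma_0_0_3 -> mma_0_0_2 -> mma_0_0_1 -> mma_0_0_0 -> acc_0_0_0
    node = last
    while isinstance(node, ToyMMA):
        acc = node.acc.name if isinstance(node.acc, ToyMMA) else node.acc
        print(f"{node.name} accumulates on {acc}")
        node = node.acc
```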