Package reduction of tiled mma into function
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu committed Oct 25, 2024
1 parent 3682533 commit 8461184
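
For readers skimming the diff: the change lifts the inner scale_idx loop out of _handle_reduction_dim into a reusable helper, _expand_mma_tiled_reduction, without changing behavior. The helper unrolls an MMA op along its reduction dimension and chains accumulators so that each expanded MMA_{t} accumulates on MMA_{t-1}. A minimal, self-contained sketch of that chaining pattern follows; ToyMMA and chain_along_reduction_dim are hypothetical stand-ins for illustration, not the actual wave kernel API.

# Toy sketch of the accumulator chaining done by _expand_mma_tiled_reduction.
# ToyMMA is a hypothetical stand-in, not the real MMA op.
from dataclasses import dataclass
from typing import Union


@dataclass
class ToyMMA:
    name: str
    acc: Union["ToyMMA", str]


def chain_along_reduction_dim(root: ToyMMA, reduction_scaling: int) -> ToyMMA:
    # Mirror the helper's loop: copies 1..N-1 each accumulate on the
    # previous copy, and the last copy becomes the loop-carried result.
    latest = root
    for scale_idx in range(1, reduction_scaling):
        latest = ToyMMA(name=f"mma_0_0_{scale_idx}", acc=latest)
    return latest


root = ToyMMA(name="mma_0_0_0", acc="acc_0_0_0")
last = chain_along_reduction_dim(root, reduction_scaling=3)
assert last.name == "mma_0_0_2"
assert last.acc.name == "mma_0_0_1"  # MMA_2 accumulates on MMA_1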
Showing 1 changed file with 56 additions and 34 deletions.
iree/turbine/kernel/wave/expansion.py
@@ -553,6 +553,53 @@ def get_dim_scaling(
     return dim_scaling


+def _expand_mma_tiled_reduction(
+    mma: MMA,
+    trace: CapturedTrace,
+    dim_query: dict[IndexSymbol, int],
+    dim_scaling: dict[IndexSymbol, int],
+    context: ExpandedNodeMap,
+    get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
+    res_idx: int,
+) -> CustomOp:
+    latest_reduced_op = mma
+    for scale_idx in range(1, dim_scaling[mma.reduction_dim]):
+        dim_query[mma.reduction_dim] = scale_idx
+        # Temporarily replace the loop-carried arg here to avoid
+        # duplicated expansion. Otherwise we have the following situation:
+        # Suppose we have:
+        # mma_0_0_0(..., acc_0_0_0)
+        # mma_0_0_1(..., mma_0_0_0)
+        # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
+        # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
+        # To avoid this we temporarily replace the use of it with a dummy
+        # placeholder which will not trigger further expansion.
+        dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
+        dummy.type = None
+
+        saved_acc = latest_reduced_op.acc
+        latest_reduced_op.update_arg("acc", dummy)
+        new_node = _expand_node(
+            latest_reduced_op,
+            trace,
+            dim_query,
+            dim_scaling,
+            context,
+            get_node_dim_scaling,
+            res_idx,
+        )
+
+        # This expansion always creates a new node; the op must never be reused.
+        assert new_node != latest_reduced_op
+        # Update MMA_{t} to accumulate on MMA_{t-1}, and then save the
+        # current MMA_{t} for use in the next iteration.
+        latest_reduced_op.update_arg("acc", saved_acc)
+        new_node.update_arg("acc", latest_reduced_op)
+        latest_reduced_op.graph.erase_node(dummy)
+        latest_reduced_op = new_node
+    return latest_reduced_op
+
+
 def _handle_reduction_dim(
     reduction: Reduction,
     output: Output,
@@ -596,39 +643,14 @@ def get_output_index(custom: CustomOp):
         op_output_index = get_output_index(root_op)
         if dim_scaling[reduction.axis] <= 1:
             continue
-        for scale_idx in range(1, dim_scaling[reduction.axis]):
-            dims[reduction.axis] = scale_idx
-            # Temporarily replace the loop carried arg here to avoid
-            # duplicated expansion. Otherwise we have the following situation:
-            # Suppose we have:
-            # mma_0_0_0(..., acc_0_0_0)
-            # mma_0_0_1(..., mma_0_0_0)
-            # Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
-            # mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
-            # To avoid this we temporarily replace the use of it with a dummy
-            # placeholder which will not trigger further expansion.
-            dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
-            dummy.type = None
-
-            saved_acc = latest_reduced_op.acc
-            latest_reduced_op.update_arg("acc", dummy)
-            new_node = _expand_node(
-                latest_reduced_op,
-                trace,
-                dims,
-                dim_scaling,
-                context,
-                get_node_dim_scaling,
-                res_idx,
-            )
-
-            # This expansion always happens, user should never be reused
-            assert new_node != latest_reduced_op
-            # Update MMA_{t} to accumulate on MMA_{t-1}, and then save
-            # current MMA_{t} to outputs for use in next loop.
-            latest_reduced_op.update_arg("acc", saved_acc)
-            new_node.update_arg("acc", latest_reduced_op)
-            latest_reduced_op.graph.erase_node(dummy)
-            latest_reduced_op = new_node
+        latest_reduced_op = _expand_mma_tiled_reduction(
+            root_op,
+            trace,
+            dims,
+            dim_scaling,
+            context,
+            get_node_dim_scaling,
+            res_idx,
+        )
         new_outputs[op_output_index] = latest_reduced_op.fx_node
     output.update_arg("return_vals", new_outputs)
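
The dummy-placeholder dance in the helper deserves a closer look: per the in-code comment, _expand_node recursively expands an op's arguments, so expanding MMA_{t} with its real accumulator still attached would re-expand MMA_{t-1} at the new dim query and duplicate it. Below is a toy model of that failure mode and the fix; expand, ToyOp, and the sentinel handling are illustrative assumptions, not the actual trace machinery.

# Toy model of why the acc arg is swapped for a dummy before expansion.
class ToyOp:
    def __init__(self, name: str, args: list["ToyOp"]):
        self.name = name
        self.args = args


def expand(op: ToyOp, suffix: int, sentinel: ToyOp) -> ToyOp:
    # Recursively clone an op and its args under a new dim suffix.
    # The sentinel arg is left alone, stopping the recursion -- this is
    # the role the dummy Placeholder plays in the real pass.
    new_args = [a if a is sentinel else expand(a, suffix, sentinel) for a in op.args]
    return ToyOp(op.name.rsplit("_", 1)[0] + f"_{suffix}", new_args)


dummy = ToyOp("dummy", [])
mma0 = ToyOp("mma_0_0_0", [])
mma1 = ToyOp("mma_0_0_1", [mma0])

# Swap the accumulator for the dummy so expanding mma_0_0_1 does not
# also clone mma_0_0_0 at the new dim query.
saved_acc = mma1.args[0]
mma1.args[0] = dummy
mma2 = expand(mma1, 2, dummy)
mma1.args[0] = saved_acc  # restore MMA_1's real accumulator
mma2.args[0] = mma1       # chain: MMA_2 accumulates on MMA_1
assert mma2.name == "mma_0_0_2" and mma2.args[0] is mma1

In the real pass the dummy placeholder is additionally erased from the graph once the saved accumulator has been restored.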
