add expansion of ReduceOp
Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu committed Oct 26, 2024
1 parent 8461184 commit d9e0b24
Showing 3 changed files with 144 additions and 11 deletions.
2 changes: 1 addition & 1 deletion iree/turbine/kernel/ops/wave_ops.py
@@ -615,7 +615,7 @@ def type(self) -> Memory:
            raise ValueError(
                "BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
            )
-       broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhstype
+       broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
        return broadcasted_type


39 changes: 29 additions & 10 deletions iree/turbine/kernel/wave/expansion.py
@@ -630,7 +630,7 @@ def get_output_index(custom: CustomOp):
    # Collect MMA and ReduceOp whose reduction axis matches parent ReductionOp.
    reduction_root_ops = []
    for node in (get_custom(fx_node) for fx_node in reduction_subgraph.nodes):
-       if isinstance(node, MMA) and reduction.axis == node.reduction_dim:
+       if isinstance(node, (MMA, ReduceOp)) and reduction.axis == node.reduction_dim:
            reduction_root_ops.append(node)

    new_outputs = list(reduction.outputs(trace.get_subgraph(reduction.subgraph_name)))
@@ -643,14 +643,33 @@ def get_output_index(custom: CustomOp):
        op_output_index = get_output_index(root_op)
        if dim_scaling[reduction.axis] <= 1:
            continue
-       latest_reduced_op = _expand_mma_tiled_reduction(
-           root_op,
-           trace,
-           dims,
-           dim_scaling,
-           context,
-           get_node_dim_scaling,
-           res_idx,
-       )
+       if isinstance(root_op, MMA):
+           latest_reduced_op = _expand_mma_tiled_reduction(
+               root_op,
+               trace,
+               dims,
+               dim_scaling,
+               context,
+               get_node_dim_scaling,
+               res_idx,
+           )
+       elif isinstance(root_op, ReduceOp):
+           original_src = latest_reduced_op.arg
+           for scale_idx in range(1, dim_scaling[reduction.axis]):
+               dims[root_op.reduction_dim] = scale_idx
+               current_src = latest_reduced_op.arg
+               if not isinstance(current_src, Sequence):
+                   current_src = [current_src]
+               expanded_src = _expand_node(
+                   get_custom(original_src),
+                   trace,
+                   dims,
+                   dim_scaling,
+                   context,
+                   get_node_dim_scaling,
+                   res_idx,
+               )
+               current_src.append(expanded_src.fx_node)
+               latest_reduced_op.update_arg("arg", current_src)
        new_outputs[op_output_index] = latest_reduced_op.fx_node
    output.update_arg("return_vals", new_outputs)
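
In plain terms: when the reduction root is an MMA, the existing tiled-MMA expansion still applies, but when it is a ReduceOp the op itself is not duplicated along the tiled axis. Instead its source is re-expanded once per extra tile index and each expanded source is appended to the op's arg list, so a single ReduceOp ends up combining all per-tile partial values. The following is a minimal, self-contained sketch of that idea only; ToyReduce, expand_reduce and the read_* names are illustrative stand-ins, not part of the wave API.

    from collections.abc import Sequence


    class ToyReduce:
        """Stand-in for a reduction op whose `arg` holds the value(s) it combines."""

        def __init__(self, arg):
            self.arg = arg


    def expand_reduce(op: ToyReduce, axis_scaling: int, expand_src) -> ToyReduce:
        """Grow op.arg with one freshly expanded source per extra tile index."""
        original_src = op.arg
        for scale_idx in range(1, axis_scaling):
            current_src = op.arg
            # Wrap a single source into a list on the first pass (strings count
            # as Sequences, hence the extra check in this toy version).
            if not isinstance(current_src, Sequence) or isinstance(current_src, str):
                current_src = [current_src]
            # Re-expand the original source at this tile index and append the result.
            current_src.append(expand_src(original_src, scale_idx))
            op.arg = current_src
        return op


    # With a reduction-axis scaling of 4, the single source "read_0" grows into
    # four per-tile sources combined by one reduce op.
    op = expand_reduce(ToyReduce("read_0"), 4, lambda src, idx: f"read_{idx}")
    print(op.arg)  # ['read_0', 'read_1', 'read_2', 'read_3']
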
114 changes: 114 additions & 0 deletions lit_tests/kernel/wave/expansion.py
@@ -550,6 +550,120 @@ def test_batched_gemm():
        # CHECK-NEXT: -----


@tkw.wave_trace_only()
def gemm_non_direct_acc(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
    c_reg = tkl.Register[M, N, tkl.f32](0.0)

    @tkw.reduction(K, init_args=[c_reg])
    def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
        a_reg = tkw.read(a, elements_per_thread=4)
        b_reg = tkw.read(b, elements_per_thread=4)
        new_acc = tkw.exp2(a_reg) + acc
        acc = tkw.mma(a_reg, b_reg, new_acc)
        return acc

    tkw.write(repeat, c, elements_per_thread=4)


@run_test
def test_gemm_non_direct_acc():
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)]
    constraints += [
        tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
    ]
    with tk.gen.TestLaunchContext(
        {
            BLOCK_M: 64,
            BLOCK_N: 64,
            BLOCK_K: 32,
        }
    ):
        graph = gemm_non_direct_acc()
        IndexingContext.current().finalize()
        set_node_indices(graph, constraints)
        expand_graph(graph, constraints)
        set_post_expansion_indices(graph, constraints)
        print_trace(graph)
        # CHECK: %add_0_0_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_0_0), kwargs = {})
        # CHECK: %add_1_1_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_1_0), kwargs = {})
        # CHECK: %add_1_0_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_0_0), kwargs = {})
        # CHECK: %add_0_1_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_1_0), kwargs = {})
        # CHECK: %mma_0_0_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_0_0, %add_0_0_0), kwargs = {})
        # CHECK: %mma_0_0_1
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_0_1, %mma_0_0_0), kwargs = {})
        # CHECK: %mma_1_1_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_1_0, %add_1_1_0), kwargs = {})
        # CHECK: %mma_1_1_1
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_1_1, %mma_1_1_0), kwargs = {})
        # CHECK: %mma_1_0_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_0_0, %add_1_0_0), kwargs = {})
        # CHECK: %mma_1_0_1
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_0_1, %mma_1_0_0), kwargs = {})
        # CHECK: %mma_0_1_0
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_1_0, %add_0_1_0), kwargs = {})
        # CHECK: %mma_0_1_1
        # CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_1_1, %mma_0_1_0), kwargs = {})


@tkw.wave_trace_only()
def tiled_max(
    a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
    c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
):
    init_max = tkl.Register[M, tkl.f16](-1e6)

    @tkw.reduction(K, init_args=[init_max])
    def repeat(acc: tkl.Register[M, tkl.f16]) -> tkl.Register[M, tkl.f16]:
        a_reg = tkw.read(a, elements_per_thread=4)
        partial_max = tkw.max(a_reg, acc, dim=K)
        return partial_max

    tkw.write(repeat, c, elements_per_thread=4)


@run_test
def test_tiled_max():
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(2, 1, 1),
            vector_shapes={M: 16, K: 4},
        )
    ]
    with tk.gen.TestLaunchContext(
        {
            BLOCK_M: 64,
            BLOCK_K: 32,
        }
    ):
        graph = tiled_max()
        IndexingContext.current().finalize()
        set_node_indices(graph, constraints)
        expand_graph(graph, constraints)
        set_post_expansion_indices(graph, constraints)
        print_trace(graph)
        # CHECK: max(arg=[read_0_0, read_0_1, read_0_2, read_0_3, read_0_4, read_0_5, read_0_6, read_0_7], init=acc_0_0
        # CHECK: max(arg=[read_1_0, read_1_1, read_1_2, read_1_3, read_1_4, read_1_5, read_1_6, read_1_7], init=acc_1_0
        # CHECK: output(return_vals=([max_0_0, max_1_0],))
        # CHECK-NEXT: -----
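
For the tiled_max case the shape of the expected trace can be sanity-checked with a little arithmetic. This note is editorial rather than part of the commit, and it assumes the expansion count per dimension is the per-wave tile size divided by that dimension's vector shape; under that assumption the numbers match the CHECK lines above: two max results along M and eight expanded reads along K feeding each one.

    BLOCK_M, BLOCK_K = 64, 32
    WAVES_M = 2                     # waves_per_block = (2, 1, 1)
    VECTOR_M, VECTOR_K = 16, 4      # vector_shapes = {M: 16, K: 4}

    m_scaling = (BLOCK_M // WAVES_M) // VECTOR_M  # 2 -> max_0_0 and max_1_0
    k_scaling = BLOCK_K // VECTOR_K               # 8 -> read_*_0 through read_*_7
    print(m_scaling, k_scaling)  # 2 8
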


@run_test
def test_gemm_reduction_expansion_only():
    # Note: This does not implement an actual gemm computation but reuses the
