[TKW] Teach expansion to handle non direct acc and ReduceOp on reduction dim. (#243)

In flash attention, we need to enable non direct acc matmul as well as
expansion of ReduceOp along the reduction dimension. The former is needed
because FA applies a scaling to the acc of the second MMA before feeding it
in. The latter is required because the ReduceOp/MaxOp lies in the backward
slice of the second MMA's LHS, which requires it to be expanded along the
K2/reduction dim as well.
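
To make both requirements concrete, below is a minimal NumPy sketch of the
flash-attention inner-loop pattern that motivates them (illustrative only,
not the tkw kernel in this repo; the names q, k, v, tile and the plain
online-softmax formulation are assumptions for exposition):

import numpy as np

def flash_attention_tiles(q, k, v, tile=16):
    # q: [M, K1], k: [K2, K1], v: [K2, N]
    m_run = np.full(q.shape[0], -np.inf)      # running max over the K2 (reduction) dim
    denom = np.zeros(q.shape[0])              # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[1]))  # accumulator of the second matmul
    for s0 in range(0, k.shape[0], tile):     # tiled reduction over K2
        s = q @ k[s0:s0 + tile].T             # first matmul (Q @ K^T tile)
        m_new = np.maximum(m_run, s.max(axis=1))  # max reduction along the reduction dim
        p = np.exp(s - m_new[:, None])
        scale = np.exp(m_run - m_new)
        # "non direct acc": the accumulator is rescaled before the second matmul adds to it
        acc = acc * scale[:, None] + p @ v[s0:s0 + tile]
        denom = denom * scale + p.sum(axis=1)
        m_run = m_new
    return acc / denom[:, None]

In the tkw kernel both the rescaled accumulator and the running max live
inside the K2 reduction loop, which is why expansion has to handle a non
direct acc MMA and a ReduceOp along that dim.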

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored Oct 28, 2024
1 parent 32a47b2 commit ddc8dbd
Showing 3 changed files with 260 additions and 61 deletions.
13 changes: 11 additions & 2 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -611,7 +611,7 @@ def type(self) -> Memory:
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
)
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhstype
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
return broadcasted_type


@@ -1300,7 +1300,12 @@ def type(self) -> Memory:
from ..wave.utils import all_equal

src_types = [get_custom(arg).type for arg in self.arg]
if not all_equal(src_types):
ref_shape = src_types[0].symbolic_shape
ref_dtype = src_types[0].dtype
if not all(
src_type.symbolic_shape == ref_shape and src_type.dtype == ref_dtype
for src_type in src_types
):
raise NotImplementedError(
"NYI: Only support case where all inputs to ReduceOp to have same type."
)
@@ -1322,6 +1327,10 @@ def num_reduction_dims(self) -> int:
else:
return 1

@property
def reduction_dim(self) -> IndexSymbol:
return self.dim


# TODO: Add support for more shuffle types.
@define_op("shuffle")
194 changes: 135 additions & 59 deletions iree/turbine/kernel/wave/expansion.py
@@ -316,6 +316,7 @@ def _expand_reduction(
dims = {dim: val for dim, val in zip(dim_scaling.keys(), dim_vals)}
if not isinstance(return_vals, Sequence):
return_vals = [return_vals]
# Proceed with expansion inside the reduction
for arg_idx, arg in enumerate(return_vals):
arg = get_custom(arg)
# Add GetResult nodes for the corresponding dimensions
@@ -327,33 +328,48 @@
(reduction, get_indexed_dims(dims, expand_dims), arg_idx)
] = new_node

# Proceed with expansion inside the reduction
new_output_args.append(
_expand_node(
arg,
trace,
dims,
get_node_dim_scaling(arg),
context,
get_node_dim_scaling,
res_idx,
)
expanded_output = _expand_node(
arg,
trace,
dims,
get_node_dim_scaling(arg),
context,
get_node_dim_scaling,
res_idx,
)
# The if condition below is needed to skip over induction variables
# that don't have all the dims of the ReductionOp. For example,
# a reduction op whose induction variables have types
# (max, mma) -> [M], [M, N]
# will have indexing dims of ([M, N]).
# The first induction variable won't expand in the N dim:
# M:0, N:0 expand(max) -> max_0_0_0
# M:0, N:1 expand(max) -> max_0_0_0
# but without the if condition it would be appended to `new_output_args` twice.

# TODO: Handle expansion of induction variables with "non-complete" dims
# by checking the indexing_dims of each induction variable.
if expanded_output in new_output_args:
continue
new_output_args.append(expanded_output)

# Proceed with expansion outside the reduction
for init_arg in reduction.init_args:
custom_init_arg = get_custom(init_arg)
new_init_args.append(
_expand_node(
custom_init_arg,
trace,
dims,
get_node_dim_scaling(custom_init_arg),
context,
get_node_dim_scaling,
res_idx,
)
expanded_init_arg = _expand_node(
custom_init_arg,
trace,
dims,
get_node_dim_scaling(custom_init_arg),
context,
get_node_dim_scaling,
res_idx,
)
# TODO: Handle expansion of induction variables with "non-complete" dims
# by checking the indexing_dims of each induction variable.
if expanded_init_arg in new_init_args:
continue
new_init_args.append(expanded_init_arg)

# Update init_args and return values
reduction.update_arg(
@@ -553,6 +569,54 @@ def get_dim_scaling(
return dim_scaling


def _expand_mma_tiled_reduction(
mma: MMA,
trace: CapturedTrace,
dim_query: dict[IndexSymbol, int],
dim_scaling: dict[IndexSymbol, int],
context: ExpandedNodeMap,
get_node_dim_scaling: Callable[[fx.Node], dict[IndexSymbol, int]],
res_idx: int,
) -> CustomOp:
latest_reduced_op = mma
# The initial nodes are expanded in the first dimension, so we start from 1
for scale_idx in range(1, dim_scaling[mma.reduction_dim]):
dim_query[mma.reduction_dim] = scale_idx
# Temporarily replace the loop carried arg here to avoid
# duplicated expansion. Otherwise we have the following situation:
# Suppose we have:
# mma_0_0_0(..., acc_0_0_0)
# mma_0_0_1(..., mma_0_0_0)
# Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
# mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
# To avoid this we temporarily replace the use of it with a dummy
# placeholder which will not trigger further expansion.
dummy = Placeholder("dummy").add_to_graph(latest_reduced_op.graph)
dummy.type = None

saved_acc = latest_reduced_op.acc
latest_reduced_op.update_arg("acc", dummy)
new_node = _expand_node(
latest_reduced_op,
trace,
dim_query,
dim_scaling,
context,
get_node_dim_scaling,
res_idx,
)

# The node is always cloned; hence, it will never be equal to the latest reduced op
assert new_node != latest_reduced_op
# Update MMA_{t} to accumulate on MMA_{t-1}, and then save
# current MMA_{t} to outputs for use in next loop.
latest_reduced_op.update_arg("acc", saved_acc)
new_node.update_arg("acc", latest_reduced_op)
latest_reduced_op.graph.erase_node(dummy)
latest_reduced_op = new_node
return latest_reduced_op


def _handle_reduction_dim(
reduction: Reduction,
output: Output,
@@ -566,53 +630,65 @@ def _handle_reduction_dim(
# TODO: Register iter args with the reduction initially so accessing them is easier
iter_args: list[CustomOp] = []
reduction_subgraph = trace.get_subgraph(reduction.subgraph_name)

# TODO: Handle case where MMAs/ReduceOps do not have Output as direct consumer.
def get_output_index(custom: CustomOp):
output_users = [
get_custom(user)
for user in custom.fx_node.users
if isinstance(get_custom(user), Output)
]
if len(output_users) != 1:
raise NotImplementedError(
"NYI: Currently only handle direct and 1:1 MMA -> Output case."
)
return output_users[0].return_vals[0].index(custom.fx_node)

# Collect MMA and ReduceOp nodes whose reduction axis matches the parent Reduction op's axis.
reduction_root_ops = []
for node in (get_custom(fx_node) for fx_node in reduction_subgraph.nodes):
if isinstance(node, IterArg):
iter_args.append(node)
if isinstance(node, (MMA, ReduceOp)) and reduction.axis == node.reduction_dim:
reduction_root_ops.append(node)

new_outputs = list(reduction.outputs(trace.get_subgraph(reduction.subgraph_name)))
# Users of the loop carried nodes will be duplicated
for idx, carried_node in enumerate(iter_args):
# The initial nodes are expanded in the first dimension, so we start from 1
dim_scaling = get_node_dim_scaling(carried_node)
for scale_idx in range(1, dim_scaling[reduction.axis]):
for user in carried_node.users:
if isinstance(user, Output):
continue

dims = dict(user.fx_node.expanded_dims)
dims[reduction.axis] = scale_idx
# Temporarily replace the loop carried arg here to avoid
# duplicated expansion. Otherwise we have the following situation:
# Suppose we have:
# mma_0_0_0(..., acc_0_0_0)
# mma_0_0_1(..., mma_0_0_0)
# Expanding mma_0_0_1 to mma_0_0_2 will trigger expansion of its arg
# mma_0_0_0 in dims 0_0_2 as well, effectively duplicating the new node.
# To avoid this we temporarily replace the use of it with a dummy
# placeholder which will not trigger further expansion.
index = user.get_node_arg_index(carried_node)
dummy = Placeholder("dummy").add_to_graph(user.graph)
dummy.type = None

saved_arg = user.node_args[index]
user.update_arg(index, dummy)
new_node = _expand_node(
user,
for root_op in reduction_root_ops:
dim_scaling = get_node_dim_scaling(root_op)
dims = dict(root_op.fx_node.expanded_dims)
latest_reduced_op = root_op
op_output_index = get_output_index(root_op)
if isinstance(root_op, MMA):
latest_reduced_op = _expand_mma_tiled_reduction(
root_op,
trace,
dims,
dim_scaling,
context,
get_node_dim_scaling,
res_idx,
)
elif isinstance(root_op, ReduceOp):
original_src = latest_reduced_op.arg
# The initial nodes are expanded in the first dimension, so we start from 1
for scale_idx in range(1, dim_scaling[reduction.axis]):
dims[root_op.reduction_dim] = scale_idx
current_src = latest_reduced_op.arg
if not isinstance(current_src, Sequence):
current_src = [current_src]
expanded_src = _expand_node(
get_custom(original_src),
trace,
dims,
dim_scaling,
context,
get_node_dim_scaling,
res_idx,
)

# This expansion always happens, user should never be reused
assert new_node != user
user.update_arg(index, saved_arg)
new_node.update_arg(index, user)
user.graph.erase_node(dummy)
carried_node = user
new_outputs[idx] = new_node.fx_node

current_src.append(expanded_src.fx_node)
latest_reduced_op.update_arg("arg", current_src)
new_outputs[op_output_index] = latest_reduced_op.fx_node
init_dims = root_op.fx_node.expanded_dims
context[
(root_op, get_indexed_dims(init_dims, root_op), res_idx)
] = latest_reduced_op
output.update_arg("return_vals", new_outputs)
114 changes: 114 additions & 0 deletions lit_tests/kernel/wave/expansion.py
@@ -550,6 +550,120 @@ def test_batched_gemm():
# CHECK-NEXT: -----


@tkw.wave_trace_only()
def gemm_non_direct_acc(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=4)
b_reg = tkw.read(b, elements_per_thread=4)
new_acc = tkw.exp2(a_reg) + acc
acc = tkw.mma(a_reg, b_reg, new_acc)
return acc

tkw.write(repeat, c, elements_per_thread=4)


@run_test
def test_gemm_non_direct_acc():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1))
]
with tk.gen.TestLaunchContext(
{
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
}
):
graph = gemm_non_direct_acc()
IndexingContext.current().finalize()
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
print_trace(graph)
# CHECK: %add_0_0_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_0_0), kwargs = {})
# CHECK: %add_1_1_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_1_0), kwargs = {})
# CHECK: %add_1_0_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_1_0_0, %acc_1_0_0), kwargs = {})
# CHECK: %add_0_1_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.add](args = (%exp2_0_0_0, %acc_0_1_0), kwargs = {})
# CHECK: %mma_0_0_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_0_0, %add_0_0_0), kwargs = {})
# CHECK: %mma_0_0_1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_0_1, %mma_0_0_0), kwargs = {})
# CHECK: %mma_1_1_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_1_0, %add_1_1_0), kwargs = {})
# CHECK: %mma_1_1_1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_1_1, %mma_1_1_0), kwargs = {})
# CHECK: %mma_1_0_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_0, %read_0_0_0, %add_1_0_0), kwargs = {})
# CHECK: %mma_1_0_1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_1_0_1, %read_0_0_1, %mma_1_0_0), kwargs = {})
# CHECK: %mma_0_1_0
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_0, %read_0_1_0, %add_0_1_0), kwargs = {})
# CHECK: %mma_0_1_1
# CHECK-SAME: call_function[target=iree.turbine.kernel.ops.wave_ops.mma](args = (%read_0_0_1, %read_0_1_1, %mma_0_1_0), kwargs = {})


@tkw.wave_trace_only()
def tiled_max(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
):
init_max = tkl.Register[M, tkl.f16](-1e6)

@tkw.reduction(K, init_args=[init_max])
def repeat(acc: tkl.Register[M, tkl.f16]) -> tkl.Register[M, tkl.f16]:
a_reg = tkw.read(a, elements_per_thread=4)
partial_max = tkw.max(a_reg, acc, dim=K)
return partial_max

tkw.write(repeat, c, elements_per_thread=4)


@run_test
def test_tiled_max():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)]
constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(2, 1, 1),
vector_shapes={M: 16, K: 4},
)
]
with tk.gen.TestLaunchContext(
{
BLOCK_M: 64,
BLOCK_K: 32,
}
):
graph = tiled_max()
IndexingContext.current().finalize()
set_node_indices(graph, constraints)
expand_graph(graph, constraints)
set_post_expansion_indices(graph, constraints)
print_trace(graph)
# CHECK: max(arg=[read_0_0, read_0_1, read_0_2, read_0_3, read_0_4, read_0_5, read_0_6, read_0_7], init=acc_0_0
# CHECK: max(arg=[read_1_0, read_1_1, read_1_2, read_1_3, read_1_4, read_1_5, read_1_6, read_1_7], init=acc_1_0
# CHECK: output(return_vals=([max_0_0, max_1_0],))
# CHECK-NEXT: -----


@run_test
def test_gemm_reduction_expansion_only():
# Note: This does not implement an actual gemm computation but reuses the