Add support for more intrinsics in attention
This PR adds support for the 32x32x8 intrinsic in
attention. This should allow for a wider search space
when tuning attention.

In order to do so, the following changes were required:
- Add reshape as an anchor op in thread shape analysis,
  primarily to block the backward propagation of thread
  shapes from the MMA operands

- Propagate thread shapes from MMA operands only if
  they are propagatable

- Modify assignment of anchor ops to take indexing dims
  into account

- Modify the computation of cluster size and stride

- Simplification of code that adds reshapes

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Nov 1, 2024
1 parent 8febe6a commit 5d55747
Showing 6 changed files with 311 additions and 23 deletions.
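
The 32x32x8 intrinsic referred to in the commit message is an MMA tile with M = N = 32 and K = 8, alongside the already supported 16x16x16 shape. The sketch below is not part of the commit; it only illustrates, with assumed vector shapes and a hypothetical tuner loop, why a second intrinsic widens the attention tuning search space.

```python
# Illustrative only: assumed mapping from intrinsic tile sizes to the
# per-dimension vector shapes these passes reason about; the real values come
# from the hardware constraints attached to the kernel.
INTRINSICS = {
    "16x16x16": {"M": 16, "N": 16, "K": 16},
    "32x32x8": {"M": 32, "N": 32, "K": 8},
}

# A hypothetical tuner can now pair either intrinsic with its other knobs,
# enumerating more candidate configurations for the attention kernel.
block_sizes = [(64, 64), (128, 64)]
candidates = [(name, bs) for name in INTRINSICS for bs in block_sizes]
print(len(candidates))  # 4 candidates instead of 2 with a single intrinsic
```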
iree/turbine/kernel/wave/decompose_reduce_ops.py (4 additions, 6 deletions)
@@ -48,8 +48,6 @@ def determine_shuffle_config(
"""
access_pattern = index[reduction_dim]
elements_per_thread = access_pattern.size
cluster_size = vector_shapes[reduction_dim] // elements_per_thread

# Since we are only concerned with what happens within a subgroup,
# we can ignore the TID_1 and TID_2 components of the index. We can
@@ -67,11 +65,11 @@ def determine_shuffle_config(
offset = access_pattern.start.subs({k: 0 for k in ignore})
offset = subs_idxc(offset)
offset_table = [offset.subs({THREAD_0: i}) for i in range(subgroup_size)]
# Determine the thread ids participating in the shuffle.
unique_offsets = list(dict.fromkeys(offset_table))
cluster_size = len(set(offset_table))
thread_ids = []
for i in range(cluster_size):
thread_ids.append(offset_table.index(i * elements_per_thread))

for thread_offset in unique_offsets:
thread_ids.append(offset_table.index(thread_offset))
cluster_stride = [x - y for x, y in zip(thread_ids[1:], thread_ids[:-1])]
assert all_equal(cluster_stride), f"Cluster stride must be equal across threads."
return cluster_size, cluster_stride[0]
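
A minimal, runnable sketch of the rework above, using a made-up per-thread offset pattern in place of the real symbolic access pattern: the cluster size is now derived from the number of distinct offsets in the offset table (rather than from vector_shapes[reduction_dim] // elements_per_thread), and the participating thread ids come from the first thread at each distinct offset.

```python
# Toy stand-in for offset.subs({THREAD_0: i}): threads 0..31 and 32..63 start at
# two different offsets along the reduction dimension (hypothetical pattern).
subgroup_size = 64
elements_per_thread = 4
offset_table = [(tid // 32) * elements_per_thread for tid in range(subgroup_size)]

# Cluster size = number of distinct offsets; one representative thread per offset.
unique_offsets = list(dict.fromkeys(offset_table))
cluster_size = len(unique_offsets)
thread_ids = [offset_table.index(off) for off in unique_offsets]

# The stride between participating threads must be uniform for a valid shuffle.
strides = [b - a for a, b in zip(thread_ids, thread_ids[1:])]
assert len(set(strides)) <= 1, "Cluster stride must be equal across threads."
print(cluster_size, strides[0])  # -> 2 32
```
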
iree/turbine/kernel/wave/index_sequence_analysis.py (3 additions, 0 deletions)
@@ -219,6 +219,9 @@ def set_vector_shapes(
or node in mma_slices[mma][MMA_LHS]
or node in mma_slices[mma][MMA_RHS]
):
# Ensure that the operators indexing dims are present in the anchor.
if not set(custom.indexing_dims).issubset(mma.indexing_dims):
continue
custom.anchor = mma
custom.vector_shapes = custom.vector_shapes | mma.vector_shapes
return
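
A small illustration of the guard added above, using invented dimension names: an op only adopts the MMA as its anchor (and merges in its vector shapes) when every dimension the op indexes over is also indexed by that MMA.

```python
# Hypothetical indexing dims; the real ones are symbolic kernel dimensions.
mma_dims = {"B", "M", "N", "K"}

def can_anchor(op_dims: set) -> bool:
    # Mirrors: set(custom.indexing_dims).issubset(mma.indexing_dims)
    return op_dims.issubset(mma_dims)

print(can_anchor({"B", "M", "K"}))   # True: fully covered by the MMA
print(can_anchor({"B", "M", "K2"}))  # False: K2 is not a dim of this MMA
```
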
iree/turbine/kernel/wave/thread_shape_analysis.py (17 additions, 4 deletions)
@@ -51,7 +51,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]):
# Anchor Indicies and Conflict resolution helpers
#################################################################

anchorOpTypes = (Read, Write, MMA, ReduceOp)
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape)
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate)
legalSubtypes = (IterArg,)
nonPropagatableTypes = anchorOpTypes + noHandleTypes
@@ -201,9 +201,13 @@ def determine_thread_shapes(trace: CapturedTrace):
index_sizes, set([])
).union(bwd_slice)
elif isinstance(custom, MMA):
lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op)
rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op)
acc_slice = capture_forward_slice(custom.acc, propagatable_op)
lhs_bwd_slice = set([custom.lhs])
if propagatable_op(custom.lhs):
lhs_bwd_slice = capture_backward_slice(custom.lhs, propagatable_op)
rhs_bwd_slice = set([custom.rhs])
if propagatable_op(custom.rhs):
rhs_bwd_slice = capture_backward_slice(custom.rhs, propagatable_op)
acc_slice = capture_forward_slice(custom.fx_node, propagatable_op)
if not isinstance(get_custom(custom.acc), MMA):
acc_slice = acc_slice.union(
capture_backward_slice(custom.acc, propagatable_op)
@@ -220,6 +224,15 @@
thread_size_to_ops[rhs_index] = thread_size_to_ops.get(
rhs_index, set([])
).union(rhs_bwd_slice)
elif isinstance(custom, Reshape):
# The reshape op acts like a barrier for the MMA preventing
# the mma from propagating the thread shapes of its reshaped
# operands backwards.
bwd_size = get_dim_sizes(custom.args.index)
bwd_slice = capture_backward_slice(custom.args, propagatable_op)
thread_size_to_ops[bwd_size] = thread_size_to_ops.get(
bwd_size, set([])
).union(bwd_slice)

# Go through each index-size buckets, and apply the index-size to ops in the bucket.
cummulative_set = set()
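
A self-contained toy (invented node type, not the real tkw ops) of the two changes above: an MMA operand's slice starts as just the operand and is widened backwards only when the operand itself is a propagatable op, and Reshape, now an anchor type, stops that backward walk.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    name: str
    is_anchor: bool = False  # Read/Write/MMA/ReduceOp/Reshape in the real pass
    inputs: list["Node"] = field(default_factory=list)

def propagatable(node: Node) -> bool:
    return not node.is_anchor

def backward_slice(node: Node) -> set:
    # Walk producers, but only through propagatable nodes.
    seen, stack = {node.name}, [n for n in node.inputs if propagatable(n)]
    while stack:
        n = stack.pop()
        if n.name not in seen:
            seen.add(n.name)
            stack.extend(i for i in n.inputs if propagatable(i))
    return seen

read = Node("read")
reshape = Node("reshape", is_anchor=True, inputs=[read])  # acts as the barrier
lhs = Node("lhs_cast", inputs=[reshape])                  # some elementwise producer

# Mirrors the new MMA handling: {operand} by default, widened only if propagatable.
lhs_slice = backward_slice(lhs) if propagatable(lhs) else {lhs.name}
print(lhs_slice)  # {'lhs_cast'}: the reshape and its producers are left untouched
```
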
iree/turbine/kernel/wave/utils.py (15 additions, 13 deletions)
@@ -265,17 +265,18 @@ def is_mma(node):
# in the backward slice of the lhs and rhs upto a previous mma (if one exists).
# So we check for the previous node of the first operator in the slice to see
# if it is an MMA and if so check if a reshape is required.
def add_reshape_if_needed(mma: MMA, prev_mma: MMA):
def add_reshape_if_needed(mma: MMA, prev_mma: MMA, arg_index: int):
with mma.graph.inserting_before(mma.fx_node):
for i, arg in mma.node_args.items():
if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes):
reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph(
custom.graph
)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = custom.vector_shapes
custom_reshape.anchor = custom
custom.update_arg(i, reshape)
arg = mma.lhs if arg_index == 0 else mma.rhs
arg = get_custom(arg)
if is_reshape_needed(arg, mma.vector_shapes, prev_mma.vector_shapes):
reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph(
custom.graph
)
custom_reshape = get_custom(reshape)
custom_reshape.vector_shapes = custom.vector_shapes
custom_reshape.anchor = custom
custom.update_arg(arg_index, reshape)

def find_mma_in_slice(node: CustomOp) -> Optional[MMA]:
"""
@@ -295,10 +296,10 @@ def find_mma_in_slice(node: CustomOp) -> Optional[MMA]:
custom_mma = get_custom(mma)
prev_mma = find_mma_in_slice(custom_mma.lhs)
if prev_mma:
add_reshape_if_needed(custom_mma, prev_mma)
add_reshape_if_needed(custom_mma, prev_mma, 0)
prev_mma = find_mma_in_slice(custom_mma.rhs)
if prev_mma:
add_reshape_if_needed(custom_mma, prev_mma)
add_reshape_if_needed(custom_mma, prev_mma, 1)

return mapping, mma_slices
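
A schematic example (hypothetical shapes and a simplified stand-in for is_reshape_needed) of when the code above inserts a reshape: the operand of the current MMA was produced by a previous MMA whose intrinsic uses different per-dimension vector shapes, so its layout has to be converted before the second MMA can consume it.

```python
# Assumed vector shapes for two chained MMAs that use different intrinsics.
prev_mma_vector_shapes = {"M": 32, "N": 32, "K": 8}  # e.g. a 32x32x8 intrinsic
mma_vector_shapes = {"M": 16, "N": 16, "K": 16}      # e.g. a 16x16x16 intrinsic

def needs_reshape(arg_dims, consumer_shapes, producer_shapes) -> bool:
    # Simplified: reshape if any dim the operand indexes over changes shape.
    return any(consumer_shapes[d] != producer_shapes[d] for d in arg_dims)

# Illustrative choice: the lhs of the second MMA indexes over M and K.
if needs_reshape(["M", "K"], mma_vector_shapes, prev_mma_vector_shapes):
    print("insert Reshape(arg, prev_mma.vector_shapes) before the second MMA")
```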

@@ -641,7 +642,8 @@ def bfs(
filter_fn: Callable[[fx.node], bool],
) -> set[fx.Node]:
"""
Run BFS on the graph to capture the forward slice of a node.
Run BFS on the graph. The filter function is not applied to
the incoming node.
"""
visited: set[fx.Node] = set()
queue: list[fx.Node] = []
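
A toy BFS over a plain dict graph (rather than fx nodes) showing the behavior the docstring now spells out: the filter is applied while expanding to neighbors, never to the node the search starts from.

```python
from collections import deque

def bfs_sketch(start, neighbors, keep):
    visited, queue = {start}, deque([start])  # the incoming node is never filtered
    while queue:
        node = queue.popleft()
        for nbr in neighbors(node):
            if nbr not in visited and keep(nbr):
                visited.add(nbr)
                queue.append(nbr)
    return visited

graph = {"mma": ["reshape", "read"], "reshape": ["write"], "read": []}
keep = lambda n: n not in ("mma", "reshape")  # e.g. stop at non-propagatable ops
print(bfs_sketch("mma", lambda n: graph.get(n, []), keep))
# contains 'mma' and 'read'; 'reshape' and everything behind it is skipped
```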