Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Oct 25, 2024
1 parent 042f49b commit 7a8a23c
Showing 1 changed file with 55 additions and 43 deletions.
98 changes: 55 additions & 43 deletions iree/turbine/kernel/wave/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,23 @@ def _expand_node(
return context[(node, get_indexed_dims(dim_query, node), res_idx)]
elif isinstance(node, MMA):
# Handle expansion of MMA nodes whose reduction dim is not the same as the reduction
# dim of the parent reduction op.
if hasattr(node.graph, "parent_op") and node.reduction_dim not in dim_query:
# dim of the parent reduction op or when there is no parent reduction op.
has_parent_op = hasattr(node.graph, "parent_op")
reduction_axes_different = False
if has_parent_op:
reduction: Reduction = get_custom(node.graph.parent_op)
if reduction.axis != node.reduction_dim:
return _expand_mma_reduction(
node,
trace,
dim_query,
get_node_dim_scaling(node),
context,
get_node_dim_scaling,
res_idx,
)
reduction_axes_different = reduction.axis != node.reduction_dim
parallel_dim_query = node.reduction_dim not in dim_query
if (not has_parent_op or reduction_axes_different) and parallel_dim_query:
return _expand_mma_reduction(
node,
trace,
dim_query,
dim_scaling,
context,
get_node_dim_scaling,
res_idx,
)
elif isinstance(node, Reduction):
return _expand_reduction(
node, trace, dim_query, dim_scaling, context, get_node_dim_scaling, res_idx
Expand Down Expand Up @@ -248,35 +252,30 @@ def _expand_node(

# Proceed with expansion of the arguments
for i, arg in node.node_args.items():
new_arg = None
if not isinstance(arg, Sequence):
if is_expandable(arg):
new_arg = _expand_node(
arg,
trace,
restricted_dims,
get_node_dim_scaling(arg),
context,
get_node_dim_scaling,
res_idx,
)
new_node.update_arg(i, new_arg)
arg_list = arg
unpack = lambda x: x
if isinstance(arg, list):
if not all(is_expandable(a) for a in arg):
continue
else:
new_arg = []
for subarg in arg:
if is_expandable(subarg):
new_subarg = _expand_node(
subarg,
trace,
restricted_dims,
get_node_dim_scaling(subarg),
context,
get_node_dim_scaling,
res_idx,
)
new_arg.append(new_subarg.fx_node)
assert len(new_arg) == len(arg), "All subargs must be expanded"
new_node.update_arg(i, new_arg)
arg_list = [arg]
unpack = lambda x: x[0]
if not is_expandable(arg):
continue

new_args = []
for subarg in arg_list:
new_subarg = _expand_node(
subarg,
trace,
restricted_dims,
get_node_dim_scaling(subarg),
context,
get_node_dim_scaling,
res_idx,
)
new_args.append(new_subarg.fx_node)
new_node.update_arg(i, unpack(new_args))

context[(node, get_indexed_dims(restricted_dims, node), res_idx)] = new_node
return new_node
Expand Down Expand Up @@ -400,21 +399,34 @@ def _expand_mma_reduction(
"""

logger.debug(f"Expanding MMA reduction: {mma} in dims: {dim_query}")
expand_dims = set(mma.indexing_dims) - set([mma.reduction_dim])

idxc = IndexingContext.current()
for dim in mma.indexing_dims:
if dim not in dim_scaling and mma.vector_shapes[dim] > 0:
tile_size = idxc.get_static_value(dim)
dim_scaling[dim] = tile_size // mma.vector_shapes[dim]

expand_dims = set(mma.indexing_dims) - set([mma.reduction_dim])

# Store the original mma node and accumulator value for expansion.
# When we begin expansion, we have a single mma node with the correct accumulator.
# This node corresponds to the dim query with all 0s and for this we reuse the
# original mma node. For all other queries, we create a new node.
# So say we have parallel dimensions {M, K2} and reduction dimension {K1}.
# For M = 0, K2 = 0, K1 = 0, we use the original mma node.
# For M = 0, K2 = 0, K1 = 1, we create a new node.
# Now, when it is time to expand along new parallel dimensions, we use the original node
# For M = 0, K2 = 1, K1 = 0, we use the original mma node so that the last cloned node's
# accumulator value is not modified.

dim_query_dims = tuple(dim_query.keys())
if not hasattr(_expand_mma_reduction, "acc"):
_expand_mma_reduction.acc = {}
if not hasattr(_expand_mma_reduction, "mma"):
_expand_mma_reduction.mma = {}
if dim_query_dims not in _expand_mma_reduction.mma:
if (
dim_query_dims not in _expand_mma_reduction.mma
or _expand_mma_reduction.mma[dim_query_dims].graph != mma.graph
):
_expand_mma_reduction.mma[dim_query_dims] = mma
_expand_mma_reduction.acc[dim_query_dims] = mma.acc

Expand Down

0 comments on commit 7a8a23c

Please sign in to comment.