
Address comments
harsh-nod committed Oct 29, 2024
1 parent 07337a9 commit 7025f48
Showing 3 changed files with 43 additions and 44 deletions.
34 changes: 17 additions & 17 deletions iree/turbine/kernel/wave/codegen.py
@@ -1324,23 +1324,23 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
     offset = custom.expanded_dims[innermost_dim]

-    # Determine whether to extract or combine.
-    if len(args) == 1:
-        # Extract the appropriate slice.
-        size = (
-            target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
-        )
-        vector = cast_vector(emitter, args[0])
-        result_type = VectorType.get([size], vector.type.element_type)
-        slice = vector_d.extract_strided_slice(
-            result_type,
-            vector,
-            [offset * size],
-            [size],
-            [1],
+    if len(args) > 1:
+        raise NotImplementedError(
+            "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes"
         )
-        emitter.bind_node_proxy(node, IRProxyValue(slice))
-        return

-    raise NotImplementedError(
-        "reshape: Currently only handles cases where target_vector_shapes > custom.vector_shapes"
+    # Extract the appropriate slice. The offset is obtained from the expanded_dim
+    # and so corresponds to the dim_query during expansion. To obtain the
+    # actual offset, we need to multiply by the size, which is determined by comparing
+    # the source and target vector shapes along the innermost dimension.
+    size = target_vector_shapes[innermost_dim] // custom.vector_shapes[innermost_dim]
+    vector = cast_vector(emitter, args[0])
+    result_type = VectorType.get([size], vector.type.element_type)
+    slice = vector_d.extract_strided_slice(
+        result_type,
+        vector,
+        [offset * size],
+        [size],
+        [1],
     )
     emitter.bind_node_proxy(node, IRProxyValue(slice))
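
The new comment describes the slice-offset arithmetic in prose; the standalone sketch below illustrates the same computation with hypothetical shape values and plain Python lists, without the wave/MLIR APIs.

# Hypothetical values illustrating the arithmetic in handle_reshape above.
target_vector_shape = 16   # stands in for target_vector_shapes[innermost_dim]
source_vector_shape = 4    # stands in for custom.vector_shapes[innermost_dim]
dim_query_offset = 2       # stands in for custom.expanded_dims[innermost_dim]

# Each expansion step (dim_query) owns `size` contiguous elements of the wider
# vector, so the element offset is the dim_query offset scaled by that ratio.
size = target_vector_shape // source_vector_shape   # 16 // 4 -> 4
start = dim_query_offset * size                      # 2 * 4  -> 8

wide_vector = list(range(target_vector_shape))       # stand-in for the source vector
extracted = wide_vector[start : start + size]        # what extract_strided_slice pulls out
print(extracted)                                      # [8, 9, 10, 11]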
13 changes: 0 additions & 13 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -185,19 +185,6 @@ def is_contiguous_dim(
     return is_innermost_dim or all_unit_dims


-def add_reshape(custom: CustomOp):
-    for arg in custom.node_args.values():
-        if not isinstance(arg, Sequence):
-            arg = [arg]
-        for subarg in arg:
-            # These are ops in the parent graph that have not had their vector shapes set yet.
-            if subarg.vector_shapes is None:
-                continue
-            if subarg.vector_shapes != custom.vector_shapes:
-                with custom.graph.inserting_before(custom.fx_node):
-                    Reshape(subarg, custom.vector_shapes).add_to_graph(custom.graph)
-
-
 def set_vector_shapes(
     constraints: Sequence[Constraint],
     mma_index: dict[MMA, dict[IndexSymbol, int]],
40 changes: 26 additions & 14 deletions iree/turbine/kernel/wave/utils.py
@@ -228,7 +228,6 @@ def is_mma(node):

     mapping: dict[MMA, dict[IndexSymbol, int]] = {}
     mma_nodes = trace.walk(is_mma)
-    last_vector_shapes = {}
     for node in mma_nodes:
         custom: MMA = get_custom(node)
         m, n = custom.acc_type.symbolic_shape[-2:]
@@ -250,19 +249,6 @@ def is_mma(node):
         custom.vector_shapes.update(hardware_constraint.vector_shapes)
         custom.anchor = custom
         custom.reduction_dim = k
-        if last_vector_shapes and last_vector_shapes != custom.vector_shapes:
-            with custom.graph.inserting_before(custom.fx_node):
-                for i, arg in custom.node_args.items():
-                    if is_reshape_needed(arg, custom.vector_shapes, last_vector_shapes):
-                        reshape = Reshape(arg.fx_node, last_vector_shapes).add_to_graph(
-                            custom.graph
-                        )
-                        custom_reshape = get_custom(reshape)
-                        custom_reshape.vector_shapes = custom.vector_shapes
-                        custom_reshape.anchor = custom
-                        custom.update_arg(i, reshape)
-
-        last_vector_shapes = custom.vector_shapes

         # Since expansion proceeds bottom-up, we set the vector shapes
         # of the parent reduction to the vector shapes of the last MMA node.
@@ -272,6 +258,32 @@ def is_mma(node):
             reduction.anchor = custom

     mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes}
+
+    # Determine if any reshapes are required. Reshapes are added for
+    # chained matmuls when the vector shapes of the operands in one matmul
+    # differ from those in another matmul.
+    for src in mma_nodes:
+        custom_src = get_custom(src)
+        for dst in mma_nodes:
+            if src == dst:
+                continue
+            custom_dst = get_custom(dst)
+            lhs_slice = capture_backward_slice(custom_dst.lhs)
+            rhs_slice = capture_backward_slice(custom_dst.rhs)
+            if src in lhs_slice or src in rhs_slice:
+                with custom_dst.graph.inserting_before(dst):
+                    for i, arg in custom_dst.node_args.items():
+                        if is_reshape_needed(
+                            arg, custom_dst.vector_shapes, custom_src.vector_shapes
+                        ):
+                            reshape = Reshape(
+                                arg.fx_node, custom_src.vector_shapes
+                            ).add_to_graph(custom.graph)
+                            custom_reshape = get_custom(reshape)
+                            custom_reshape.vector_shapes = custom.vector_shapes
+                            custom_reshape.anchor = custom
+                            custom.update_arg(i, reshape)
+
     return mapping, mma_slices
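
Design note readable from the diff: reshape insertion now keys off whether one matmul lies in the backward slice of another's operands (via capture_backward_slice), rather than off the vector shapes of the previously visited MMA node. Below is a minimal standalone sketch of that decision, with plain dicts standing in for MMA nodes and hypothetical shape values; is_reshape_needed_sketch is an illustration, not the library's is_reshape_needed.

# Hypothetical chained-matmul example: mma0's result feeds mma1, but the two
# MMAs anchor different vector shapes, so a Reshape would be inserted.
M, N, K = "M", "N", "K"

mma0 = {"name": "mma0", "vector_shapes": {M: 16, N: 16, K: 16}}
mma1 = {"name": "mma1", "vector_shapes": {M: 32, N: 32, K: 8}, "lhs_producer": "mma0"}

def is_reshape_needed_sketch(arg_dims, dst_shapes, src_shapes):
    # A reshape is needed when producer and consumer disagree on the vector
    # shape of any dimension carried by the operand.
    return any(dst_shapes[d] != src_shapes[d] for d in arg_dims)

operand_dims = [M, K]  # dimensions carried by the operand flowing from mma0 to mma1
if mma1["lhs_producer"] == mma0["name"] and is_reshape_needed_sketch(
    operand_dims, mma1["vector_shapes"], mma0["vector_shapes"]
):
    # Mirrors Reshape(arg.fx_node, custom_src.vector_shapes) inserted before dst.
    print("insert Reshape(arg, mma0.vector_shapes) before mma1")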


