Skip to content

Commit

Permalink
Address comments #3
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Oct 30, 2024
1 parent 9d6bb0f commit e09dff2
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,14 @@ def find_mma_in_slice(node: CustomOp) -> Optional[MMA]:
return prev_mma
return None

# Look in the backward slices of both the LHS and RHS to find
# mmas. If found, add reshapes if necessary.
for mma in mma_nodes:
custom_mma = get_custom(mma)
prev_mma = find_mma_in_slice(custom_mma.lhs) or find_mma_in_slice(
custom_mma.rhs
)
prev_mma = find_mma_in_slice(custom_mma.lhs)
if prev_mma:
add_reshape_if_needed(custom_mma, prev_mma)
prev_mma = find_mma_in_slice(custom_mma.rhs)
if prev_mma:
add_reshape_if_needed(custom_mma, prev_mma)

Expand Down

0 comments on commit e09dff2

Please sign in to comment.