diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 6f00caa5..12ed2151 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -289,11 +289,14 @@ def find_mma_in_slice(node: CustomOp) -> Optional[MMA]: return prev_mma return None + # Look in the backward slices of both the LHS and RHS to find + # mmas. If found, add reshapes if necessary. for mma in mma_nodes: custom_mma = get_custom(mma) - prev_mma = find_mma_in_slice(custom_mma.lhs) or find_mma_in_slice( - custom_mma.rhs - ) + prev_mma = find_mma_in_slice(custom_mma.lhs) + if prev_mma: + add_reshape_if_needed(custom_mma, prev_mma) + prev_mma = find_mma_in_slice(custom_mma.rhs) if prev_mma: add_reshape_if_needed(custom_mma, prev_mma)