diff --git a/iree/turbine/kernel/wave/shared_memory_indexing.py b/iree/turbine/kernel/wave/shared_memory_indexing.py index 8b2a11b6..fb345c75 100644 --- a/iree/turbine/kernel/wave/shared_memory_indexing.py +++ b/iree/turbine/kernel/wave/shared_memory_indexing.py @@ -5,13 +5,17 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._support.tracing import CapturedTrace -from ..ops.wave_ops import Read, Write, get_custom +from ..ops.wave_ops import Read, Write, MMA, get_custom from ..lang.global_symbols import * -from .utils import remove_global_indexing +from .utils import remove_global_indexing, align_index_vars from .constraints import Constraint, TilingConstraint import torch.fx as fx +def is_shared_mem_access(custom: "CustomOp") -> bool: + return custom.memory_type.address_space == SHARED_ADDRESS_SPACE + + def apply_shared_memory_indexing_corrections( trace: CapturedTrace, constraints: list[Constraint] ): @@ -23,9 +27,25 @@ def apply_shared_memory_indexing_corrections( def is_shared_memory_read_or_write(node: fx.Node): custom = get_custom(node) - if isinstance(custom, (Read, Write)): - if custom.memory_type.address_space == SHARED_ADDRESS_SPACE: - custom.index = remove_global_indexing(custom.index, constraints) + if isinstance(custom, (Read, Write)) and is_shared_mem_access(custom): + custom.index = remove_global_indexing(custom.index, constraints) return False trace.walk(is_shared_memory_read_or_write) + + +def align_index_sizes(trace: CapturedTrace, constraints: list[Constraint]): + """ + Adjust ops index sizes to WG/Tile size, so shared mem ops never need to + do partial read/writes. + """ + + def need_align(node: fx.Node): + custom = get_custom(node) + if isinstance(custom, (Read, Write)) and is_shared_mem_access(custom): + custom.index = align_index_vars(custom.index, constraints) + elif isinstance(custom, MMA): + custom.index = align_index_vars(custom.index, constraints) + return False + + trace.walk(need_align) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index aa29b1c7..d0d00e89 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -273,18 +273,27 @@ def remove_global_indexing( workgroup_ids = [WORKGROUP_0, WORKGROUP_1, WORKGROUP_2] subs = {w: 0 for w in workgroup_ids} + new_index = {key: safe_subs(index[key], subs) for key in index} + for key in new_index: + for constraint in tiling_constraints: + new_index[key] = new_index[key].subs({constraint.induction_var: 0}) + return new_index + + +def align_index_vars( + index: dict[IndexSymbol, IndexSequence], constraints: list[Constraint] +) -> dict[IndexSymbol, IndexSequence]: + """ + This function aligns index vars with Workgroup/Tiling constraints so it never + need partial reads/writes. + """ key_subs = { c.dim: (c.count * c.tile_size) for c in constraints if isinstance(c, (TilingConstraint, WorkgroupConstraint)) and subs_idxc(c.dim) != subs_idxc(c.count * c.tile_size) } - - new_index = {safe_subs(key, key_subs): safe_subs(index[key], subs) for key in index} - for key in new_index: - for constraint in tiling_constraints: - new_index[key] = new_index[key].subs({constraint.induction_var: 0}) - return new_index + return {safe_subs(key, key_subs): index[key] for key in index} def _invoke(vm_context, device, entry_function, inputs, outputs): diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 177d9867..e45788ff 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -38,7 +38,10 @@ from ..ops import wave_ops from ..ops.wave_ops import Reduction, CustomOp, get_custom from .index_sequence_analysis import partition_strided_operators -from .shared_memory_indexing import apply_shared_memory_indexing_corrections +from .shared_memory_indexing import ( + apply_shared_memory_indexing_corrections, + align_index_sizes, +) from .thread_shape_analysis import determine_thread_shapes from .scheduling.schedule import schedule_graph from .._support.indexing import IndexingContext, IndexExpr @@ -239,6 +242,9 @@ def _trace_and_get_kernel_signature( # Analyze Thread Shapes per Op. determine_thread_shapes(graph) + # Align sizes to WG/Tile sizes + align_index_sizes(graph, self.constraints) + # Decompose reduce Ops. decompose_reduce_ops(graph, self.constraints, idxc.subs)