Skip to content

Commit

Permalink
separate pass to laign sizes
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
  • Loading branch information
Hardcode84 committed Oct 16, 2024
1 parent db6b0a6 commit 3fedd88
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 12 deletions.
30 changes: 25 additions & 5 deletions iree/turbine/kernel/wave/shared_memory_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._support.tracing import CapturedTrace
from ..ops.wave_ops import Read, Write, get_custom
from ..ops.wave_ops import Read, Write, MMA, get_custom
from ..lang.global_symbols import *
from .utils import remove_global_indexing
from .utils import remove_global_indexing, align_index_vars
from .constraints import Constraint, TilingConstraint
import torch.fx as fx


def is_shared_mem_access(custom: "CustomOp") -> bool:
return custom.memory_type.address_space == SHARED_ADDRESS_SPACE


def apply_shared_memory_indexing_corrections(
trace: CapturedTrace, constraints: list[Constraint]
):
Expand All @@ -23,9 +27,25 @@ def apply_shared_memory_indexing_corrections(

def is_shared_memory_read_or_write(node: fx.Node):
custom = get_custom(node)
if isinstance(custom, (Read, Write)):
if custom.memory_type.address_space == SHARED_ADDRESS_SPACE:
custom.index = remove_global_indexing(custom.index, constraints)
if isinstance(custom, (Read, Write)) and is_shared_mem_access(custom):
custom.index = remove_global_indexing(custom.index, constraints)
return False

trace.walk(is_shared_memory_read_or_write)


def align_index_sizes(trace: CapturedTrace, constraints: list[Constraint]):
"""
Adjust ops index sizes to WG/Tile size, so shared mem ops never need to
do partial read/writes.
"""

def need_align(node: fx.Node):
custom = get_custom(node)
if isinstance(custom, (Read, Write)) and is_shared_mem_access(custom):
custom.index = align_index_vars(custom.index, constraints)
elif isinstance(custom, MMA):
custom.index = align_index_vars(custom.index, constraints)
return False

trace.walk(need_align)
21 changes: 15 additions & 6 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,18 +273,27 @@ def remove_global_indexing(
workgroup_ids = [WORKGROUP_0, WORKGROUP_1, WORKGROUP_2]
subs = {w: 0 for w in workgroup_ids}

new_index = {key: safe_subs(index[key], subs) for key in index}
for key in new_index:
for constraint in tiling_constraints:
new_index[key] = new_index[key].subs({constraint.induction_var: 0})
return new_index


def align_index_vars(
index: dict[IndexSymbol, IndexSequence], constraints: list[Constraint]
) -> dict[IndexSymbol, IndexSequence]:
"""
This function aligns index vars with Workgroup/Tiling constraints so it never
need partial reads/writes.
"""
key_subs = {
c.dim: (c.count * c.tile_size)
for c in constraints
if isinstance(c, (TilingConstraint, WorkgroupConstraint))
and subs_idxc(c.dim) != subs_idxc(c.count * c.tile_size)
}

new_index = {safe_subs(key, key_subs): safe_subs(index[key], subs) for key in index}
for key in new_index:
for constraint in tiling_constraints:
new_index[key] = new_index[key].subs({constraint.induction_var: 0})
return new_index
return {safe_subs(key, key_subs): index[key] for key in index}


def _invoke(vm_context, device, entry_function, inputs, outputs):
Expand Down
8 changes: 7 additions & 1 deletion iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
from ..ops import wave_ops
from ..ops.wave_ops import Reduction, CustomOp, get_custom
from .index_sequence_analysis import partition_strided_operators
from .shared_memory_indexing import apply_shared_memory_indexing_corrections
from .shared_memory_indexing import (
apply_shared_memory_indexing_corrections,
align_index_sizes,
)
from .thread_shape_analysis import determine_thread_shapes
from .scheduling.schedule import schedule_graph
from .._support.indexing import IndexingContext, IndexExpr
Expand Down Expand Up @@ -239,6 +242,9 @@ def _trace_and_get_kernel_signature(
# Analyze Thread Shapes per Op.
determine_thread_shapes(graph)

# Align sizes to WG/Tile sizes
align_index_sizes(graph, self.constraints)

# Decompose reduce Ops.
decompose_reduce_ops(graph, self.constraints, idxc.subs)

Expand Down

0 comments on commit 3fedd88

Please sign in to comment.