[TKW] Add support for multiple/local reduceOp (#234)
To support flash attention, we need to be able to expand ReduceOps along the reduction dimension as well. We will do this by expanding the sources of a ReduceOp and locally reducing all of them. To that end, this PR (1st of 2) adds support for locally reducing over multiple variables.

The second PR, already on the way, will handle the expansion of ReduceOp.

In this PR we contribute two things:
1. Checks that indexing_dims, types, and thread shapes are consistent across the multiple sources of a ReduceOp.
2. A modified local-reduction emitter that iteratively slices and reduces over multiple arguments/sources (see the sketch below).
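
As a rough illustration (not the actual emitter; `local_reduce_sum` and the sample values are made up), the lowered per-thread computation for two sources behaves like this:

```python
from functools import reduce

def local_reduce_sum(sources: list[list[float]]) -> float:
    # Step 1 (source reduce): combine the sources element-wise with the
    # reduction's binary op (addition here).
    combined = [sum(elems) for elems in zip(*sources)]
    # Step 2 (local reduce): fold the combined per-thread elements down to
    # a single value with the same binary op.
    return reduce(lambda a, b: a + b, combined)

# Per-thread slices of two sources, e.g. with ELEMS_PER_THREAD == 2.
lhs = [1.0, 2.0]
rhs = [3.0, 4.0]
assert local_reduce_sum([lhs, rhs]) == 10.0
```

The cross-thread (butterfly shuffle) and accumulator steps are left unchanged by this PR.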

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored Oct 24, 2024
1 parent 00dcee7 commit 50e17a5
Showing 4 changed files with 130 additions and 10 deletions.
28 changes: 25 additions & 3 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -1236,19 +1236,41 @@ class ReduceOp(CustomOp, ABC):
dim: which symbolic dim to reduce.
"""

arg: fx.Node
arg: fx.Node | list[fx.Node]
init: fx.Node = None
dim: Optional[Any] = None

@property
def indexing_dims(self) -> list[IndexSymbol]:
src_indexing = get_custom(self.arg).indexing_dims
# Local import to break circular dep.
from ..wave.utils import all_equal

if isinstance(self.arg, Sequence):
src_indexings = [get_custom(arg).indexing_dims for arg in self.arg]
if not all_equal(src_indexings):
raise NotImplementedError(
"NYI: Only support case where all inputs to ReduceOp to have same indexing dim."
)
src_indexing = src_indexings[0]
else:
src_indexing = get_custom(self.arg).indexing_dims
dst_indexing = [dim for dim in src_indexing if dim != self.dim]
return dst_indexing

@property
def type(self) -> Memory:
src_type = get_custom(self.arg).type
if isinstance(self.arg, Sequence):
# Local import to break circular dep.
from ..wave.utils import all_equal

src_types = [get_custom(arg).type for arg in self.arg]
if not all_equal(src_types):
raise NotImplementedError(
"NYI: Only support case where all inputs to ReduceOp to have same type."
)
src_type = src_types[0]
else:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
return dst_type
43 changes: 36 additions & 7 deletions iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -23,7 +23,7 @@
Reduction,
)

from .utils import DCE, subs_idxc
from .utils import DCE, subs_idxc, all_equal
import torch.fx as fx
import math
from typing import Callable
@@ -37,6 +37,16 @@ def get_graph_node(custom: CustomOp, graph: fx.Graph):
return custom


def emit_sources_reduction(
binary_fn: Callable, src: list[fx.Node], graph: fx.Graph
) -> fx.Node:
init = src[0]
for i in range(1, len(src)):
init = get_graph_node(binary_fn(init, src[i]), graph)
init.index = src[0].index
return init


def emit_local_reduction(
binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int
) -> fx.Node:
@@ -67,11 +77,12 @@ def decompose_reduce_ops(
):
"""
The lowering for multi_reduction is done in four steps:
1. Local Reduce: Each thread reduces all elements carried by it along
1. Source Reduce: Each thread locally reduces all of its sources.
2. Local Reduce: Each thread reduces all elements carried by it along
the reduction dimensions.
2. Thread Reduce: Each thread reduces result of step 1 across threads
3. Thread Reduce: Each thread reduces result of step 2 across threads
by doing a butterfly shuffle.
3. Accumulator Reduce: Each thread reduces its intermediate reduced
4. Accumulator Reduce: Each thread reduces its intermediate reduced
results with the accumulator it holds.
"""
# Get reduce nodes.
@@ -98,19 +109,37 @@
raise ValueError(
"No reduction dim specified, please specify a reduction dim."
)
if not isinstance(reduction_src, (list, tuple)):
reduction_src = [reduction_src]

# Local Reduce
if reduction_dim is not get_custom(custom.arg).type.symbolic_shape[-1]:
src_fastest_dims = [
get_custom(arg).type.symbolic_shape[-1] for arg in reduction_src
]
if not all_equal(src_fastest_dims):
raise NotImplementedError(
"NYI: Expect all reduce_src to have same fastest dim."
)
if reduction_dim is not src_fastest_dims[0]:
raise NotImplementedError(
"Only implemented reduction on fastest dimension."
)

get_thread_shape = lambda index: max(
subs_idxc(x.size) for x in index.values()
)
local_reduction_size = get_thread_shape(get_custom(custom.arg).index)
local_reduce_sizes = [
get_thread_shape(get_custom(arg).index) for arg in reduction_src
]
if not all_equal(local_reduce_sizes):
raise NotImplementedError(
"NYI: Expect all reduce_src to have same local reduce size."
)
src_reduction = emit_sources_reduction(
binary_fn, reduction_src, custom.graph
)
local_reduction = emit_local_reduction(
binary_fn, reduction_src, custom.graph, local_reduction_size
binary_fn, src_reduction, custom.graph, local_reduce_sizes[0]
)

# Global Reduce
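
The thread-reduce step mentioned in the docstring above is not touched by this PR, but for context, here is a pure-Python model of the butterfly (shuffle-xor) reduction it performs, assuming a power-of-two wave size; `butterfly_shuffle_reduce` is an illustrative name, not an API in this repository:

```python
def butterfly_shuffle_reduce(lane_vals: list[float]) -> list[float]:
    """Model of reducing one value per lane across a wave.

    Each round, lane i combines its value with the value held by
    lane i ^ offset (what `gpu.shuffle xor` provides); after
    log2(wave_size) rounds every lane holds the full reduction.
    For a 64-wide wave that is 6 rounds, which is why the lit test
    below checks `CHECK-COUNT-6: gpu.shuffle xor`.
    """
    wave_size = len(lane_vals)
    offset = wave_size // 2
    while offset >= 1:
        lane_vals = [lane_vals[i] + lane_vals[i ^ offset] for i in range(wave_size)]
        offset //= 2
    return lane_vals

lanes = [float(i) for i in range(64)]
assert all(v == sum(range(64)) for v in butterfly_shuffle_reduce(lanes))
```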
6 changes: 6 additions & 0 deletions iree/turbine/kernel/wave/utils.py
@@ -739,3 +739,9 @@ def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int:
return 4
case MMAType.F32_32x32x16_F8:
return 16


def all_equal(input_list: list[Any]) -> bool:
if len(input_list) == 0:
return True
return all(elem == input_list[0] for elem in input_list)
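
A quick usage sketch of the new helper (values are arbitrary; in the passes above it is applied to lists of indexing dims, types, and thread shapes):

```python
# Assumed import path, based on the file location in this diff.
from iree.turbine.kernel.wave.utils import all_equal

assert all_equal([])                # vacuously True for an empty list
assert all_equal(["x", "x", "x"])   # every element compares equal to the first
assert not all_equal([1, 1, 2])     # any mismatch makes the check fail
```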
63 changes: 63 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1211,6 +1211,69 @@ def test(
# CHECK: arith.addf {{.*}} : vector<1xf16>


# Tests multiple local reductions, verifying that we correctly emit iterative slicing and reduction over multiple variables.
@run_test
def test_mutliple_local_reduce_sum():
M = tkl.sym.M
N = tkl.sym.N
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(1, 1, 1),
vector_shapes={M: 1, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
):
lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD)
res = tkw.sum([lhs, rhs], dim=N)
tkw.write(res, c, elements_per_thread=1)

config = {"backend": "rocm", "device": "hip", "target": "gfx942"}

shape = (256, 128)
a = torch.randn(shape, dtype=torch.float16)
b = torch.randn(shape, dtype=torch.float16)
c = torch.zeros((shape[0],), dtype=torch.float16)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
BLOCK_M: 1,
BLOCK_N: 128,
ELEMS_PER_THREAD: 2,
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
):
print(test(a, b, c).module_op)
# CHECK: %[[LHS:.+]] = vector.load {{.*}} : memref<256x128xf16
# CHECK: %[[RHS:.+]] = vector.load {{.*}} : memref<256x128xf16
# Reduce all sources locally.
# CHECK: %[[SRC_REDUC:.+]] = arith.addf %[[LHS]], %[[RHS]] : vector<2xf16>
# Do Local Reductions.
# CHECK: %[[LOCAL_REDUC0:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [0], sizes = [1], strides = [1]}
# CHECK: %[[LOCAL_REDUC1:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [1], sizes = [1], strides = [1]}
# CHECK: %[[REDUC_0:.+]] = arith.addf %[[LOCAL_REDUC0]], %[[LOCAL_REDUC1]] : vector<1xf16>
# Expanded Global Reduction
# CHECK-COUNT-6: gpu.shuffle xor


# This test is to ensure that the propagation of indexing_dims between reduction and operations
# outside the reduction is working properly.
@run_test
