diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index 19e3f64c..d0c96637 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -1236,19 +1236,41 @@ class ReduceOp(CustomOp, ABC):
         dim: which symbolic dim to reduce.
     """
 
-    arg: fx.Node
+    arg: fx.Node | list[fx.Node]
     init: fx.Node = None
     dim: Optional[Any] = None
 
     @property
     def indexing_dims(self) -> list[IndexSymbol]:
-        src_indexing = get_custom(self.arg).indexing_dims
+        # Local import to break circular dep.
+        from ..wave.utils import all_equal
+
+        if isinstance(self.arg, Sequence):
+            src_indexings = [get_custom(arg).indexing_dims for arg in self.arg]
+            if not all_equal(src_indexings):
+                raise NotImplementedError(
+                    "NYI: Only support the case where all inputs to ReduceOp have the same indexing dims."
+                )
+            src_indexing = src_indexings[0]
+        else:
+            src_indexing = get_custom(self.arg).indexing_dims
         dst_indexing = [dim for dim in src_indexing if dim != self.dim]
         return dst_indexing
 
     @property
     def type(self) -> Memory:
-        src_type = get_custom(self.arg).type
+        if isinstance(self.arg, Sequence):
+            # Local import to break circular dep.
+            from ..wave.utils import all_equal
+
+            src_types = [get_custom(arg).type for arg in self.arg]
+            if not all_equal(src_types):
+                raise NotImplementedError(
+                    "NYI: Only support the case where all inputs to ReduceOp have the same type."
+                )
+            src_type = src_types[0]
+        else:
+            src_type = get_custom(self.arg).type
         reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
         dst_type = Register[*reduced_dims, src_type.dtype]
         return dst_type
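For intuition, the multi-source branch of indexing_dims above only drops the reduced dim after verifying that every source agrees. A minimal standalone sketch of that rule, with strings standing in for IndexSymbols and plain lists standing in for fx.Node sources (the helper name reduce_indexing_dims is hypothetical, not part of the PR):

def reduce_indexing_dims(src_indexings: list[list[str]], dim: str) -> list[str]:
    # Mirrors the all_equal check: every source must share its indexing dims.
    if any(ind != src_indexings[0] for ind in src_indexings):
        raise NotImplementedError("NYI: mismatched indexing dims across sources")
    # The reduced dim is dropped; all remaining dims survive in order.
    return [d for d in src_indexings[0] if d != dim]

assert reduce_indexing_dims([["M", "N"], ["M", "N"]], dim="N") == ["M"]
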
diff --git a/iree/turbine/kernel/wave/decompose_reduce_ops.py b/iree/turbine/kernel/wave/decompose_reduce_ops.py
index 0be318ef..bf972b75 100644
--- a/iree/turbine/kernel/wave/decompose_reduce_ops.py
+++ b/iree/turbine/kernel/wave/decompose_reduce_ops.py
@@ -23,7 +23,7 @@
     Reduction,
 )
 
-from .utils import DCE, subs_idxc
+from .utils import DCE, subs_idxc, all_equal
 import torch.fx as fx
 import math
 from typing import Callable
@@ -37,6 +37,16 @@ def get_graph_node(custom: CustomOp, graph: fx.Graph):
     return custom
 
 
+def emit_sources_reduction(
+    binary_fn: Callable, src: list[fx.Node], graph: fx.Graph
+) -> fx.Node:
+    init = src[0]
+    for i in range(1, len(src)):
+        init = get_graph_node(binary_fn(init, src[i]), graph)
+    init.index = src[0].index
+    return init
+
+
 def emit_local_reduction(
     binary_fn: Callable, src: fx.Node, graph: fx.Graph, local_reduction_size: int
 ) -> fx.Node:
@@ -67,11 +77,12 @@ def decompose_reduce_ops(
 ):
     """
     The lowering for multi_reduction is done in four steps:
-    1. Local Reduce: Each thread reduces all elements carried by it along
+    1. Source Reduce: Each thread locally reduces all of its sources.
+    2. Local Reduce: Each thread reduces all elements carried by it along
     the reduction dimensions.
-    2. Thread Reduce: Each thread reduces result of step 1 across threads
+    3. Thread Reduce: Each thread reduces the result of step 2 across threads
     by doing a butterfly shuffle.
-    3. Accumulator Reduce: Each thread reduces it's intermediate reduced
+    4. Accumulator Reduce: Each thread reduces its intermediate reduced
     results with the accumulator it holds.
     """
     # Get reduce nodes.
@@ -98,9 +109,18 @@
             raise ValueError(
                 "No reduction dim specified, please specify a reduction dim."
             )
+        if not isinstance(reduction_src, (list, tuple)):
+            reduction_src = [reduction_src]
 
         # Local Reduce
-        if reduction_dim is not get_custom(custom.arg).type.symbolic_shape[-1]:
+        src_fastest_dims = [
+            get_custom(arg).type.symbolic_shape[-1] for arg in reduction_src
+        ]
+        if not all_equal(src_fastest_dims):
+            raise NotImplementedError(
+                "NYI: Expect all sources in reduction_src to have the same fastest dim."
+            )
+        if reduction_dim is not src_fastest_dims[0]:
             raise NotImplementedError(
                 "Only implemented reduction on fastest dimension."
             )
@@ -108,9 +128,18 @@
         get_thread_shape = lambda index: max(
             subs_idxc(x.size) for x in index.values()
         )
-        local_reduction_size = get_thread_shape(get_custom(custom.arg).index)
+        local_reduce_sizes = [
+            get_thread_shape(get_custom(arg).index) for arg in reduction_src
+        ]
+        if not all_equal(local_reduce_sizes):
+            raise NotImplementedError(
+                "NYI: Expect all sources in reduction_src to have the same local reduce size."
+            )
+        src_reduction = emit_sources_reduction(
+            binary_fn, reduction_src, custom.graph
+        )
         local_reduction = emit_local_reduction(
-            binary_fn, reduction_src, custom.graph, local_reduction_size
+            binary_fn, src_reduction, custom.graph, local_reduce_sizes[0]
         )
 
         # Global Reduce
diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py
index f3c9201d..674b8d5b 100644
--- a/iree/turbine/kernel/wave/utils.py
+++ b/iree/turbine/kernel/wave/utils.py
@@ -739,3 +739,9 @@ def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int:
             return 4
         case MMAType.F32_32x32x16_F8:
             return 16
+
+
+def all_equal(input_list: list[Any]) -> bool:
+    if len(input_list) == 0:
+        return True
+    return all(elem == input_list[0] for elem in input_list)
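emit_sources_reduction is a left fold of binary_fn over the source list: the first source seeds the accumulator and each remaining source is combined into it, with every partial result materialized as a graph node and the final node inheriting src[0].index. With plain numbers and operator.add standing in for fx.Nodes and graph insertion, the fold behaves like this sketch (the name fold_sources is hypothetical):

import operator
from functools import reduce

# Stand-in for emit_sources_reduction: left-fold binary_fn over the sources.
# The real pass additionally wraps each partial result in an fx graph node
# and copies src[0].index onto the final node.
def fold_sources(binary_fn, src):
    return reduce(binary_fn, src[1:], src[0])

# Several sources per thread collapse to one value before local slicing;
# a single-source list degenerates to a no-op, matching the new
# isinstance(reduction_src, (list, tuple)) normalization above.
assert fold_sources(operator.add, [1.0, 2.0, 3.0]) == 6.0
assert fold_sources(operator.add, [5.0]) == 5.0
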
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 5778b689..61e08298 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -1211,6 +1211,69 @@ def test(
 # CHECK: arith.addf {{.*}} : vector<1xf16>
 
 
+# Tests multiple local reductions: we should emit a source reduction, then iteratively slice and reduce the combined vector.
+@run_test
+def test_multiple_local_reduce_sum():
+    M = tkl.sym.M
+    N = tkl.sym.N
+    BLOCK_M = tkl.sym.BLOCK_M
+    BLOCK_N = tkl.sym.BLOCK_N
+    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD
+    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
+
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(1, 1, 1),
+            vector_shapes={M: 1, N: BLOCK_N},
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        c: tkl.Memory[M, ADDRESS_SPACE, tkl.f16],
+    ):
+        lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
+        rhs = tkw.read(b, elements_per_thread=ELEMS_PER_THREAD)
+        res = tkw.sum([lhs, rhs], dim=N)
+        tkw.write(res, c, elements_per_thread=1)
+
+    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
+
+    shape = (256, 128)
+    a = torch.randn(shape, dtype=torch.float16)
+    b = torch.randn(shape, dtype=torch.float16)
+    c = torch.zeros((shape[0],), dtype=torch.float16)
+    with tk.gen.TestLaunchContext(
+        {
+            M: shape[0],
+            N: shape[1],
+            BLOCK_M: 1,
+            BLOCK_N: 128,
+            ELEMS_PER_THREAD: 2,
+            ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
+        },
+        canonicalize=True,
+    ):
+        print(test(a, b, c).module_op)
+        # CHECK: %[[LHS:.+]] = vector.load {{.*}} : memref<256x128xf16
+        # CHECK: %[[RHS:.+]] = vector.load {{.*}} : memref<256x128xf16
+        # Reduce all sources locally.
+        # CHECK: %[[SRC_REDUC:.+]] = arith.addf %[[LHS]], %[[RHS]] : vector<2xf16>
+        # Do local reductions.
+        # CHECK: %[[LOCAL_REDUC0:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [0], sizes = [1], strides = [1]}
+        # CHECK: %[[LOCAL_REDUC1:.+]] = vector.extract_strided_slice %[[SRC_REDUC]] {offsets = [1], sizes = [1], strides = [1]}
+        # CHECK: %[[REDUC_0:.+]] = arith.addf %[[LOCAL_REDUC0]], %[[LOCAL_REDUC1]] : vector<1xf16>
+        # Expanded global sum reduction.
+        # CHECK-COUNT-6: gpu.shuffle xor
+
+
 # This test is to ensure that the propagation of indexing_dims between reduction and operations
 # outside the reduction is working properly.
 @run_test
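Numerically, the kernel in the new test computes a row-wise sum over both inputs. A hedged PyTorch reference for the semantics of tkw.sum([lhs, rhs], dim=N) as exercised above (assuming, per the lowering, that sources are combined elementwise before the reduction):

import torch

# Reference for the lit test above: the source reduction adds the inputs
# elementwise, then the local/global reduce sums along N (dim=1 here),
# yielding one value per row, matching output c of shape (256,).
shape = (256, 128)
a = torch.randn(shape, dtype=torch.float16)
b = torch.randn(shape, dtype=torch.float16)
expected = (a.float() + b.float()).sum(dim=1)

# The fold order does not change the result: reducing each source separately
# and then adding agrees up to floating-point rounding.
assert torch.allclose(expected, a.float().sum(dim=1) + b.float().sum(dim=1), atol=1e-3)
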