diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index e19f9833..92219c1e 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -98,6 +98,53 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
         # CHECK: vector.load %[[DATA]][%[[IDX_X]], %[[IDX_Y]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>
 
 
+@run
+def test_read_mapped():
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    mapping = tkw.IndexMapping(
+        num_iterators=2, inputs={N: i, M: j}, outputs={N: i, M: j}
+    )
+
+    @tkw.wave(constraints)
+    def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
+        tkw.read(a, mapping=mapping, elements_per_thread=16)
+
+    with codegen_test_context():
+        a = torch.randn(16, 16, dtype=torch.float16)
+        print(test(a).module_op)
+        # CHECK: func.func @test(%[[ARG0:.+]]: !stream.binding)
+        # CHECK: %[[WG_0:.+]] = stream.dispatch.workgroup.id[0]
+        # CHECK: %[[WG_1:.+]] = stream.dispatch.workgroup.id[1]
+        # CHECK: %[[T0:.+]] = gpu.thread_id x
+        # CHECK: %[[T1:.+]] = gpu.thread_id y
+        # CHECK: %[[DATA:.+]] = stream.binding.subspan %[[ARG0]]
+        # CHECK: %[[C16:.+]] = arith.constant 16 : index
+        # CHECK: %[[WG0_OFF:.+]] = arith.muli %[[WG_0]], %[[C16]]
+        # CHECK: %[[C4:.+]] = arith.constant 4 : index
+        # CHECK: %[[T0_OFF:.+]] = arith.divsi %[[T0]], %[[C4]]
+        # CHECK: %[[IDX_X:.+]] = arith.addi %[[T0_OFF]], %[[WG0_OFF]]
+        # CHECK: %[[C16_0:.+]] = arith.constant 16 : index
+        # CHECK: %[[T1_OFF:.+]] = arith.muli %[[T1]], %[[C16_0]] : index
+        # CHECK: %[[C16_1:.+]] = arith.constant 16 : index
+        # CHECK: %[[WG1_OFF:.+]] = arith.muli %[[WG_1]], %[[C16_1]]
+        # CHECK: %[[IDX_Y:.+]] = arith.addi %[[WG1_OFF]], %[[T1_OFF]]
+        # CHECK: %[[OFF:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+        # CHECK: %[[MASK:.+]] = vector.constant_mask [16] : vector<16xi1>
+        # CHECK: %[[PASSTHRU:.+]] = vector.splat %{{.*}} : vector<16xf16>
+        # CHECK: %[[RES:.+]] = vector.gather %[[DATA]][%[[IDX_X]], %[[IDX_Y]]] [%[[OFF]]], %[[MASK]], %[[PASSTHRU]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16> into vector<16xf16>
+
+
 @run
 def test_read_write():
     constraints: list[tkw.Constraint] = [
diff --git a/shark_turbine/kernel/_support/tracing.py b/shark_turbine/kernel/_support/tracing.py
index 05b726a0..ec609561 100644
--- a/shark_turbine/kernel/_support/tracing.py
+++ b/shark_turbine/kernel/_support/tracing.py
@@ -30,6 +30,7 @@ from ..lang.types import (
     Index,
 )
 
+from ..lang.wave_types import IndexMapping
 from ..ops.wave_ops import CustomOp, Placeholder, Reduction, Unknown
 from .regions import RegionGraph, SubgraphTracer
 
@@ -111,6 +112,8 @@ def create_arg(self, a):
         # Let DataType persist as arguments.
         if isinstance(a, DataType):
             return a
+        if isinstance(a, IndexMapping):
+            return a
         return super().create_arg(a)
 
diff --git a/shark_turbine/kernel/lang/wave_types.py b/shark_turbine/kernel/lang/wave_types.py
index 4d0fc658..f25dd75b 100644
--- a/shark_turbine/kernel/lang/wave_types.py
+++ b/shark_turbine/kernel/lang/wave_types.py
@@ -137,6 +137,17 @@ def _subs_expr(expr: Any, subs: Iterable[tuple[IndexExpr, IndexExpr]]) -> Any:
     return expr
 
 
+def _is_identity_mapping(iters: Iterable[IndexSymbol], mapping: SymbolsMap) -> bool:
+    if len(iters) != len(mapping):
+        return False
+
+    for it, val in zip(iters, mapping.values()):
+        if it != val:
+            return False
+
+    return True
+
+
 class IndexMapping:
     """
     Represents a mapping between 2 sets of indices.
@@ -209,3 +220,6 @@ def map_output_indices(
         self, symbols: Optional[tuple[IndexSymbol, ...]] = None
     ) -> tuple[IndexExpr, ...]:
         return self._map_indices(self.output_mapping, symbols)
+
+    def is_output_identity(self) -> bool:
+        return _is_identity_mapping(self.iters.keys(), self.output_mapping)
diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py
index 32fac5b3..5c3aaddf 100644
--- a/shark_turbine/kernel/ops/wave_ops.py
+++ b/shark_turbine/kernel/ops/wave_ops.py
@@ -572,6 +572,8 @@ class Read(CustomOp):
 
     @property
     def indexing_dims(self) -> list[IndexSymbol]:
+        if self.mapping is not None:
+            return list(self.mapping.output_shape)
         # TODO: This could contain ints.
         return list(self.memory.type.symbolic_shape)
 
diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py
index 45c4e29a..ba845425 100644
--- a/shark_turbine/kernel/wave/codegen.py
+++ b/shark_turbine/kernel/wave/codegen.py
@@ -5,16 +5,18 @@ import torch.fx as fx
 
 from ..compiler.ir import (
+    DenseElementsAttr,
+    IndexType,
     InsertionPoint,
-    Location,
-    OpResult,
+    IntegerAttr,
+    IntegerType,
     IrType,
-    Value,
-    IndexType,
+    Location,
     MemRefType,
+    OpResult,
     ShapedType,
+    Value,
     VectorType,
-    IntegerAttr,
     arith_d,
     func_d,
     gpu_d,
@@ -30,6 +32,7 @@
 from ..compiler.kernel_codegen import BoundKernelSignature
 from .._support.tracing import CapturedTrace
 from ..compiler.builder import IRProxyValue
+from ..compiler.utils import strides_from_symbolic_shape
 from ..compiler.vector_codegen import (
     cast_kernel_buffer,
     cast_py_literal,
@@ -38,7 +41,8 @@
 )
 
 # Indexing imports.
-from .._support.indexing import IndexingContext
+from .._support.indexing import IndexingContext, IndexExpr
+from .indexing import IndexSequence
 
 
 @dataclass
@@ -219,6 +223,23 @@ def handle_register(emitter: WaveEmitter, node: fx.Node):
     raise NotImplementedError("Register: Currently only stub implementation")
 
 
+def _get_start_indices(
+    emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr]
+) -> list[OpResult]:
+    start_indices = []
+    for dim_indexing in src_indices:
+        i = src_indices[dim_indexing]
+        if isinstance(i, IndexSequence):
+            i = i.start
+        start_indices.append(gen_sympy_index(emitter, i))
+
+    return start_indices
+
+
+def _compute_offset(indices: list[int], strides: list[int]) -> int:
+    return int(sum(i * s for i, s in zip(indices, strides)))
+
+
 @handle_op(read)
 def handle_read(emitter: WaveEmitter, node: fx.Node):
     # This is similar to tkl.store with fixed start indices for now.
@@ -227,8 +248,6 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
     except ValueError as e:
         raise ValidationError("Malformed arguments") from e
 
-    assert mapping is None, "mapping is not supported yet"
-
     vector_shape = cast_py_literal(emitter, (elements_per_thread,))
     # memory has no IR node yet.
    kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, memory)
@@ -236,13 +255,56 @@
     if not hasattr(node, "index"):
         raise ValidationError("codegen expected read to have index attr.")
 
-    start_indices = []
-    for dim_indexing in node.index:
-        start_indices.append(gen_sympy_index(emitter, node.index[dim_indexing].start))
+    index = node.index
 
     element_type = kb_ir_type.element_type
     vector_type = VectorType.get(vector_shape, element_type)
-    result = vector_d.load(vector_type, kb_src, start_indices)
+    if mapping is None:
+        start_indices = _get_start_indices(emitter, index)
+        result = vector_d.load(vector_type, kb_src, start_indices)
+    else:
+        assert (
+            mapping.is_output_identity()
+        ), "non-identity output mapping is not supported yet"
+        mem_index = memory.index
+        input_mapping = mapping.map_input_indices(mem_index.keys())
+
+        iters = mapping.iters
+        subs = [(sym, expr.start) for sym, expr in zip(iters.keys(), index.values())]
+
+        input_index = {
+            key: m.subs(subs) for key, m in zip(mem_index.keys(), input_mapping)
+        }
+
+        strides = strides_from_symbolic_shape(IndexingContext.current(), mem_index)
+        offsets = []
+        subs = [(sym, 0) for sym in iters.keys()]
+        for i in range(elements_per_thread):
+            # Step the most-minor dim; for an identity mapping this is
+            # equivalent to a plain vector.load.
+            subs[-1] = (subs[-1][0], i)
+            indices = [int(i.subs(subs)) for i in input_mapping]
+            offsets.append(
+                IntegerAttr.get(IndexType.get(), _compute_offset(indices, strides))
+            )
+
+        start_indices = _get_start_indices(emitter, input_index)
+        offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get())
+        mask_vec_type = VectorType.get(
+            [elements_per_thread], IntegerType.get_signless(1)
+        )
+
+        offsets_vec = arith_d.ConstantOp(
+            offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
+        )
+        mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread])
+        zero = arith_d.ConstantOp(vector_type.element_type, 0)
+        passthru = vector_d.splat(vector_type, zero)
+
+        result = vector_d.gather(
+            vector_type, kb_src, start_indices, offsets_vec, mask, passthru
+        )
 
     emitter.bind_node_proxy(node, IRProxyValue(result))
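
Note on the gather-offset computation (reviewer sketch, not part of the patch): the mapped-read path in handle_read pins every mapping iterator to 0, steps only the most-minor iterator from 0 to elements_per_thread - 1, and linearizes each resulting index tuple with the buffer strides. The sketch below reproduces that loop standalone with sympy; the helper name compute_gather_offsets is hypothetical.

import sympy

# Mapping iterators, mirroring tkw.IndexMapping.iterator(0) / iterator(1).
i, j = sympy.symbols("i j")

def compute_gather_offsets(input_mapping, iters, strides, elements_per_thread):
    # Mirrors the loop in handle_read: pin all iterators to 0, then step the
    # most-minor iterator and linearize each index tuple with the strides.
    subs = [(sym, 0) for sym in iters]
    offsets = []
    for elem in range(elements_per_thread):
        subs[-1] = (subs[-1][0], elem)
        indices = [int(expr.subs(subs)) for expr in input_mapping]
        offsets.append(sum(idx * s for idx, s in zip(indices, strides)))
    return offsets

# test_read_mapped uses inputs={N: i, M: j}, i.e. memory dims (M, N) are
# indexed by (j, i); a row-major 16x16 buffer has strides (16, 1).
print(compute_gather_offsets((j, i), (i, j), (16, 1), 16))
# -> [0, 16, 32, 48, ..., 240]

Stepping j walks down a column of the row-major buffer, so consecutive elements are one full row (stride 16) apart. That is exactly the dense<[0, 16, 32, ..., 240]> constant checked in test_read_mapped, and it is why a transposing read lowers to vector.gather rather than a contiguous vector.load.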