[TK] Lower read with mapping to vector.gather (#74)
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Hardcode84 authored Aug 13, 2024
1 parent 7fb008f commit bd6ebc6
Showing 5 changed files with 140 additions and 12 deletions.
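In short: a tkw.read that carries an IndexMapping is now lowered to vector.gather. The start indices come from the mapped index expressions, and a per-lane offset vector is precomputed by substituting each lane position into the input mapping and flattening with the buffer strides. A minimal standalone sketch of that offset derivation (illustrative only; gather_offsets is a hypothetical helper, not code from this commit):

def gather_offsets(minor_stride: int, elements_per_thread: int) -> list[int]:
    # Lane k reads the element k steps along the most-minor mapped
    # dimension, so its linear offset is k times that dimension's stride.
    return [k * minor_stride for k in range(elements_per_thread)]

# For the 16x16 row-major buffer in the lit test below (strides [16, 1]),
# the mapping swaps M and N, so the most-minor iterator walks M (stride 16):
assert gather_offsets(16, 16) == [
    0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240
]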
47 changes: 47 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -98,6 +98,53 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
# CHECK: vector.load %[[DATA]][%[[IDX_X]], %[[IDX_Y]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>


@run
def test_read_mapped():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
mapping = tkw.IndexMapping(
num_iterators=2, inputs={N: i, M: j}, outputs={N: i, M: j}
)

@tkw.wave(constraints)
def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
tkw.read(a, mapping=mapping, elements_per_thread=16)

with codegen_test_context():
a = torch.randn(16, 16, dtype=torch.float16)
print(test(a).module_op)
# CHECK: func.func @test(%[[ARG0:.+]]: !stream.binding)
# CHECK: %[[WG_0:.+]] = stream.dispatch.workgroup.id[0]
# CHECK: %[[WG_1:.+]] = stream.dispatch.workgroup.id[1]
# CHECK: %[[T0:.+]] = gpu.thread_id x
# CHECK: %[[T1:.+]] = gpu.thread_id y
# CHECK: %[[DATA:.+]] = stream.binding.subspan %[[ARG0]]
# CHECK: %[[C16:.+]] = arith.constant 16 : index
# CHECK: %[[WG0_OFF:.+]] = arith.muli %[[WG_0]], %[[C16]]
# CHECK: %[[C4:.+]] = arith.constant 4 : index
# CHECK: %[[T0_OFF:.+]] = arith.divsi %[[T0]], %[[C4]]
# CHECK: %[[IDX_X:.+]] = arith.addi %[[T0_OFF]], %[[WG0_OFF]]
# CHECK: %[[C16_0:.+]] = arith.constant 16 : index
# CHECK: %[[T1_OFF:.+]] = arith.muli %[[T1]], %[[C16_0]] : index
# CHECK: %[[C16_1:.+]] = arith.constant 16 : index
# CHECK: %[[WG1_OFF:.+]] = arith.muli %[[WG_1]], %[[C16_1]]
# CHECK: %[[IDX_Y:.+]] = arith.addi %[[WG1_OFF]], %[[T1_OFF]]
# CHECK: %[[OFF:.+]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
# CHECK: %[[MASK:.+]] = vector.constant_mask [16] : vector<16xi1>
# CHECK: %[[PASSTHRU:.+]] = vector.splat %{{.*}} : vector<16xf16>
# CHECK: %[[RES:.+]] = vector.gather %[[DATA]][%[[IDX_X]], %[[IDX_Y]]] [%[[OFF]]], %[[MASK]], %[[PASSTHRU]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16> into vector<16xf16>
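    # Note: the dense offsets above are k * 16 for k = 0..15. The mapping
    # swaps M and N, so consecutive lanes step along M, whose row-major
    # stride is 16; the all-true mask and zero passthru make the gather a
    # plain transposed read.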


@run
def test_read_write():
constraints: list[tkw.Constraint] = [
3 changes: 3 additions & 0 deletions shark_turbine/kernel/_support/tracing.py
@@ -30,6 +30,7 @@
from ..lang.types import (
Index,
)
from ..lang.wave_types import IndexMapping
from ..ops.wave_ops import CustomOp, Placeholder, Reduction, Unknown

from .regions import RegionGraph, SubgraphTracer
@@ -111,6 +112,8 @@ def create_arg(self, a):
# Let DataType persist as arguments.
if isinstance(a, DataType):
return a
if isinstance(a, IndexMapping):
return a
return super().create_arg(a)


14 changes: 14 additions & 0 deletions shark_turbine/kernel/lang/wave_types.py
@@ -137,6 +137,17 @@ def _subs_expr(expr: Any, subs: Iterable[tuple[IndexExpr, IndexExpr]]) -> Any:
return expr


def _is_identity_mapping(iters: Iterable[IndexSymbol], mapping: SymbolsMap) -> bool:
if len(iters) != len(mapping):
return False

for it, val in zip(iters, mapping.values()):
if it != val:
return False

return True
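# Illustration (hypothetical symbols): with iters = (i, j), the mapping
# {N: i, M: j} is an identity (its values match the iterators in order),
# while {N: j, M: i} is not.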


class IndexMapping:
"""
Represents a mapping between 2 sets of indices.
@@ -209,3 +220,6 @@ def map_output_indices(
self, symbols: Optional[tuple[IndexSymbol, ...]] = None
) -> tuple[IndexExpr, ...]:
return self._map_indices(self.output_mapping, symbols)

def is_output_identity(self) -> bool:
return _is_identity_mapping(self.iters.keys(), self.output_mapping)
2 changes: 2 additions & 0 deletions shark_turbine/kernel/ops/wave_ops.py
@@ -572,6 +572,8 @@ class Read(CustomOp):

@property
def indexing_dims(self) -> list[IndexSymbol]:
if self.mapping is not None:
return list(self.mapping.output_shape)
# TODO: This could contain ints.
return list(self.memory.type.symbolic_shape)

86 changes: 74 additions & 12 deletions shark_turbine/kernel/wave/codegen.py
@@ -5,16 +5,18 @@
import torch.fx as fx

from ..compiler.ir import (
DenseElementsAttr,
IndexType,
InsertionPoint,
Location,
OpResult,
IntegerAttr,
IntegerType,
IrType,
Value,
IndexType,
Location,
MemRefType,
OpResult,
ShapedType,
Value,
VectorType,
IntegerAttr,
arith_d,
func_d,
gpu_d,
@@ -30,6 +32,7 @@
from ..compiler.kernel_codegen import BoundKernelSignature
from .._support.tracing import CapturedTrace
from ..compiler.builder import IRProxyValue
from ..compiler.utils import strides_from_symbolic_shape
from ..compiler.vector_codegen import (
cast_kernel_buffer,
cast_py_literal,
@@ -38,7 +41,8 @@
)

# Indexing imports.
from .._support.indexing import IndexingContext
from .._support.indexing import IndexingContext, IndexExpr
from .indexing import IndexSequence


@dataclass
@@ -219,6 +223,23 @@ def handle_register(emitter: WaveEmitter, node: fx.Node):
raise NotImplementedError("Register: Currently only stub implementation")


def _get_start_indices(
emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr]
) -> list[OpResult]:
start_indices = []
for dim_indexing in src_indices:
i = src_indices[dim_indexing]
if isinstance(i, IndexSequence):
i = i.start
start_indices.append(gen_sympy_index(emitter, i))

return start_indices


def _compute_offset(indices: list[int], strides: list[int]) -> int:
return int(sum(i * s for i, s in zip(indices, strides)))
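# Illustration (not part of the commit): with row-major strides [16, 1],
# indices [2, 3] flatten to 2 * 16 + 3 * 1 == 35.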


@handle_op(read)
def handle_read(emitter: WaveEmitter, node: fx.Node):
# This is similar to tkl.store with fixed start indices for now.
@@ -227,22 +248,63 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
except ValueError as e:
raise ValidationError("Malformed arguments") from e

assert mapping is None, "mapping is not supported yet"

vector_shape = cast_py_literal(emitter, (elements_per_thread,))
# memory has no IR node yet.
kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, memory)

if not hasattr(node, "index"):
raise ValidationError("codegen expected read to have index attr.")

start_indices = []
for dim_indexing in node.index:
start_indices.append(gen_sympy_index(emitter, node.index[dim_indexing].start))
index = node.index

element_type = kb_ir_type.element_type
vector_type = VectorType.get(vector_shape, element_type)
result = vector_d.load(vector_type, kb_src, start_indices)
if mapping is None:
start_indices = _get_start_indices(emitter, index)
result = vector_d.load(vector_type, kb_src, start_indices)
else:
assert (
mapping.is_output_identity()
), "non-dentity output mapping is not supported yet"
mem_index = memory.index
input_mapping = mapping.map_input_indices(mem_index.keys())

iters = mapping.iters
subs = [(sym, expr.start) for sym, expr in zip(iters.keys(), index.values())]

input_index = {
key: m.subs(subs) for key, m in zip(mem_index.keys(), input_mapping)
}

strides = strides_from_symbolic_shape(IndexingContext.current(), mem_index)
offsets = []
subs = [(sym, 0) for sym in iters.keys()]
for i in range(elements_per_thread):
            # Update the most-minor dim; for an identity mapping this is
            # equivalent to a plain vector.load.
subs[-1] = (subs[-1][0], i)
indices = [int(i.subs(subs)) for i in input_mapping]
offsets.append(
IntegerAttr.get(IndexType.get(), _compute_offset(indices, strides))
)
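        # Note: for an identity input mapping over a row-major buffer the
        # most-minor stride is 1, so the offsets are
        # [0, 1, ..., elements_per_thread - 1] and the gather degenerates to
        # the contiguous vector.load of the branch above.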

start_indices = _get_start_indices(emitter, input_index)
offsets_vec_type = VectorType.get([elements_per_thread], IndexType.get())
mask_vec_type = VectorType.get(
[elements_per_thread], IntegerType.get_signless(1)
)

offsets_vec = arith_d.ConstantOp(
offsets_vec_type, DenseElementsAttr.get(offsets, offsets_vec_type)
)
mask = vector_d.constant_mask(mask_vec_type, [elements_per_thread])
zero = arith_d.ConstantOp(vector_type.element_type, 0)
passthru = vector_d.splat(vector_type, zero)

result = vector_d.gather(
vector_type, kb_src, start_indices, offsets_vec, mask, passthru
)

emitter.bind_node_proxy(node, IRProxyValue(result))

