[TK] Basic tkw.write lowering (#76)
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Hardcode84 authored Aug 13, 2024
1 parent 1257104 commit 7fb008f
Showing 2 changed files with 72 additions and 1 deletion.
lit_tests/kernel/wave/codegen.py (47 additions, 0 deletions)
@@ -98,6 +98,53 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
         # CHECK: vector.load %[[DATA]][%[[IDX_X]], %[[IDX_Y]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>


+@run
+def test_read_write():
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+    ):
+        res = tkw.read(a, elements_per_thread=16)
+        tkw.write(res, b, elements_per_thread=16)
+
+    with codegen_test_context():
+        a = torch.randn(16, 16, dtype=torch.float16)
+        b = torch.zeros(16, 16, dtype=torch.float16)
+        print(test(a, b).module_op)
+        # CHECK: func.func @test(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding)
+        # CHECK: %[[WG_0:.+]] = stream.dispatch.workgroup.id[0]
+        # CHECK: %[[WG_1:.+]] = stream.dispatch.workgroup.id[1]
+        # CHECK: %[[T0:.+]] = gpu.thread_id x
+        # CHECK: %[[T1:.+]] = gpu.thread_id y
+
+        # CHECK: %[[RES:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>
+
+        # CHECK: %[[OUT:.+]] = stream.binding.subspan %[[ARG1]]
+
+        # CHECK: %[[C16:.+]] = arith.constant 16 : index
+        # CHECK: %[[WG0_OFF:.+]] = arith.muli %[[WG_0]], %[[C16]]
+        # CHECK: %[[C4:.+]] = arith.constant 4 : index
+        # CHECK: %[[T0_OFF:.+]] = arith.divsi %[[T0]], %[[C4]]
+        # CHECK: %[[IDX_X:.+]] = arith.addi %[[T0_OFF]], %[[WG0_OFF]]
+        # CHECK: %[[C16_0:.+]] = arith.constant 16 : index
+        # CHECK: %[[T1_OFF:.+]] = arith.muli %[[T1]], %[[C16_0]] : index
+        # CHECK: %[[C16_1:.+]] = arith.constant 16 : index
+        # CHECK: %[[WG1_OFF:.+]] = arith.muli %[[WG_1]], %[[C16_1]]
+        # CHECK: %[[IDX_Y:.+]] = arith.addi %[[WG1_OFF]], %[[T1_OFF]]
+        # CHECK: vector.store %[[RES]], %[[OUT]][%[[IDX_X]], %[[IDX_Y]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>
+
+
 @run
 def test_add_float():
     constraints: list[tkw.Constraint] = [
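The CHECK lines above pin down the per-thread index arithmetic the new lowering emits for both the load and the store. As a rough standalone model (not part of the commit; it assumes the test binds BLOCK_M = BLOCK_N = 16 and spreads a 64-thread wave over vector<16xf16> rows, which is what the constants in the expected IR suggest), the start indices reduce to:

# Hedged sketch of the index arithmetic the CHECK lines verify; the constants
# 16 and 4 are read off the arith.constant ops in the expected IR.
def start_indices(wg0: int, wg1: int, t0: int, t1: int) -> tuple[int, int]:
    # IDX_X = T0 / 4 + WG_0 * 16   (arith.divsi, arith.muli, arith.addi)
    # Python's // matches arith.divsi here because thread ids are non-negative.
    idx_x = t0 // 4 + wg0 * 16
    # IDX_Y = WG_1 * 16 + T1 * 16  (two arith.muli, one arith.addi)
    idx_y = wg1 * 16 + t1 * 16
    return idx_x, idx_y

# Example: thread t0=7, t1=0 in workgroup (1, 0) starts its vector<16xf16>
# load and store at row 17, column 0.
assert start_indices(1, 0, 7, 0) == (17, 0)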
shark_turbine/kernel/wave/codegen.py (25 additions, 1 deletion)
@@ -248,7 +248,31 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):

 @handle_op(write)
 def handle_write(emitter: WaveEmitter, node: fx.Node):
-    raise NotImplementedError("Write: Currently only stub implementation")
+    try:
+        register, memory, elements_per_thread, mapping = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    assert mapping is None, "mapping is not supported yet"
+
+    # memory has no IR node yet.
+    kb_dest, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, memory)
+    insert_vector = cast_vector(emitter, register, element_type=kb_ir_type.element_type)
+    insert_type = VectorType(insert_vector.type)
+
+    # TODO: Support elements_per_thread size mismatch and broadcasting
+    assert tuple(insert_type.shape) == (
+        elements_per_thread,
+    ), f"Shape doesn't match: {tuple(insert_type.shape)} and {(elements_per_thread,)}"
+
+    if not hasattr(node, "index"):
+        raise ValidationError("codegen expected write to have index attr.")
+
+    start_indices = []
+    for dim_indexing in node.index:
+        start_indices.append(gen_sympy_index(emitter, node.index[dim_indexing].start))
+
+    vector_d.store(insert_vector, kb_dest, start_indices)


 ###############################################################################
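For comparison, the hunk header above points at handle_read, which the new handler mirrors: resolve the kernel buffer, turn the sympy start expressions attached to node.index into index values with gen_sympy_index, then emit a single vector op. A condensed, hand-written sketch of that read side (an assumed helper, not the file's verbatim handle_read) would look like:

# Hedged sketch: same start-index generation as handle_write above, but it
# produces a value with vector.load instead of consuming one with vector.store.
def emit_read_like(
    emitter: WaveEmitter, node: fx.Node, memory, elements_per_thread: int
):
    kb_src, kb_ir_type, _ = cast_kernel_buffer(emitter, memory)
    start_indices = [
        gen_sympy_index(emitter, node.index[dim_indexing].start)
        for dim_indexing in node.index
    ]
    vector_type = VectorType.get([elements_per_thread], kb_ir_type.element_type)
    return vector_d.load(vector_type, kb_src, start_indices)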
