[TKW] Fix indexing of permute to enable attention (#244)
In GPU programming we typically only materialize "transposes"/"permutes" of
data during reads and writes. Transposing/permuting data that already lives in
GPU registers is usually free (a no-op): each thread keeps owning exactly the
same data, only the symbolic labeling of its dimensions changes (see the
illustrative sketch below the change summary).

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
raikonenfnu authored Oct 29, 2024
1 parent ddc8dbd commit 8846708
Showing 2 changed files with 166 additions and 16 deletions.
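To make the "free in registers" point concrete, here is a minimal sketch (not the TKW source; the shapes and per-thread index expressions are made up for illustration) of what a register-level permute amounts to: the symbolic shape of the value is reordered, while each thread's index expressions, keyed by symbolic dimension, are carried over unchanged, so no data has to move.

# Minimal sketch, not the TKW implementation: a permute of a value held in
# registers only reorders its symbolic shape; the per-thread index expressions
# (keyed by symbolic dimension) are inherited from the source unchanged.
src_shape = ("B", "K2", "M")
target_shape = ("B", "M", "K2")  # a permutation of src_shape
# Hypothetical per-thread index of the source value.
src_index = {"B": "wg2", "K2": "16 * wg0 + lane // 4", "M": "4 * (lane % 4)"}

assert sorted(target_shape) == sorted(src_shape)
permuted_shape = target_shape     # the type of the result changes ...
permuted_index = dict(src_index)  # ... but the per-thread indexing does not,
                                  # so no data movement is required.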
15 changes: 0 additions & 15 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -1400,18 +1400,3 @@ def type(self) -> Register:
            self.target_shape
        ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
        return Register[*self.target_shape, src_type.dtype]

    @property
    def index(self) -> Optional[dict[IndexSymbol, IndexSequence]]:
        """
        Computes the permuted index based on the target shape.
        """
        src_type = get_custom(self.arg).type
        dim_map = {
            tgt: src for src, tgt in zip(src_type.symbolic_shape, self.target_shape)
        }
        return {tgt: get_custom(self.arg).index[src] for tgt, src in dim_map.items()}

    @index.setter
    def index(self, value: Any):
        CustomOp.index.fset(self, value)
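A rough sketch of the effect of this deletion (an assumption based on the diff, not the actual TKW class hierarchy): with the Permute-specific override gone, Permute falls back to the index property and setter it inherits from CustomOp, so the per-thread index assigned during index propagation is returned as-is instead of being remapped positionally against the target shape.

# Hypothetical, simplified rendering of the fallback behavior (class names match
# the diff, but the bodies here are illustrative, not the real iree.turbine code).
class CustomOp:
    @property
    def index(self):
        # Return whatever index propagation stored on this op.
        return getattr(self, "_index", None)

    @index.setter
    def index(self, value):
        self._index = value


class Permute(CustomOp):
    # No `index` override anymore: reads and writes of `index` go straight to
    # CustomOp, so a permute no longer rewrites its operand's index expressions.
    ...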
167 changes: 166 additions & 1 deletion tests/kernel/wave/wave_attention_test.py
@@ -7,6 +7,7 @@
import logging
import pytest
import torch
import math
import unittest
import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
@@ -20,7 +21,7 @@
from iree.turbine.kernel.wave.constraints import MMAType
import os
import json
from torch.testing import assert_close
from torch.testing import assert_close, assert_allclose

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
@@ -194,3 +195,167 @@ def repeat(
"chain_mmt", [q, k, v], [iree_ref], config, run_bench=run_bench
)
assert_close(output, iree_ref)


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_attention"))
@pytest.mark.parametrize("enable_scheduling", [False])
@pytest.mark.parametrize(
    "mfma_variant",
    [
        MMAType.F32_16x16x16_F16,
    ],
)
def testAttention(
    shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request
):
    run_bench = request.config.getoption("--runperf")
    dump_perf = request.config.getoption("--dump-perf-files-path")
    # Input sizes
    B = tkl.sym.B
    M = tkl.sym.M
    N = tkl.sym.N
    K1 = tkl.sym.K1
    K2 = tkl.sym.K2
    # Workgroup tile sizes
    BLOCK_B = tkl.sym.BLOCK_B
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K2 = tkl.sym.BLOCK_K2
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD
    STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD

    # Expose user-constraints
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)]
    constraints += [tkw.TilingConstraint(K2, BLOCK_K2)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(2, 2, 1),
            mma_type=mfma_variant,
            vector_shapes={B: 0, M: 16, N: 16},
        )
    ]

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)
    k = tkw.IndexMapping.iterator(2)
    mapping = tkw.IndexMapping(
        num_iterators=3, inputs={B: i, M: j, N: k}, outputs={B: i, N: k, M: j}
    )

    @tkw.wave(constraints)
    def base_attention(
        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
        k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
        v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
        c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        c_reg = tkl.Register[B, N, M, tkl.f32](0.0)
        init_sum = tkl.Register[B, M, tkl.f32](0.0)
        init_max = tkl.Register[B, M, tkl.f32](-1e6)
        # This microkernel encodes the fact that if the reduction
        # dimension were tiled, then we would need to materialize a loop.
        @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])
        def repeat(
            partial_max: tkl.Register[B, M, tkl.f32],
            partial_sum: tkl.Register[B, M, tkl.f32],
            acc: tkl.Register[B, N, M, tkl.f32],
        ) -> (
            tkl.Register[B, M, tkl.f32],
            tkl.Register[B, M, tkl.f32],
            tkl.Register[B, N, M, tkl.f32],
        ):
            imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0)
            q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            # k_reg: tkw.Register[B, K2, K1, tkl.f16]
            k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            # inner_acc: tkw.Register[B, K2, M, tkl.f32]
            inner_acc = tkw.mma(k_reg, q_reg, imm_reg)
            x_j = tkw.permute(inner_acc, target_shape=[B, M, K2])
            m_j = tkw.max(x_j, partial_max, dim=K2)
            e_delta_max = tkw.exp2(partial_max - m_j)
            e_delta = tkw.exp2(x_j - m_j)
            e_init = partial_sum * e_delta_max
            d_j = tkw.sum(e_delta, e_init, dim=K2)
            imm_f16 = tkw.cast(e_delta, tkl.f16)
            v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD)
            new_acc = acc * e_delta_max
            acc = tkw.mma(v_reg, imm_f16, new_acc)
            return m_j, d_j, acc

        # repeat represents the results of the loop
        res_max, res_sum, res_mm = repeat
        res = res_mm / res_sum
        tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

    hyperparams = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant),
        STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant),
        BLOCK_B: 1,
        BLOCK_M: 64,
        BLOCK_N: 64,
        BLOCK_K2: 32,
        B: shape[0],
        M: shape[1],
        N: shape[2],
        K1: shape[3],
        K2: shape[4],
        READ_SHARED_DELAY: 1,
        WRITE_SHARED_DELAY: 1,
        READ_GLOBAL_DELAY: 2,
        WRITE_GLOBAL_DELAY: 2,
        MMA_DELAY: 1,
        SHARED_MEMORY_UNITS: 4,
        GLOBAL_MEMORY_UNITS: 4,
        MMA_UNITS: 4,
    }
    config = {"backend": "rocm", "device": "hip", "target": "gfx942"}
    if run_bench:
        config["benchmark_batch_size"] = 10
        config["benchmark_repetitions"] = 3
    if dump_perf is not None:
        perf_filename = request.node.name + ".json"
        config["benchmark_results_file"] = os.path.join(
            dump_perf, "tk_" + perf_filename
        )

    with tk.gen.TestLaunchContext(
        hyperparams,
        canonicalize=True,
        run=True,
        run_bench=run_bench,
        run_config=config,
        schedule=enable_scheduling,
        use_scheduling_barriers=enable_scheduling_barriers,
    ):
        torch.manual_seed(0)
        q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16)
        k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16)
        v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16)
        output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
        log2e = 1.44269504089
        dk_sqrt = math.sqrt(1.0 / shape[3])
        # TODO: Add scaling of QK as part of kernel.
        # TODO: Add variant of non-transposed V attention kernel.
        mb = base_attention(q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), output)
        torch_ref = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None
        )

        if test_dump_generated_mlir:
            filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir"
            with open(filename, "w") as f:
                f.write(mb.module_op.get_asm())

        # TODO: Fix transposed writes to output.
        assert_allclose(output.permute([0, 2, 1]), torch_ref)
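For readers unfamiliar with the online-softmax formulation the `repeat` loop above implements, here is a plain PyTorch sketch of the same tiled accumulation. It is a reference under assumed shapes q: [B, M, K1], k: [B, K2, K1], v: [B, K2, N]; the function name `online_softmax_attention` and the `block_k2` parameter are illustrative and not part of the test. It also shows why q is pre-multiplied by log2(e) and 1/sqrt(K1) outside the kernel: the kernel uses exp2, and exp(x) == exp2(x * log2(e)).

# Reference sketch (assumptions noted above; not the TKW kernel).
import math

import torch


def online_softmax_attention(q, k, v, block_k2=32):
    # q: [B, M, K1], k: [B, K2, K1], v: [B, K2, N], all fp16.
    B, M, K1 = q.shape
    K2 = k.shape[1]
    N = v.shape[2]
    log2e = 1.44269504089
    # Pre-scale q exactly like the test does, so exp2 can replace exp below.
    q = (q * (1.0 / math.sqrt(K1)) * log2e).float()

    m = torch.full((B, M, 1), -1e6)  # running row-wise max of the scores
    d = torch.zeros(B, M, 1)         # running row-wise sum of exp2(scores - m)
    acc = torch.zeros(B, M, N)       # running, un-normalized output
    for start in range(0, K2, block_k2):
        kj = k[:, start : start + block_k2].float()
        vj = v[:, start : start + block_k2].float()
        x_j = q @ kj.transpose(-1, -2)                 # scores for this K2 tile
        m_j = torch.maximum(m, x_j.amax(dim=-1, keepdim=True))
        e_delta_max = torch.exp2(m - m_j)              # rescale old statistics
        e_delta = torch.exp2(x_j - m_j)
        d = d * e_delta_max + e_delta.sum(dim=-1, keepdim=True)
        acc = acc * e_delta_max + e_delta @ vj
        m = m_j
    return acc / d                                     # softmax(q k^T / sqrt(K1)) @ v

Up to fp16 rounding this matches torch.nn.functional.scaled_dot_product_attention(q, k, v); note the sketch produces a [B, M, N] layout, while the kernel above currently writes [B, N, M], which is why the final assert permutes `output` (see the TODO about transposed writes).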
