
Add transposed outputs
Adds a transpose mapping to the output write to bring it closer to attention.

Signed-off-by: Harsh Menon <harsh@nod-labs.com>
harsh-nod committed Oct 25, 2024
1 parent 7a8a23c commit 6633188
Showing 2 changed files with 21 additions and 9 deletions.
13 changes: 8 additions & 5 deletions iree/turbine/kernel/wave/iree_utils.py
@@ -16,10 +16,11 @@ def get_chain_mmt_asm(
B, M, K1, input_dtype = query_type.split("x")
B, K2, K1, input_dtype = key_type.split("x")
B, N, K2, input_dtype = value_type.split("x")
- B, M, N, output_dtype = output_type.split("x")
+ B, N, M, output_dtype = output_type.split("x")
intermediate_output_type = f"{B}x{K2}x{M}x{output_dtype}"
intermediate_cast_type = f"{B}x{K2}x{M}x{input_dtype}"
transposed_cast_type = f"{B}x{M}x{K2}x{input_dtype}"
+ transposed_output_type = f"{B}x{M}x{N}x{output_dtype}"
return f"""
func.func @chain_mmt(%query: tensor<{query_type}>, %key: tensor<{key_type}>, %value: tensor<{value_type}>) -> tensor<{output_type}> {{
%c0 = arith.constant 0.0 : f32
@@ -30,11 +31,13 @@ def get_chain_mmt_asm(
%trunc = arith.truncf %result : tensor<{intermediate_output_type}> to tensor<{intermediate_cast_type}>
%init2 = tensor.empty() : tensor<{transposed_cast_type}>
%transpose = linalg.transpose ins(%trunc: tensor<{intermediate_cast_type}>) outs(%init2: tensor<{transposed_cast_type}>) permutation=[0, 2, 1]
- %init3 = tensor.empty() : tensor<{output_type}>
- %inital_result3 = linalg.fill ins(%c0 : f32) outs(%init3 : tensor<{output_type}>) -> tensor<{output_type}>
+ %init3 = tensor.empty() : tensor<{transposed_output_type}>
+ %inital_result3 = linalg.fill ins(%c0 : f32) outs(%init3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
%result2 = linalg.batch_matmul_transpose_b ins(%transpose, %value: tensor<{transposed_cast_type}>, tensor<{value_type}>)
-     outs(%inital_result3 : tensor<{output_type}>) -> tensor<{output_type}>
- return %result2 : tensor<{output_type}>
+     outs(%inital_result3 : tensor<{transposed_output_type}>) -> tensor<{transposed_output_type}>
+ %init4 = tensor.empty() : tensor<{output_type}>
+ %transpose2 = linalg.transpose ins(%result2: tensor<{transposed_output_type}>) outs(%init4: tensor<{output_type}>) permutation=[0, 2, 1]
+ return %transpose2 : tensor<{output_type}>
}}"""


17 changes: 13 additions & 4 deletions tests/kernel/wave/wave_attention_test.py
@@ -102,12 +102,19 @@ def testChainedGemm(
)
]

+ i = tkw.IndexMapping.iterator(0)
+ j = tkw.IndexMapping.iterator(1)
+ k = tkw.IndexMapping.iterator(2)
+ mapping = tkw.IndexMapping(
+     num_iterators=3, inputs={B: i, M: j, N: k}, outputs={B: i, N: k, M: j}
+ )

@tkw.wave(constraints)
def chained_gemm(
q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
- c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
+ c: tkl.Memory[B, N, M, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[B, M, N, tkl.f32](0.0)
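
A minimal sketch of what the mapping above is assumed to do when passed to tkw.write (see the next hunk): it permutes the last two output indices, so the value produced at logical position (b, m, n) lands at c[b, n, m]. An eager-mode equivalent, with made-up sizes, would be:

```python
import torch

B_, M_, N_ = 2, 16, 32
acc = torch.randn(B_, M_, N_)     # logical (B, M, N) result of the loop
c = torch.empty(B_, N_, M_)       # transposed (B, N, M) output buffer
c.copy_(acc.transpose(-1, -2))    # effect attributed to write(..., mapping=mapping)
```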

@@ -126,7 +133,9 @@ def repeat(
return acc

# repeat represents the results of the loop
- tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+ tkw.write(
+     repeat, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD
+ )

hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
@@ -172,15 +181,15 @@ def repeat(
q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16)
k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16)
v = torch.randn(shape[0], shape[2], shape[4], dtype=torch.float16)
- output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
+ output = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
mb = chained_gemm(q, k, v, output)

if test_dump_generated_mlir:
filename = f"wave_cgemm_{'x'.join(map(str, shape))}.mlir"
with open(filename, "w") as f:
f.write(mb.module_op.get_asm())

- iree_ref = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
+ iree_ref = torch.zeros(shape[0], shape[2], shape[1], dtype=torch.float32)
generate_iree_ref(
"chain_mmt", [q, k, v], [iree_ref], config, run_bench=run_bench
)
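
A hypothetical follow-up check, not part of this commit: with the transposed layout, both the kernel output and the IREE reference are (B, N, M), i.e. (shape[0], shape[2], shape[1]), so they can be compared element-wise. The tolerances below are assumptions for f16 inputs with f32 accumulation.

```python
assert output.shape == (shape[0], shape[2], shape[1]) == iree_ref.shape
torch.testing.assert_close(output, iree_ref, atol=2e-2, rtol=2e-2)
```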