From 69107f14fe3258e4be8a38ab431ed61f8279b203 Mon Sep 17 00:00:00 2001
From: "Ling, Liyang"
Date: Tue, 9 Jul 2024 08:37:33 +0000
Subject: [PATCH] Move Stream K kernel to benchmark folder

---
 .../xetla_benchmark/gemm_streamk_benchmark.py | 128 ++++++++++--------
 1 file changed, 73 insertions(+), 55 deletions(-)
 rename python/tutorials/10-experimental-block-pointer-streamk.py => benchmarks/xetla_benchmark/gemm_streamk_benchmark.py (74%)

diff --git a/python/tutorials/10-experimental-block-pointer-streamk.py b/benchmarks/xetla_benchmark/gemm_streamk_benchmark.py
similarity index 74%
rename from python/tutorials/10-experimental-block-pointer-streamk.py
rename to benchmarks/xetla_benchmark/gemm_streamk_benchmark.py
index 518cc44dd8..38c408821e 100644
--- a/python/tutorials/10-experimental-block-pointer-streamk.py
+++ b/benchmarks/xetla_benchmark/gemm_streamk_benchmark.py
@@ -29,6 +29,18 @@ def swizzle_tile(tile_id,
     return pid_m, pid_n
 
 
+@triton.jit
+def linear_tile(tile_id,
+                # Matrix dimensions
+                M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                # Meta-parameters
+                BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+                GROUP_SIZE_M: tl.constexpr):
+    pid_m = tile_id // tl.cdiv(N, BLOCK_SIZE_N)
+    pid_n = tile_id % tl.cdiv(N, BLOCK_SIZE_N)
+    return pid_m, pid_n
+
+
 # Multiply-accumulate loop in GEMM Stream K tiles
 @triton.jit
 def mac_loop(
@@ -45,9 +57,11 @@ def mac_loop(
 
     tile_id = start_iter // iters_per_tile
     remain_iters = start_iter % iters_per_tile
-    # Assume GROUP_M > 0
-    # pid swizzle to get better L2 cache performance
-    pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+    if GROUP_SIZE_M > 0:
+        # pid swizzle to get better L2 cache performance
+        pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+    else:
+        pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
     a_ptr += BLOCK_SIZE_K * stride_ak * remain_iters
     a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                     offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
@@ -82,10 +96,8 @@ def mac_loop(
 @triton.autotune(
     configs=[
         triton.Config(
-            {
-                'double_GRF': True, 'threads_per_warp': 16, 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K':
-                32, 'GROUP_SIZE_M': 4
-            }, num_stages=4, num_warps=32),
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, "threads_per_warp": 16},
+            num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
 )
@@ -119,10 +131,8 @@ def first_wave(
 @triton.autotune(
     configs=[
         triton.Config(
-            {
-                'double_GRF': True, 'threads_per_warp': 16, 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K':
-                32, 'GROUP_SIZE_M': 4
-            }, num_stages=4, num_warps=32),
+            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, "threads_per_warp": 16},
+            num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
 )
@@ -144,8 +154,11 @@ def full_tiles(
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
         GROUP_SIZE_M: tl.constexpr):
     tile_id = tl.program_id(axis=0) + streamk_tiles
-    # Assume GROUP_M > 0
-    pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+    if GROUP_SIZE_M > 0:
+        # pid swizzle to get better L2 cache performance
+        pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+    else:
+        pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
 
     a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                                     offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
@@ -169,8 +182,11 @@ def full_tiles(
     tl.store(c_block_ptr, acc, boundary_check=(0, 1))
 
 
-# We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
+# ---------------------------------------------------------------------------
+# Wrapper
+# ---------------------------------------------------------------------------
+
+
 def matmul(a: torch.Tensor, b: torch.Tensor):
     num_xe_core = torch.xpu.get_device_capability(0)['gpu_subslice_count']
     streamk_programs = num_xe_core
@@ -248,43 +264,45 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
     else:
         exit("❌ Triton and Torch differ")
 
-# # Benchmark Performance
-# @triton.testing.perf_report(
-#     triton.testing.Benchmark(
-#         # argument names to use as an x-axis for the plot
-#         x_names=['M', 'K', 'N'],
-#         x_vals=[[3072, 4096, 3072]],
-#         line_arg='provider',
-#         # argument name whose value corresponds to a different line in the plot
-#         # possible values for `line_arg``
-#         line_vals=['onednn', 'triton'],
-#         # label name for the lines
-#         line_names=["onednn", "Triton"],
-#         # line styles
-#         # styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
-#         ylabel="TFLOPS",  # label name for the y-axis
-#         plot_name="matmul-performance",
-#         # name for the plot. Used also as a file name for saving the plot.
-#         args={},
-#     ))
-# def benchmark(M, N, K, provider):
-#     torch.manual_seed(0)
-#     a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
-#     b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
-#     quantiles = [0.5, 0.2, 0.8]
-
-#     # calculate tflops for oneDNN kernel
-#     def calculate_tflops(ms):
-#         return 2 * M * N * K * 1e-12 / (ms * 1e-3)
-
-#     if provider == 'onednn':
-#         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
-#                                                      fast_flush=False)
-#         print(f"oneDNN Peak TFlops {calculate_tflops(min_ms)}")
-#     if provider == 'triton':
-#         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles,
-#                                                      fast_flush=False)
-
-#     return calculate_tflops(ms), calculate_tflops(min_ms), calculate_tflops(max_ms)
-
-# benchmark.run(show_plots=True, print_data=True)
+
+# Benchmark Performance
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=['M', 'K', 'N'],
+        x_vals=[[3072, 4096, 3072]],
+        line_arg='provider',
+        # argument name whose value corresponds to a different line in the plot
+        # possible values for `line_arg`
+        line_vals=['onednn', 'triton'],
+        # label name for the lines
+        line_names=["onednn", "Triton"],
+        # line styles
+        # styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel="TFLOPS",  # label name for the y-axis
+        plot_name="matmul-performance",
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(M, N, K, provider):
+    torch.manual_seed(0)
+    a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
+    b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
+    quantiles = [0.5, 0.2, 0.8]
+
+    # compute achieved TFLOPS from a runtime in milliseconds
+    def calculate_tflops(ms):
+        return 2 * M * N * K * 1e-12 / (ms * 1e-3)
+
+    if provider == 'onednn':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
+                                                     fast_flush=False)
+        print(f"oneDNN Peak TFlops {calculate_tflops(min_ms)}")
+    if provider == 'triton':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles,
+                                                     fast_flush=False)
+
+    return calculate_tflops(ms), calculate_tflops(min_ms), calculate_tflops(max_ms)
+
+
+benchmark.run(show_plots=True, print_data=True)
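
The two plain-Python sketches below are illustrative references for readers
of this patch, not part of the file being moved; helper names and exact
policies are assumptions unless they appear in the hunks above.

First, the new GROUP_SIZE_M > 0 branch chooses between `swizzle_tile` (whose
body is outside these hunks) and the new `linear_tile` fallback. A grouped
mapping of the kind used by Triton's matmul tutorial, shown for contrast
with `linear_tile`:

    def grouped_swizzle(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M):
        # Walk tiles column-major inside groups of GROUP_SIZE_M tile-rows,
        # so consecutive programs reuse the same A/B blocks in L2.
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
        return pid_m, pid_n

Second, the `matmul` wrapper (only partially visible here) derives a
Stream-K schedule, launches `first_wave` on `streamk_programs` programs that
each walk a contiguous slice of K-iterations (possibly crossing tile
boundaries, hence the partial-tile handling in `mac_loop`), and launches
`full_tiles` for the remaining data-parallel tiles, whose `tile_id` starts
at `tl.program_id(axis=0) + streamk_tiles`. A minimal sketch of how such a
schedule is typically computed (the helper name and split policy are
assumed, not taken from this file):

    import math

    def streamk_schedule(M, N, K, BLOCK_SIZE_M=256, BLOCK_SIZE_N=256,
                         BLOCK_SIZE_K=32, streamk_programs=32):
        # One output tile per BLOCK_SIZE_M x BLOCK_SIZE_N block of C.
        total_tiles = math.ceil(M / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N)
        # K-loop iterations needed to finish one tile.
        iters_per_tile = math.ceil(K / BLOCK_SIZE_K)
        # Tiles that do not fill a whole wave of programs are processed
        # cooperatively by the first wave; the rest are classic full tiles.
        streamk_tiles = total_tiles % streamk_programs
        blocking_tiles = total_tiles - streamk_tiles
        # Spread the first-wave K-iterations evenly; the first
        # `streamk_remainder_iters` programs take one extra iteration.
        streamk_iters = streamk_tiles * iters_per_tile
        streamk_iters_pcu = streamk_iters // streamk_programs
        streamk_remainder_iters = streamk_iters % streamk_programs
        return (streamk_tiles, blocking_tiles, iters_per_tile,
                streamk_iters_pcu, streamk_remainder_iters)

For example, with the benchmark shape M = N = 3072, K = 4096 and 32 Stream-K
programs, this schedule yields 144 total tiles, 16 of them handled by the
first wave (128 K-iterations each at BLOCK_SIZE_K = 32, i.e. 2048 Stream-K
iterations, 64 per program) and the remaining 128 handled by `full_tiles`.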