Move Stream K kernel to benchmark folder
LiyangLingIntel committed Jul 10, 2024
1 parent 8d9922e commit 69107f1
Showing 1 changed file with 73 additions and 55 deletions.
@@ -29,6 +29,18 @@ def swizzle_tile(tile_id,
return pid_m, pid_n


+@triton.jit
+def linear_tile(tile_id,
+                # Matrix dimensions
+                M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+                # Meta-parameters
+                BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
+                GROUP_SIZE_M: tl.constexpr):
+    pid_m = tile_id // tl.cdiv(N, BLOCK_SIZE_N)
+    pid_n = tile_id % tl.cdiv(N, BLOCK_SIZE_N)
+    return pid_m, pid_n
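
The added linear_tile helper walks tiles in plain row-major order, while swizzle_tile above regroups tile ids (per its comment, for better L2 cache reuse). Below is a minimal host-side sketch of the row-major mapping; the sizes and the helper name linear_tile_host are made up purely for illustration and are not part of this commit.

# Illustration only: a pure-Python equivalent of linear_tile, evaluated on
# the host with hypothetical sizes; not part of the kernel in this commit.
def linear_tile_host(tile_id, N, BLOCK_SIZE_N):
    tiles_per_row = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N  # same as tl.cdiv(N, BLOCK_SIZE_N)
    return tile_id // tiles_per_row, tile_id % tiles_per_row

# With N = 512 and BLOCK_SIZE_N = 256 there are two tile columns, so
# tile ids 0, 1, 2, 3 map to (0, 0), (0, 1), (1, 0), (1, 1).
for tid in range(4):
    print(tid, linear_tile_host(tid, N=512, BLOCK_SIZE_N=256))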


# Multiply-accumulate loop in GEMM Stream K tiles
@triton.jit
def mac_loop(
@@ -45,9 +57,11 @@ def mac_loop(

tile_id = start_iter // iters_per_tile
remain_iters = start_iter % iters_per_tile
-# Assume GROUP_M > 0
-# pid swizzle to get better L2 cache performance
-pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+if GROUP_SIZE_M > 0:
+    # pid swizzle to get better L2 cache performance
+    pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+else:
+    pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)

a_ptr += BLOCK_SIZE_K * stride_ak * remain_iters
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
@@ -82,10 +96,8 @@ def mac_loop(
@triton.autotune(
configs=[
triton.Config(
-{
-'double_GRF': True, 'threads_per_warp': 16, 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K':
-32, 'GROUP_SIZE_M': 4
-}, num_stages=4, num_warps=32),
+{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, "threads_per_warp": 16},
+num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
)
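
Both Stream-K kernels are driven by triton.autotune, which benchmarks every listed triton.Config the first time a new (M, N, K) key combination is seen and caches the fastest one. If a wider sweep is wanted later, extra candidates can simply be appended to configs; the block sizes in the sketch below are assumptions chosen only to show the shape of the API, not tuned or validated values.

# Hypothetical extra candidates for the autotuner; the sizes are guesses,
# not settings validated in this file.
extra_configs = [
    triton.Config(
        {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, "threads_per_warp": 16},
        num_stages=2, num_warps=32),
    triton.Config(
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, "threads_per_warp": 16},
        num_stages=3, num_warps=32),
]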
@@ -119,10 +131,8 @@ def first_wave(
@triton.autotune(
configs=[
triton.Config(
-{
-'double_GRF': True, 'threads_per_warp': 16, 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K':
-32, 'GROUP_SIZE_M': 4
-}, num_stages=4, num_warps=32),
+{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, "threads_per_warp": 16},
+num_stages=2, num_warps=32),
],
key=['M', 'N', 'K'],
)
@@ -144,8 +154,11 @@ def full_tiles(
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):

tile_id = tl.program_id(axis=0) + streamk_tiles
-# Assume GROUP_M > 0
-pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+if GROUP_SIZE_M > 0:
+    # pid swizzle to get better L2 cache performance
+    pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)
+else:
+    pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M)

a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
@@ -169,8 +182,11 @@ def full_tiles(
tl.store(c_block_ptr, acc, boundary_check=(0, 1))


-# We can now create a convenience wrapper function that only takes two input tensors,
-# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
+# ---------------------------------------------------------------------------
+# Wrapper
+# ---------------------------------------------------------------------------


def matmul(a: torch.Tensor, b: torch.Tensor):
num_xe_core = torch.xpu.get_device_capability(0)['gpu_subslice_count']
streamk_programs = num_xe_core
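
The matmul wrapper takes the two input tensors, checks their shapes, allocates the output, and launches the first_wave and full_tiles kernels defined above, with the number of Stream-K programs pinned to the Xe-core count queried here. Its body is collapsed in this diff, so the following is only a hedged sketch of how a Stream-K work split is commonly derived; streamk_split and the other names are illustrative, not necessarily those used in the file.

# Illustration only: a typical Stream-K split of tiles between the
# cooperative first wave and the ordinary one-program-per-tile wave.
def streamk_split(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, streamk_programs):
    total_tiles = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
    iters_per_tile = triton.cdiv(K, BLOCK_K)
    # Leftover tiles that do not divide evenly across the programs are
    # processed cooperatively by the Stream-K wave ...
    streamk_tiles = total_tiles % streamk_programs
    # ... and the remaining tiles run as ordinary full tiles.
    blocking_tiles = total_tiles - streamk_tiles
    streamk_iters = streamk_tiles * iters_per_tile
    return streamk_tiles, blocking_tiles, iters_per_tile, streamk_iters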
@@ -248,43 +264,45 @@ def matmul(a: torch.Tensor, b: torch.Tensor):
else:
exit("❌ Triton and Torch differ")

-# # Benchmark Performance
-# @triton.testing.perf_report(
-# triton.testing.Benchmark(
-# # argument names to use as an x-axis for the plot
-# x_names=['M', 'K', 'N'],
-# x_vals=[[3072, 4096, 3072]],
-# line_arg='provider',
-# # argument name whose value corresponds to a different line in the plot
-# # possible values for `line_arg``
-# line_vals=['onednn', 'triton'],
-# # label name for the lines
-# line_names=["onednn", "Triton"],
-# # line styles
-# # styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
-# ylabel="TFLOPS", # label name for the y-axis
-# plot_name="matmul-performance",
-# # name for the plot. Used also as a file name for saving the plot.
-# args={},
-# ))
-# def benchmark(M, N, K, provider):
-# torch.manual_seed(0)
-# a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
-# b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
-# quantiles = [0.5, 0.2, 0.8]
-
-# # calculate tflops for oneDNN kernel
-# def calculate_tflops(ms):
-# return 2 * M * N * K * 1e-12 / (ms * 1e-3)
-
-# if provider == 'onednn':
-# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
-# fast_flush=False)
-# print(f"oneDNN Peak TFlops {calculate_tflops(min_ms)}")
-# if provider == 'triton':
-# ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles,
-# fast_flush=False)
-
-# return calculate_tflops(ms), calculate_tflops(min_ms), calculate_tflops(max_ms)
-
-# benchmark.run(show_plots=True, print_data=True)

+# Benchmark Performance
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        # argument names to use as an x-axis for the plot
+        x_names=['M', 'K', 'N'],
+        x_vals=[[3072, 4096, 3072]],
+        line_arg='provider',
+        # argument name whose value corresponds to a different line in the plot
+        # possible values for `line_arg`
+        line_vals=['onednn', 'triton'],
+        # label name for the lines
+        line_names=["onednn", "Triton"],
+        # line styles
+        # styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel="TFLOPS",  # label name for the y-axis
+        plot_name="matmul-performance",
+        # name for the plot. Used also as a file name for saving the plot.
+        args={},
+    ))
+def benchmark(M, N, K, provider):
+    torch.manual_seed(0)
+    a = torch.rand((M, K), device='xpu', dtype=torch.bfloat16)
+    b = torch.rand((K, N), device='xpu', dtype=torch.bfloat16)
+    quantiles = [0.5, 0.2, 0.8]
+
+    # calculate tflops for oneDNN kernel
+    def calculate_tflops(ms):
+        return 2 * M * N * K * 1e-12 / (ms * 1e-3)
+
+    if provider == 'onednn':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles,
+                                                     fast_flush=False)
+        print(f"oneDNN Peak TFlops {calculate_tflops(min_ms)}")
+    if provider == 'triton':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles,
+                                                     fast_flush=False)
+
+    return calculate_tflops(ms), calculate_tflops(min_ms), calculate_tflops(max_ms)
+
+
+benchmark.run(show_plots=True, print_data=True)
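
For the single shape benchmarked here (M = 3072, K = 4096, N = 3072), the conversion in calculate_tflops works out as below; the 0.5 ms timing is an assumed value used purely to show the arithmetic.

# One GEMM performs 2 * M * N * K floating-point operations.
M, K, N = 3072, 4096, 3072
flops = 2 * M * N * K                 # 77_309_411_328 FLOPs, about 0.0773 TFLOP
ms = 0.5                              # hypothetical kernel time in milliseconds
tflops = flops * 1e-12 / (ms * 1e-3)
print(f"{tflops:.1f} TFLOPS")         # about 154.6 TFLOPS at the assumed 0.5 ms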
