From 4f245a877c9cde0c467a721a54aba249c1f2c326 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 16 Jul 2024 10:48:54 -0700 Subject: [PATCH] Add persistent+TMA matmul to fp8 gemm benchmark (#2377) Summary: Pull Request resolved: https://github.com/pytorch/benchmark/pull/2377 Reviewed By: xuzhao9, sijiac Differential Revision: D59812172 Pulled By: bertmaher fbshipit-source-id: 450229888e09c22b9cd11a37015e6d601ec919ce --- torchbenchmark/operators/fp8_gemm/fp8_gemm.py | 23 +- .../operators/fp8_gemm/persistent.py | 368 ++++++++++++++++++ torchbenchmark/operators/fp8_gemm/tutorial.py | 348 +++++++++++++++++ 3 files changed, 734 insertions(+), 5 deletions(-) create mode 100644 torchbenchmark/operators/fp8_gemm/persistent.py create mode 100644 torchbenchmark/operators/fp8_gemm/tutorial.py diff --git a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py index f25fb3149..5f8d97256 100644 --- a/torchbenchmark/operators/fp8_gemm/fp8_gemm.py +++ b/torchbenchmark/operators/fp8_gemm/fp8_gemm.py @@ -2,7 +2,6 @@ import os import statistics import torch -import triton.ops import triton.language as tl from triton.runtime.jit import reinterpret @@ -17,10 +16,15 @@ register_metric, ) +from .tutorial import matmul as tutorial_matmul +from .persistent import matmul_persistent, matmul_tma_persistent, allocate_matmul_tma def parse_args(args): parser = argparse.ArgumentParser(description="TritonBench fp8_gemm") parser.add_argument("--llama", action="store_true") + parser.add_argument("--m", type=int) + parser.add_argument("--k", type=int) + parser.add_argument("--n", type=int) return parser.parse_args(args) @@ -45,7 +49,8 @@ def args(m, n, k): if self.extra_args.llama: for m, n, k, _bias in llama_shapes(): yield args(m, n, k) - + elif self.extra_args.m: + yield args(self.extra_args.m, self.extra_args.n, self.extra_args.k) else: for i in range(10, 15): for j in range(0, 4): @@ -70,9 +75,17 @@ def torch_fp8_gemm(self, a, b): @register_benchmark() def triton_fp8_gemm(self, a, b): - a = reinterpret(a, tl.float8e4nv) - b = reinterpret(b, tl.float8e4nv) - return lambda: triton.ops.matmul(a, b) + return lambda: tutorial_matmul(a, b) + + @register_benchmark() + def triton_persistent_fp8_gemm(self, a, b): + return lambda: matmul_persistent(a, b) + + @register_benchmark() + def triton_tma_persistent_fp8_gemm(self, a, b): + b = b.T.contiguous() + c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, b) + return lambda: matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c) @register_metric() def gbps( diff --git a/torchbenchmark/operators/fp8_gemm/persistent.py b/torchbenchmark/operators/fp8_gemm/persistent.py new file mode 100644 index 000000000..7e8d482f0 --- /dev/null +++ b/torchbenchmark/operators/fp8_gemm/persistent.py @@ -0,0 +1,368 @@ +from functools import lru_cache + +import torch +import triton +import triton.language as tl +import triton.tools.experimental_descriptor + +if torch.cuda.is_available(): + from triton._C.libtriton import nvidia + + cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) + cublas = nvidia.cublas.CublasLt(cublas_workspace) +else: + cublas = None + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_tma(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def _matmul_launch_metadata(grid, kernel, args): + ret = {} + M, N, K = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" + ret["flops8"] = 2.0 * M * N 
* K + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret["bytes"] = bytes_per_elem * (M * K + N * K) + return ret + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, # + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + pid_m = 0 + pid_n = 0 + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = tl.arange(0, BLOCK_SIZE_M) + offs_bn = tl.arange(0, BLOCK_SIZE_N) + offs_am = tl.where(offs_am < M - start_m, offs_am, 0) + offs_bn = tl.where(offs_bn < N - start_n, offs_bn, 0) + offs_am = tl.max_contiguous( + tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M + ) + offs_bn = tl.max_contiguous( + tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N + ) + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + a = tl.load( + a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + b = tl.load( + b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0 + ) + accumulator = tl.dot(a, b, accumulator) + + if ki == k_tiles - 1: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if c_ptr.dtype == tl.float8e4nv: + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_persistent(a, b): + configs = { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 4, + "num_warps": 8, + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + } + # Check constraints. 
+ assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + # 1D launch kernel where each block gets its own program. + grid = lambda META: ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) + matmul_kernel_persistent[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + NUM_SMS=NUM_SMS, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + ) + return c + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_tma_persistent( + a_desc_ptr, + b_desc_ptr, + c_desc_ptr, # + M, + N, + K, # + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + FP8_OUTPUT: tl.constexpr, # + NUM_SMS: tl.constexpr, +): # + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + + offs_k = ki * BLOCK_SIZE_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_SIZE_N, BLOCK_SIZE_K], dtype + ) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + tl._experimental_descriptor_store(c_desc_ptr, c, [offs_am, offs_bn]) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def matmul_configs(): + # Autotuner does not work with TMA. Use manual config. + return { + torch.float8_e4m3fn: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 4, + "num_warps": 8, + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + } + + +def allocate_matmul_tma(a, b): + configs = matmul_configs() + + # Check constraints. 
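+    # Note: callers pass b already transposed to (N, K) (see triton_tma_persistent_fp8_gemm),
+    # so both TMA descriptors are built over row-major tiles whose inner, contiguous axis is K.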
+ assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + c = torch.empty((M, N), device=a.device, dtype=dtype) + desc_a = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + a.data_ptr(), + M, + K, + configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_K"], + a.element_size(), + ) + desc_b = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + b.data_ptr(), + N, + K, + configs[dtype]["BLOCK_SIZE_N"], + configs[dtype]["BLOCK_SIZE_K"], + b.element_size(), + ) + desc_c = triton.tools.experimental_descriptor.create_2d_tma_descriptor( + c.data_ptr(), + M, + N, + configs[dtype]["BLOCK_SIZE_M"], + configs[dtype]["BLOCK_SIZE_N"], + c.element_size(), + ) + return c, desc_a, desc_b, desc_c + + +def matmul_tma_persistent(a, b, c, desc_a, desc_b, desc_c): + configs = matmul_configs() + + # Check constraints. + assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed + assert a.dtype == b.dtype, "Incompatible dtypes" + + M, K = a.shape + N, K = b.shape + dtype = a.dtype + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + grid = lambda META: ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) + matmul_kernel_tma_persistent[grid]( + desc_a, + desc_b, + desc_c, # + M, + N, + K, # + BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # + BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # + BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # + GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # + FP8_OUTPUT=dtype == torch.float8_e4m3fn, # + NUM_SMS=NUM_SMS, # + num_stages=configs[dtype]["num_stages"], # + num_warps=configs[dtype]["num_warps"], # + ) + return c diff --git a/torchbenchmark/operators/fp8_gemm/tutorial.py b/torchbenchmark/operators/fp8_gemm/tutorial.py new file mode 100644 index 000000000..d4225b6ea --- /dev/null +++ b/torchbenchmark/operators/fp8_gemm/tutorial.py @@ -0,0 +1,348 @@ +""" +Matrix Multiplication +===================== +In this tutorial, you will write a very short high-performance FP16 matrix multiplication kernel that achieves +performance on par with cuBLAS or rocBLAS. + +You will specifically learn about: + +* Block-level matrix multiplications. + +* Multi-dimensional pointer arithmetic. + +* Program re-ordering for improved L2 cache hit rate. + +* Automatic performance tuning. + +""" + +# %% +# Motivations +# ----------- +# +# Matrix multiplications are a key building block of most modern high-performance computing systems. +# They are notoriously hard to optimize, hence their implementation is generally done by +# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +# Unfortunately, these libraries are often proprietary and cannot be easily customized +# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). +# In this tutorial, you will learn how to implement efficient matrix multiplications by +# yourself with Triton, in a way that is easy to customize and extend. +# +# Roughly speaking, the kernel that we will write will implement the following blocked +# algorithm to multiply a (M, K) by a (K, N) matrix: +# +# .. 
code-block:: python +# +# # Do in parallel +# for m in range(0, M, BLOCK_SIZE_M): +# # Do in parallel +# for n in range(0, N, BLOCK_SIZE_N): +# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) +# for k in range(0, K, BLOCK_SIZE_K): +# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] +# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] +# acc += dot(a, b) +# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc +# +# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. + +# %% +# Compute Kernel +# -------------- +# +# The above algorithm is, actually, fairly straightforward to implement in Triton. +# The main difficulty comes from the computation of the memory locations at which blocks +# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need +# multi-dimensional pointer arithmetic. +# +# Pointer Arithmetic +# ~~~~~~~~~~~~~~~~~~~ +# +# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given +# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`. +# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and +# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: +# +# .. code-block:: python +# +# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1); +# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1); +# +# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as the following +# code. Also note that we need an extra modulo to handle the case where :code:`M` is not a multiple of +# :code:`BLOCK_SIZE_M` or :code:`N` is not a multiple of :code:`BLOCK_SIZE_N`, in which case we can pad the data with +# some useless values, which will not contribute to the results. For the :code:`K` dimension, we will handle that later +# using masking load semantics. +# +# .. code-block:: python +# +# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M +# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) +# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) +# +# And then updated in the inner loop as follows: +# +# .. code-block:: python +# +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; +# +# +# L2 Cache Optimizations +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]` +# block of :code:`C`. +# It is important to remember that the order in which these blocks are computed does +# matter, since it affects the L2 cache hit rate of our program, and unfortunately, a +# simple row-major ordering +# +# .. code-block:: Python +# +# pid = tl.program_id(axis=0) +# grid_n = tl.cdiv(N, BLOCK_SIZE_N) +# pid_m = pid // grid_n +# pid_n = pid % grid_n +# +# is just not going to cut it. +# +# One possible solution is to launch blocks in an order that promotes data reuse. +# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before +# switching to the next column: +# +# .. 
code-block:: python +# +# # Program ID +# pid = tl.program_id(axis=0) +# # Number of program ids along the M axis +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# # Number of programs ids along the N axis +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# # Number of programs in group +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# # Id of the group this program is in +# group_id = pid // num_pid_in_group +# # Row-id of the first program in the group +# first_pid_m = group_id * GROUP_SIZE_M +# # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# # *Within groups*, programs are ordered in a column-major order +# # Row-id of the program in the *launch grid* +# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) +# # Col-id of the program in the *launch grid* +# pid_n = (pid % num_pid_in_group) // group_size_m +# +# For example, in the following matmul where each matrix is 9 blocks by 9 blocks, +# we can see that if we compute the output in row-major ordering, we need to load 90 +# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped +# ordering, we only need to load 54 blocks. +# +# .. image:: grouped_vs_row_major_ordering.png +# +# In practice, this can improve the performance of our matrix multiplication kernel by +# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). +# + +# %% +# Final Result +# ------------ + +import torch + +import triton +import triton.language as tl + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def is_hip_mi200(): + target = triton.runtime.driver.active.get_current_target() + return target.backend == 'hip' and target.arch == 'gfx90a' + + +def get_cuda_autotune_config(): + return [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + # Good config for fp8 inputs. 
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4) + ] + + +def get_hip_autotune_config(): + return [ + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, + num_warps=8, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, + num_warps=4, num_stages=0), + triton.Config( + {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, + num_warps=4, num_stages=0), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=['M', 'N', 'K'], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. 
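+    # Worked example (illustrative numbers, not from the tutorial text): with GROUP_SIZE_M = 2
+    # and a 4x4 grid of output tiles, num_pid_in_group = 8, so pids 0..7 walk tiles
+    # (0,0), (1,0), (0,1), (1,1), ... column-major within rows 0-1, and pids 8..15
+    # cover rows 2-3 the same way.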
+ pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. 
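+    # The grid has ceil(M / BLOCK_SIZE_M) * ceil(N / BLOCK_SIZE_N) programs; META is filled in
+    # by the autotuner from the selected triton.Config.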
+    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+    matmul_kernel[grid](
+        a, b, c,  #
+        M, N, K,  #
+        a.stride(0), a.stride(1),  #
+        b.stride(0), b.stride(1),  #
+        c.stride(0), c.stride(1),  #
+        ACTIVATION=activation  #
+    )
+    return c
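
A minimal usage sketch of the kernels added by this patch (not part of the diff itself). It assumes the benchmark repo root is on PYTHONPATH, that the local PyTorch build supports casting with .to(torch.float8_e4m3fn), that the GPU/Triton build supports fp8 tl.dot, and that an sm90+ (Hopper) device is available for the TMA path:

    import torch

    from torchbenchmark.operators.fp8_gemm.tutorial import matmul as tutorial_matmul
    from torchbenchmark.operators.fp8_gemm.persistent import (
        allocate_matmul_tma,
        matmul_persistent,
        matmul_tma_persistent,
        supports_tma,
    )

    M, K, N = 4096, 4096, 4096
    # fp8 operands (assumption: .to(torch.float8_e4m3fn) is available in this PyTorch build).
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)

    c_tutorial = tutorial_matmul(a, b)      # autotuned tutorial kernel, fp16 output
    c_persistent = matmul_persistent(a, b)  # persistent kernel, launch grid capped at NUM_SMS

    if supports_tma():                      # TMA path needs an sm90+ GPU
        bt = b.T.contiguous()               # TMA kernel expects b pre-transposed to (N, K)
        c, desc_a, desc_b, desc_c = allocate_matmul_tma(a, bt)
        c_tma = matmul_tma_persistent(a, bt, c, desc_a, desc_b, desc_c)

This mirrors how the benchmark wires the kernels up in fp8_gemm.py: the tutorial and persistent variants take (M, K) x (K, N) operands directly, while the TMA variant transposes b once and reuses the descriptors across calls.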