Add non-persistent fp8 triton_rowwise kernel (#3212)
Summary:
X-link: pytorch/benchmark#2484

Pull Request resolved: #3212

X-link: facebookresearch/FBGEMM#308

The triton_rowwise persistent kernel performs poorly on MI300 compared to the non-persistent kernel when both are run with exhaustive AMD-specific tuning.

Reviewed By: htyu

Differential Revision: D63741099

fbshipit-source-id: c276415ddf8f5d24ffeba70b8ee6493011b393e1
karthik-man authored and facebook-github-bot committed Oct 3, 2024
1 parent 9a845cc commit 8e2d4a0
Showing 1 changed file with 243 additions and 1 deletion.
244 changes: 243 additions & 1 deletion fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
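For context, a minimal usage sketch of the new path. It assumes matmul_fp8_row takes the two fp8 tensors followed by their row-wise scales as leading arguments, and that a quantize_fp8_row helper analogous to the quantize_fp8_block shown further down is exported from the same module; neither signature appears in this diff.

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    matmul_fp8_row,
    quantize_fp8_row,  # assumed helper; only quantize_fp8_block appears in this diff
)

a = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(2048, 512, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = quantize_fp8_row(a)  # [M, K] fp8 tensor and [M] row-wise scale
b_fp8, b_scale = quantize_fp8_row(b)  # [N, K] fp8 tensor and [N] row-wise scale

# Opt into the non-persistent kernel added here (e.g. on MI300); bias is not supported on this path.
c = matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale, no_use_persistent=True)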
@@ -997,6 +997,7 @@ def matmul_fp8_row(
fp8_fast_accum: bool = True,
imprecise_acc: bool = False,
tma_persistent: bool = True,
no_use_persistent: bool = False,
) -> torch.Tensor:
"""
Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
@@ -1056,7 +1057,38 @@ def persistent_grid(META):
),
)

if no_use_persistent:
logger.info("Using non-persistent kernel")
if bias is not None:
raise AssertionError("bias is not supported in non-persistent kernel")
# pyre-ignore
torch._library.capture_triton(_kernel_matmul_fp8_row_non_persistent)[grid](
a,
b,
c,
M,
N,
K,
m_key,
n_key,
k_key,
a_scale,
b_scale,
# bias,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
dot_out_dtype=dot_out_dtype_triton,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
# GROUP_M=8,
# USE_BIAS=bias is not None,
AB_DTYPE=False,
)
elif tma_persistent:
# used by TMA persistent kernel
desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("a")
@@ -2422,3 +2454,213 @@ def quantize_fp8_block(
x_scale = x_scale.to(output_device) # pyre-ignore
del x, x_padded
return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def prune_configs(configs, named_args, **kwargs):
# Per the matmul_fp8_row docstring, A is [M, K], B is [N, K], and C is [M, N],
# so N comes from B.shape[0] and K from A.shape[1].
SIZE_M = named_args["A"].shape[0]
SIZE_N = named_args["B"].shape[0]
SIZE_K = named_args["A"].shape[1]

pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_SIZE_M, BLOCK_SIZE_N, _ = (
kw["BLOCK_M"],
kw["BLOCK_N"],
kw["BLOCK_K"],
)
SPLIT_K = kw["SPLIT_K"]
if SIZE_M <= 32 and BLOCK_SIZE_M != 32:
continue
if SIZE_N <= 32 and BLOCK_SIZE_N != 32:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K):
continue
pruned_configs.append(config)
logger.info(f"pruned_configs: {len(pruned_configs)} configs remaining after pruning")
return pruned_configs


def get_full_non_persistent_tuning_space(use_split_k):
if torch.version.hip is None:
logger.warning("Using HIP configs on CUDA device, this may be slow.")
configs = []
block_mn_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128]
split_k_range = [1]
num_warps_range = [1, 2, 4, 8, 16]
group_m_range = [1, 4, 8]
num_stage_range = [0]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for split_k in split_k_range:
for num_stages in num_stage_range:
configs.append(
triton.Config(
{
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"GROUP_M": group_m,
"SPLIT_K": split_k,
},
num_stages=num_stages,
num_warps=num_warps,
)
)

return configs


MATMUL_CONFIGS: List[Config] = get_full_non_persistent_tuning_space(True)
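# The ranges above enumerate 4 (BLOCK_M) x 4 (BLOCK_N) x 3 (BLOCK_K) x 5 (num_warps)
# x 3 (GROUP_M) x 1 (SPLIT_K) x 1 (num_stages) = 720 candidate configs per (M, N, K)
# key; the early_config_prune hook below trims this set before benchmarking.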


@triton.autotune(
configs=MATMUL_CONFIGS,
key=["M", "N", "K"],
prune_configs_by={
"early_config_prune": prune_configs,
"perf_model": None,
"top_k": None,
},
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def _kernel_matmul_fp8_row_non_persistent(
A,
B,
C,
M,
N,
K,
m_key,
n_key,
k_key,
A_scale,
B_scale,
stride_am,
stride_ak,
stride_bn,
stride_bk,
stride_cm,
stride_cn,
dot_out_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
fp8_fast_accum: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
AB_DTYPE: tl.constexpr,
) -> None:
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
Args:
A (TensorWrapper): [M, K] input tensor.
B (TensorWrapper): [N, K] input tensor.
C (TensorWrapper): [M, N] output tensor.
M (int): M dimension of input tensor.
N (int): N dimension of input tensor.
K (int): K dimension of input tensor.
m_key (int): Autotuning key for M dimension of input tensor.
n_key (int): Autotuning key for N dimension of input tensor.
k_key (int): Autotuning key for K dimension of input tensor.
A_scale (TensorWrapper): [M] reciprocal scale tensor per row. A * A_scale = original A
B_scale (TensorWrapper): [N] reciprocal scale tensor per row. B * B_scale = original B
stride_am (int): Stride of M dimension of A.
stride_ak (int): Stride of K dimension of A.
stride_bn (int): Stride of N dimension of B.
stride_bk (int): Stride of K dimension of B.
stride_cm (int): Stride of M dimension of C.
stride_cn (int): Stride of N dimension of C.
dot_out_dtype (torch.dtype): Output type of tensor core.
allow_tf32 (bool): Whether to use TF32 for tensor core.
fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
BLOCK_M (int): Block size for M dimension.
BLOCK_N (int): Block size for N dimension.
BLOCK_K (int): Block size for K dimension.
GROUP_M (int): Number of groups for M dimension swizzle.
SPLIT_K (int): Number of SMs to launch per row.
EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
"""
# Matrix multiplication.
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# Re-order program ID for better L2 performance (swizzle).
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# Do matrix multiplication.
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# Pointers.
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)

for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
if AB_DTYPE:
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
if fp8_fast_accum:
acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
else:
acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk

# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

# Invert scaling.
a_scale = tl.load(A_scale + rm, mask=rm < M)
b_scale = tl.load(B_scale + rn, mask=rn < N)
# Invert vector, then multiply on matrix for speed.
# pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
scale = a_scale[:, None] * b_scale[None, :]
acc *= scale

acc = acc.to(C.dtype.element_ty)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# Handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
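
In eager PyTorch terms, the value this kernel writes back corresponds approximately to the following reference, a sketch assuming a_fp8/b_fp8 are the [M, K]/[N, K] fp8 inputs and a_scale/b_scale the [M]/[N] reciprocal row-wise scales described in the docstring:

c_ref = (a_fp8.to(torch.float32) * a_scale[:, None]) @ (
    b_fp8.to(torch.float32) * b_scale[:, None]
).T
# i.e. C[m, n] is approximately sum_k A[m, k] * a_scale[m] * B[n, k] * b_scale[n],
# matching the `acc *= a_scale[:, None] * b_scale[None, :]` step above.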
