Make some fbgemm fp8 triton ops pt2 friendly #3188

Closed · wants to merge 1 commit
90 changes: 51 additions & 39 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -6,7 +6,7 @@

# pyre-unsafe
import logging
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

import torch
import triton # @manual
@@ -43,7 +43,7 @@ def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
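(Aside: the constants tuple is ordered (torch fp8 dtype, triton fp8 dtype, max representable value, eps), and the call-site changes below switch which element gets unpacked, so the order is worth keeping in mind. An illustrative unpacking, with hypothetical names for the two float slots:)

```python
# Tuple order follows the return statement above; fp8_max/fp8_eps are
# illustrative names for the float constants.
pt_fp8_dtype, tl_fp8_dtype, fp8_max, fp8_eps = get_fp8_constants()
```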


-def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:
+def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
"""
Converts tensor to triton fp8 type.
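(For context on the rename: the helper does not convert values, it reinterprets the tensor's raw storage as a Triton fp8 element type, hence `reinterpret_fp8_type`. A minimal sketch of a plausible body, an assumption since the implementation is collapsed in this diff, using Triton's `triton.reinterpret` and `triton.TensorWrapper`:)

```python
import torch
import triton.language as tl
from triton import TensorWrapper, reinterpret as tl_reinterpret

def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
    # Zero-copy: wrap `tensor` so Triton kernels read its bytes as `dtype`.
    return tl_reinterpret(tensor, dtype=dtype)
```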

@@ -213,11 +213,6 @@ def get_configs_io_bound() -> List[Config]:
"k_key",
],
)
-@triton.heuristics(
-    {
-        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
-    }
-)
@triton.jit
def _kernel_matmul_fp8_row(
A_ptr,
@@ -246,7 +241,6 @@ def _kernel_matmul_fp8_row(
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
-    EVEN_K: tl.constexpr,
USE_BIAS: tl.constexpr,
AB_DTYPE: tl.constexpr,
NUM_SMS: tl.constexpr,
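(Context for the two deletions above: the `EVEN_K` heuristic derived one launch-time flag from the deleted lambda, namely whether `K` divides evenly into `BLOCK_K * SPLIT_K` chunks, which allowed unmasked loads on the K loop. With it removed, the kernel always takes the general, masked path; plausibly `@triton.heuristics` was dropped because it did not interact well with the `capture_triton` tracing introduced below. The deleted flag as a standalone sketch:)

```python
# Equivalent of the removed @triton.heuristics lambda (illustrative only).
def is_even_k(K: int, BLOCK_K: int, SPLIT_K: int) -> bool:
    return K % (BLOCK_K * SPLIT_K) == 0
```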
@@ -964,7 +958,7 @@ def get_tma_descriptor_kernel_param(self, name):
return self.cuda_descriptors[name]


-@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
+@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
def matmul_fp8_row(
a: torch.Tensor,
b: torch.Tensor,
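(Background on the decorator swap: `torch._library.triton_op` registers a PT2-friendly custom op whose internal Triton kernel launches are expected to go through `torch._library.capture_triton`, so `torch.compile` can trace into the kernels instead of treating the op as opaque; that is exactly what the launch sites further down now do. A minimal self-contained sketch of the pattern, built around a hypothetical `example::add_one` op that is not part of this PR:)

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Elementwise y = x + 1 over a 1D range, masked at the boundary.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1.0, mask=mask)


@torch._library.triton_op("example::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Route the launch through capture_triton so torch.compile can trace
    # into the Triton kernel rather than graph-breaking on it.
    torch._library.capture_triton(add_one_kernel)[grid](x, y, n, BLOCK=1024)
    return y
```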
@@ -995,15 +989,15 @@ def matmul_fp8_row(
torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
"""
# Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    pt_fp8_dtype, _, _, _ = get_fp8_constants()
# Handle 3D+ a shape
a_shape = a.shape
a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper torch fp8 dtype.
+    assert a.dtype == pt_fp8_dtype
+    assert b.dtype == pt_fp8_dtype
M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
-        prep_matmul(a_tl, b_tl, dot_out_dtype)
+        prep_matmul(a, b, dot_out_dtype)
)

output_shape = a_shape[:-1] + (N,)
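(A hypothetical call site for this op, matching the shapes in the docstring above: `a` is [M, K], `b` is [N, K], and scales are per-row. The fp8 quantization recipe is elided and unit scales are used purely for illustration:)

```python
import torch

M, N, K = 128, 256, 512
a_fp8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
a_scale = torch.ones(M, device="cuda")  # placeholder per-row scales
b_scale = torch.ones(N, device="cuda")

out = torch.ops.triton.matmul_fp8_row(a_fp8, b_fp8, a_scale, b_scale)  # [M, N]
```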
@@ -1049,22 +1043,22 @@ def persistent_grid_tma(META):
nonlocal desc_helper
desc_helper.fill_2d_tma_descriptor(
"a",
-            a_tl.data_ptr(),
+            a.data_ptr(),
M,
K,
META["BLOCK_M"],
META["BLOCK_K"],
-            a_tl.element_size(),
+            a.element_size(),
)

desc_helper.fill_2d_tma_descriptor(
"b",
-            b_tl.data_ptr(),
+            b.data_ptr(),
N,
K,
META["BLOCK_N"],
META["BLOCK_K"],
-            b_tl.element_size(),
+            b.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"c",
@@ -1111,8 +1105,10 @@ def persistent_grid_tma(META):
desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

-        # pyre-ignore[28]:
-        _kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma](
+        # pyre-ignore
+        torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
+            persistent_grid_tma
+        ](
desc_a,
desc_b,
desc_c,
@@ -1141,9 +1137,9 @@ def persistent_grid_tma(META):
USE_BIAS=bias is not None,
)
elif imprecise_acc:
-        _kernel_matmul_fp8_row_imprecise_acc[grid](
-            a_tl,
-            b_tl,
+        torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
+            a,
+            b,
c,
M,
N,
@@ -1168,9 +1164,9 @@ def persistent_grid_tma(META):
AB_DTYPE=False,
)
elif fp8_fast_accum:
-        _kernel_matmul_fp8_row[persistent_grid](
-            a_tl,
-            b_tl,
+        torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
+            a,
+            b,
c,
M,
N,
@@ -1196,9 +1192,11 @@ def persistent_grid_tma(META):
NUM_SMS=NUM_SMS,
)
else:
-        _kernel_matmul_fp8_row_no_fast_acc[persistent_grid](
-            a_tl,
-            b_tl,
+        torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
+            persistent_grid
+        ](
+            a,
+            b,
c,
M,
N,
@@ -1659,13 +1657,13 @@ def matmul_fp8_block(
Tensor: [M, N] output tensor, (a / a_scale) @ (b / b_scale)
"""
# Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    _, tl_fp8_dtype, _, _ = get_fp8_constants()
# Handle 3D+ a shape
a_shape = a.shape
a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper triton fp8 dtype.
+    a_tl = reinterpret_fp8_type(a, tl_fp8_dtype)
+    b_tl = reinterpret_fp8_type(b, tl_fp8_dtype)

M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
a_tl, b_tl, dot_out_dtype
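(Note the asymmetry with `matmul_fp8_row`: `matmul_fp8_block` still wraps its inputs in `TensorWrapper` via `reinterpret_fp8_type`, while `matmul_fp8_row` now passes plain `torch.Tensor`s, which is why `prep_matmul` below widens its parameter types to `Union[TensorWrapper, torch.Tensor]`.)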
@@ -1794,14 +1792,18 @@ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:


def prep_matmul(
-    a: TensorWrapper, b: TensorWrapper, dot_out_dtype: Optional[torch.dtype]
-) -> Tuple[int, int, int, int, int, int, torch.Tensor, str, str, torch.device]:
+    a: Union[TensorWrapper, torch.Tensor],
+    b: Union[TensorWrapper, torch.Tensor],
+    dot_out_dtype: Optional[torch.dtype],
+) -> Tuple[
+    int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
+]:
"""
Shared bookkeeping for a @ b.T matmul.

Args:
-        a (TensorWrapper): [M, K] input tensor.
-        b (TensorWrapper): [N, K] input tensor.
+        a (torch.Tensor): [M, K] input tensor.
+        b (torch.Tensor): [N, K] input tensor.
dot_out_dtype (tl.dtype): Output type of tensor core.

Returns:
@@ -1812,7 +1814,8 @@ def prep_matmul(
n_key (int): Autotuning key for N dim.
k_key (int): Autotuning key for K dim.
c (Tensor): [M, N] output tensor.
-        dot_out_dtype (torch.dtype): Output type of tensor core.
+        c_dtype_triton (tl.dtype): Type of output tensor.
+        dot_out_dtype (tl.dtype): Output type of tensor core.
device (torch.device): Device of output tensor.
"""
device = a.device
@@ -1827,11 +1830,20 @@

# allocates output
assert a.dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
tl.float8e4nv,
tl.float8e4b15,
tl.float8e5,
tl.float8e4b8,
-    ] and b.dtype in [
+    ]
+    assert b.dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
tl.float8e4nv,
tl.float8e4b15,
tl.float8e5,