Make some fbgemm fp8 triton ops pt2 friendly (pytorch#3188)
Summary:
X-link: facebookresearch/FBGEMM#283

Pull Request resolved: pytorch#3188

Make some fbgemm fp8 triton ops pt2 friendly.

# What this diff tries to do
* stop using TensorWrapper and tl.reinterpret
* Remove the use of triton_heuristics for _kernel_matmul_fp8_row (a sketch of both changes is below)
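
For illustration, here is a minimal sketch of what the two changes above look like together. It is not the fbgemm code: the trivial copy kernel, the `example::fp8_copy` op name, and the uint8 input are made up for the example, and it assumes the `torch._library.triton_op` / `torch._library.capture_triton` APIs used in this diff plus a Triton/GPU combination with fp8 support.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Trivial elementwise copy standing in for the real fp8 matmul kernels.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


@torch._library.triton_op("example::fp8_copy", mutates_args=())
def fp8_copy(x: torch.Tensor) -> torch.Tensor:
    # A plain dtype view replaces tl_reinterpret/TensorWrapper, so only
    # torch.Tensors cross the op boundary (here x holds raw fp8 bytes).
    x_fp8 = x.view(dtype=torch.float8_e4m3fn)
    y = torch.empty_like(x_fp8)
    n = x_fp8.numel()
    grid = (triton.cdiv(n, 1024),)
    # capture_triton keeps the kernel launch traceable by torch.compile
    # instead of treating it as an opaque call.
    torch._library.capture_triton(_copy_kernel)[grid](x_fp8, y, n, BLOCK=1024)
    return y
```

With the launch wrapped this way, the op composes with pt2: e.g. `torch.compile(fp8_copy)(torch.zeros(4096, dtype=torch.uint8, device="cuda"))` can be traced rather than falling back to an opaque kernel call.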

# What this diff won't help with:
* The triton_heuristics use of EVEN_K. One option is to just merge that into the autotuning configs

# Need to do in the future:
* Update other ops, like quantize_fp8_row.
* Update documentation. It feels pretty outdated, and some of it still references TensorWrapper.

Reviewed By: jwfromm

Differential Revision: D63560103
henrylhtsang authored and facebook-github-bot committed Sep 29, 2024
1 parent 00f2fd5 commit 4f38802
Showing 1 changed file with 39 additions and 39 deletions.
78 changes: 39 additions & 39 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -19,7 +19,6 @@
     early_config_prune,
     estimate_matmul_time,
 )
-from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper  # @manual
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -43,18 +42,18 @@ def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
     return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
 
 
-def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:
+def convert_fp8_type(tensor: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
     """
-    Converts tensor to triton fp8 type.
+    Converts tensor to fp8 type.
     Args:
         tensor (torch.Tensor): input tensor.
-        dtype (tl.dtype): target triton dtype.
+        dtype (torch.dtype): target torch dtype.
     Returns:
-        triton.TensorWrapper: fp8 tensor.
+        torch.Tensor: fp8 tensor.
     """
-    return tl_reinterpret(tensor, dtype=dtype)
+    return tensor.view(dtype=dtype)
 
 
 def init_to_zero(name):
@@ -213,11 +212,6 @@ def get_configs_io_bound() -> List[Config]:
         "k_key",
     ],
 )
-@triton.heuristics(
-    {
-        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
-    }
-)
 @triton.jit
 def _kernel_matmul_fp8_row(
     A_ptr,
@@ -246,7 +240,6 @@ def _kernel_matmul_fp8_row(
     BLOCK_K: tl.constexpr,
     GROUP_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
-    EVEN_K: tl.constexpr,
     USE_BIAS: tl.constexpr,
     AB_DTYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
@@ -964,7 +957,7 @@ def get_tma_descriptor_kernel_param(self, name):
         return self.cuda_descriptors[name]
 
 
-@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
+@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
 def matmul_fp8_row(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -995,13 +988,13 @@ def matmul_fp8_row(
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
     """
     # Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
     a_shape = a.shape
     a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper torch fp8 dtype.
+    a_tl = convert_fp8_type(a, pt_fp8_dtype)
+    b_tl = convert_fp8_type(b, pt_fp8_dtype)
     M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
         prep_matmul(a_tl, b_tl, dot_out_dtype)
     )
@@ -1111,8 +1104,10 @@ def persistent_grid_tma(META):
         desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
         desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
 
-        # pyre-ignore[28]:
-        _kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma](
+        # pyre-ignore
+        torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
+            persistent_grid_tma
+        ](
             desc_a,
             desc_b,
             desc_c,
@@ -1141,7 +1136,7 @@ def persistent_grid_tma(META):
             USE_BIAS=bias is not None,
         )
     elif imprecise_acc:
-        _kernel_matmul_fp8_row_imprecise_acc[grid](
+        torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
             a_tl,
             b_tl,
             c,
Expand All @@ -1168,7 +1163,7 @@ def persistent_grid_tma(META):
AB_DTYPE=False,
)
elif fp8_fast_accum:
_kernel_matmul_fp8_row[persistent_grid](
torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
a_tl,
b_tl,
c,
@@ -1196,7 +1191,9 @@ def persistent_grid_tma(META):
             NUM_SMS=NUM_SMS,
         )
     else:
-        _kernel_matmul_fp8_row_no_fast_acc[persistent_grid](
+        torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
+            persistent_grid
+        ](
             a_tl,
             b_tl,
             c,
@@ -1659,13 +1656,13 @@ def matmul_fp8_block(
         Tensor: [M, N] output tensor, (a / a_scale) @ (b / b_scale)
     """
     # Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
     a_shape = a.shape
     a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper torch fp8 dtype.
+    a_tl = convert_fp8_type(a, pt_fp8_dtype)
+    b_tl = convert_fp8_type(b, pt_fp8_dtype)
 
     M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
         a_tl, b_tl, dot_out_dtype
@@ -1794,14 +1791,16 @@ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:
 
 
 def prep_matmul(
-    a: TensorWrapper, b: TensorWrapper, dot_out_dtype: Optional[torch.dtype]
-) -> Tuple[int, int, int, int, int, int, torch.Tensor, str, str, torch.device]:
+    a: torch.Tensor, b: torch.Tensor, dot_out_dtype: Optional[torch.dtype]
+) -> Tuple[
+    int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
+]:
     """
     Shared bookkeeping for a @ b.T matmul.
     Args:
-        a (TensorWrapper): [M, K] input tensor.
-        b (TensorWrapper): [N, K] input tensor.
+        a (torch.Tensor): [M, K] input tensor.
+        b (torch.Tensor): [N, K] input tensor.
         dot_out_dtype (tl.dtype): Output type of tensor core.
     Returns:
@@ -1827,15 +1826,16 @@ def prep_matmul(
 
     # allocates output
     assert a.dtype in [
-        tl.float8e4nv,
-        tl.float8e4b15,
-        tl.float8e5,
-        tl.float8e4b8,
-    ] and b.dtype in [
-        tl.float8e4nv,
-        tl.float8e4b15,
-        tl.float8e5,
-        tl.float8e4b8,
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
+    ]
+    assert b.dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
     ]
     c_dtype = torch.bfloat16
     c_dtype_triton = tl.bfloat16
