From 259dd02ab3347b67611647b9ef4fdcc9b8b17579 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 30 Sep 2024 10:32:10 -0700
Subject: [PATCH] Make some fbgemm fp8 triton ops pt2 friendly (#3188)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/283

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3188

Make some fbgemm fp8 triton ops pt2 friendly.

# What this diff tries to do
* Stop using TensorWrapper and tl.reinterpret.
* Remove the use of triton_heuristics for _kernel_matmul_fp8_row.

# What this diff won't help with:
* triton_heuristics use cases of EVEN_K. One option is to just merge that into the autotuning configs.

# Need to do in the future:
* Update other ops, like quantize_fp8_row.
* Update the documentation. It feels pretty outdated, and some of it still references TensorWrapper.

Differential Revision: D63560103
---
 .../experimental/gemm/triton_gemm/fp8_gemm.py | 99 ++++++++++---------
 1 file changed, 52 insertions(+), 47 deletions(-)

diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 07765fa21..9a9c64cea 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -6,7 +6,7 @@
 # pyre-unsafe
 
 import logging
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import triton  # @manual
@@ -15,10 +15,7 @@
 from torch._tensor import Tensor
 
 from triton import Config  # @manual
-from triton.ops.matmul_perf_model import (  # @manual
-    early_config_prune,
-    estimate_matmul_time,
-)
+from triton.ops.matmul_perf_model import early_config_prune  # @manual
 from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper  # @manual
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -43,7 +40,7 @@ def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
     return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12
 
 
-def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:
+def reinterpret_fp8_type(tensor: torch.Tensor, dtype: tl.dtype) -> TensorWrapper:
     """
     Converts tensor to triton fp8 type.
 
@@ -213,11 +210,6 @@ def get_configs_io_bound() -> List[Config]:
         "k_key",
     ],
 )
-@triton.heuristics(
-    {
-        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
-    }
-)
 @triton.jit
 def _kernel_matmul_fp8_row(
     A_ptr,
@@ -246,7 +238,6 @@ def _kernel_matmul_fp8_row(
     BLOCK_K: tl.constexpr,
     GROUP_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
-    EVEN_K: tl.constexpr,
     USE_BIAS: tl.constexpr,
     AB_DTYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
@@ -964,7 +955,7 @@ def get_tma_descriptor_kernel_param(self, name):
         return self.cuda_descriptors[name]
 
 
-@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
+@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
 def matmul_fp8_row(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -995,15 +986,15 @@
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
     """
     # Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
     a_shape = a.shape
     a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper torch fp8 dtype.
+    assert a.dtype == pt_fp8_dtype
+    assert b.dtype == pt_fp8_dtype
     M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
-        prep_matmul(a_tl, b_tl, dot_out_dtype)
+        prep_matmul(a, b, dot_out_dtype)
     )
 
     output_shape = a_shape[:-1] + (N,)
@@ -1049,22 +1040,22 @@ def persistent_grid_tma(META):
         nonlocal desc_helper
         desc_helper.fill_2d_tma_descriptor(
             "a",
-            a_tl.data_ptr(),
+            a.data_ptr(),
             M,
             K,
             META["BLOCK_M"],
             META["BLOCK_K"],
-            a_tl.element_size(),
+            a.element_size(),
         )
 
         desc_helper.fill_2d_tma_descriptor(
             "b",
-            b_tl.data_ptr(),
+            b.data_ptr(),
             N,
             K,
             META["BLOCK_N"],
             META["BLOCK_K"],
-            b_tl.element_size(),
+            b.element_size(),
         )
         desc_helper.fill_2d_tma_descriptor(
             "c",
@@ -1111,8 +1102,10 @@ def persistent_grid_tma(META):
         desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
         desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")
 
-        # pyre-ignore[28]:
-        _kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma](
+        # pyre-ignore
+        torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
+            persistent_grid_tma
+        ](
             desc_a,
             desc_b,
             desc_c,
@@ -1141,9 +1134,9 @@
             USE_BIAS=bias is not None,
         )
     elif imprecise_acc:
-        _kernel_matmul_fp8_row_imprecise_acc[grid](
-            a_tl,
-            b_tl,
+        torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
+            a,
+            b,
             c,
             M,
             N,
@@ -1168,9 +1161,9 @@
             AB_DTYPE=False,
         )
     elif fp8_fast_accum:
-        _kernel_matmul_fp8_row[persistent_grid](
-            a_tl,
-            b_tl,
+        torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
+            a,
+            b,
             c,
             M,
             N,
@@ -1196,9 +1189,11 @@
             NUM_SMS=NUM_SMS,
         )
     else:
-        _kernel_matmul_fp8_row_no_fast_acc[persistent_grid](
-            a_tl,
-            b_tl,
+        torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
+            persistent_grid
+        ](
+            a,
+            b,
             c,
             M,
             N,
@@ -1269,8 +1264,6 @@ def prune_configs_block(configs, named_args, **kwargs):
     ],  # TODO caller side bin keys so similar shapes can use same triton.autotune.
     prune_configs_by={
         "early_config_prune": prune_configs_block,
-        "perf_model": estimate_matmul_time,
-        "top_k": 10,
     },
 )
 @triton.heuristics(
@@ -1465,8 +1458,6 @@ def _kernel_matmul_fp8_block_fastacc(
     ],  # TODO caller side bin keys so similar shapes can use same triton.autotune.
     prune_configs_by={
         "early_config_prune": early_config_prune,
-        "perf_model": estimate_matmul_time,
-        "top_k": 10,
     },
 )
 @triton.heuristics(
@@ -1659,13 +1650,13 @@
         Tensor: [M, N] output tensor, (a / a_scale) @ (b / b_scale)
     """
     # Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    _, tl_fp8_dtype, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
     a_shape = a.shape
     a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper triton fp8 dtype.
+    a_tl = reinterpret_fp8_type(a, tl_fp8_dtype)
+    b_tl = reinterpret_fp8_type(b, tl_fp8_dtype)
     M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
         a_tl, b_tl, dot_out_dtype
     )
@@ -1794,14 +1785,18 @@ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:
 
 
 def prep_matmul(
-    a: TensorWrapper, b: TensorWrapper, dot_out_dtype: Optional[torch.dtype]
-) -> Tuple[int, int, int, int, int, int, torch.Tensor, str, str, torch.device]:
+    a: Union[TensorWrapper, torch.Tensor],
+    b: Union[TensorWrapper, torch.Tensor],
+    dot_out_dtype: Optional[torch.dtype],
+) -> Tuple[
+    int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
+]:
     """
     Shared bookkeeping for a @ b.T matmul.
 
     Args:
-        a (TensorWrapper): [M, K] input tensor.
-        b (TensorWrapper): [N, K] input tensor.
+        a (torch.Tensor): [M, K] input tensor.
+        b (torch.Tensor): [N, K] input tensor.
         dot_out_dtype (tl.dtype): Output type of tensor core.
 
     Returns:
@@ -1812,7 +1807,8 @@
         n_key (int): Autotuning key for N dim.
         k_key (int): Autotuning key for K dim.
         c (Tensor): [M, N] output tensor.
-        dot_out_dtype (torch.dtype): Output type of tensor core.
+        c_dtype_triton (tl.dtype): Type of output tensor.
+        dot_out_dtype (tl.dtype): Output type of tensor core.
         device (torch.device): Device of output tensor.
     """
     device = a.device
@@ -1827,11 +1823,20 @@
 
     # allocates output
     assert a.dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
         tl.float8e4nv,
         tl.float8e4b15,
         tl.float8e5,
         tl.float8e4b8,
-    ] and b.dtype in [
+    ]
+    assert b.dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
         tl.float8e4nv,
         tl.float8e4b15,
         tl.float8e5,
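
For readers unfamiliar with the pattern this patch adopts, below is a minimal sketch (not part of the diff) of how a Triton kernel launch is made traceable by torch.compile: register the Python wrapper with torch._library.triton_op and launch the kernel through torch._library.capture_triton, the same calls matmul_fp8_row uses above. The example::add_one op and the kernel are hypothetical illustrations, not FBGEMM code.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _add_one_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous BLOCK_SIZE chunk.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x + 1.0, mask=mask)


# triton_op registers a genuine custom op; capture_triton lets torch.compile
# trace into the kernel launch instead of graph-breaking on it.
@torch._library.triton_op("example::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    torch._library.capture_triton(_add_one_kernel)[grid](
        x, y, n_elements, BLOCK_SIZE=1024
    )
    return y


# Usage (GPU required): torch.compile(add_one)(torch.randn(1024, device="cuda"))
```

The intent of triton_op, as opposed to the plain torch.library.custom_op the diff replaces, is to keep the Triton kernel visible to the compiler so Inductor can optimize around it rather than treating the call as opaque.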