From c6b13bc0141723b70db320ee6e5870085eeb8915 Mon Sep 17 00:00:00 2001
From: Henry Tsang
Date: Sun, 29 Sep 2024 16:51:03 -0700
Subject: [PATCH] Make some fbgemm fp8 triton ops pt2 friendly (#3188)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/283

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3188

Make some fbgemm fp8 triton ops pt2 friendly.

# What this diff does
* Stop using TensorWrapper and tl.reinterpret.
* Remove the use of the triton.heuristics decorator for _kernel_matmul_fp8_row.

# What this diff won't help with:
* triton.heuristics use cases of EVEN_K. One option is to just merge that into the autotuning configs.

# Need to do in the future:
* Update other ops, like quantize_fp8_row.
* Update the documentation. It feels pretty outdated, and some of it still references TensorWrapper.

Reviewed By: jwfromm

Differential Revision: D63560103
---
 .../experimental/gemm/triton_gemm/fp8_gemm.py | 78 +++++++++----------
 1 file changed, 39 insertions(+), 39 deletions(-)

diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 07765fa21..7703d5ef5 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -19,7 +19,6 @@
     early_config_prune,
     estimate_matmul_time,
 )
-from triton.runtime.jit import reinterpret as tl_reinterpret, TensorWrapper  # @manual

 logger: logging.Logger = logging.getLogger(__name__)
@@ -43,18 +42,18 @@ def get_fp8_constants() -> Tuple[torch.dtype, tl.dtype, float, float]:
     return pt_fp8_dtype, tl_fp8_dtype, torch.finfo(pt_fp8_dtype).max, 1e-12


-def convert_fp8_type(tensor, dtype) -> triton.TensorWrapper:
+def convert_fp8_type(tensor: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
     """
-    Converts tensor to triton fp8 type.
+    Converts tensor to fp8 type.

     Args:
         tensor (torch.Tensor): input tensor.
-        dtype (tl.dtype): target triton dtype.
+        dtype (torch.dtype): target torch dtype.

     Returns:
-        triton.TensorWrapper: fp8 tensor.
+        torch.Tensor: fp8 tensor.
     """
-    return tl_reinterpret(tensor, dtype=dtype)
+    return tensor.view(dtype=dtype)


 def init_to_zero(name):
@@ -213,11 +212,6 @@ def get_configs_io_bound() -> List[Config]:
         "k_key",
     ],
 )
-@triton.heuristics(
-    {
-        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
-    }
-)
 @triton.jit
 def _kernel_matmul_fp8_row(
     A_ptr,
@@ -246,7 +240,6 @@ def _kernel_matmul_fp8_row(
     BLOCK_K: tl.constexpr,
     GROUP_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
-    EVEN_K: tl.constexpr,
     USE_BIAS: tl.constexpr,
     AB_DTYPE: tl.constexpr,
     NUM_SMS: tl.constexpr,
@@ -964,7 +957,7 @@ def get_tma_descriptor_kernel_param(self, name):
         return self.cuda_descriptors[name]


-@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
+@torch._library.triton_op("triton::matmul_fp8_row", mutates_args=())
 def matmul_fp8_row(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -995,13 +988,13 @@ def matmul_fp8_row(
         torch.Tensor: [M, N] Output tensor a @ b / (a_scale[:, None] * b_scale[None, :])
     """
     # Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
     a_shape = a.shape
     a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper torch fp8 dtype.
+    a_tl = convert_fp8_type(a, pt_fp8_dtype)
+    b_tl = convert_fp8_type(b, pt_fp8_dtype)
     M, N, K, m_key, n_key, k_key, c, c_dtype_triton, dot_out_dtype_triton, device = (
         prep_matmul(a_tl, b_tl, dot_out_dtype)
     )
@@ -1111,8 +1104,10 @@ def persistent_grid_tma(META):
         desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
         desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

-        # pyre-ignore[28]:
-        _kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma](
+        # pyre-ignore
+        torch._library.capture_triton(_kernel_matmul_fp8_row_tma_persistent)[
+            persistent_grid_tma
+        ](
             desc_a,
             desc_b,
             desc_c,
@@ -1141,7 +1136,7 @@ def persistent_grid_tma(META):
             USE_BIAS=bias is not None,
         )
     elif imprecise_acc:
-        _kernel_matmul_fp8_row_imprecise_acc[grid](
+        torch._library.capture_triton(_kernel_matmul_fp8_row_imprecise_acc)[grid](
             a_tl,
             b_tl,
             c,
@@ -1168,7 +1163,7 @@ def persistent_grid_tma(META):
             AB_DTYPE=False,
         )
     elif fp8_fast_accum:
-        _kernel_matmul_fp8_row[persistent_grid](
+        torch._library.capture_triton(_kernel_matmul_fp8_row)[persistent_grid](
             a_tl,
             b_tl,
             c,
@@ -1196,7 +1191,9 @@ def persistent_grid_tma(META):
             NUM_SMS=NUM_SMS,
         )
     else:
-        _kernel_matmul_fp8_row_no_fast_acc[persistent_grid](
+        torch._library.capture_triton(_kernel_matmul_fp8_row_no_fast_acc)[
+            persistent_grid
+        ](
             a_tl,
             b_tl,
             c,
@@ -1659,13 +1656,13 @@ def matmul_fp8_block(
         Tensor: [M, N] output tensor, (a / a_scale) @ (b / b_scale)
     """
     # Get datatypes and constants to use.
-    _, tl_dtype, _, _ = get_fp8_constants()
+    pt_fp8_dtype, _, _, _ = get_fp8_constants()
     # Handle 3D+ a shape
     a_shape = a.shape
     a = a.view(-1, a.size(-1))
-    # Reinterpret inputs into proper triton fp8 dtype.
-    a_tl = convert_fp8_type(a, tl_dtype)
-    b_tl = convert_fp8_type(b, tl_dtype)
+    # View inputs into proper torch fp8 dtype.
+    a_tl = convert_fp8_type(a, pt_fp8_dtype)
+    b_tl = convert_fp8_type(b, pt_fp8_dtype)
     M, N, K, m_key, n_key, k_key, c, _, dot_out_dtype_triton, device = prep_matmul(
         a_tl, b_tl, dot_out_dtype
     )
@@ -1794,14 +1791,16 @@ def get_matmul_tune(M: int, N: int, K: int) -> Tuple[int, int, int]:


 def prep_matmul(
-    a: TensorWrapper, b: TensorWrapper, dot_out_dtype: Optional[torch.dtype]
-) -> Tuple[int, int, int, int, int, int, torch.Tensor, str, str, torch.device]:
+    a: torch.Tensor, b: torch.Tensor, dot_out_dtype: Optional[torch.dtype]
+) -> Tuple[
+    int, int, int, int, int, int, torch.Tensor, tl.dtype, tl.dtype, torch.device
+]:
     """
     Shared bookkeeping for a @ b.T matmul.

     Args:
-        a (TensorWrapper): [M, K] input tensor.
-        b (TensorWrapper): [N, K] input tensor.
+        a (torch.Tensor): [M, K] input tensor.
+        b (torch.Tensor): [N, K] input tensor.
         dot_out_dtype (tl.dtype): Output type of tensor core.

     Returns:
@@ -1827,15 +1826,16 @@ def prep_matmul(

     # allocates output
     assert a.dtype in [
-        tl.float8e4nv,
-        tl.float8e4b15,
-        tl.float8e5,
-        tl.float8e4b8,
-    ] and b.dtype in [
-        tl.float8e4nv,
-        tl.float8e4b15,
-        tl.float8e5,
-        tl.float8e4b8,
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
+    ]
+    assert b.dtype in [
+        torch.float8_e4m3fn,
+        torch.float8_e5m2,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
     ]
     c_dtype = torch.bfloat16
     c_dtype_triton = tl.bfloat16
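
For reference, here is a minimal standalone sketch (not FBGEMM code) of the two patterns this diff moves to: reinterpreting fp8 bytes with a plain `torch.Tensor.view(dtype=...)` instead of `TensorWrapper` / `tl.reinterpret`, and registering the wrapper with `torch._library.triton_op` while launching through `torch._library.capture_triton` so `torch.compile` can trace the launch. The kernel `_fp8_scale_kernel`, the op name `example::fp8_scale`, and the block size are made up for illustration, and the snippet assumes a CUDA device and a Triton/PyTorch build with fp8 support.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _fp8_scale_kernel(x_ptr, y_ptr, scale, n_elements, BLOCK: tl.constexpr):
    # Hypothetical elementwise kernel: y = float32(x) * scale over a 1D range.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x.to(tl.float32) * scale, mask=mask)


@torch._library.triton_op("example::fp8_scale", mutates_args=())
def fp8_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Reinterpret raw bytes as fp8 with a torch view (the convert_fp8_type
    # approach in this diff), instead of TensorWrapper / tl.reinterpret.
    x_fp8 = x.view(dtype=torch.float8_e4m3fn)
    y = torch.empty(x_fp8.shape, dtype=torch.float32, device=x.device)
    n_elements = x_fp8.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    # capture_triton lets torch.compile trace into the kernel launch rather
    # than treating it as an opaque call.
    torch._library.capture_triton(_fp8_scale_kernel)[grid](
        x_fp8, y, scale, n_elements, BLOCK=1024
    )
    return y


# Example usage (assumes CUDA and fp8 support):
#   x = torch.randn(4096, device="cuda").to(torch.float8_e4m3fn).view(torch.uint8)
#   y = torch.compile(fp8_scale)(x, 2.0)
```

Unlike a plain `torch.library.custom_op`, which PT2 treats as an opaque call, the `triton_op` / `capture_triton` pair exposes the underlying Triton kernel to the compiler, which is what makes the matmul_fp8_row wrappers in this diff pt2 friendly.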