From 7a4472a8d14912ab7a3b7ca12bca030448f8fec8 Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Fri, 4 Oct 2024 16:08:13 -0700 Subject: [PATCH] Fine-tune FP8 BMM performance (#3224) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3224 X-link: https://github.com/facebookresearch/FBGEMM/pull/322 Fine-tune FP8 BMM performance to get an additional 20% performance gain. Reviewed By: jianyuh Differential Revision: D63882833 fbshipit-source-id: 04fd5d38e8e127edd2d8771681a180757eaf7321 --- .../f8f8bf16_rowwise_batched.cu | 13 +++++++++++++ .../cutlass_extensions/include/kernel_mode.h | 16 ++++++++-------- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu index 313c81298..a34c694e0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu @@ -335,6 +335,19 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel( UseBias, InputDType, BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else if (kernel == KernelMode::Medium) { + return f8f8bf16_rowwise_batched_impl< + 64, + 128, + 128, + 1, + 2, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); } else if (kernel == KernelMode::Large) { return f8f8bf16_rowwise_batched_impl< 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h index 93b96fb04..9a267193a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h @@ -12,7 +12,7 @@ namespace fbgemm_gpu { -enum class 
KernelMode { Small, Large, Default }; +enum class KernelMode { Small, Medium, Large, Default }; inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { auto M = XQ.size(0); @@ -37,14 +37,14 @@ inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) { auto K = XQ.size(2); auto N = WQ.size(1); auto BM = B * M; - auto BN = B * N; - auto BK = B * K; - // Use a large kernel if at least two shapes are large.... - bool use_large_kernel = - ((BM >= 2048 && BK >= 2048) || (BM >= 2048 && BK >= 2048) || - (BK >= 2048 && BN >= 2048)); - if (BM <= 128 || BN <= 128) { + // Heuristic to determine kernel mode + bool use_medium_kernel = + ((BM <= 512 && ((N <= 8192 && K < 8192) || (N < 8192 && K <= 8192)))); + bool use_large_kernel = ((BM > 512 && (N >= 1024 || K >= 1024))); + if (BM <= 128 || N <= 128) { return KernelMode::Small; + } else if (use_medium_kernel) { + return KernelMode::Medium; } else if (use_large_kernel) { return KernelMode::Large; } else {