From 33855a35b3857a39dd954d5614d1504619fe44bd Mon Sep 17 00:00:00 2001
From: Manish Gupta
Date: Wed, 2 Oct 2024 22:26:37 -0700
Subject: [PATCH] Rowwise F8F8BF16 GEMMs - Auto-generate kernel library and heuristics cache, add to FBGEMM quantize_bench (#3210)

Summary:
# Summary
- Auto-generate the F8F8BF16 rowwise scaled kernel library.
- Auto-generate the heuristics cache.
- Add the new op to quantize_bench.

# Performance Improvements
## DisaggBench

Cutlass:
Prefill B=1 T=2048: Elapsed: 109.13ms FLOPs: 333.74TF/s
Prefill B=1 T=4928: Elapsed: 272.55ms FLOPs: 338.62TF/s
Prefill B=1 T=6336: Elapsed: 354.93ms FLOPs: 342.55TF/s
Prefill B=1 T=8192: Elapsed: 468.64ms FLOPs: 346.06TF/s

Cutlass extensions:
Prefill B=1 T=2048: Elapsed: 108.83ms FLOPs: 334.66TF/s
Prefill B=1 T=4928: Elapsed: 260.46ms FLOPs: 354.34TF/s
Prefill B=1 T=6336: Elapsed: 336.39ms FLOPs: 361.43TF/s
Prefill B=1 T=8192: Elapsed: 442.64ms FLOPs: 366.39TF/s

Differential Revision: D63744054
---
 .../experimental/gen_ai/bench/quantize_ops.py | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index 65b34a956..8673bdb34 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -500,7 +500,11 @@ def cuda(self) -> bool:
         return True
 
 
+####################################################################################################
 # CUTLASS kernel v2
+####################################################################################################
+
+
 @register_quantize_op
 class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
     """
@@ -535,6 +539,44 @@ def cuda(self) -> bool:
         return True
 
 
+# CUTLASS kernel v2
+@register_quantize_op
+class CutlassFP8RowwiseGemm_v2(QuantizeOpBase):
+    """
+    FP8 matmul with rowwise scaling.
+    """
+
+    def quantize(self, x, w):
+        # Quantize both input tensors.
+        xq, x_scale = quantize_fp8_row(x)
+        wq, w_scale = quantize_fp8_row(w)
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        return torch.ops.cutlass_extensions.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_rowwise_v2"
+
+    @property
+    def hip(self) -> bool:
+        # Need to add support for a better quantize kernel.
+        # May also have an issue with CUDA graphs.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+####################################################################################################
+
+
 @register_quantize_op
 class F8I4RowwiseGemm(QuantizeOpBase):
     """
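
# Usage sketch

A minimal sketch of driving the new op directly, outside the quantize_bench harness. It assumes `quantize_fp8_row` is importable from `fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm` and that importing `fbgemm_gpu.experimental.gen_ai` registers the `torch.ops.cutlass_extensions` library; the shapes and dtypes are illustrative, not taken from the diff.

```python
import torch

# Assumption: importing the gen_ai package registers torch.ops.cutlass_extensions.
import fbgemm_gpu.experimental.gen_ai  # noqa: F401

# Assumption: quantize_fp8_row lives in the triton_gemm fp8_gemm module.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

M, N, K = 2048, 8192, 8192  # illustrative prefill-like problem size
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)  # activations
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)  # weights

# Rowwise FP8 quantization of both operands, mirroring
# CutlassFP8RowwiseGemm_v2.quantize() in the diff above.
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = quantize_fp8_row(w)

# FP8 x FP8 -> BF16 GEMM with the rowwise scales applied in the epilogue,
# mirroring CutlassFP8RowwiseGemm_v2.compute().
out = torch.ops.cutlass_extensions.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
assert out.shape == (M, N) and out.dtype == torch.bfloat16
```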