diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index 65b34a956..8673bdb34 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -500,7 +500,11 @@ def cuda(self) -> bool:
         return True


+####################################################################################################
 # CUTLASS kernel v2
+####################################################################################################
+
+
 @register_quantize_op
 class CutlassFP8TensorwiseGemm_v2(QuantizeOpBase):
     """
@@ -535,6 +539,44 @@ def cuda(self) -> bool:
         return True


+# CUTLASS kernel v2
+@register_quantize_op
+class CutlassFP8RowwiseGemm_v2(QuantizeOpBase):
+    """
+    FP8 matmul with rowwise scaling.
+    """
+
+    def quantize(self, x, w):
+        # Quantize both input tensors.
+        xq, x_scale = quantize_fp8_row(x)
+        wq, w_scale = quantize_fp8_row(w)
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        return torch.ops.cutlass_extensions.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_rowwise_v2"
+
+    @property
+    def hip(self) -> bool:
+        # Need to add support for a better quantize kernel.
+        # Also may have an issue with CUDA graphs.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
+####################################################################################################
+
+
 @register_quantize_op
 class F8I4RowwiseGemm(QuantizeOpBase):
     """
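
A minimal pure-PyTorch sketch of the numerics the new CutlassFP8RowwiseGemm_v2 op exercises:
each row is scaled by its own max magnitude before the fp8 cast, and the matmul output is
rescaled per output row and column. The helpers below (quantize_fp8_row_reference,
f8f8bf16_rowwise_reference) are hypothetical stand-ins, not fbgemm's quantize_fp8_row or the
cutlass_extensions f8f8bf16_rowwise kernel; the sketch assumes the returned scales are per-row
dequantization scales and that fp8 means torch.float8_e4m3fn (max representable value 448.0).

    import torch

    def quantize_fp8_row_reference(t: torch.Tensor, eps: float = 1e-12):
        # Per-row dequantization scale: map each row's max magnitude onto the e4m3 range.
        row_max = t.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
        scale = row_max / 448.0
        tq = (t / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
        return tq, scale.squeeze(-1)

    def f8f8bf16_rowwise_reference(xq, wq, x_scale, w_scale):
        # xq is [M, K], wq is [N, K]; dequantize via the per-row scales and
        # return the [M, N] product in bf16, matching the op's output dtype.
        y = xq.to(torch.float32) @ wq.to(torch.float32).t()
        return (y * x_scale[:, None] * w_scale[None, :]).to(torch.bfloat16)

    x, w = torch.randn(4, 64), torch.randn(8, 64)
    xq, x_scale = quantize_fp8_row_reference(x)
    wq, w_scale = quantize_fp8_row_reference(w)
    print(f8f8bf16_rowwise_reference(xq, wq, x_scale, w_scale).shape)  # torch.Size([4, 8])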