From 581fcec178832e44d5ff461352bdd3beb6d60bde Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Mon, 20 May 2024 16:19:35 -0700
Subject: [PATCH] Add FP8 rowwise and tensorwise unittests with cudagraph
 (#2609)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2609

Add cudagraph coverage in FP8 rowwise and tensorwise unittests to guard FP8 GEMM kernel-wise issues with cudagraph

TODO: Merge f8f8bf16 and f8f8bf16_tensorwise into one kernel

Reviewed By: jasonjk-park

Differential Revision: D57578841

fbshipit-source-id: 780481848ff3443f5addc1afdc1f563fa4f1df82
---
 .../gen_ai/test/quantize/quantize_test.py    | 70 +++++++++++++++----
 1 file changed, 56 insertions(+), 14 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 8f0235f5d..38b5e158a 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -113,9 +113,17 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
         Mode=st.sampled_from(["tensorwise", "tensorwise_broadcast", "rowwise"]),
         QType=st.sampled_from([torch.float8_e4m3fn, torch.float8_e5m2]),
         Bias=st.sampled_from([True, False]),
+        CudaGraph=st.sampled_from([True, False]),
     )
     def test_quantize_fp8_matmul(
-        self, B_T: int, D: int, HD_L: int, Mode: str, QType: torch.dtype, Bias: bool
+        self,
+        B_T: int,
+        D: int,
+        HD_L: int,
+        Mode: str,
+        QType: torch.dtype,
+        Bias: bool,
+        CudaGraph: bool,
     ) -> None:
         x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
         w = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
@@ -126,23 +134,57 @@ def test_quantize_fp8_matmul(
         )

         if Mode == "tensorwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
-            if bias is not None:
-                zq += bias
+            if CudaGraph:
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+                    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+                    zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+                    if bias is not None:
+                        zq += bias
+                g.replay()
+            else:
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+                zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+                if bias is not None:
+                    zq += bias
         elif Mode == "tensorwise_broadcast":
             xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
             wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                xq, wq, (x_scale * w_scale).item()
-            )
-            if bias is not None:
-                zq += bias
+            x_scale = x_scale.item()
+            w_scale = w_scale.item()
+            if CudaGraph:
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                    if bias is not None:
+                        zq += bias
+                g.replay()
+            else:
+                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+                if bias is not None:
+                    zq += bias
         elif Mode == "rowwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, output_dtype=QType)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-            zq = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, bias=bias)
+            if CudaGraph:
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                        x, output_dtype=QType
+                    )
+                    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+                    zq = torch.ops.fbgemm.f8f8bf16_rowwise(
+                        xq, wq, x_scale, w_scale, bias=bias
+                    )
+                g.replay()
+            else:
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                    x, output_dtype=QType
+                )
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+                zq = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    xq, wq, x_scale, w_scale, bias=bias
+                )
         else:
             raise ValueError(f"Invalid mode {Mode}")
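Note on the capture/replay pattern the new CudaGraph=True test paths exercise: the sketch below shows the same torch.cuda.CUDAGraph mechanics in isolation. It is a minimal sketch, not FBGEMM code; torch.matmul stands in for the FP8 GEMM ops (torch.ops.fbgemm.f8f8bf16 and friends require the fbgemm_gpu extension and FP8-capable hardware), and the helper name graphed_matmul is hypothetical.

import torch


def graphed_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Warm up on a side stream first so one-time work (cuBLAS handle creation,
    # caching-allocator pool growth) does not happen during graph capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        _ = torch.matmul(x, w.t())  # torch.matmul stands in for the FP8 GEMM op
    torch.cuda.current_stream().wait_stream(s)

    # Capture: kernels launched inside this context are recorded, not executed.
    # The output tensor z is allocated during capture and reused on every replay.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        z = torch.matmul(x, w.t())

    # Replay launches the whole captured graph as a single unit. Any op that is
    # not capture-safe (e.g. one forcing a device-to-host sync) would already
    # have failed during capture, which is what the CudaGraph=True paths guard.
    g.replay()
    torch.cuda.synchronize()
    return z


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
    z = graphed_matmul(x, w)
    torch.testing.assert_close(z, torch.matmul(x, w.t()), atol=1e-1, rtol=1.6e-2)

This is also why the tensorwise_broadcast branch in the diff hoists x_scale.item() and w_scale.item() out of the captured region: .item() forces a device-to-host synchronization, which is not allowed inside CUDA graph capture.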