From 464493927de801c19dfded16a487f88adaeb43a2 Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Mon, 20 May 2024 12:14:31 -0700
Subject: [PATCH] Add FP8 rowwise and tensorwise unittests with cudagraph
 (#2609)

Summary:

Add cudagraph coverage to the FP8 rowwise and tensorwise unittests to guard
against kernel-level FP8 GEMM issues when running under cudagraph.

TODO: Merge f8f8bf16 and f8f8bf16_tensorwise into one kernel

Differential Revision: D57578841
---
 .../gen_ai/test/quantize/quantize_test.py     | 31 +++++++++++++------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 8f0235f5d..4b3057746 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -126,23 +126,36 @@ def test_quantize_fp8_matmul(
         )
 
         if Mode == "tensorwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+                zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+            g.replay()
             if bias is not None:
                 zq += bias
         elif Mode == "tensorwise_broadcast":
             xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
             wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                xq, wq, (x_scale * w_scale).item()
-            )
+            x_scale = x_scale.item()
+            w_scale = w_scale.item()
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+            g.replay()
             if bias is not None:
                 zq += bias
         elif Mode == "rowwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, output_dtype=QType)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-            zq = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, bias=bias)
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                    x, output_dtype=QType
+                )
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+                zq = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    xq, wq, x_scale, w_scale, bias=bias
+                )
+            g.replay()
         else:
             raise ValueError(f"Invalid mode {Mode}")
 
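
Reviewer note (not part of the patch): the capture/replay pattern added above is
easy to get subtly wrong, so below is a minimal, self-contained sketch of how
torch.cuda.CUDAGraph capture and replay behave. This is an illustration under
assumptions, not the test itself: torch.matmul stands in for the fbgemm FP8
GEMM ops, and the tensor shapes are arbitrary.

# Minimal sketch of the CUDA graph capture/replay pattern used in the test
# above. Assumption: torch.matmul stands in for the fbgemm FP8 GEMM ops.
import torch

assert torch.cuda.is_available(), "CUDA graphs require a CUDA device"

# A captured graph replays against fixed memory addresses, so inputs must be
# "static" tensors that are refilled in place between replays.
x = torch.randn(128, 256, device="cuda")
w = torch.randn(512, 256, device="cuda")

# Warm up on a side stream first, so one-time work (lazy allocations,
# autotuning) is not baked into the captured graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x @ w.t()
torch.cuda.current_stream().wait_stream(s)

# Capture: kernels launched inside the context are recorded, not executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = x @ w.t()

# Replay with new data: refill the static input in place, then replay.
x.copy_(torch.randn(128, 256, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(y.shape)  # torch.Size([128, 512])

This also explains a design choice in the tensorwise_broadcast branch: the
patch calls x_scale.item() and w_scale.item() before opening the capture
context, because .item() forces a device-to-host synchronization, which is not
allowed while a graph is being captured.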