diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 8f0235f5d..4b3057746 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -126,23 +126,36 @@ def test_quantize_fp8_matmul(
         )
         if Mode == "tensorwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+                zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
+            g.replay()
             if bias is not None:
                 zq += bias
         elif Mode == "tensorwise_broadcast":
             xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
             wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                xq, wq, (x_scale * w_scale).item()
-            )
+            x_scale = x_scale.item()
+            w_scale = w_scale.item()
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                zq = torch.ops.fbgemm.f8f8bf16_tensorwise(xq, wq, x_scale * w_scale)
+            g.replay()
             if bias is not None:
                 zq += bias
         elif Mode == "rowwise":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, output_dtype=QType)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-            zq = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, bias=bias)
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+                    x, output_dtype=QType
+                )
+                wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
+                zq = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    xq, wq, x_scale, w_scale, bias=bias
+                )
+            g.replay()
         else:
             raise ValueError(f"Invalid mode {Mode}")
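For context on the pattern this diff applies, below is a minimal standalone sketch (not from the PR) of `torch.cuda.CUDAGraph` capture and replay. It uses a plain `torch` matmul rather than the `torch.ops.fbgemm` kernels so it runs without fbgemm_gpu installed; the shapes and the warmup side stream are illustrative assumptions. It also shows why the `tensorwise_broadcast` branch hoists the `.item()` calls out of the capture region: `.item()` forces a device-to-host sync, which is not allowed while a CUDA graph is being captured.

```python
import torch

assert torch.cuda.is_available()

x = torch.randn(128, 256, device="cuda")
w = torch.randn(512, 256, device="cuda")

# Warm up once on a side stream so lazy initialization (cuBLAS handles,
# autotuning) happens outside capture, where it would otherwise fail.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x @ w.t()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Ops issued here are recorded into the graph, not executed.
    # Host-synchronizing calls such as tensor.item() must happen
    # before this block, as the tensorwise_broadcast branch does.
    y = x @ w.t()

# Refill the static input buffer in place, then replay the recorded kernels;
# y is the graph's static output tensor and now holds the new result.
x.copy_(torch.randn_like(x))
g.replay()
torch.cuda.synchronize()
print(y.shape)  # torch.Size([128, 512])
```

The test follows the same shape: capture the quantize + GEMM ops once, call `g.replay()`, and then check `zq` as usual, which verifies the FBGEMM FP8 kernels are CUDA-graph-capturable.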