From 63ca6dcedc501f66cd1306c1b57817d9842239f0 Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Thu, 23 May 2024 19:20:11 -0700
Subject: [PATCH] Add INT4-FP8 rowwise matmul tests (#2622)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2622

- Add INT4-FP8 rowwise matmul tests with cudagraph and eager
- Change number of groups to group size to align with E2E variable
- Fix a minor dtype issue

Reviewed By: jwfromm

Differential Revision: D57677844

fbshipit-source-id: 78578490025d72a5d73dd207fce4518304283de0
---
 .../gen_ai/test/quantize/quantize_test.py | 78 ++++++++++++++++++-
 1 file changed, 77 insertions(+), 1 deletion(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 8d96dc0f3..1fb88ffcb 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -48,6 +48,42 @@ def fp8_col_quantize_ref(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return xq, scale.reciprocal().to(torch.float32)
 
 
+def int4_row_quantize(
+    x: torch.Tensor,
+    group_size: int = 128,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    N = x.shape[0]
+    K = x.shape[1]
+    assert K >= group_size and K % group_size == 0
+    num_groups = K // group_size
+    x = x.view(N, group_size, num_groups).to(torch.float)
+    # Zero point should be values in each row closest to zero.
+    zp_ind = torch.argmin(torch.abs(x), dim=1, keepdim=True)
+    zero_point = torch.gather(x, 1, zp_ind)  # + (.5 / scale)
+    x_centered = x - zero_point
+    scale = 7.5 / torch.max(torch.abs(x_centered), dim=1, keepdim=True).values
+    xq = torch.clamp(torch.round((x_centered * scale) - 0.5), min=-8, max=7)
+    return (
+        xq.to(torch.int8).view(N, -1).contiguous(),
+        scale.reciprocal().view(N, -1).contiguous(),  # pyre-ignore
+        (zero_point + (0.5 / scale)).view(N, -1).contiguous(),
+    )
+
+
+def pack_int4(x: torch.Tensor) -> torch.Tensor:
+    # Given int8 x, pack adjacent int4 values into a single int8.
+    low_x = x[:, ::2]
+    high_x = x[:, 1::2]
+
+    # High bits need to left shift, this also masks off extra bits.
+    high_x = torch.bitwise_left_shift(high_x, 4)
+    # Low bits need to have sign bits removed.
+    low_x = torch.bitwise_and(low_x, 0xF)
+
+    # Recombine into a single value with bitwise or.
+    return torch.bitwise_or(low_x, high_x).contiguous()
+
+
 @unittest.skipIf(
     not torch.cuda.is_available()
     or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
@@ -121,7 +157,7 @@ def test_f8f8bf16_rowwise_simple(self) -> None:
         x = xq.bfloat16()
         w = wq.bfloat16()
 
-        zq_ref = (x @ w.T).to(torch.bfloat16) * x_scale[:, None] * w_scale[None, :]
+        zq_ref = (x @ w.T * x_scale[:, None] * w_scale[None, :]).to(torch.bfloat16)
 
         torch.testing.assert_close(zq, zq_ref, atol=1.0e-3, rtol=1.0e-3)
 
@@ -216,6 +252,46 @@ def test_quantize_fp8_matmul(
             zq_ref += bias
         torch.testing.assert_close(zq, zq_ref, atol=8.0e-2, rtol=8.0e-2)
 
+    @unittest.skipIf(
+        not torch.version.cuda, "Skip on AMD: built in quantize ops not yet suported."
+    )
+    @settings(deadline=None)
+    @given(
+        B_T=st.sampled_from([2048, 4096]),
+        D=st.sampled_from([128, 256]),
+        HD_L=st.sampled_from([256, 512]),
+        QType=st.sampled_from([torch.float8_e4m3fn, torch.float8_e5m2]),
+        CudaGraph=st.sampled_from([True, False]),
+    )
+    def test_quantize_int4_fp8_matmul(
+        self,
+        B_T: int,
+        D: int,
+        HD_L: int,
+        QType: torch.dtype,
+        CudaGraph: bool,
+    ) -> None:
+        x = torch.randn(size=(B_T, D), dtype=torch.bfloat16, device="cuda") * 0.1
+        w = torch.randn(size=(HD_L, D), dtype=torch.bfloat16, device="cuda") * 0.01
+
+        wq, w_scale, w_zp = int4_row_quantize(w, 128)
+        wq = pack_int4(wq).contiguous().to(device="cuda")
+        w_scale = w_scale.contiguous().to(device="cuda")
+        w_zp = w_zp.contiguous().to(device="cuda")
+
+        if CudaGraph:
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
+                zq = torch.ops.fbgemm.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)
+            g.replay()
+        else:
+            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)
+            zq = torch.ops.fbgemm.f8i4bf16_rowwise(xq, wq, x_scale, w_scale, w_zp)
+
+        zq_ref = (x @ w.T).to(torch.bfloat16)
+        torch.testing.assert_close(zq, zq_ref, atol=8.0e-2, rtol=8.0e-2)
+
     @unittest.skipIf(
         not torch.version.cuda, "Skip on AMD: built in quantize ops not yet suported."
     )
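
For context, the reference path added by this patch can be exercised on its own without a GPU. The sketch below is a minimal CPU-only round-trip check, assuming int4_row_quantize and pack_int4 from this patch are copied into (or imported by) the same script; unpack_int4 and dequantize_int4_row are hypothetical helpers written only for this illustration and are not FBGEMM APIs.

# Minimal round-trip sketch for the reference INT4 row quantization above.
# Assumes int4_row_quantize and pack_int4 from this patch are available in
# the same module; unpack_int4 and dequantize_int4_row are illustrative only.
import torch


def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Reverse pack_int4: split each int8 into two signed int4 values.
    low = torch.bitwise_and(packed, 0xF)
    high = torch.bitwise_and(torch.bitwise_right_shift(packed, 4), 0xF)
    # Re-apply the sign bits that packing masked off.
    low = torch.where(low >= 8, low - 16, low)
    high = torch.where(high >= 8, high - 16, high)
    out = torch.empty(packed.shape[0], packed.shape[1] * 2, dtype=torch.int8)
    out[:, ::2] = low
    out[:, 1::2] = high
    return out


def dequantize_int4_row(
    xq: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    group_size: int = 128,
) -> torch.Tensor:
    # Invert int4_row_quantize: x ~= xq * scale + zero_point per (row, group),
    # using the same (N, group_size, num_groups) layout as the quantizer.
    N, K = xq.shape
    num_groups = K // group_size
    xq = xq.view(N, group_size, num_groups).to(torch.float)
    return (xq * scale[:, None, :] + zero_point[:, None, :]).view(N, K)


if __name__ == "__main__":
    w = torch.randn(512, 256) * 0.01
    wq, w_scale, w_zp = int4_row_quantize(w, 128)
    w_hat = dequantize_int4_row(unpack_int4(pack_int4(wq)), w_scale, w_zp, 128)
    # Per-element error is bounded by half an INT4 step, so a loose tolerance passes.
    torch.testing.assert_close(w_hat, w, atol=1.0e-2, rtol=1.0e-2)

Note the design of the reference quantizer: the zero point is the group value closest to zero, and the 0.5 rounding offset is folded into the returned zero point (zero_point + 0.5 / scale), so dequantization reduces to a single multiply-add per element.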