From 14c958278e0736ca12dfcc3c4c37c84e9e98b983 Mon Sep 17 00:00:00 2001
From: Wei Su
Date: Fri, 16 Aug 2024 17:57:47 -0700
Subject: [PATCH] Add unit test for int4 to int4 sequence CPU TBE (#2997)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2997

X-link: https://github.com/facebookresearch/FBGEMM/pull/88

Unit test for int4 to int4 sequential CPU TBE

Reviewed By: sryap

Differential Revision: D61305982
---
 .../test/tbe/inference/nbit_forward_test.py | 51 ++++++++++++++++++-
 1 file changed, 50 insertions(+), 1 deletion(-)

diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
index 4598f360c0..1a61d99b38 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
@@ -558,8 +558,14 @@ def execute_nbit_forward_(  # noqa C901
             f = torch.cat([f.view(B, -1) for f in fs], dim=1)
         else:
             f = torch.cat(fs, dim=0).view(-1, D)
+        if fc2.dtype == torch.quint4x2:
+            fc2_float = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(
+                fc2.cpu(), bit_rate=4
+            )
+        else:
+            fc2_float = fc2.float()
         torch.testing.assert_close(
-            fc2.float().cpu(),
+            fc2_float.cpu(),
             f.float().cpu(),
             atol=1.0e-2,
             rtol=1.0e-2,
@@ -941,6 +947,49 @@ def test_nbit_forward_cpu_seq_int8(
             equal_nan=False,
         )
 
+    @given(
+        D=st.sampled_from([32, 256, 384, 512, 1024]),
+        B=st.integers(min_value=8, max_value=32),
+        T=st.integers(min_value=10, max_value=20),
+        L=st.integers(min_value=10, max_value=100),
+        MAXH=st.integers(min_value=50, max_value=100),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+    )
+    def test_nbit_forward_cpu_seq_int4(
+        self,
+        D: int,
+        B: int,
+        T: int,
+        L: int,
+        MAXH: int,
+    ) -> None:
+        """
+        Initialize a quantized split embedding bag with int4 weights, a scale of 1 and
+        a bias of 0, then compare brute-force table lookup vs TBE-based int4 output lookup.
+        """
+        self.execute_nbit_forward_(
+            T,
+            D,
+            B,
+            log_E=4,
+            L=L,
+            weighted=False,
+            mixed=False,
+            pooling_mode=PoolingMode.NONE,
+            weights_ty=SparseType.INT4,
+            use_cache=False,
+            cache_algorithm=CacheAlgorithm.LRU,  # doesn't matter since we don't use cache
+            use_cpu=True,
+            use_array_for_index_remapping=True,
+            do_pruning=False,
+            mixed_weights_ty=False,
+            output_dtype=SparseType.INT4,
+        )
+
     @unittest.skipIf(*gpu_unavailable)
     @given(
         nbit_weights_ty=st.sampled_from(
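
---

Note (editor's sketch, not part of the patch): the new fc2_float branch converts the
packed int4 TBE output back to float before comparison. FBGEMM's fused 4-bit rowwise
format stores each row as packed int4 nibbles plus an fp16 scale and bias; the "Front"
variant used in the diff expects the scale/bias at the front of each row, matching the
layout the sequence CPU TBE emits for int4 output. Below is a minimal round trip under
the assumption that the end-of-row quantize/dequantize pair
FloatToFusedNBitRowwiseQuantizedSBHalf / FusedNBitRowwiseQuantizedSBHalfToFloat from
fbgemm_gpu is available:

    import torch
    import fbgemm_gpu  # noqa: F401  # loading the module registers torch.ops.fbgemm.*

    rows = torch.rand(8, 32, dtype=torch.float32)  # arbitrary example shape

    # Quantize: each row becomes 32 / 2 = 16 packed uint8 bytes, followed by an
    # fp16 scale and an fp16 bias, so `packed` has shape (8, 20) and dtype uint8.
    packed = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(rows, bit_rate=4)

    # Dequantize back to float32, mirroring the fc2_float branch in the diff above.
    recovered = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(packed, bit_rate=4)

    # 4-bit rowwise quantization is lossy: the worst-case rounding error is half a
    # quantization step, about (1 / 15) / 2 ~= 0.033 for rows spanning [0, 1), so the
    # tolerance here is looser than the 1e-2 the patched assert uses for TBE outputs.
    torch.testing.assert_close(recovered, rows, atol=5.0e-2, rtol=5.0e-2)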