diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
index 4598f360c0..684e7eb297 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
@@ -558,8 +558,14 @@ def execute_nbit_forward_(  # noqa C901
             f = torch.cat([f.view(B, -1) for f in fs], dim=1)
         else:
             f = torch.cat(fs, dim=0).view(-1, D)
+        if fc2.dtype == torch.quint4x2:
+            fc2_float = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(
+                fc2.cpu(), bit_rate=4
+            )
+        else:
+            fc2_float = fc2.float()
         torch.testing.assert_close(
-            fc2.float().cpu(),
+            fc2_float.cpu(),
             f.float().cpu(),
             atol=1.0e-2,
             rtol=1.0e-2,
@@ -941,6 +947,54 @@ def test_nbit_forward_cpu_seq_int8(
             equal_nan=False,
         )
 
+    @given(
+        D=st.sampled_from([32, 256, 384, 512, 1024]),
+        B=st.integers(min_value=8, max_value=32),
+        T=st.integers(min_value=10, max_value=20),
+        L=st.integers(min_value=10, max_value=100),
+        MAXH=st.integers(min_value=50, max_value=100),
+    )
+    @settings(
+        verbosity=VERBOSITY,
+        max_examples=MAX_EXAMPLES_LONG_RUNNING,
+        deadline=None,
+    )
+    def test_nbit_forward_cpu_seq_int4(
+        self,
+        D: int,
+        B: int,
+        T: int,
+        L: int,
+        MAXH: int,
+    ) -> None:
+        """
+        Initialize a quantized split embedding bag with int4 weights (scale 1,
+        bias 0) and compare a brute-force table lookup against the TBE-based
+        int4 output lookup.
+        """
+        pooling_mode = PoolingMode.NONE
+        nbit_weights_ty = SparseType.INT4
+        log_E = 4
+        weighted = False
+        mixed = False
+        self.execute_nbit_forward_(
+            T,
+            D,
+            B,
+            log_E,
+            L,
+            weighted,
+            mixed,
+            pooling_mode,
+            nbit_weights_ty,
+            use_cache=False,
+            cache_algorithm=CacheAlgorithm.LRU,  # doesn't matter since we don't use cache
+            use_cpu=True,
+            use_array_for_index_remapping=True,
+            do_pruning=False,
+            mixed_weights_ty=False,
+            output_dtype=nbit_weights_ty,
+        )
+
     @unittest.skipIf(*gpu_unavailable)
     @given(
         nbit_weights_ty=st.sampled_from(
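
For context (not part of the patch): below is a rough pure-PyTorch sketch of the dequantization that the `torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(..., bit_rate=4)` call in the first hunk performs on the int4 sequence output. It assumes each output row is laid out as an fp16 scale and fp16 bias at the front, followed by the D 4-bit values packed two per byte with the low nibble first; the helper name and the exact layout here are illustrative assumptions, not taken from the patch.

import torch

def dequantize_int4_rows_sbfront(rows: torch.Tensor, D: int) -> torch.Tensor:
    """Dequantize (N, 4 + ceil(D/2)) uint8 rows with a front fp16 scale/bias.

    ASSUMED row layout (illustrative, not confirmed by the patch):
    [fp16 scale][fp16 bias][D int4 values, 2 per byte, low nibble first].
    """
    assert rows.dtype == torch.uint8 and rows.dim() == 2
    # Reinterpret the first 4 bytes of each row as two fp16 values.
    scale = rows[:, 0:2].contiguous().view(torch.half).float()  # (N, 1)
    bias = rows[:, 2:4].contiguous().view(torch.half).float()   # (N, 1)
    packed = rows[:, 4:]                                        # (N, ceil(D/2))
    lo = (packed & 0x0F).float()                                # even elements
    hi = (packed >> 4).float()                                  # odd elements
    # Interleave the nibbles back into element order, then trim padding.
    q = torch.stack([lo, hi], dim=-1).flatten(1)[:, :D]         # (N, D)
    return q * scale + bias

Since the new test initializes the table with scale 1 and bias 0, this reduces to returning the raw 4-bit codes, which is what makes the exact comparison against the brute-force lookup meaningful at the given tolerances.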