Skip to content

Commit

Permalink
Add unit test for int4 to int4 sequence CPU TBE (#2997)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2997

X-link: facebookresearch/FBGEMM#88

Adds a unit test for the int4-to-int4 sequence (`PoolingMode.NONE`) CPU TBE forward path.

Reviewed By: sryap

Differential Revision: D61305982
  • Loading branch information
Wei Su authored and facebook-github-bot committed Aug 22, 2024
1 parent 71cc276 commit 2c71929
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,14 @@ def execute_nbit_forward_( # noqa C901
f = torch.cat([f.view(B, -1) for f in fs], dim=1)
else:
f = torch.cat(fs, dim=0).view(-1, D)
if fc2.dtype == torch.quint4x2:
fc2_float = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(
fc2.cpu(), bit_rate=4
)
else:
fc2_float = fc2.float()
torch.testing.assert_close(
fc2.float().cpu(),
fc2_float.cpu(),
f.float().cpu(),
atol=1.0e-2,
rtol=1.0e-2,
Expand Down Expand Up @@ -941,6 +947,49 @@ def test_nbit_forward_cpu_seq_int8(
equal_nan=False,
)

@given(
    D=st.sampled_from([32, 256, 384, 512, 1024]),
    B=st.integers(min_value=8, max_value=32),
    T=st.integers(min_value=10, max_value=20),
    L=st.integers(min_value=10, max_value=100),
    MAXH=st.integers(min_value=50, max_value=100),
)
@settings(
    verbosity=VERBOSITY,
    max_examples=MAX_EXAMPLES_LONG_RUNNING,
    deadline=None,
)
def test_nbit_forward_cpu_seq_int4(
    self,
    D: int,
    B: int,
    T: int,
    L: int,
    MAXH: int,
) -> None:
    """
    Exercise the CPU n-bit forward path with INT4 weights and INT4 output
    in sequence mode (``PoolingMode.NONE``).

    All checking is delegated to ``execute_nbit_forward_``, which is called
    with the hypothesis-drawn ``T``, ``D``, ``B`` and ``L`` plus a fixed
    configuration: unweighted, non-mixed dims and weight types, no cache,
    no pruning, array-based index remapping and ``log_E=4``.

    Per the original author's note, the intent is to compare a brute-force
    table lookup against the TBE int4 output lookup for a quantized table
    initialised with scale 1 and bias 0.

    NOTE: ``MAXH`` is drawn by hypothesis but is not referenced by this
    test; it is kept so the strategy/signature pairing stays unchanged.
    """
    self.execute_nbit_forward_(
        T,
        D,
        B,
        # Problem shape.
        log_E=4,
        L=L,
        # INT4 in, INT4 out, with a single weight type across tables.
        weights_ty=SparseType.INT4,
        output_dtype=SparseType.INT4,
        mixed_weights_ty=False,
        # Sequence (non-pooled), unweighted lookup with uniform dims.
        pooling_mode=PoolingMode.NONE,
        weighted=False,
        mixed=False,
        # CPU only; the cache is disabled, so the algorithm choice is moot.
        use_cpu=True,
        use_cache=False,
        cache_algorithm=CacheAlgorithm.LRU,
        # Index remapping via arrays, without pruning.
        use_array_for_index_remapping=True,
        do_pruning=False,
    )

@unittest.skipIf(*gpu_unavailable)
@given(
nbit_weights_ty=st.sampled_from(
Expand Down

0 comments on commit 2c71929

Please sign in to comment.