diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index b9385b5cc..27f21978b 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -2775,6 +2775,52 @@ def test_multi_permute_forward(self, device_str: str, batch_size: int) -> None: for out, ref in zip(outputs, refs): torch.testing.assert_close(out, ref) + @repeat_test( + device_str=["meta", "cpu", "cuda"], + dtype=[ + # torch.int, + # torch.uint8, + # torch.int8, + # torch.int16, + # torch.float64, + torch.float, + torch.float32, + torch.float16, + torch.bfloat16, + ], + ) + def test_multi_permute_dtype(self, device_str: str, dtype: torch.dtype) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) + batch_size = 4 + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] + lengths = [[3, 4], [5, 6, 7], [8]] + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] + values = [ + torch.randn(batch_size, sum(L), device=device, dtype=dtype) for L in lengths + ] + permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments( + values[0], keys, lengths, groups + ) + outputs = torch.ops.fbgemm.permute_multi_embedding( + values, permutes, in_shapes, out_shapes, out_lengths + ) + + if device_str == "meta": + for out, ref in zip(outputs, out_lengths): + self.assertEqual(out.shape, (batch_size, ref)) + else: + refs = [[] for _ in groups] + for i in range(permutes.size(0)): + in_idx, out, in_start, _, length, _ = permutes[i].tolist() + refs[out].append(values[in_idx][:, in_start : (in_start + length)]) + refs = [torch.cat(ref, dim=1) for ref in refs] + for out, ref in zip(outputs, refs): + torch.testing.assert_close(out, ref) + self.assertEqual(out.dtype, ref.dtype) + @repeat_test( ["cpu", 32, [[3, 4], [5, 6, 7], [8]]], ["cuda", 128, [[96, 256], [512, 128, 768], [1024]]],