diff --git a/fbgemm_gpu/test/jagged/batched_dense_vec_jagged_2d_mul_test.py b/fbgemm_gpu/test/jagged/batched_dense_vec_jagged_2d_mul_test.py new file mode 100644 index 000000000..ada35e2b3 --- /dev/null +++ b/fbgemm_gpu/test/jagged/batched_dense_vec_jagged_2d_mul_test.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# pyre-ignore-all-errors[56] + +import unittest + +import hypothesis.strategies as st +import torch +import torch._dynamo +from hypothesis import assume, given, settings, Verbosity + +from .common import additional_decorators, open_source, torch_compiled + +if open_source: + # pyre-ignore[21] + from test_utils import ( + cpu_and_maybe_gpu, + gradcheck, + optests, + symint_vector_unsupported, + ) +else: + from fbgemm_gpu.test.test_utils import ( + cpu_and_maybe_gpu, + gradcheck, + optests, + symint_vector_unsupported, + ) + + +@optests.generate_opcheck_tests(additional_decorators=additional_decorators) +class BatchedDenseVecJagged2DMulTest(unittest.TestCase): + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + @given( + B=st.integers(0, 32), + H=st.integers(1, 3), + max_L=st.integers(1, 32), + D=st.integers(0, 32), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=cpu_and_maybe_gpu(), + ) + def test_batched_dense_vec_jagged_2d_mul( + self, + B: int, + H: int, + max_L: int, + D: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + assume(H == 1 or B != 0) + # CPU doesn't support bfloat16 + assume(device != torch.device("cpu") or dtype != torch.bfloat16) + + torch.backends.cuda.matmul.allow_tf32 = False + + # Sometimes length[i] exceed max_L meaning jagged->dense will be + # truncation vs. padding + lengths = torch.randint(max_L * 2, size=(B,), device=device) + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device) + dense = torch.rand((B * H, max_L), dtype=dtype, device=device) + padded_values = torch.ops.fbgemm.jagged_to_padded_dense( + values, + [offsets], + [max_L], + ) # [B, N, H * D] + + bmm_arg1 = dense.unsqueeze(1) + bmm_arg2 = ( + padded_values.reshape(B, max_L, H, D) + .transpose(1, 2) + .reshape(B * H, max_L, D) + ) + # torch.bmm not implemented for Half on CPU + if dtype in [torch.half, torch.bfloat16] and device == torch.device("cpu"): + bmm_arg1 = bmm_arg1.float() + bmm_arg2 = bmm_arg2.float() + output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze( + 1 + ) # [B H, 1, N] x [B H, N, D] = [B H, 1, D] + if dtype in [torch.half, torch.bfloat16] and device == torch.device("cpu"): + output_ref = output_ref.to(dtype) + output = torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul( + dense, values, offsets + ) + torch.testing.assert_close( + output, + output_ref, + rtol=1e-2 if dtype in [torch.half, torch.bfloat16] else None, + atol=1e-2 if dtype in [torch.half, torch.bfloat16] else None, + ) + + gradcheck( + torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul, + ( + dense.clone().detach().float().requires_grad_(True), + values.clone().detach().float().requires_grad_(True), + offsets, + ), + eps=1e-2, + atol=1e-3, + rtol=1e-3, + ) + + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + @given( + B=st.integers(0, 32), + H=st.integers(1, 3), + max_L=st.integers(1, 32), + D=st.integers(0, 32), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device_type=st.sampled_from(["meta"]), + ) + def test_batched_dense_vec_jagged_2d_mul_meta_backend( + self, + B: int, + H: int, + max_L: int, + D: int, + dtype: torch.dtype, + device_type: str, + ) -> None: + assume(H == 1 or B != 0) + + device = torch.device("cpu") + torch.backends.cuda.matmul.allow_tf32 = False + + # Sometimes length[i] exceed max_L meaning jagged->dense will be + # truncation vs. padding + lengths = torch.randint(max_L * 2, size=(B,), device=device) + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device) + dense = torch.rand((B * H, max_L), dtype=dtype, device=device) + padded_values = torch.ops.fbgemm.jagged_to_padded_dense( + values, + [offsets], + [max_L], + ) # [B, N, H * D] + + bmm_arg1 = dense.unsqueeze(1) + bmm_arg2 = ( + padded_values.reshape(B, max_L, H, D) + .transpose(1, 2) + .reshape(B * H, max_L, D) + ) + # torch.bmm not implemented for Half on CPU + if dtype in [torch.half, torch.bfloat16]: + bmm_arg1 = bmm_arg1.float() + bmm_arg2 = bmm_arg2.float() + output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze( + 1 + ) # [B H, 1, N] x [B H, N, D] = [B H, 1, D] + dense.to(device_type) + values.to(device_type) + output = torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul( + dense, values, offsets + ) + assert output.size() == output_ref.size() + + @optests.dontGenerateOpCheckTests("tests that call torch.compile are slow") + @unittest.skipIf(*symint_vector_unsupported()) + @settings( + verbosity=Verbosity.verbose, + max_examples=20, + deadline=None, + ) + @given( + B=st.integers(2, 32), + H=st.integers(1, 3), + max_L=st.integers(1, 32), + D=st.integers(2, 32), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device_type=st.just("cpu"), + ) + def test_batched_dense_vec_jagged_2d_mul_dynamic_shape( + self, + B: int, + H: int, + max_L: int, + D: int, + dtype: torch.dtype, + device_type: str, + ) -> None: + # Start a fresh compile for each parameter of the test case + torch._dynamo.reset() + + assume(H == 1 or B != 0) + + device = torch.device(device_type) + torch.backends.cuda.matmul.allow_tf32 = False + + # Sometimes length[i] exceed max_L meaning jagged->dense will be + # truncation vs. padding + lengths = torch.randint(low=1, high=max_L * 2, size=(B,), device=device) + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device) + dense = torch.rand((B * H, max_L), dtype=dtype, device=device) + padded_values = torch.ops.fbgemm.jagged_to_padded_dense( + values, + [offsets], + [max_L], + ) # [B, N, H * D] + + bmm_arg1 = dense.unsqueeze(1) + bmm_arg2 = ( + padded_values.reshape(B, max_L, H, D) + .transpose(1, 2) + .reshape(B * H, max_L, D) + ) + # torch.bmm not implemented for Half on CPU + if dtype in [torch.half, torch.bfloat16]: + bmm_arg1 = bmm_arg1.float() + bmm_arg2 = bmm_arg2.float() + output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze( + 1 + ) # [B H, 1, N] x [B H, N, D] = [B H, 1, D] + dense.to(device_type) + values.to(device_type) + + torch._dynamo.mark_dynamic(dense, 0) + torch._dynamo.mark_dynamic(values, 0) + torch._dynamo.mark_dynamic(values, 1) + torch._dynamo.mark_dynamic(offsets, 0) + + output = torch_compiled( + torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul, + fullgraph=True, + dynamic=True, + )(dense, values, offsets) + assert output.size() == output_ref.size() + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/jagged/dense_bmm_test.py b/fbgemm_gpu/test/jagged/dense_bmm_test.py index 74cf7e467..1cace6b69 100644 --- a/fbgemm_gpu/test/jagged/dense_bmm_test.py +++ b/fbgemm_gpu/test/jagged/dense_bmm_test.py @@ -19,10 +19,10 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available, optests, symint_vector_unsupported + from test_utils import cpu_and_maybe_gpu, optests, symint_vector_unsupported else: from fbgemm_gpu.test.test_utils import ( - gpu_available, + cpu_and_maybe_gpu, optests, symint_vector_unsupported, ) @@ -36,9 +36,7 @@ class DenseBmmTest(unittest.TestCase): N=st.integers(1, 32), max_L=st.integers(1, 32), dtype=st.sampled_from([torch.float]), - device_type=( - st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu") - ), + device=cpu_and_maybe_gpu(), ) @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) def test_jagged_jagged_bmm( @@ -48,10 +46,9 @@ def test_jagged_jagged_bmm( N: int, max_L: int, dtype: torch.dtype, - device_type: str, + device: torch.device, ) -> None: assume(B != 0) - device = torch.device(device_type) torch.backends.cuda.matmul.allow_tf32 = False lengths = torch.randint(max_L + 1, size=(B,), device=device) offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) @@ -101,9 +98,7 @@ def test_jagged_jagged_bmm( N=st.integers(1, 32), max_L=st.integers(1, 32), dtype=st.sampled_from([torch.float]), - device_type=( - st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu") - ), + device=cpu_and_maybe_gpu(), ) @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None) def test_jagged_dense_bmm( @@ -113,10 +108,9 @@ def test_jagged_dense_bmm( N: int, max_L: int, dtype: torch.dtype, - device_type: str, + device: torch.device, ) -> None: assume(B != 0) - device = torch.device(device_type) torch.backends.cuda.matmul.allow_tf32 = False lengths = torch.randint(max_L + 1, size=(B,), device=device) total_length = int(lengths.sum().item()) @@ -165,7 +159,7 @@ def test_jagged_dense_bmm( N=st.integers(2, 32), max_L=st.integers(2, 32), dtype=st.sampled_from([torch.float]), - device_type=st.just("cpu"), + device=st.just(torch.device("cpu")), ) @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) def test_jagged_dense_bmm_dynamic_shape( @@ -175,13 +169,12 @@ def test_jagged_dense_bmm_dynamic_shape( N: int, max_L: int, dtype: torch.dtype, - device_type: str, + device: torch.device, ) -> None: # Start a fresh compile for each parameter of the test case torch._dynamo.reset() assume(B != 0) - device = torch.device(device_type) torch.backends.cuda.matmul.allow_tf32 = False lengths = torch.randint(low=1, high=max_L + 1, size=(B,), device=device) total_length = int(lengths.sum().item()) diff --git a/fbgemm_gpu/test/jagged/dense_to_jagged_test.py b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py new file mode 100644 index 000000000..b774537a2 --- /dev/null +++ b/fbgemm_gpu/test/jagged/dense_to_jagged_test.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +# pyre-ignore-all-errors[56] + +import unittest +from typing import List, Tuple + +import hypothesis.strategies as st +import torch +import torch._dynamo +from hypothesis import given, settings, Verbosity + +from .common import additional_decorators, generate_jagged_tensor, open_source + +if open_source: + # pyre-ignore[21] + from test_utils import ( + cpu_and_maybe_gpu, + gpu_unavailable, + optests, + symint_vector_unsupported, + ) +else: + from fbgemm_gpu.test.test_utils import ( + cpu_and_maybe_gpu, + gpu_unavailable, + optests, + symint_vector_unsupported, + ) + + +@optests.generate_opcheck_tests(additional_decorators=additional_decorators) +class DenseToJaggedTest(unittest.TestCase): + def _test_dense_to_jagged( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + # Generate multi-dim jagged tensor + values_2d, offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, outer_dense_size, inner_dense_size, dtype, device + ) + values_2d = values_2d.clone().detach().requires_grad_(True) + + # jagged -> dense + dense = torch.ops.fbgemm.jagged_to_padded_dense(values_2d, offsets, max_lengths) + + # dense -> jagged (op which is being tested) + if precompute_total_L: + total_L = values_2d.size(0) + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets, total_L + ) + else: + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets + ) + + # jagged -> dense + dense2 = torch.ops.fbgemm.jagged_to_padded_dense( + jagged_values, jagged_offsets, max_lengths + ) + + # verify forward + torch.testing.assert_close(dense, dense2) + + # verify backward + dense.retain_grad() + ref_output_values = jagged_values.clone().detach().requires_grad_(True) + ref_values = dense.clone().detach().requires_grad_(True) + jagged_values.backward(ref_output_values) + torch.testing.assert_close(dense.grad, ref_values) + + @given( + num_jagged_dim=st.integers(1, 5), + outer_dense_size=st.integers(0, 5), + inner_dense_size=st.integers(0, 5), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=cpu_and_maybe_gpu(), + precompute_total_L=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_dense_to_jagged( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + self._test_dense_to_jagged( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + precompute_total_L, + ) + + @unittest.skipIf(*gpu_unavailable) + @given( + num_jagged_dim=st.just(1), + outer_dense_size=st.integers(0, 6000), + inner_dense_size=st.sampled_from([8, 16, 23, 24, 48, 50, 64, 72, 96, 192]), + dtype=st.just(torch.half), + device=cpu_and_maybe_gpu(), + precompute_total_L=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_dense_to_jagged_opt( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + self._test_dense_to_jagged( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + precompute_total_L, + ) + + # (8000+1) * 8 (size of the element of LongTensor/int64_t offsets) + # = ~62.5KB > 48KB default shared memory on V100/A100. + @unittest.skipIf(*gpu_unavailable) + @given( + num_jagged_dim=st.just(1), + outer_dense_size=st.just(8000), + inner_dense_size=st.just(16), + dtype=st.just(torch.half), + device=cpu_and_maybe_gpu(), + precompute_total_L=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_dense_to_jagged_opt_large_batch( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + self._test_dense_to_jagged( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + precompute_total_L, + ) + + @given( + num_jagged_dim=st.integers(1, 5), + outer_dense_size=st.integers(0, 5), + inner_dense_size=st.integers(0, 5), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=st.sampled_from([torch.device("meta")]), + precompute_total_L=st.booleans(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_dense_to_jagged_meta_backend( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + precompute_total_L: bool, + ) -> None: + device = torch.device("cpu") + values_2d, offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, outer_dense_size, inner_dense_size, dtype, device + ) + values_2d = values_2d.clone().detach().requires_grad_(True) + + # jagged -> dense + dense = torch.ops.fbgemm.jagged_to_padded_dense(values_2d, offsets, max_lengths) + + # dense -> jagged (op which is being tested) + if precompute_total_L: + total_L = values_2d.size(0) + dense.to(device) + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets, total_L + ) + else: + dense.to(device) + jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( + dense, offsets + ) + + jagged_values.to(device) + # jagged -> dense + dense2 = torch.ops.fbgemm.jagged_to_padded_dense( + jagged_values, jagged_offsets, max_lengths + ) + + # verify forward + assert dense.size() == dense2.size() + + @optests.dontGenerateOpCheckTests("tests that call torch.compile are slow") + @unittest.skipIf(*symint_vector_unsupported()) + @given( + num_jagged_dim=st.integers(1, 5), + # TODO: size = 0/1 will be incorrectly specialized + outer_dense_size=st.integers(2, 5), + inner_dense_size=st.integers(2, 5), + dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), + device=cpu_and_maybe_gpu(), + ) + @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) + def test_dense_to_jagged_dynamic_shape( + self, + num_jagged_dim: int, + outer_dense_size: int, + inner_dense_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + # Start a fresh compile for each parameter of the test case + torch._dynamo.reset() + + values_2d, offsets, max_lengths = generate_jagged_tensor( + num_jagged_dim, + outer_dense_size, + inner_dense_size, + dtype, + device, + mark_dynamic=True, + ) + values_2d = values_2d.clone().detach().requires_grad_(True) + + def jagged_to_dense( + values: torch.Tensor, + offsets: List[torch.LongTensor], + max_lengths: List[int], + ) -> torch.Tensor: + return torch.ops.fbgemm.jagged_to_padded_dense(values, offsets, max_lengths) + + # jagged -> dense + dense = jagged_to_dense(values_2d, offsets, max_lengths.tolist()) + + # dense -> jagged, it is required to pre-compute totalL + total_L = values_2d.size(0) + dense = dense.clone().detach().to(device) + + torch._dynamo.mark_dynamic(dense, 0) + torch._dynamo.mark_dynamic(dense, -1) + + def dense_to_jagged_withL( + dense: torch.Tensor, offsets: List[torch.LongTensor], total_L: List[int] + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.fbgemm.dense_to_jagged(dense, offsets, total_L) + + def dense_to_jagged_noL( + dense: torch.Tensor, offsets: List[torch.LongTensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.fbgemm.dense_to_jagged(dense, offsets) + + jagged_values, jagged_offsets = dense_to_jagged_noL(dense, offsets) + jagged_values, jagged_offsets = dense_to_jagged_withL(dense, offsets, total_L) + + jagged_values.to(device) + # jagged -> dense + dense2 = torch.ops.fbgemm.jagged_to_padded_dense( + jagged_values, jagged_offsets, max_lengths + ) + + # verify forward + assert dense.size() == dense2.size() + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/jagged/jagged_tensor_ops_test.py b/fbgemm_gpu/test/jagged/jagged_tensor_ops_test.py index e4467b2a7..a269dd0e3 100644 --- a/fbgemm_gpu/test/jagged/jagged_tensor_ops_test.py +++ b/fbgemm_gpu/test/jagged/jagged_tensor_ops_test.py @@ -11,7 +11,7 @@ import itertools import random import unittest -from typing import List, Tuple +from typing import List import hypothesis.strategies as st import torch @@ -23,7 +23,6 @@ generate_jagged_tensor, open_source, to_padded_dense, - torch_compiled, ) if open_source: @@ -132,264 +131,6 @@ def test_expand_into_jagged_permute( output_permute_gpu.cpu(), output_permute_ref_tensor ) - def _test_dense_to_jagged( - self, - num_jagged_dim: int, - outer_dense_size: int, - inner_dense_size: int, - dtype: torch.dtype, - device_type: str, - precompute_total_L: bool, - ) -> None: - # Generate multi-dim jagged tensor - device = torch.device(device_type) - values_2d, offsets, max_lengths = generate_jagged_tensor( - num_jagged_dim, outer_dense_size, inner_dense_size, dtype, device - ) - values_2d = values_2d.clone().detach().requires_grad_(True) - - # jagged -> dense - dense = torch.ops.fbgemm.jagged_to_padded_dense(values_2d, offsets, max_lengths) - - # dense -> jagged (op which is being tested) - if precompute_total_L: - total_L = values_2d.size(0) - jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( - dense, offsets, total_L - ) - else: - jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( - dense, offsets - ) - - # jagged -> dense - dense2 = torch.ops.fbgemm.jagged_to_padded_dense( - jagged_values, jagged_offsets, max_lengths - ) - - # verify forward - torch.testing.assert_close(dense, dense2) - - # verify backward - dense.retain_grad() - ref_output_values = jagged_values.clone().detach().requires_grad_(True) - ref_values = dense.clone().detach().requires_grad_(True) - jagged_values.backward(ref_output_values) - torch.testing.assert_close(dense.grad, ref_values) - - @given( - num_jagged_dim=st.integers(1, 5), - outer_dense_size=st.integers(0, 5), - inner_dense_size=st.integers(0, 5), - dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), - device_type=( - st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu") - ), - precompute_total_L=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) - def test_dense_to_jagged( - self, - num_jagged_dim: int, - outer_dense_size: int, - inner_dense_size: int, - dtype: torch.dtype, - device_type: str, - precompute_total_L: bool, - ) -> None: - self._test_dense_to_jagged( - num_jagged_dim, - outer_dense_size, - inner_dense_size, - dtype, - device_type, - precompute_total_L, - ) - - @unittest.skipIf(*gpu_unavailable) - @given( - num_jagged_dim=st.just(1), - outer_dense_size=st.integers(0, 6000), - inner_dense_size=st.sampled_from([8, 16, 23, 24, 48, 50, 64, 72, 96, 192]), - dtype=st.just(torch.half), - device_type=( - st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu") - ), - precompute_total_L=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) - def test_dense_to_jagged_opt( - self, - num_jagged_dim: int, - outer_dense_size: int, - inner_dense_size: int, - dtype: torch.dtype, - device_type: str, - precompute_total_L: bool, - ) -> None: - self._test_dense_to_jagged( - num_jagged_dim, - outer_dense_size, - inner_dense_size, - dtype, - device_type, - precompute_total_L, - ) - - # (8000+1) * 8 (size of the element of LongTensor/int64_t offsets) - # = ~62.5KB > 48KB default shared memory on V100/A100. - @unittest.skipIf(*gpu_unavailable) - @given( - num_jagged_dim=st.just(1), - outer_dense_size=st.just(8000), - inner_dense_size=st.just(16), - dtype=st.just(torch.half), - device_type=( - st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu") - ), - precompute_total_L=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) - def test_dense_to_jagged_opt_large_batch( - self, - num_jagged_dim: int, - outer_dense_size: int, - inner_dense_size: int, - dtype: torch.dtype, - device_type: str, - precompute_total_L: bool, - ) -> None: - self._test_dense_to_jagged( - num_jagged_dim, - outer_dense_size, - inner_dense_size, - dtype, - device_type, - precompute_total_L, - ) - - @given( - num_jagged_dim=st.integers(1, 5), - outer_dense_size=st.integers(0, 5), - inner_dense_size=st.integers(0, 5), - dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), - device_type=st.sampled_from(["meta"]), - precompute_total_L=st.booleans(), - ) - @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) - def test_dense_to_jagged_meta_backend( - self, - num_jagged_dim: int, - outer_dense_size: int, - inner_dense_size: int, - dtype: torch.dtype, - device_type: str, - precompute_total_L: bool, - ) -> None: - device = torch.device("cpu") - values_2d, offsets, max_lengths = generate_jagged_tensor( - num_jagged_dim, outer_dense_size, inner_dense_size, dtype, device - ) - values_2d = values_2d.clone().detach().requires_grad_(True) - - # jagged -> dense - dense = torch.ops.fbgemm.jagged_to_padded_dense(values_2d, offsets, max_lengths) - - # dense -> jagged (op which is being tested) - if precompute_total_L: - total_L = values_2d.size(0) - dense.to(device_type) - jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( - dense, offsets, total_L - ) - else: - dense.to(device_type) - jagged_values, jagged_offsets = torch.ops.fbgemm.dense_to_jagged( - dense, offsets - ) - - jagged_values.to(device_type) - # jagged -> dense - dense2 = torch.ops.fbgemm.jagged_to_padded_dense( - jagged_values, jagged_offsets, max_lengths - ) - - # verify forward - assert dense.size() == dense2.size() - - @optests.dontGenerateOpCheckTests("tests that call torch.compile are slow") - @unittest.skipIf(*symint_vector_unsupported()) - @given( - num_jagged_dim=st.integers(1, 5), - # TODO: size = 0/1 will be incorrectly specialized - outer_dense_size=st.integers(2, 5), - inner_dense_size=st.integers(2, 5), - dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), - device_type=( - st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu") - ), - ) - @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) - def test_dense_to_jagged_dynamic_shape( - self, - num_jagged_dim: int, - outer_dense_size: int, - inner_dense_size: int, - dtype: torch.dtype, - device_type: str, - ) -> None: - # Start a fresh compile for each parameter of the test case - torch._dynamo.reset() - - values_2d, offsets, max_lengths = generate_jagged_tensor( - num_jagged_dim, - outer_dense_size, - inner_dense_size, - dtype, - torch.device(device_type), - mark_dynamic=True, - ) - values_2d = values_2d.clone().detach().requires_grad_(True) - - def jagged_to_dense( - values: torch.Tensor, - offsets: List[torch.LongTensor], - max_lengths: List[int], - ) -> torch.Tensor: - return torch.ops.fbgemm.jagged_to_padded_dense(values, offsets, max_lengths) - - # jagged -> dense - dense = jagged_to_dense(values_2d, offsets, max_lengths.tolist()) - - # dense -> jagged, it is required to pre-compute totalL - total_L = values_2d.size(0) - dense = dense.clone().detach().to(device_type) - - torch._dynamo.mark_dynamic(dense, 0) - torch._dynamo.mark_dynamic(dense, -1) - - def dense_to_jagged_withL( - dense: torch.Tensor, offsets: List[torch.LongTensor], total_L: List[int] - ) -> Tuple[torch.Tensor, torch.Tensor]: - return torch.ops.fbgemm.dense_to_jagged(dense, offsets, total_L) - - def dense_to_jagged_noL( - dense: torch.Tensor, offsets: List[torch.LongTensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: - return torch.ops.fbgemm.dense_to_jagged(dense, offsets) - - jagged_values, jagged_offsets = dense_to_jagged_noL(dense, offsets) - jagged_values, jagged_offsets = dense_to_jagged_withL(dense, offsets, total_L) - - jagged_values.to(device_type) - # jagged -> dense - dense2 = torch.ops.fbgemm.jagged_to_padded_dense( - jagged_values, jagged_offsets, max_lengths - ) - - # verify forward - assert dense.size() == dense2.size() - @given( num_jagged_dim=st.integers(1, 5), outer_dense_size=st.integers(0, 5), @@ -512,217 +253,6 @@ def test_jagged_to_padded_dense_meta_backend( assert output.size() == output_ref.size() - @settings( - verbosity=Verbosity.verbose, - max_examples=20, - deadline=None, - ) - @given( - B=st.integers(0, 32), - H=st.integers(1, 3), - max_L=st.integers(1, 32), - D=st.integers(0, 32), - dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), - device_type=( - st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu") - ), - ) - def test_batched_dense_vec_jagged_2d_mul( - self, - B: int, - H: int, - max_L: int, - D: int, - dtype: torch.dtype, - device_type: str, - ) -> None: - assume(H == 1 or B != 0) - # CPU doesn't support bfloat16 - assume(device_type != "cpu" or dtype != torch.bfloat16) - - device = torch.device(device_type) - torch.backends.cuda.matmul.allow_tf32 = False - - # Sometimes length[i] exceed max_L meaning jagged->dense will be - # truncation vs. padding - lengths = torch.randint(max_L * 2, size=(B,), device=device) - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device) - dense = torch.rand((B * H, max_L), dtype=dtype, device=device) - padded_values = torch.ops.fbgemm.jagged_to_padded_dense( - values, - [offsets], - [max_L], - ) # [B, N, H * D] - - bmm_arg1 = dense.unsqueeze(1) - bmm_arg2 = ( - padded_values.reshape(B, max_L, H, D) - .transpose(1, 2) - .reshape(B * H, max_L, D) - ) - # torch.bmm not implemented for Half on CPU - if dtype in [torch.half, torch.bfloat16] and device_type == "cpu": - bmm_arg1 = bmm_arg1.float() - bmm_arg2 = bmm_arg2.float() - output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze( - 1 - ) # [B H, 1, N] x [B H, N, D] = [B H, 1, D] - if dtype in [torch.half, torch.bfloat16] and device_type == "cpu": - output_ref = output_ref.to(dtype) - output = torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul( - dense, values, offsets - ) - torch.testing.assert_close( - output, - output_ref, - rtol=1e-2 if dtype in [torch.half, torch.bfloat16] else None, - atol=1e-2 if dtype in [torch.half, torch.bfloat16] else None, - ) - - gradcheck( - torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul, - ( - dense.clone().detach().float().requires_grad_(True), - values.clone().detach().float().requires_grad_(True), - offsets, - ), - eps=1e-2, - atol=1e-3, - rtol=1e-3, - ) - - @settings( - verbosity=Verbosity.verbose, - max_examples=20, - deadline=None, - ) - @given( - B=st.integers(0, 32), - H=st.integers(1, 3), - max_L=st.integers(1, 32), - D=st.integers(0, 32), - dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), - device_type=st.sampled_from(["meta"]), - ) - def test_batched_dense_vec_jagged_2d_mul_meta_backend( - self, - B: int, - H: int, - max_L: int, - D: int, - dtype: torch.dtype, - device_type: str, - ) -> None: - assume(H == 1 or B != 0) - - device = torch.device("cpu") - torch.backends.cuda.matmul.allow_tf32 = False - - # Sometimes length[i] exceed max_L meaning jagged->dense will be - # truncation vs. padding - lengths = torch.randint(max_L * 2, size=(B,), device=device) - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device) - dense = torch.rand((B * H, max_L), dtype=dtype, device=device) - padded_values = torch.ops.fbgemm.jagged_to_padded_dense( - values, - [offsets], - [max_L], - ) # [B, N, H * D] - - bmm_arg1 = dense.unsqueeze(1) - bmm_arg2 = ( - padded_values.reshape(B, max_L, H, D) - .transpose(1, 2) - .reshape(B * H, max_L, D) - ) - # torch.bmm not implemented for Half on CPU - if dtype in [torch.half, torch.bfloat16]: - bmm_arg1 = bmm_arg1.float() - bmm_arg2 = bmm_arg2.float() - output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze( - 1 - ) # [B H, 1, N] x [B H, N, D] = [B H, 1, D] - dense.to(device_type) - values.to(device_type) - output = torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul( - dense, values, offsets - ) - assert output.size() == output_ref.size() - - @optests.dontGenerateOpCheckTests("tests that call torch.compile are slow") - @unittest.skipIf(*symint_vector_unsupported()) - @settings( - verbosity=Verbosity.verbose, - max_examples=20, - deadline=None, - ) - @given( - B=st.integers(2, 32), - H=st.integers(1, 3), - max_L=st.integers(1, 32), - D=st.integers(2, 32), - dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]), - device_type=st.just("cpu"), - ) - def test_batched_dense_vec_jagged_2d_mul_dynamic_shape( - self, - B: int, - H: int, - max_L: int, - D: int, - dtype: torch.dtype, - device_type: str, - ) -> None: - # Start a fresh compile for each parameter of the test case - torch._dynamo.reset() - - assume(H == 1 or B != 0) - - device = torch.device(device_type) - torch.backends.cuda.matmul.allow_tf32 = False - - # Sometimes length[i] exceed max_L meaning jagged->dense will be - # truncation vs. padding - lengths = torch.randint(low=1, high=max_L * 2, size=(B,), device=device) - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) - values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device) - dense = torch.rand((B * H, max_L), dtype=dtype, device=device) - padded_values = torch.ops.fbgemm.jagged_to_padded_dense( - values, - [offsets], - [max_L], - ) # [B, N, H * D] - - bmm_arg1 = dense.unsqueeze(1) - bmm_arg2 = ( - padded_values.reshape(B, max_L, H, D) - .transpose(1, 2) - .reshape(B * H, max_L, D) - ) - # torch.bmm not implemented for Half on CPU - if dtype in [torch.half, torch.bfloat16]: - bmm_arg1 = bmm_arg1.float() - bmm_arg2 = bmm_arg2.float() - output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze( - 1 - ) # [B H, 1, N] x [B H, N, D] = [B H, 1, D] - dense.to(device_type) - values.to(device_type) - - torch._dynamo.mark_dynamic(dense, 0) - torch._dynamo.mark_dynamic(values, 0) - torch._dynamo.mark_dynamic(values, 1) - torch._dynamo.mark_dynamic(offsets, 0) - - output = torch_compiled( - torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul, - fullgraph=True, - dynamic=True, - )(dense, values, offsets) - assert output.size() == output_ref.size() - @staticmethod def jagged_index_select_2d_ref( values: torch.Tensor,