Skip to content

Commit

Permalink
Re-organize jagged tensor tests, pt 4 (#2409)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2409

- Re-organize jagged tensor tests, pt 4

Reviewed By: spcyppt

Differential Revision: D54705747
  • Loading branch information
q10 authored and facebook-github-bot committed Mar 11, 2024
1 parent dc3a268 commit 98a4a5d
Show file tree
Hide file tree
Showing 4 changed files with 549 additions and 486 deletions.
249 changes: 249 additions & 0 deletions fbgemm_gpu/test/jagged/batched_dense_vec_jagged_2d_mul_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[56]

import unittest

import hypothesis.strategies as st
import torch
import torch._dynamo
from hypothesis import assume, given, settings, Verbosity

from .common import additional_decorators, open_source, torch_compiled

if open_source:
# pyre-ignore[21]
from test_utils import (
cpu_and_maybe_gpu,
gradcheck,
optests,
symint_vector_unsupported,
)
else:
from fbgemm_gpu.test.test_utils import (
cpu_and_maybe_gpu,
gradcheck,
optests,
symint_vector_unsupported,
)


@optests.generate_opcheck_tests(additional_decorators=additional_decorators)
class BatchedDenseVecJagged2DMulTest(unittest.TestCase):
@settings(
verbosity=Verbosity.verbose,
max_examples=20,
deadline=None,
)
@given(
B=st.integers(0, 32),
H=st.integers(1, 3),
max_L=st.integers(1, 32),
D=st.integers(0, 32),
dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
device=cpu_and_maybe_gpu(),
)
def test_batched_dense_vec_jagged_2d_mul(
self,
B: int,
H: int,
max_L: int,
D: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
assume(H == 1 or B != 0)
# CPU doesn't support bfloat16
assume(device != torch.device("cpu") or dtype != torch.bfloat16)

torch.backends.cuda.matmul.allow_tf32 = False

# Sometimes length[i] exceed max_L meaning jagged->dense will be
# truncation vs. padding
lengths = torch.randint(max_L * 2, size=(B,), device=device)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device)
dense = torch.rand((B * H, max_L), dtype=dtype, device=device)
padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
values,
[offsets],
[max_L],
) # [B, N, H * D]

bmm_arg1 = dense.unsqueeze(1)
bmm_arg2 = (
padded_values.reshape(B, max_L, H, D)
.transpose(1, 2)
.reshape(B * H, max_L, D)
)
# torch.bmm not implemented for Half on CPU
if dtype in [torch.half, torch.bfloat16] and device == torch.device("cpu"):
bmm_arg1 = bmm_arg1.float()
bmm_arg2 = bmm_arg2.float()
output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze(
1
) # [B H, 1, N] x [B H, N, D] = [B H, 1, D]
if dtype in [torch.half, torch.bfloat16] and device == torch.device("cpu"):
output_ref = output_ref.to(dtype)
output = torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul(
dense, values, offsets
)
torch.testing.assert_close(
output,
output_ref,
rtol=1e-2 if dtype in [torch.half, torch.bfloat16] else None,
atol=1e-2 if dtype in [torch.half, torch.bfloat16] else None,
)

gradcheck(
torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul,
(
dense.clone().detach().float().requires_grad_(True),
values.clone().detach().float().requires_grad_(True),
offsets,
),
eps=1e-2,
atol=1e-3,
rtol=1e-3,
)

@settings(
verbosity=Verbosity.verbose,
max_examples=20,
deadline=None,
)
@given(
B=st.integers(0, 32),
H=st.integers(1, 3),
max_L=st.integers(1, 32),
D=st.integers(0, 32),
dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
device_type=st.sampled_from(["meta"]),
)
def test_batched_dense_vec_jagged_2d_mul_meta_backend(
self,
B: int,
H: int,
max_L: int,
D: int,
dtype: torch.dtype,
device_type: str,
) -> None:
assume(H == 1 or B != 0)

device = torch.device("cpu")
torch.backends.cuda.matmul.allow_tf32 = False

# Sometimes length[i] exceed max_L meaning jagged->dense will be
# truncation vs. padding
lengths = torch.randint(max_L * 2, size=(B,), device=device)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device)
dense = torch.rand((B * H, max_L), dtype=dtype, device=device)
padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
values,
[offsets],
[max_L],
) # [B, N, H * D]

bmm_arg1 = dense.unsqueeze(1)
bmm_arg2 = (
padded_values.reshape(B, max_L, H, D)
.transpose(1, 2)
.reshape(B * H, max_L, D)
)
# torch.bmm not implemented for Half on CPU
if dtype in [torch.half, torch.bfloat16]:
bmm_arg1 = bmm_arg1.float()
bmm_arg2 = bmm_arg2.float()
output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze(
1
) # [B H, 1, N] x [B H, N, D] = [B H, 1, D]
dense.to(device_type)
values.to(device_type)
output = torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul(
dense, values, offsets
)
assert output.size() == output_ref.size()

@optests.dontGenerateOpCheckTests("tests that call torch.compile are slow")
@unittest.skipIf(*symint_vector_unsupported())
@settings(
verbosity=Verbosity.verbose,
max_examples=20,
deadline=None,
)
@given(
B=st.integers(2, 32),
H=st.integers(1, 3),
max_L=st.integers(1, 32),
D=st.integers(2, 32),
dtype=st.sampled_from([torch.float, torch.half, torch.bfloat16]),
device_type=st.just("cpu"),
)
def test_batched_dense_vec_jagged_2d_mul_dynamic_shape(
self,
B: int,
H: int,
max_L: int,
D: int,
dtype: torch.dtype,
device_type: str,
) -> None:
# Start a fresh compile for each parameter of the test case
torch._dynamo.reset()

assume(H == 1 or B != 0)

device = torch.device(device_type)
torch.backends.cuda.matmul.allow_tf32 = False

# Sometimes length[i] exceed max_L meaning jagged->dense will be
# truncation vs. padding
lengths = torch.randint(low=1, high=max_L * 2, size=(B,), device=device)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
values = torch.rand((offsets[-1], H * D), dtype=dtype, device=device)
dense = torch.rand((B * H, max_L), dtype=dtype, device=device)
padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
values,
[offsets],
[max_L],
) # [B, N, H * D]

bmm_arg1 = dense.unsqueeze(1)
bmm_arg2 = (
padded_values.reshape(B, max_L, H, D)
.transpose(1, 2)
.reshape(B * H, max_L, D)
)
# torch.bmm not implemented for Half on CPU
if dtype in [torch.half, torch.bfloat16]:
bmm_arg1 = bmm_arg1.float()
bmm_arg2 = bmm_arg2.float()
output_ref = torch.bmm(bmm_arg1, bmm_arg2).squeeze(
1
) # [B H, 1, N] x [B H, N, D] = [B H, 1, D]
dense.to(device_type)
values.to(device_type)

torch._dynamo.mark_dynamic(dense, 0)
torch._dynamo.mark_dynamic(values, 0)
torch._dynamo.mark_dynamic(values, 1)
torch._dynamo.mark_dynamic(offsets, 0)

output = torch_compiled(
torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul,
fullgraph=True,
dynamic=True,
)(dense, values, offsets)
assert output.size() == output_ref.size()


if __name__ == "__main__":
unittest.main()
23 changes: 8 additions & 15 deletions fbgemm_gpu/test/jagged/dense_bmm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

if open_source:
# pyre-ignore[21]
from test_utils import gpu_available, optests, symint_vector_unsupported
from test_utils import cpu_and_maybe_gpu, optests, symint_vector_unsupported
else:
from fbgemm_gpu.test.test_utils import (
gpu_available,
cpu_and_maybe_gpu,
optests,
symint_vector_unsupported,
)
Expand All @@ -36,9 +36,7 @@ class DenseBmmTest(unittest.TestCase):
N=st.integers(1, 32),
max_L=st.integers(1, 32),
dtype=st.sampled_from([torch.float]),
device_type=(
st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu")
),
device=cpu_and_maybe_gpu(),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_jagged_jagged_bmm(
Expand All @@ -48,10 +46,9 @@ def test_jagged_jagged_bmm(
N: int,
max_L: int,
dtype: torch.dtype,
device_type: str,
device: torch.device,
) -> None:
assume(B != 0)
device = torch.device(device_type)
torch.backends.cuda.matmul.allow_tf32 = False
lengths = torch.randint(max_L + 1, size=(B,), device=device)
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
Expand Down Expand Up @@ -101,9 +98,7 @@ def test_jagged_jagged_bmm(
N=st.integers(1, 32),
max_L=st.integers(1, 32),
dtype=st.sampled_from([torch.float]),
device_type=(
st.sampled_from(["cpu", "cuda"]) if gpu_available else st.just("cpu")
),
device=cpu_and_maybe_gpu(),
)
@settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
def test_jagged_dense_bmm(
Expand All @@ -113,10 +108,9 @@ def test_jagged_dense_bmm(
N: int,
max_L: int,
dtype: torch.dtype,
device_type: str,
device: torch.device,
) -> None:
assume(B != 0)
device = torch.device(device_type)
torch.backends.cuda.matmul.allow_tf32 = False
lengths = torch.randint(max_L + 1, size=(B,), device=device)
total_length = int(lengths.sum().item())
Expand Down Expand Up @@ -165,7 +159,7 @@ def test_jagged_dense_bmm(
N=st.integers(2, 32),
max_L=st.integers(2, 32),
dtype=st.sampled_from([torch.float]),
device_type=st.just("cpu"),
device=st.just(torch.device("cpu")),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_jagged_dense_bmm_dynamic_shape(
Expand All @@ -175,13 +169,12 @@ def test_jagged_dense_bmm_dynamic_shape(
N: int,
max_L: int,
dtype: torch.dtype,
device_type: str,
device: torch.device,
) -> None:
# Start a fresh compile for each parameter of the test case
torch._dynamo.reset()

assume(B != 0)
device = torch.device(device_type)
torch.backends.cuda.matmul.allow_tf32 = False
lengths = torch.randint(low=1, high=max_L + 1, size=(B,), device=device)
total_length = int(lengths.sum().item())
Expand Down
Loading

0 comments on commit 98a4a5d

Please sign in to comment.