add PT2 support for permute_multi_embedding (pytorch#2381)
Summary:
Pull Request resolved: pytorch#2381

It looks like test_pt2 already passed; not sure why the existing test couldn't capture the PT2 incompatibility in the op.

graph breaks: P1557581728
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmphgx6wM/rank_0/failures_and_restarts.html

Differential Revision: D62226292
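
For context, a minimal sketch (not part of this diff) of how graph breaks in an op can be surfaced under PT2: compiling with fullgraph=True makes Dynamo raise on the first graph break instead of silently falling back to eager.

import torch

def check_no_graph_breaks(fn, *example_inputs):
    # fullgraph=True turns any Dynamo graph break into a hard error,
    # so an op that is not PT2-compatible fails loudly here.
    compiled = torch.compile(fn, fullgraph=True)
    return compiled(*example_inputs)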
TroyGarden authored and facebook-github-bot committed Sep 11, 2024
1 parent 0bc1baa commit dfadd39
torchrec/distributed/tests/test_pt2.py: 28 additions & 0 deletions
@@ -37,6 +37,7 @@
    kjt_for_pt2_tracing,
    register_fake_classes,
)
from torchrec.sparse.jagged_tensor import _kt_regroup_arguments

try:
    # pyre-ignore
@@ -842,6 +843,33 @@ def test_permute_pooled_embs_split(self) -> None:
        inp = torch.randn(12, 3)
        _test_compile_fwd_bwd(m, inp, device)

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs, this test requires at least two GPUs",
    )
    def test_permute_multi_embedding(self) -> None:
        device = "cuda"
        batch_size = 16

        def func(values, permutes, in_shapes, out_shapes, out_lengths):
            return torch.ops.fbgemm.permute_multi_embedding(
                values, permutes, in_shapes, out_shapes, out_lengths.tolist()
            )

        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
        lengths = [[3, 4], [5, 6, 7], [8]]
        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
        values = [torch.randn(batch_size, sum(L), device=device) for L in lengths]
        for embs in values:
            torch._dynamo.mark_dynamic(embs, 0)
            torch._dynamo.mark_dynamic(embs, 1)
        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
            values[0], keys, lengths, groups
        )
        out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32)
        inp = (values, permutes, in_shapes, out_shapes, out_lengths)
        _test_compile_fwd_bwd(func, inp, device, unpack_inp=True)

    @unittest.skipIf(
        torch.cuda.device_count() < 1,
        "Not enough GPUs, this test requires at least one GPU",
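For reference, a minimal, simplified sketch (an assumption; the actual helper in test_pt2.py may differ in signature and backend) of what a compile forward+backward check like _test_compile_fwd_bwd with unpack_inp=True could look like:

import torch

def _compile_fwd_bwd_sketch(fn, inp, unpack_inp: bool = False) -> None:
    # unpack_inp=True means the input tuple is splatted into fn as positional args.
    args = list(inp) if unpack_inp else [inp]
    # Make floating-point tensors differentiable so the backward pass has work to do.
    for a in args:
        tensors = a if isinstance(a, (list, tuple)) else [a]
        for t in tensors:
            if isinstance(t, torch.Tensor) and t.is_floating_point():
                t.requires_grad_(True)
    compiled = torch.compile(fn, backend="aot_eager", fullgraph=True)
    out = compiled(*args)
    outs = out if isinstance(out, (list, tuple)) else [out]
    # Reduce to a scalar and run backward through the compiled graph.
    sum(o.sum() for o in outs).backward()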
