From ae2b84feba6ea6fa6367239059151327d0594ba2 Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Thu, 23 May 2024 14:21:19 -0700 Subject: [PATCH] permute_2D_sparse_data Autogad formula (#2625) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2625 X-link: https://github.com/pytorch/torchrec/pull/2034 Adding permute_2D_sparse_data python formula for Inductor compilation. It's needed for Torchrec sync collectives - verifying with torchrec/distributed/tests:test_comm that tests fwd and bwd for collectives that use this op. Reviewed By: williamwen42 Differential Revision: D57625720 fbshipit-source-id: 4596111afb12e1e72e643f929497e5941b8706bc --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 37 +++++++++++++++++++++++ fbgemm_gpu/test/sparse/common.py | 6 ---- fbgemm_gpu/test/sparse/failures_dict.json | 22 ++------------ 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index c43053a03..ca935b2f2 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -102,6 +102,43 @@ def permute_2D_sparse_data_meta( return permuted_lengths, permuted_indices, permuted_weights +@impl_abstract("fbgemm::invert_permute") +def invert_permute_abstract(permute: Tensor) -> Tensor: + return torch.empty_like(permute) + + +# pyre-ignore +def permute_2D_sparse_data_setup_context(ctx, inputs, output): + permute, lengths, values, weights, permuted_lengths_sum = inputs + permuted_lengths, permuted_values, permuted_weights = output + ctx.permute = permute + ctx.permuted_lengths = permuted_lengths + + +# pyre-ignore +def permute_2D_sparse_data_backward(ctx, grad_lengths, grad_values, grad_weights): + inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute) + permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = ( + torch.ops.fbgemm.permute_2D_sparse_data( + inv_permute, ctx.permuted_lengths, grad_values, grad_weights + ) + ) + return ( + None, + permuted_grad_lengths, + permuted_grad_values, + permuted_grad_weights, + None, + ) + + +torch.library.register_autograd( + "fbgemm::permute_2D_sparse_data", + permute_2D_sparse_data_backward, + setup_context=permute_2D_sparse_data_setup_context, +) + + @impl_abstract("fbgemm::permute_1D_sparse_data") def permute_1D_sparse_data_meta( permute: Tensor, diff --git a/fbgemm_gpu/test/sparse/common.py b/fbgemm_gpu/test/sparse/common.py index 5a3fba27c..a8cbe714c 100644 --- a/fbgemm_gpu/test/sparse/common.py +++ b/fbgemm_gpu/test/sparse/common.py @@ -129,12 +129,6 @@ def extend_test_class( additional_decorators = { **(additional_decorators or {}), - **{ - "test_pt2_compliant_tag_fbgemm_permute_2D_sparse_data": [ - # This operator has been grandfathered in. We need to fix this test failure. - unittest.expectedFailure, - ] - }, } # Only generate tests for PyTorch 2.2+ diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index 35e6d52d2..a8e067d37 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -144,16 +144,7 @@ "status": "xfail" } }, - "fbgemm::invert_permute": { - "MiscOpsTest.test_aot_dispatch_dynamic__test_invert_permute": { - "comment": "", - "status": "xfail" - }, - "MiscOpsTest.test_faketensor__test_invert_permute": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::invert_permute": {}, "fbgemm::pack_segments": {}, "fbgemm::permute102_baddbmm_permute102": { "MiscOpsTest.test_aot_dispatch_dynamic__test_permute102_baddbmm_permute102": { @@ -166,16 +157,7 @@ } }, "fbgemm::permute_1D_sparse_data": {}, - "fbgemm::permute_2D_sparse_data": { - "PermuteEmbeddingsTest.test_aot_dispatch_dynamic__test_permute_embeddings": { - "comment": "", - "status": "xfail" - }, - "PermuteIndicesTest.test_aot_dispatch_dynamic__test_permute_indices": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::permute_2D_sparse_data": {}, "fbgemm::permute_sequence_embeddings": { "PermuteEmbeddingsTest.test_aot_dispatch_dynamic__test_permute_embeddings": { "comment": "",