From e860e0916316b14baeab8d30a0b583e4d61b3d7d Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Fri, 24 May 2024 05:10:00 -0700 Subject: [PATCH] permute_2D_sparse_data Autograd formula (#2629) Summary: Reland of D57625720 without dependency on torchrec Adding permute_2D_sparse_data python formula for Inductor compilation. Differential Revision: D57773001 --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 37 +++++++++++++++++++++++ fbgemm_gpu/test/sparse/common.py | 6 ---- fbgemm_gpu/test/sparse/failures_dict.json | 22 ++------------ 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index c43053a034..ca935b2f2f 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -102,6 +102,43 @@ def permute_2D_sparse_data_meta( return permuted_lengths, permuted_indices, permuted_weights +@impl_abstract("fbgemm::invert_permute") +def invert_permute_abstract(permute: Tensor) -> Tensor: + return torch.empty_like(permute) + + +# pyre-ignore +def permute_2D_sparse_data_setup_context(ctx, inputs, output): + permute, lengths, values, weights, permuted_lengths_sum = inputs + permuted_lengths, permuted_values, permuted_weights = output + ctx.permute = permute + ctx.permuted_lengths = permuted_lengths + + +# pyre-ignore +def permute_2D_sparse_data_backward(ctx, grad_lengths, grad_values, grad_weights): + inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute) + permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = ( + torch.ops.fbgemm.permute_2D_sparse_data( + inv_permute, ctx.permuted_lengths, grad_values, grad_weights + ) + ) + return ( + None, + permuted_grad_lengths, + permuted_grad_values, + permuted_grad_weights, + None, + ) + + +torch.library.register_autograd( + "fbgemm::permute_2D_sparse_data", + permute_2D_sparse_data_backward, + setup_context=permute_2D_sparse_data_setup_context, +) + + @impl_abstract("fbgemm::permute_1D_sparse_data") def permute_1D_sparse_data_meta( permute: Tensor, diff --git a/fbgemm_gpu/test/sparse/common.py b/fbgemm_gpu/test/sparse/common.py index 5a3fba27c8..a8cbe714cb 100644 --- a/fbgemm_gpu/test/sparse/common.py +++ b/fbgemm_gpu/test/sparse/common.py @@ -129,12 +129,6 @@ def extend_test_class( additional_decorators = { **(additional_decorators or {}), - **{ - "test_pt2_compliant_tag_fbgemm_permute_2D_sparse_data": [ - # This operator has been grandfathered in. We need to fix this test failure. - unittest.expectedFailure, - ] - }, } # Only generate tests for PyTorch 2.2+ diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index 35e6d52d22..a8e067d377 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -144,16 +144,7 @@ "status": "xfail" } }, - "fbgemm::invert_permute": { - "MiscOpsTest.test_aot_dispatch_dynamic__test_invert_permute": { - "comment": "", - "status": "xfail" - }, - "MiscOpsTest.test_faketensor__test_invert_permute": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::invert_permute": {}, "fbgemm::pack_segments": {}, "fbgemm::permute102_baddbmm_permute102": { "MiscOpsTest.test_aot_dispatch_dynamic__test_permute102_baddbmm_permute102": { @@ -166,16 +157,7 @@ } }, "fbgemm::permute_1D_sparse_data": {}, - "fbgemm::permute_2D_sparse_data": { - "PermuteEmbeddingsTest.test_aot_dispatch_dynamic__test_permute_embeddings": { - "comment": "", - "status": "xfail" - }, - "PermuteIndicesTest.test_aot_dispatch_dynamic__test_permute_indices": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::permute_2D_sparse_data": {}, "fbgemm::permute_sequence_embeddings": { "PermuteEmbeddingsTest.test_aot_dispatch_dynamic__test_permute_embeddings": { "comment": "",