From 935494c3ee336743f7392318990f94ec6745ba13 Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Wed, 22 May 2024 13:40:56 -0700 Subject: [PATCH] permute_*pooled_embs_split Autograd pt2 compat (#2619) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2619 permute_pooled_ _auto_grad was not compatible with PT2 because it was calling directly _impl function. For PT2 compatibility we need to call it via Dispatcher. Meta functions were missing. Reviewed By: ezyang, Microve Differential Revision: D57671023 fbshipit-source-id: 25977a9ed8de82b23e1bb02f8d8bbbe796acabc8 --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 25 +++++++++ .../permute_pooled_embs_function_split.h | 1 + ...permute_pooled_embedding_ops_split_cpu.cpp | 50 +++++++++++++++-- ...permute_pooled_embedding_ops_split_gpu.cpp | 55 +++++++++++++++++-- 4 files changed, 120 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index c06e5aa51..2cbc1af1c 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -46,6 +46,9 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops") + torch.ops.load_library( + "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu" + ) import torch.utils._pytree as pytree @@ -875,3 +878,25 @@ def keyed_jagged_index_select_dim1_backward_cuda_impl_abstract( saved_tensor: torch.Tensor, ) -> torch.Tensor: return grad.new_empty([torch.library.get_ctx().new_dynamic_size()]) + + +@impl_abstract("fbgemm::permute_pooled_embs_split") +def permute_pooled_embs_split_abstract( + pooled_embs: Tensor, + offset_dim_list: Tensor, + permute_list: Tensor, + inv_offset_dim_list: Tensor, + inv_permute_list: Tensor, +) -> Tensor: + return torch.empty_like(pooled_embs) + + +@impl_abstract("fbgemm::permute_duplicate_pooled_embs_split") +def permute_duplicate_pooled_embs_split_abstract( + pooled_embs: Tensor, + offset_dim_list: Tensor, + permute_list: Tensor, + inv_offset_dim_list: Tensor, + inv_permute_list: Tensor, +) -> Tensor: + return torch.empty_like(pooled_embs) diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h b/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h index e5938f6cb..9d4a49179 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h +++ b/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h @@ -34,6 +34,7 @@ class PermutePooledEmbsFunctionSplit const at::Tensor& permute_list, const at::Tensor& inv_offset_dim_list, const at::Tensor& inv_permute_list) { + at::AutoDispatchBelowADInplaceOrView guard; ctx->saved_data["offset_dim_list"] = offset_dim_list; ctx->saved_data["permute_list"] = permute_list; ctx->saved_data["inv_offset_dim_list"] = inv_offset_dim_list; diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp index 1935f2a83..2075f4bff 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp @@ -84,13 +84,36 @@ Tensor permute_duplicate_pooled_embs_split_cpu( true); } -Tensor permute_pooled_embs_auto_grad_split_cpu( +Tensor permute_pooled_embs_split_dispatch_call( + const Tensor& pooled_embs, // [B_local][Sum_T_global(D)] + const Tensor& offset_dim_list, + const Tensor& permute_list, + const Tensor& inv_offset_dim_list, + const Tensor& inv_permute_list) { + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_pooled_embs_split", "") + .typed(); + return op.call( + pooled_embs, + offset_dim_list, + permute_list, + inv_offset_dim_list, + inv_permute_list); +} + +Tensor permute_duplicate_pooled_embs_split_dispatch_call( const Tensor& pooled_embs, const Tensor& offset_dim_list, const Tensor& permute_list, const Tensor& inv_offset_dim_list, const Tensor& inv_permute_list) { - return PermutePooledEmbsFunctionSplit::apply( + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_duplicate_pooled_embs_split", "") + .typed< + decltype(fbgemm_gpu::permute_duplicate_pooled_embs_split_cpu)>(); + return op.call( pooled_embs, offset_dim_list, permute_list, @@ -98,6 +121,22 @@ Tensor permute_pooled_embs_auto_grad_split_cpu( inv_permute_list); } +Tensor permute_pooled_embs_auto_grad_split_cpu( + const Tensor& pooled_embs, + const Tensor& offset_dim_list, + const Tensor& permute_list, + const Tensor& inv_offset_dim_list, + const Tensor& inv_permute_list) { + return PermutePooledEmbsFunctionSplit< + permute_pooled_embs_split_dispatch_call>:: + apply( + pooled_embs, + offset_dim_list, + permute_list, + inv_offset_dim_list, + inv_permute_list); +} + Tensor permute_duplicate_pooled_embs_auto_grad_split_cpu( const Tensor& pooled_embs, const Tensor& offset_dim_list, @@ -105,7 +144,7 @@ Tensor permute_duplicate_pooled_embs_auto_grad_split_cpu( const Tensor& inv_offset_dim_list, const Tensor& inv_permute_list) { return PermutePooledEmbsFunctionSplit< - permute_duplicate_pooled_embs_split_cpu>:: + permute_duplicate_pooled_embs_split_dispatch_call>:: apply( pooled_embs, offset_dim_list, @@ -116,6 +155,7 @@ Tensor permute_duplicate_pooled_embs_auto_grad_split_cpu( } // namespace fbgemm_gpu TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.set_python_module("fbgemm_gpu.sparse_ops"); m.def( "permute_pooled_embs_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); DISPATCH_TO_CPU( @@ -127,12 +167,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { fbgemm_gpu::permute_duplicate_pooled_embs_split_cpu); m.def( "permute_pooled_embs_auto_grad_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); - DISPATCH_TO_CPU( + DISPATCH_TO_AUTOGRAD_CPU( "permute_pooled_embs_auto_grad_split", fbgemm_gpu::permute_pooled_embs_auto_grad_split_cpu); m.def( "permute_duplicate_pooled_embs_auto_grad_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor"); - DISPATCH_TO_CPU( + DISPATCH_TO_AUTOGRAD_CPU( "permute_duplicate_pooled_embs_auto_grad_split", fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_split_cpu); } diff --git a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp index 2831a22fb..1ec23e0a3 100644 --- a/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp +++ b/fbgemm_gpu/src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp @@ -18,15 +18,19 @@ using Tensor = at::Tensor; -namespace fbgemm_gpu { +namespace { -Tensor permute_pooled_embs_auto_grad_split_gpu( +Tensor permute_pooled_embs_split_dispatch_call( const Tensor& pooled_embs, const Tensor& offset_dim_list, const Tensor& permute_list, const Tensor& inv_offset_dim_list, const Tensor& inv_permute_list) { - return PermutePooledEmbsFunctionSplit::apply( + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_pooled_embs_split", "") + .typed(); + return op.call( pooled_embs, offset_dim_list, permute_list, @@ -34,6 +38,45 @@ Tensor permute_pooled_embs_auto_grad_split_gpu( inv_permute_list); } +Tensor permute_duplicate_pooled_embs_split_dispatch_call( + const Tensor& pooled_embs, + const Tensor& offset_dim_list, + const Tensor& permute_list, + const Tensor& inv_offset_dim_list, + const Tensor& inv_permute_list) { + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::permute_duplicate_pooled_embs_split", "") + .typed< + decltype(fbgemm_gpu::permute_duplicate_pooled_embs_split_gpu)>(); + return op.call( + pooled_embs, + offset_dim_list, + permute_list, + inv_offset_dim_list, + inv_permute_list); +} + +} // namespace + +namespace fbgemm_gpu { + +Tensor permute_pooled_embs_auto_grad_split_gpu( + const Tensor& pooled_embs, + const Tensor& offset_dim_list, + const Tensor& permute_list, + const Tensor& inv_offset_dim_list, + const Tensor& inv_permute_list) { + return PermutePooledEmbsFunctionSplit< + permute_pooled_embs_split_dispatch_call>:: + apply( + pooled_embs, + offset_dim_list, + permute_list, + inv_offset_dim_list, + inv_permute_list); +} + Tensor permute_duplicate_pooled_embs_auto_grad_split_gpu( const Tensor& pooled_embs, const Tensor& offset_dim_list, @@ -41,7 +84,7 @@ Tensor permute_duplicate_pooled_embs_auto_grad_split_gpu( const Tensor& inv_offset_dim_list, const Tensor& inv_permute_list) { return PermutePooledEmbsFunctionSplit< - permute_duplicate_pooled_embs_split_gpu>:: + permute_duplicate_pooled_embs_split_dispatch_call>:: apply( pooled_embs, offset_dim_list, @@ -57,10 +100,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { DISPATCH_TO_CUDA( "permute_duplicate_pooled_embs_split", fbgemm_gpu::permute_duplicate_pooled_embs_split_gpu); - DISPATCH_TO_CUDA( + DISPATCH_TO_AUTOGRAD_CUDA( "permute_pooled_embs_auto_grad_split", fbgemm_gpu::permute_pooled_embs_auto_grad_split_gpu); - DISPATCH_TO_CUDA( + DISPATCH_TO_AUTOGRAD_CUDA( "permute_duplicate_pooled_embs_auto_grad_split", fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_split_gpu); }