permute_*pooled_embs_split Autograd pt2 compat (#2619)
Summary:
Pull Request resolved: #2619

The permute_pooled_embs_auto_grad_split operators were not compatible with PT2 because they called the backend _impl functions directly; for PT2 compatibility the call has to go through the Dispatcher.

The meta (abstract) functions for these operators were also missing; this diff adds them.
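Concretely, the pattern the diff adopts is the following condensed sketch (the wrapper name and the written-out typed<> signature here are illustrative; the diff itself uses decltype of the backend function and spells the same singleton torch::Dispatcher):

    #include <ATen/ATen.h>
    #include <ATen/core/dispatch/Dispatcher.h>

    at::Tensor permute_pooled_embs_split_via_dispatcher(
        const at::Tensor& pooled_embs,
        const at::Tensor& offset_dim_list,
        const at::Tensor& permute_list,
        const at::Tensor& inv_offset_dim_list,
        const at::Tensor& inv_permute_list) {
      // Resolve the operator once by its registered schema name; typed<>
      // pins down the C++ signature so call() is statically checked.
      static auto op =
          c10::Dispatcher::singleton()
              .findSchemaOrThrow("fbgemm::permute_pooled_embs_split", "")
              .typed<at::Tensor(
                  const at::Tensor&,
                  const at::Tensor&,
                  const at::Tensor&,
                  const at::Tensor&,
                  const at::Tensor&)>();
      // call() walks the dispatch keys (Autograd, Meta, FakeTensor, ...)
      // instead of jumping straight to one backend implementation, which
      // is what makes the op visible to PT2.
      return op.call(
          pooled_embs,
          offset_dim_list,
          permute_list,
          inv_offset_dim_list,
          inv_permute_list);
    }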

Reviewed By: ezyang, Microve

Differential Revision: D57671023

fbshipit-source-id: 25977a9ed8de82b23e1bb02f8d8bbbe796acabc8
Ivan Kobzarev authored and facebook-github-bot committed May 22, 2024
1 parent 7ff0bdc commit 935494c
Showing 4 changed files with 120 additions and 11 deletions.
25 changes: 25 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -46,6 +46,9 @@
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:embedding_ops_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
)


import torch.utils._pytree as pytree
@@ -875,3 +878,25 @@ def keyed_jagged_index_select_dim1_backward_cuda_impl_abstract(
saved_tensor: torch.Tensor,
) -> torch.Tensor:
return grad.new_empty([torch.library.get_ctx().new_dynamic_size()])


@impl_abstract("fbgemm::permute_pooled_embs_split")
def permute_pooled_embs_split_abstract(
pooled_embs: Tensor,
offset_dim_list: Tensor,
permute_list: Tensor,
inv_offset_dim_list: Tensor,
inv_permute_list: Tensor,
) -> Tensor:
return torch.empty_like(pooled_embs)


@impl_abstract("fbgemm::permute_duplicate_pooled_embs_split")
def permute_duplicate_pooled_embs_split_abstract(
pooled_embs: Tensor,
offset_dim_list: Tensor,
permute_list: Tensor,
inv_offset_dim_list: Tensor,
inv_permute_list: Tensor,
) -> Tensor:
return torch.empty_like(pooled_embs)
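The abstract impls above just allocate an output with the same shape, dtype, and device as the input, which is all PT2's FakeTensor tracing needs to propagate shapes without running the real kernel. This diff registers them from Python via impl_abstract; for comparison only (not what this diff does), a functionally similar shape function could be registered from C++ under the Meta dispatch key:

    #include <ATen/ATen.h>
    #include <torch/library.h>

    // Hypothetical C++ meta kernel mirroring torch.empty_like(pooled_embs);
    // only the first argument affects the output's shape/dtype.
    at::Tensor permute_pooled_embs_split_meta(
        const at::Tensor& pooled_embs,
        const at::Tensor& /*offset_dim_list*/,
        const at::Tensor& /*permute_list*/,
        const at::Tensor& /*inv_offset_dim_list*/,
        const at::Tensor& /*inv_permute_list*/) {
      return at::empty_like(pooled_embs);
    }

    TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
      m.impl("permute_pooled_embs_split", permute_pooled_embs_split_meta);
    }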
@@ -34,6 +34,7 @@ class PermutePooledEmbsFunctionSplit
const at::Tensor& permute_list,
const at::Tensor& inv_offset_dim_list,
const at::Tensor& inv_permute_list) {
at::AutoDispatchBelowADInplaceOrView guard;
ctx->saved_data["offset_dim_list"] = offset_dim_list;
ctx->saved_data["permute_list"] = permute_list;
ctx->saved_data["inv_offset_dim_list"] = inv_offset_dim_list;
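The one-line addition above, at::AutoDispatchBelowADInplaceOrView guard;, is needed because forward() now invokes the operator through the Dispatcher: the guard excludes the Autograd and ADInplaceOrView dispatch keys for the rest of the scope, so the nested call lands on the CPU/CUDA kernel instead of re-entering the autograd wrapper. A minimal hypothetical wrapper (not the FBGEMM class) showing the placement:

    #include <torch/torch.h>

    using torch::autograd::AutogradContext;
    using torch::autograd::Function;
    using torch::autograd::variable_list;

    struct ScaleBy2 : public Function<ScaleBy2> {
      static at::Tensor forward(AutogradContext* ctx, const at::Tensor& x) {
        // Exclude Autograd/ADInplaceOrView keys for this scope: ops
        // dispatched below here go straight to the backend kernel.
        at::AutoDispatchBelowADInplaceOrView guard;
        return x * 2; // stand-in for the dispatcher call to the real kernel
      }
      static variable_list backward(AutogradContext* ctx, variable_list grad) {
        return {grad[0] * 2}; // d(2x)/dx = 2
      }
    };

    // Usage: at::Tensor y = ScaleBy2::apply(x);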
@@ -84,28 +84,67 @@ Tensor permute_duplicate_pooled_embs_split_cpu(
true);
}

Tensor permute_pooled_embs_auto_grad_split_cpu(
Tensor permute_pooled_embs_split_dispatch_call(
const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::permute_pooled_embs_split", "")
.typed<decltype(fbgemm_gpu::permute_pooled_embs_split_cpu)>();
return op.call(
pooled_embs,
offset_dim_list,
permute_list,
inv_offset_dim_list,
inv_permute_list);
}

Tensor permute_duplicate_pooled_embs_split_dispatch_call(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
return PermutePooledEmbsFunctionSplit<permute_pooled_embs_split_cpu>::apply(
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::permute_duplicate_pooled_embs_split", "")
.typed<
decltype(fbgemm_gpu::permute_duplicate_pooled_embs_split_cpu)>();
return op.call(
pooled_embs,
offset_dim_list,
permute_list,
inv_offset_dim_list,
inv_permute_list);
}

Tensor permute_pooled_embs_auto_grad_split_cpu(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
return PermutePooledEmbsFunctionSplit<
permute_pooled_embs_split_dispatch_call>::
apply(
pooled_embs,
offset_dim_list,
permute_list,
inv_offset_dim_list,
inv_permute_list);
}

Tensor permute_duplicate_pooled_embs_auto_grad_split_cpu(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
return PermutePooledEmbsFunctionSplit<
permute_duplicate_pooled_embs_split_cpu>::
permute_duplicate_pooled_embs_split_dispatch_call>::
apply(
pooled_embs,
offset_dim_list,
@@ -116,6 +155,7 @@ Tensor permute_duplicate_pooled_embs_auto_grad_split_cpu(
} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.set_python_module("fbgemm_gpu.sparse_ops");
m.def(
"permute_pooled_embs_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
DISPATCH_TO_CPU(
@@ -127,12 +167,12 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
fbgemm_gpu::permute_duplicate_pooled_embs_split_cpu);
m.def(
"permute_pooled_embs_auto_grad_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
DISPATCH_TO_CPU(
DISPATCH_TO_AUTOGRAD_CPU(
"permute_pooled_embs_auto_grad_split",
fbgemm_gpu::permute_pooled_embs_auto_grad_split_cpu);
m.def(
"permute_duplicate_pooled_embs_auto_grad_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
DISPATCH_TO_CPU(
DISPATCH_TO_AUTOGRAD_CPU(
"permute_duplicate_pooled_embs_auto_grad_split",
fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_split_cpu);
}
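Two registration changes above are worth noting. m.set_python_module("fbgemm_gpu.sparse_ops") tells the dispatcher which Python module supplies the impl_abstract meta functions for this library, and the auto_grad ops move from DISPATCH_TO_CPU to DISPATCH_TO_AUTOGRAD_CPU. That macro is FBGEMM-internal; assuming it registers the wrapper under the AutogradCPU dispatch key (an assumption based on its name and this diff's intent), the equivalent raw registration would be roughly:

    #include <torch/library.h>

    // Assumed expansion (hedged): register the autograd wrapper under the
    // AutogradCPU key so it runs when autograd is tracked for CPU tensors.
    TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) {
      m.impl(
          "permute_pooled_embs_auto_grad_split",
          TORCH_FN(fbgemm_gpu::permute_pooled_embs_auto_grad_split_cpu));
    }

The CUDA file below makes the mirror-image change with DISPATCH_TO_AUTOGRAD_CUDA.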
@@ -18,30 +18,73 @@

using Tensor = at::Tensor;

namespace fbgemm_gpu {
namespace {

Tensor permute_pooled_embs_auto_grad_split_gpu(
Tensor permute_pooled_embs_split_dispatch_call(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
return PermutePooledEmbsFunctionSplit<permute_pooled_embs_split_gpu>::apply(
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::permute_pooled_embs_split", "")
.typed<decltype(fbgemm_gpu::permute_pooled_embs_split_gpu)>();
return op.call(
pooled_embs,
offset_dim_list,
permute_list,
inv_offset_dim_list,
inv_permute_list);
}

Tensor permute_duplicate_pooled_embs_split_dispatch_call(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::permute_duplicate_pooled_embs_split", "")
.typed<
decltype(fbgemm_gpu::permute_duplicate_pooled_embs_split_gpu)>();
return op.call(
pooled_embs,
offset_dim_list,
permute_list,
inv_offset_dim_list,
inv_permute_list);
}

} // namespace

namespace fbgemm_gpu {

Tensor permute_pooled_embs_auto_grad_split_gpu(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
return PermutePooledEmbsFunctionSplit<
permute_pooled_embs_split_dispatch_call>::
apply(
pooled_embs,
offset_dim_list,
permute_list,
inv_offset_dim_list,
inv_permute_list);
}

Tensor permute_duplicate_pooled_embs_auto_grad_split_gpu(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
return PermutePooledEmbsFunctionSplit<
permute_duplicate_pooled_embs_split_gpu>::
permute_duplicate_pooled_embs_split_dispatch_call>::
apply(
pooled_embs,
offset_dim_list,
@@ -57,10 +100,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
DISPATCH_TO_CUDA(
"permute_duplicate_pooled_embs_split",
fbgemm_gpu::permute_duplicate_pooled_embs_split_gpu);
DISPATCH_TO_CUDA(
DISPATCH_TO_AUTOGRAD_CUDA(
"permute_pooled_embs_auto_grad_split",
fbgemm_gpu::permute_pooled_embs_auto_grad_split_gpu);
DISPATCH_TO_CUDA(
DISPATCH_TO_AUTOGRAD_CUDA(
"permute_duplicate_pooled_embs_auto_grad_split",
fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_split_gpu);
}
