Revert D63335491: Multisect successfully blamed "D63335491: Refactor the GIS to reuse same autograd function for all backends" for one test failure

Summary:
X-link: facebookresearch/FBGEMM#311

This diff reverts D63335491.
D63335491 ("Refactor the GIS to reuse same autograd function for all backends", by egienvalue) causes the following test failure:

Tests affected:
- [cogwheel:cogwheel_model_import_inference_ads_v0_test#test_ads_v0_inference_model_import](https://www.internalfb.com/intern/test/281475118337004/)

Here's the Multisect link:
https://www.internalfb.com/multisect/11357230
Here are the tasks relevant to this breakage:
T203477897: Test cogwheel:cogwheel_model_import_inference_ads_v0_test#test_ads_v0_inference_model_import failing for ai_test_validation

The backout may land if someone accepts it.

If this diff has been generated in error, you can Commandeer and Abandon it.

Differential Revision: D63757977
Dark Knight authored and facebook-github-bot committed Oct 2, 2024
1 parent 4a4d187 commit f847440
Showing 4 changed files with 483 additions and 521 deletions.
4 changes: 2 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -1073,11 +1073,11 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
 )
 impl_abstract("fbgemm::bounds_check_indices", bounds_check_indices_abstract)
 impl_abstract(
-    "fbgemm::group_index_select_dim0_forward_impl",
+    "fbgemm::group_index_select_dim0_gpu_impl",
     group_index_select_dim0_gpu_impl_abstract,
 )
 impl_abstract(
-    "fbgemm::group_index_select_dim0_backward_impl",
+    "fbgemm::group_index_select_dim0_gpu_backward",
     group_index_select_dim0_gpu_backward_abstract,
 )
 impl_abstract(
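A note on these registrations: `impl_abstract` attaches an abstract (fake/meta) implementation to a custom op so that shape inference and tracing (e.g. under `torch.compile`) can run without the real kernel. Below is a minimal standalone sketch of the same idea using the public `torch.library` API; the op name `mylib::index_select_dim0` and its schema are hypothetical, not FBGEMM's.

```python
import torch

# Define a hypothetical custom op (illustrative only; not an FBGEMM schema).
torch.library.define(
    "mylib::index_select_dim0", "(Tensor input, Tensor indices) -> Tensor"
)

# Abstract/meta implementation: computes only output metadata (shape/dtype),
# which is the role the impl_abstract(...) calls in the diff above play.
@torch.library.impl_abstract("mylib::index_select_dim0")
def index_select_dim0_abstract(input, indices):
    torch._check(indices.dim() == 1)
    # Row count comes from indices; trailing dims come from input.
    return input.new_empty((indices.shape[0], *input.shape[1:]))
```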
104 changes: 0 additions & 104 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -9,11 +9,8 @@
 #pragma once
 
 #include <ATen/ATen.h>
-#include <ATen/core/dispatch/Dispatcher.h>
 #include <c10/core/SymInt.h>
-#include <c10/core/SymIntArrayRef.h>
-#include <torch/csrc/autograd/custom_function.h>
 
 #include <cstdint>
 
 namespace fbgemm_gpu {
@@ -927,107 +924,6 @@ at::Tensor index_add_with_unique_indices_cuda(
     const int consecutive_range_start,
     const int consecutive_range_length);
 
-torch::autograd::variable_list group_index_select_dim0_decomposed(
-    at::TensorList input_group,
-    at::TensorList indices_group);
-
-torch::autograd::variable_list group_index_select_dim0_autograd_impl(
-    at::TensorList all_indices_input,
-    const int64_t group_size);
-
-torch::autograd::variable_list group_index_select_dim0(
-    at::TensorList input_group,
-    at::TensorList indices_group);
-
-torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
-    at::TensorList all_indices_input,
-    const int64_t group_size);
-
-torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
-    at::TensorList all_inputs,
-    c10::SymIntArrayRef output_shape_group_ref);
-
-std::pair<std::vector<at::Tensor>, std::vector<at::Tensor>>
-group_index_select_dim0_unpack(
-    at::TensorList all_indices_input,
-    const int64_t group_size);
-
-class GroupIndexSelectDim0Op
-    : public torch::autograd::Function<GroupIndexSelectDim0Op> {
- public:
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      at::TensorList all_indices_input,
-      const int64_t group_size) {
-    at::AutoDispatchBelowADInplaceOrView guard;
-    static auto forward_op =
-        at::Dispatcher::singleton()
-            .findSchemaOrThrow(
-                "fbgemm::group_index_select_dim0_forward_impl", "")
-            .typed<decltype(group_index_select_dim0_forward_impl_cpu)>();
-    auto result = forward_op.call(all_indices_input, group_size);
-    TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);
-
-    auto [input_group, indices_group] =
-        group_index_select_dim0_unpack(all_indices_input, group_size);
-    const auto input_dim = input_group[0].dim();
-    std::vector<c10::SymInt> input_shape_group;
-    input_shape_group.reserve(group_size * input_dim);
-
-    for (const auto i : c10::irange(group_size)) {
-      const auto& input = input_group[i];
-      // Copy input shape
-      auto input_shape = input.sym_sizes().vec();
-      input_shape_group.insert(
-          input_shape_group.end(), input_shape.begin(), input_shape.end());
-    }
-
-    // save indices, args_tensor, saved_data
-    auto saved_tensors = std::vector<at::Tensor>(indices_group);
-    saved_tensors.insert(
-        saved_tensors.end(), result.cbegin() + group_size, result.cend());
-    saved_tensors.push_back(input_group[0]);
-    ctx->save_for_backward(saved_tensors);
-    ctx->saved_data["input_shape_group"] = input_shape_group;
-
-    return result;
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::variable_list grad_output_group) {
-    TORCH_CHECK(grad_output_group.size() >= 2);
-    if (grad_output_group.size() == 2) {
-      // empty outputs
-      return torch::autograd::variable_list(1);
-    }
-    // remove redundant grads
-    auto group_size = grad_output_group.size() - 2;
-    grad_output_group.resize(group_size);
-
-    auto saved_tensors = ctx->get_saved_variables();
-    TORCH_CHECK(saved_tensors.size() == group_size + 3);
-    auto output_shape_group =
-        ctx->saved_data["input_shape_group"].toSymIntVector();
-    grad_output_group.insert(
-        grad_output_group.end(), saved_tensors.begin(), saved_tensors.end());
-    static auto backward_op =
-        at::Dispatcher::singleton()
-            .findSchemaOrThrow(
-                "fbgemm::group_index_select_dim0_backward_impl", "")
-            .typed<decltype(group_index_select_dim0_backward_impl_cpu)>();
-    auto res = backward_op.call(grad_output_group, output_shape_group);
-    // 1) Add group_size Variable()'s for indices
-    // Replace all empty tensors with Variable(). This must be done after the
-    // op.call to make __torch_dispatch__ work for the backward op.
-    std::fill(
-        res.begin(), res.begin() + group_size, torch::autograd::Variable());
-    // 3) Add 1 Variable() for group_size
-    res.push_back({});
-    return res;
-  }
-};
-
 ///@ingroup sparse-data-cuda
 void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
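Context for the deleted class above: GroupIndexSelectDim0Op packed each group's tensors into a single TensorList of length 2 * group_size, indices first and inputs second (see group_index_select_dim0_unpack), and its forward op returned the group_size outputs plus two extra tensors saved for backward. A small Python sketch of just that packing convention, for illustration only:

```python
from typing import List, Sequence, Tuple
import torch

# Sketch of the TensorList layout used by the deleted autograd wrapper:
# the first group_size entries are indices, the next group_size are inputs.
def pack(indices_group: Sequence[torch.Tensor],
         input_group: Sequence[torch.Tensor]) -> List[torch.Tensor]:
    return list(indices_group) + list(input_group)

def unpack(all_indices_input: Sequence[torch.Tensor],
           group_size: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    indices_group = list(all_indices_input[:group_size])
    input_group = list(all_indices_input[group_size:2 * group_size])
    # Mirrors group_index_select_dim0_unpack: returns (inputs, indices).
    return input_group, indices_group
```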
122 changes: 27 additions & 95 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2851,103 +2851,41 @@ Tensor pack_segments_cpu(
     const int64_t max_length) {
   return pack_segments_forward_cpu(t_in, lengths, max_length);
 }
 
-torch::autograd::variable_list group_index_select_dim0_autograd_impl(
-    at::TensorList all_indices_input,
-    const int64_t group_size) {
-  return GroupIndexSelectDim0Op::apply(all_indices_input, group_size);
-}
-
-std::pair<std::vector<Tensor>, std::vector<Tensor>>
-group_index_select_dim0_unpack(
-    at::TensorList all_indices_input,
-    const int64_t group_size) {
-  std::vector<Tensor> indices_group;
-  std::vector<Tensor> input_group;
-
-  indices_group.reserve(group_size);
-  input_group.reserve(group_size);
-
-  for (const auto i : c10::irange(group_size)) {
-    indices_group.push_back(all_indices_input[i]);
-    input_group.push_back(all_indices_input[group_size + i]);
-  }
-
-  TORCH_CHECK(group_size == static_cast<int64_t>(indices_group.size()));
-
-  return std::make_pair(input_group, indices_group);
+namespace {
+Tensor index_select_dim0(
+    const Tensor& input,
+    const Tensor& indices,
+    std::optional<int64_t> /*consecutive_range_start*/,
+    std::optional<int64_t> /*consecutive_range_length*/,
+    std::optional<bool> /*skip_indices_sorting_fwd*/) {
+  return at::index_select(input, 0, indices);
 }
 
 torch::autograd::variable_list group_index_select_dim0(
     at::TensorList input_group,
     at::TensorList indices_group) {
-  const auto group_size = indices_group.size();
+  int num_groups = input_group.size();
+  TORCH_CHECK(num_groups == (int)indices_group.size())
+  std::vector<Tensor> output_group;
-
-  if (group_size == 0) {
-    return std::vector<Tensor>();
-  }
-
-  // Pack input_group and indices_group into TensorList
-  std::vector<Tensor> all_indices_input_vec;
-  all_indices_input_vec.reserve(group_size * 2);
-
-  for (const Tensor& index : indices_group) {
-    all_indices_input_vec.push_back(index);
-  }
-  for (const Tensor& input : input_group) {
-    all_indices_input_vec.push_back(input);
+  for (const auto i : c10::irange(num_groups)) {
+    output_group.push_back(
+        at::index_select(input_group[i], 0, indices_group[i]));
   }
-
-  at::TensorList all_indices_input_tensor = all_indices_input_vec;
-
-  static auto forward_op =
-      at::Dispatcher::singleton()
-          .findSchemaOrThrow(
-              "fbgemm::group_index_select_dim0_autograd_impl", "")
-          .typed<decltype(group_index_select_dim0_autograd_impl)>();
-  auto res = forward_op.call(all_indices_input_tensor, group_size);
-  TORCH_CHECK(res.size() == group_size + 2);
-  // only return the outputs (the first group_size elements)
-  res.resize(group_size);
-  return res;
+  return output_group;
 }
 
-torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
+torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu(
     at::TensorList all_indices_input,
     const int64_t group_size) {
   throw std::runtime_error(
-      "group_index_select_dim0_forward_impl is not implemented for CPU");
+      "group_index_select_dim0_gpu_impl is not implemented for CPU");
 }
 
-torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
+torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu(
     at::TensorList all_inputs,
     c10::SymIntArrayRef output_shape_group_ref) {
   throw std::runtime_error(
-      "group_index_select_dim0_backward_impl is not implemented for CPU");
-}
-
-torch::autograd::variable_list group_index_select_dim0_decomposed(
-    at::TensorList input_group,
-    at::TensorList indices_group) {
-  int num_groups = input_group.size();
-  TORCH_CHECK(num_groups == (int)indices_group.size())
-  std::vector<Tensor> output_group;
-  for (const auto i : c10::irange(num_groups)) {
-    output_group.push_back(
-        at::index_select(input_group[i], 0, indices_group[i]));
-  }
-  return output_group;
-}
-
-namespace {
-Tensor index_select_dim0(
-    const Tensor& input,
-    const Tensor& indices,
-    std::optional<int64_t> /*consecutive_range_start*/,
-    std::optional<int64_t> /*consecutive_range_length*/,
-    std::optional<bool> /*skip_indices_sorting_fwd*/) {
-  return at::index_select(input, 0, indices);
+      "group_index_select_dim0_gpu_backward is not implemented for CPU");
 }
 
 Tensor bottom_k_per_row(
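The group_index_select_dim0 restored above is the decomposed form: one index_select along dim 0 per (input, indices) pair. A plain-PyTorch reference of the same semantics (illustrative, not the FBGEMM op itself):

```python
from typing import List, Sequence
import torch

# Reference semantics of the decomposed group_index_select_dim0:
# one index_select along dim 0 per (input, indices) pair.
def group_index_select_dim0_reference(
    input_group: Sequence[torch.Tensor],
    indices_group: Sequence[torch.Tensor],
) -> List[torch.Tensor]:
    assert len(input_group) == len(indices_group)
    return [
        torch.index_select(inp, 0, idx)
        for inp, idx in zip(input_group, indices_group)
    ]

# Example: two groups with different shapes and index sets.
outs = group_index_select_dim0_reference(
    [torch.randn(5, 3), torch.randn(4, 2)],
    [torch.tensor([0, 2]), torch.tensor([1, 1, 3])],
)
```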
@@ -3108,11 +3046,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       {PT2_COMPLIANT_TAG});
   // group_index_select_dim0_gpu helper functions - not defined for CPU!
   m.def(
-      "group_index_select_dim0_autograd_impl(Tensor[] inputs, int group_size) -> Tensor[]");
-  m.def(
-      "group_index_select_dim0_forward_impl(Tensor[] inputs, int group_size) -> Tensor[]");
+      "group_index_select_dim0_gpu_impl(Tensor[] inputs, int group_size) -> Tensor[]");
   m.def(
-      "group_index_select_dim0_backward_impl(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]");
+      "group_index_select_dim0_gpu_backward(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]");
   // This is an one-off op to be used in split_embedding_utils.py for zipf
   // generation w/o replacement along dim=-1. If requires_unique=True, find
   // smallest unique k. If the number of unique elements is less than k,
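For reference, ops declared via m.def with schemas like these are reachable from Python through torch.ops once the fbgemm_gpu extension is loaded. A hedged usage sketch (assumes an environment with fbgemm_gpu installed; the tensors are illustrative):

```python
import torch
import fbgemm_gpu  # noqa: F401  # importing the extension registers fbgemm::* ops

# Calls the public op whose CPU dispatch is shown in this file; the
# *_gpu_impl / *_gpu_backward helpers above are internal and raise on CPU.
outs = torch.ops.fbgemm.group_index_select_dim0(
    [torch.randn(5, 3)], [torch.tensor([0, 2])]
)
```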
@@ -3196,14 +3132,13 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
       "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu);
   DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
   DISPATCH_TO_CPU(
-      "group_index_select_dim0",
-      fbgemm_gpu::group_index_select_dim0_decomposed);
+      "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
   DISPATCH_TO_CPU(
-      "group_index_select_dim0_forward_impl",
-      fbgemm_gpu::group_index_select_dim0_forward_impl_cpu);
+      "group_index_select_dim0_gpu_impl",
+      fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu);
   DISPATCH_TO_CPU(
-      "group_index_select_dim0_backward_impl",
-      fbgemm_gpu::group_index_select_dim0_backward_impl_cpu);
+      "group_index_select_dim0_gpu_backward",
+      fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu);
   DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row);
 }

@@ -3212,14 +3147,11 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) {
-  m.impl(
-      "group_index_select_dim0",
-      &fbgemm_gpu::group_index_select_dim0_decomposed);
+  m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0);
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   // CPU group_index_select_dim0 is decomposable
   m.impl(
-      "group_index_select_dim0",
-      TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed));
+      "group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0));
 }
[Diff for the fourth changed file did not load and is not shown.]
