
Revert D63335491: Multisect successfully blamed "D63335491: Refactor the GIS to reuse same autograd function for all backends" for one test failure #3214

Closed
wants to merge 1 commit
4 changes: 2 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -1073,11 +1073,11 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
)
impl_abstract("fbgemm::bounds_check_indices", bounds_check_indices_abstract)
impl_abstract(
"fbgemm::group_index_select_dim0_forward_impl",
"fbgemm::group_index_select_dim0_gpu_impl",
group_index_select_dim0_gpu_impl_abstract,
)
impl_abstract(
"fbgemm::group_index_select_dim0_backward_impl",
"fbgemm::group_index_select_dim0_gpu_backward",
group_index_select_dim0_gpu_backward_abstract,
)
impl_abstract(
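
The Python hunk above re-points the two abstract (meta) registrations from the refactor's group_index_select_dim0_forward_impl / group_index_select_dim0_backward_impl names back to the pre-refactor group_index_select_dim0_gpu_impl / group_index_select_dim0_gpu_backward names. For context, a meta ("abstract") implementation only propagates shapes and dtypes so the op can be traced and compiled without running real kernels. Below is a minimal sketch of such a registration with torch.library; the library name "mylib", the op "group_select", and the shape logic are illustrative placeholders, not FBGEMM's actual helpers.

# Minimal sketch of a meta ("abstract") registration with torch.library,
# assuming PyTorch >= 2.1. "mylib" and "group_select" are placeholder names.
import torch
from torch.library import Library, impl

lib = Library("mylib", "DEF")
lib.define("group_select(Tensor[] inputs, Tensor[] indices) -> Tensor[]")

@impl(lib, "group_select", "Meta")
def group_select_meta(inputs, indices):
    # Shape propagation only: dim 0 becomes the number of selected indices,
    # trailing dims are kept from the corresponding input.
    return [
        inp.new_empty((idx.numel(), *inp.shape[1:]))
        for inp, idx in zip(inputs, indices)
    ]
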
104 changes: 0 additions & 104 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -9,11 +9,8 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <torch/csrc/autograd/custom_function.h>

#include <cstdint>

namespace fbgemm_gpu {
@@ -927,107 +924,6 @@ at::Tensor index_add_with_unique_indices_cuda(
const int consecutive_range_start,
const int consecutive_range_length);

torch::autograd::variable_list group_index_select_dim0_decomposed(
at::TensorList input_group,
at::TensorList indices_group);

torch::autograd::variable_list group_index_select_dim0_autograd_impl(
at::TensorList all_indices_input,
const int64_t group_size);

torch::autograd::variable_list group_index_select_dim0(
at::TensorList input_group,
at::TensorList indices_group);

torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
at::TensorList all_indices_input,
const int64_t group_size);

torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
at::TensorList all_inputs,
c10::SymIntArrayRef output_shape_group_ref);

std::pair<std::vector<at::Tensor>, std::vector<at::Tensor>>
group_index_select_dim0_unpack(
at::TensorList all_indices_input,
const int64_t group_size);

class GroupIndexSelectDim0Op
: public torch::autograd::Function<GroupIndexSelectDim0Op> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
at::TensorList all_indices_input,
const int64_t group_size) {
at::AutoDispatchBelowADInplaceOrView guard;
static auto forward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::group_index_select_dim0_forward_impl", "")
.typed<decltype(group_index_select_dim0_forward_impl_cpu)>();
auto result = forward_op.call(all_indices_input, group_size);
TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);

auto [input_group, indices_group] =
group_index_select_dim0_unpack(all_indices_input, group_size);
const auto input_dim = input_group[0].dim();
std::vector<c10::SymInt> input_shape_group;
input_shape_group.reserve(group_size * input_dim);

for (const auto i : c10::irange(group_size)) {
const auto& input = input_group[i];
// Copy input shape
auto input_shape = input.sym_sizes().vec();
input_shape_group.insert(
input_shape_group.end(), input_shape.begin(), input_shape.end());
}

// save indices, args_tensor, saved_data
auto saved_tensors = std::vector<at::Tensor>(indices_group);
saved_tensors.insert(
saved_tensors.end(), result.cbegin() + group_size, result.cend());
saved_tensors.push_back(input_group[0]);
ctx->save_for_backward(saved_tensors);
ctx->saved_data["input_shape_group"] = input_shape_group;

return result;
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output_group) {
TORCH_CHECK(grad_output_group.size() >= 2);
if (grad_output_group.size() == 2) {
// empty outputs
return torch::autograd::variable_list(1);
}
// remove redundant grads
auto group_size = grad_output_group.size() - 2;
grad_output_group.resize(group_size);

auto saved_tensors = ctx->get_saved_variables();
TORCH_CHECK(saved_tensors.size() == group_size + 3);
auto output_shape_group =
ctx->saved_data["input_shape_group"].toSymIntVector();
grad_output_group.insert(
grad_output_group.end(), saved_tensors.begin(), saved_tensors.end());
static auto backward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::group_index_select_dim0_backward_impl", "")
.typed<decltype(group_index_select_dim0_backward_impl_cpu)>();
auto res = backward_op.call(grad_output_group, output_shape_group);
// 1) Add group_size Variable()'s for indices
// Replace all empty tensors with Variable(). This must be done after the
// op.call to make __torch_dispatch__ work for the backward op.
std::fill(
res.begin(), res.begin() + group_size, torch::autograd::Variable());
// 3) Add 1 Variable() for group_size
res.push_back({});
return res;
}
};

///@ingroup sparse-data-cuda
void group_index_select_or_add_cuda(
const int64_t* input_ptrs,
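
The header diff above removes the GroupIndexSelectDim0Op autograd Function (and the related declarations) that D63335491 had hoisted into sparse_ops.h. For readers less familiar with the pattern, the sketch below mirrors the same control flow in Python: the forward pass selects rows per group and stashes the indices and input shapes, and the backward pass scatter-adds the output gradients back and pads the returned list with None for the non-differentiable inputs. This is an illustrative analogue only; the real op dispatches to fused FBGEMM kernels rather than the loops shown here.

# Illustrative Python analogue of the removed C++ GroupIndexSelectDim0Op.
# This sketches the autograd control flow only; it is not FBGEMM's implementation.
import torch

class GroupIndexSelectDim0Sketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, group_size, *indices_then_inputs):
        indices_group = indices_then_inputs[:group_size]
        input_group = indices_then_inputs[group_size:]
        # Stand-in for the forward kernel: one index_select along dim 0 per group.
        outputs = [
            inp.index_select(0, idx)
            for inp, idx in zip(input_group, indices_group)
        ]
        ctx.save_for_backward(*indices_group)
        ctx.input_shapes = [inp.shape for inp in input_group]
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *grad_outputs):
        indices_group = ctx.saved_tensors
        input_grads = []
        for grad, idx, shape in zip(grad_outputs, indices_group, ctx.input_shapes):
            # Scatter-add each output gradient into a zero-initialized input
            # gradient, which is what the backward kernel computes.
            g = grad.new_zeros(shape)
            g.index_add_(0, idx, grad)
            input_grads.append(g)
        # No gradients for group_size or for the integer index tensors.
        return (None,) + (None,) * len(indices_group) + tuple(input_grads)
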
122 changes: 27 additions & 95 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2851,103 +2851,41 @@ Tensor pack_segments_cpu(
const int64_t max_length) {
return pack_segments_forward_cpu(t_in, lengths, max_length);
}

torch::autograd::variable_list group_index_select_dim0_autograd_impl(
at::TensorList all_indices_input,
const int64_t group_size) {
return GroupIndexSelectDim0Op::apply(all_indices_input, group_size);
}

std::pair<std::vector<Tensor>, std::vector<Tensor>>
group_index_select_dim0_unpack(
at::TensorList all_indices_input,
const int64_t group_size) {
std::vector<Tensor> indices_group;
std::vector<Tensor> input_group;

indices_group.reserve(group_size);
input_group.reserve(group_size);

for (const auto i : c10::irange(group_size)) {
indices_group.push_back(all_indices_input[i]);
input_group.push_back(all_indices_input[group_size + i]);
}

TORCH_CHECK(group_size == static_cast<int64_t>(indices_group.size()));

return std::make_pair(input_group, indices_group);
namespace {
Tensor index_select_dim0(
const Tensor& input,
const Tensor& indices,
std::optional<int64_t> /*consecutive_range_start*/,
std::optional<int64_t> /*consecutive_range_length*/,
std::optional<bool> /*skip_indices_sorting_fwd*/) {
return at::index_select(input, 0, indices);
}

torch::autograd::variable_list group_index_select_dim0(
at::TensorList input_group,
at::TensorList indices_group) {
const auto group_size = indices_group.size();
int num_groups = input_group.size();
TORCH_CHECK(num_groups == (int)indices_group.size())
std::vector<Tensor> output_group;

if (group_size == 0) {
return std::vector<Tensor>();
}

// Pack input_group and indices_group into TensorList
std::vector<Tensor> all_indices_input_vec;
all_indices_input_vec.reserve(group_size * 2);

for (const Tensor& index : indices_group) {
all_indices_input_vec.push_back(index);
}
for (const Tensor& input : input_group) {
all_indices_input_vec.push_back(input);
for (const auto i : c10::irange(num_groups)) {
output_group.push_back(
at::index_select(input_group[i], 0, indices_group[i]));
}

at::TensorList all_indices_input_tensor = all_indices_input_vec;

static auto forward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow(
"fbgemm::group_index_select_dim0_autograd_impl", "")
.typed<decltype(group_index_select_dim0_autograd_impl)>();
auto res = forward_op.call(all_indices_input_tensor, group_size);
TORCH_CHECK(res.size() == group_size + 2);
// only return the outputs (the first group_size elements)
res.resize(group_size);
return res;
return output_group;
}

torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu(
at::TensorList all_indices_input,
const int64_t group_size) {
throw std::runtime_error(
"group_index_select_dim0_forward_impl is not implemented for CPU");
"group_index_select_dim0_gpu_impl is not implemented for CPU");
}

torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu(
at::TensorList all_inputs,
c10::SymIntArrayRef output_shape_group_ref) {
throw std::runtime_error(
"group_index_select_dim0_backward_impl is not implemented for CPU");
}

torch::autograd::variable_list group_index_select_dim0_decomposed(
at::TensorList input_group,
at::TensorList indices_group) {
int num_groups = input_group.size();
TORCH_CHECK(num_groups == (int)indices_group.size())
std::vector<Tensor> output_group;
for (const auto i : c10::irange(num_groups)) {
output_group.push_back(
at::index_select(input_group[i], 0, indices_group[i]));
}
return output_group;
}

namespace {
Tensor index_select_dim0(
const Tensor& input,
const Tensor& indices,
std::optional<int64_t> /*consecutive_range_start*/,
std::optional<int64_t> /*consecutive_range_length*/,
std::optional<bool> /*skip_indices_sorting_fwd*/) {
return at::index_select(input, 0, indices);
"group_index_select_dim0_gpu_backward is not implemented for CPU");
}

Tensor bottom_k_per_row(
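
The hunk above restores group_index_select_dim0 as a plain decomposition: one at::index_select along dim 0 per (input, indices) pair, which is what the CPU, AutogradCPU, and Meta registrations further down rely on. The snippet below is a reference sketch of that semantics (function and variable names are illustrative), which can be used to check outputs against the fused GPU path.

# Reference sketch of the decomposed semantics of group_index_select_dim0:
# one index_select along dim 0 per (input, indices) pair.
import torch

def group_index_select_dim0_reference(input_group, indices_group):
    assert len(input_group) == len(indices_group)
    return [
        inp.index_select(0, idx)
        for inp, idx in zip(input_group, indices_group)
    ]

inputs = [torch.randn(5, 3), torch.randn(7, 2)]
indices = [torch.tensor([0, 2, 4]), torch.tensor([1, 1, 6])]
outputs = group_index_select_dim0_reference(inputs, indices)
assert [o.shape for o in outputs] == [torch.Size([3, 3]), torch.Size([3, 2])]
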
@@ -3108,11 +3046,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
{PT2_COMPLIANT_TAG});
// group_index_select_dim0_gpu helper functions - not defined for CPU!
m.def(
"group_index_select_dim0_autograd_impl(Tensor[] inputs, int group_size) -> Tensor[]");
m.def(
"group_index_select_dim0_forward_impl(Tensor[] inputs, int group_size) -> Tensor[]");
"group_index_select_dim0_gpu_impl(Tensor[] inputs, int group_size) -> Tensor[]");
m.def(
"group_index_select_dim0_backward_impl(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]");
"group_index_select_dim0_gpu_backward(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]");
// This is an one-off op to be used in split_embedding_utils.py for zipf
// generation w/o replacement along dim=-1. If requires_unique=True, find
// smallest unique k. If the number of unique elements is less than k,
@@ -3196,14 +3132,13 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu);
DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
DISPATCH_TO_CPU(
"group_index_select_dim0",
fbgemm_gpu::group_index_select_dim0_decomposed);
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
DISPATCH_TO_CPU(
"group_index_select_dim0_forward_impl",
fbgemm_gpu::group_index_select_dim0_forward_impl_cpu);
"group_index_select_dim0_gpu_impl",
fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu);
DISPATCH_TO_CPU(
"group_index_select_dim0_backward_impl",
fbgemm_gpu::group_index_select_dim0_backward_impl_cpu);
"group_index_select_dim0_gpu_backward",
fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu);
DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row);
}

@@ -3212,14 +3147,11 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
}

TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) {
m.impl(
"group_index_select_dim0",
&fbgemm_gpu::group_index_select_dim0_decomposed);
m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0);
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
// CPU group_index_select_dim0 is decomposable
m.impl(
"group_index_select_dim0",
TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed));
"group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0));
}
}
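
End to end, the user-facing entry point after this revert is still torch.ops.fbgemm.group_index_select_dim0, with the CPU, AutogradCPU, and Meta backends using the decomposed implementation registered above. A usage sketch, assuming an fbgemm_gpu build in which the op is registered:

# Usage sketch; assumes fbgemm_gpu is installed and registers the fbgemm ops.
import torch
import fbgemm_gpu  # noqa: F401  (importing loads the operator library)

input_group = [
    torch.randn(6, 4, requires_grad=True),
    torch.randn(8, 4, requires_grad=True),
]
indices_group = [torch.tensor([0, 3, 5]), torch.tensor([2, 2, 7, 1])]

output_group = torch.ops.fbgemm.group_index_select_dim0(input_group, indices_group)
loss = sum(o.sum() for o in output_group)
loss.backward()  # exercises the backward path this revert restores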