Refactor the GIS to reuse the same autograd function for all backends (#3216)

Summary:
Pull Request resolved: #3216

X-link: facebookresearch/FBGEMM#313

This diff refactors the GIS (group_index_select) implementation so that all device backends share a single autograd function (GroupIndexSelectDim0Op).

Device backends only need to implement and register the following two ops. For backward compatibility, the op signatures are kept the same.

```
torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
    at::TensorList all_indices_input,
    const int64_t group_size) {
  throw std::runtime_error(
      "group_index_select_dim0_forward_impl is not implemented for CPU");
}

torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
    at::TensorList all_inputs,
    c10::SymIntArrayRef output_shape_group_ref) {
  throw std::runtime_error(
      "group_index_select_dim0_backward_impl is not implemented for CPU");
}
```
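These are then registered under the backend's dispatch key, which for CPU looks like: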

```
  DISPATCH_TO_CPU(
      "group_index_select_dim0_gpu_impl", fbgemm_gpu::group_index_select_dim0_forward_impl_cpu);
  DISPATCH_TO_CPU(
      "group_index_select_dim0_gpu_backward", fbgemm_gpu::group_index_select_dim0_forward_impl_cpu);
```
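A GPU (or other) backend registers its own kernels under the same two op names, and the shared autograd function reaches them through the dispatcher. A minimal sketch: the `DISPATCH_TO_CUDA` usage mirrors `DISPATCH_TO_CPU` above, and the `*_impl_gpu` names here are illustrative placeholders, not the actual kernel names.

```
// Sketch only: the *_impl_gpu function names are hypothetical placeholders.
// Registration mirrors the CPU case, so the shared GroupIndexSelectDim0Op
// finds the backend kernels via the dispatcher.
DISPATCH_TO_CUDA(
    "group_index_select_dim0_gpu_impl",
    fbgemm_gpu::group_index_select_dim0_forward_impl_gpu);
DISPATCH_TO_CUDA(
    "group_index_select_dim0_gpu_backward",
    fbgemm_gpu::group_index_select_dim0_backward_impl_gpu);
```

With this split, the autograd bookkeeping (saving shapes, filling grads for the non-differentiable indices) lives in one shared place, and bringing up a new backend reduces to supplying these two kernels.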

Reviewed By: spcyppt

Differential Revision: D63809121

fbshipit-source-id: 4f6ca5bdf1810241730a77ab125d6aef31b0cd5b
egienvalue authored and facebook-github-bot committed Oct 3, 2024
1 parent d9ae5c4 commit 788cd2a
Showing 3 changed files with 512 additions and 472 deletions.
41 changes: 41 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -9,8 +9,11 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <torch/csrc/autograd/custom_function.h>

#include <cstdint>

namespace fbgemm_gpu {
@@ -924,6 +927,44 @@ at::Tensor index_add_with_unique_indices_cuda(
const int consecutive_range_start,
const int consecutive_range_length);

torch::autograd::variable_list group_index_select_dim0_decomposed(
at::TensorList input_group,
at::TensorList indices_group);

torch::autograd::variable_list group_index_select_dim0_autograd_impl(
at::TensorList all_indices_input,
const int64_t group_size);

torch::autograd::variable_list group_index_select_dim0(
at::TensorList input_group,
at::TensorList indices_group);

torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
at::TensorList all_indices_input,
const int64_t group_size);

torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
at::TensorList all_inputs,
c10::SymIntArrayRef output_shape_group_ref);

std::pair<std::vector<at::Tensor>, std::vector<at::Tensor>>
group_index_select_dim0_unpack(
at::TensorList all_indices_input,
const int64_t group_size);

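// Shared autograd function used by all device backends: forward() and
// backward() re-dispatch to the ops registered as
// "group_index_select_dim0_gpu_impl" and "group_index_select_dim0_gpu_backward".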
class GroupIndexSelectDim0Op
: public torch::autograd::Function<GroupIndexSelectDim0Op> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
at::TensorList all_indices_input,
const int64_t group_size);

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output_group);
};

///@ingroup sparse-data-cuda
void group_index_select_or_add_cuda(
const int64_t* input_ptrs,
176 changes: 155 additions & 21 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2851,19 +2851,84 @@ Tensor pack_segments_cpu(
const int64_t max_length) {
return pack_segments_forward_cpu(t_in, lengths, max_length);
}

torch::autograd::variable_list group_index_select_dim0_autograd_impl(
at::TensorList all_indices_input,
const int64_t group_size) {
return GroupIndexSelectDim0Op::apply(all_indices_input, group_size);
}

std::pair<std::vector<Tensor>, std::vector<Tensor>>
group_index_select_dim0_unpack(
at::TensorList all_indices_input,
const int64_t group_size) {
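  // all_indices_input is packed as
  // [indices_0, ..., indices_{group_size-1}, input_0, ..., input_{group_size-1}]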
std::vector<Tensor> indices_group;
std::vector<Tensor> input_group;

indices_group.reserve(group_size);
input_group.reserve(group_size);

for (const auto i : c10::irange(group_size)) {
indices_group.push_back(all_indices_input[i]);
input_group.push_back(all_indices_input[group_size + i]);
}

TORCH_CHECK(group_size == static_cast<int64_t>(indices_group.size()));

return std::make_pair(input_group, indices_group);
}

torch::autograd::variable_list group_index_select_dim0(
at::TensorList input_group,
at::TensorList indices_group) {
const auto group_size = indices_group.size();
std::vector<Tensor> output_group;

if (group_size == 0) {
return std::vector<Tensor>();
}

// Pack input_group and indices_group into TensorList
std::vector<Tensor> all_indices_input_vec;
all_indices_input_vec.reserve(group_size * 2);

for (const Tensor& index : indices_group) {
all_indices_input_vec.push_back(index);
}
for (const Tensor& input : input_group) {
all_indices_input_vec.push_back(input);
}

at::TensorList all_indices_input_tensor = all_indices_input_vec;

static auto forward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
.typed<decltype(group_index_select_dim0_autograd_impl)>();
auto res = forward_op.call(all_indices_input_tensor, group_size);
TORCH_CHECK(res.size() == group_size + 2);
// only return the outputs (the first group_size elements)
res.resize(group_size);
return res;
}

torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
at::TensorList all_indices_input,
const int64_t group_size) {
throw std::runtime_error(
"group_index_select_dim0_forward_impl is not implemented for CPU");
}

torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
at::TensorList all_inputs,
c10::SymIntArrayRef output_shape_group_ref) {
throw std::runtime_error(
"group_index_select_dim0_backward_impl is not implemented for CPU");
}

torch::autograd::variable_list group_index_select_dim0_decomposed(
at::TensorList input_group,
at::TensorList indices_group) {
int num_groups = input_group.size();
  TORCH_CHECK(num_groups == (int)indices_group.size());
std::vector<Tensor> output_group;
@@ -2874,18 +2939,83 @@ torch::autograd::variable_list group_index_select_dim0(
return output_group;
}

torch::autograd::variable_list GroupIndexSelectDim0Op::forward(
    torch::autograd::AutogradContext* ctx,
    at::TensorList all_indices_input,
    const int64_t group_size) {
at::AutoDispatchBelowADInplaceOrView guard;
static auto forward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
.typed<decltype(group_index_select_dim0_forward_impl_cpu)>();
auto result = forward_op.call(all_indices_input, group_size);
TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);

auto [input_group, indices_group] =
group_index_select_dim0_unpack(all_indices_input, group_size);
const auto input_dim = input_group[0].dim();
std::vector<c10::SymInt> input_shape_group;
input_shape_group.reserve(group_size * input_dim);

for (const auto i : c10::irange(group_size)) {
const auto& input = input_group[i];
// Copy input shape
auto input_shape = input.sym_sizes().vec();
input_shape_group.insert(
input_shape_group.end(), input_shape.begin(), input_shape.end());
}

// save indices, args_tensor, saved_data
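  // Saved layout: [indices_0, ..., indices_{group_size-1}, args_tensor,
  // saved_data, input_group[0]], i.e. group_size + 3 tensors (checked in backward)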
auto saved_tensors = std::vector<at::Tensor>(indices_group);
saved_tensors.insert(
saved_tensors.end(), result.cbegin() + group_size, result.cend());
saved_tensors.push_back(input_group[0]);
ctx->save_for_backward(saved_tensors);
ctx->saved_data["input_shape_group"] = input_shape_group;

return result;
}

torch::autograd::variable_list GroupIndexSelectDim0Op::backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_output_group) {
TORCH_CHECK(grad_output_group.size() >= 2);
if (grad_output_group.size() == 2) {
// empty outputs
return torch::autograd::variable_list(1);
}
  // Drop the grads for the two auxiliary outputs (args_tensor and saved_data);
  // only the first group_size entries are grads of real outputs.
auto group_size = grad_output_group.size() - 2;
grad_output_group.resize(group_size);

auto saved_tensors = ctx->get_saved_variables();
TORCH_CHECK(saved_tensors.size() == group_size + 3);
auto output_shape_group =
ctx->saved_data["input_shape_group"].toSymIntVector();
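  // Pack [grads..., indices..., args_tensor, saved_data, first_input] as the
  // input list for the backend's backward op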
grad_output_group.insert(
grad_output_group.end(), saved_tensors.begin(), saved_tensors.end());
static auto backward_op =
at::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_backward", "")
.typed<decltype(group_index_select_dim0_backward_impl_cpu)>();
auto res = backward_op.call(grad_output_group, output_shape_group);
  // 1) Replace the first group_size entries (the grads w.r.t. the
  // non-differentiable indices) with empty Variable()s. This must be done
  // after the op.call to make __torch_dispatch__ work for the backward op.
  std::fill(res.begin(), res.begin() + group_size, torch::autograd::Variable());
  // 2) Append one Variable() for the group_size argument
  res.push_back({});
return res;
}

namespace {
Tensor index_select_dim0(
const Tensor& input,
const Tensor& indices,
std::optional<int64_t> /*consecutive_range_start*/,
std::optional<int64_t> /*consecutive_range_length*/,
std::optional<bool> /*skip_indices_sorting_fwd*/) {
return at::index_select(input, 0, indices);
}

Tensor bottom_k_per_row(
@@ -3132,13 +3262,14 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu);
DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
DISPATCH_TO_CPU(
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
"group_index_select_dim0",
fbgemm_gpu::group_index_select_dim0_decomposed);
DISPATCH_TO_CPU(
"group_index_select_dim0_gpu_impl",
fbgemm_gpu::group_index_select_dim0_forward_impl_cpu);
DISPATCH_TO_CPU(
"group_index_select_dim0_gpu_backward",
fbgemm_gpu::group_index_select_dim0_backward_impl_cpu);
DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row);
}

@@ -3147,11 +3278,14 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
}

TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) {
m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0);
m.impl(
"group_index_select_dim0",
&fbgemm_gpu::group_index_select_dim0_decomposed);
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
// CPU group_index_select_dim0 is decomposable
m.impl(
"group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0));
"group_index_select_dim0",
TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed));
}