diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 9bea430ef..41ba190fc 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -9,8 +9,11 @@
 #pragma once
 
 #include
+#include
 #include
 #include
+#include
+
 #include
 
 namespace fbgemm_gpu {
@@ -924,6 +927,44 @@ at::Tensor index_add_with_unique_indices_cuda(
     const int consecutive_range_start,
     const int consecutive_range_length);
 
+torch::autograd::variable_list group_index_select_dim0_decomposed(
+    at::TensorList input_group,
+    at::TensorList indices_group);
+
+torch::autograd::variable_list group_index_select_dim0_autograd_impl(
+    at::TensorList all_indices_input,
+    const int64_t group_size);
+
+torch::autograd::variable_list group_index_select_dim0(
+    at::TensorList input_group,
+    at::TensorList indices_group);
+
+torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
+    at::TensorList all_indices_input,
+    const int64_t group_size);
+
+torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
+    at::TensorList all_inputs,
+    c10::SymIntArrayRef output_shape_group_ref);
+
+std::pair<std::vector<at::Tensor>, std::vector<at::Tensor>>
+group_index_select_dim0_unpack(
+    at::TensorList all_indices_input,
+    const int64_t group_size);
+
+class GroupIndexSelectDim0Op
+    : public torch::autograd::Function<GroupIndexSelectDim0Op> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      at::TensorList all_indices_input,
+      const int64_t group_size);
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_output_group);
+};
+
 ///@ingroup sparse-data-cuda
 void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
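For context on the API declared above: group_index_select_dim0 performs one dim-0 index_select per (input, indices) pair in the group, which is exactly what the decomposed CPU/Meta implementation in the next file does. A minimal reference sketch of those semantics (illustrative only, not part of the patch; the function name below is made up):

#include <ATen/ATen.h>
#include <vector>

// Reference semantics (sketch): out[i] = input_group[i].index_select(0, indices_group[i]).
std::vector<at::Tensor> group_index_select_dim0_reference(
    at::TensorList input_group,
    at::TensorList indices_group) {
  std::vector<at::Tensor> out;
  out.reserve(input_group.size());
  for (size_t i = 0; i < input_group.size(); ++i) {
    out.push_back(at::index_select(input_group[i], 0, indices_group[i]));
  }
  return out;
}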
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index 7734cc69a..88d9ef2e6 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2851,19 +2851,84 @@ Tensor pack_segments_cpu(
     const int64_t max_length) {
   return pack_segments_forward_cpu(t_in, lengths, max_length);
 }
-namespace {
-Tensor index_select_dim0(
-    const Tensor& input,
-    const Tensor& indices,
-    std::optional<int64_t> /*consecutive_range_start*/,
-    std::optional<int64_t> /*consecutive_range_length*/,
-    std::optional<bool> /*skip_indices_sorting_fwd*/) {
-  return at::index_select(input, 0, indices);
+
+torch::autograd::variable_list group_index_select_dim0_autograd_impl(
+    at::TensorList all_indices_input,
+    const int64_t group_size) {
+  return GroupIndexSelectDim0Op::apply(all_indices_input, group_size);
+}
+
+std::pair<std::vector<Tensor>, std::vector<Tensor>>
+group_index_select_dim0_unpack(
+    at::TensorList all_indices_input,
+    const int64_t group_size) {
+  std::vector<Tensor> indices_group;
+  std::vector<Tensor> input_group;
+
+  indices_group.reserve(group_size);
+  input_group.reserve(group_size);
+
+  for (const auto i : c10::irange(group_size)) {
+    indices_group.push_back(all_indices_input[i]);
+    input_group.push_back(all_indices_input[group_size + i]);
+  }
+
+  TORCH_CHECK(group_size == static_cast<int64_t>(indices_group.size()));
+
+  return std::make_pair(input_group, indices_group);
 }
 
 torch::autograd::variable_list group_index_select_dim0(
     at::TensorList input_group,
     at::TensorList indices_group) {
+  const auto group_size = indices_group.size();
+  std::vector<Tensor> output_group;
+
+  if (group_size == 0) {
+    return std::vector<Tensor>();
+  }
+
+  // Pack input_group and indices_group into TensorList
+  std::vector<Tensor> all_indices_input_vec;
+  all_indices_input_vec.reserve(group_size * 2);
+
+  for (const Tensor& index : indices_group) {
+    all_indices_input_vec.push_back(index);
+  }
+  for (const Tensor& input : input_group) {
+    all_indices_input_vec.push_back(input);
+  }
+
+  at::TensorList all_indices_input_tensor = all_indices_input_vec;
+
+  static auto forward_op =
+      at::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
+          .typed();
+  auto res = forward_op.call(all_indices_input_tensor, group_size);
+  TORCH_CHECK(res.size() == group_size + 2);
+  // only return the outputs (the first group_size elements)
+  res.resize(group_size);
+  return res;
+}
+
+torch::autograd::variable_list group_index_select_dim0_forward_impl_cpu(
+    at::TensorList all_indices_input,
+    const int64_t group_size) {
+  throw std::runtime_error(
+      "group_index_select_dim0_forward_impl is not implemented for CPU");
+}
+
+torch::autograd::variable_list group_index_select_dim0_backward_impl_cpu(
+    at::TensorList all_inputs,
+    c10::SymIntArrayRef output_shape_group_ref) {
+  throw std::runtime_error(
+      "group_index_select_dim0_backward_impl is not implemented for CPU");
+}
+
+torch::autograd::variable_list group_index_select_dim0_decomposed(
+    at::TensorList input_group,
+    at::TensorList indices_group) {
   int num_groups = input_group.size();
   TORCH_CHECK(num_groups == (int)indices_group.size())
   std::vector<Tensor> output_group;
@@ -2874,18 +2939,83 @@ torch::autograd::variable_list group_index_select_dim0(
   return output_group;
 }
 
-torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu(
+torch::autograd::variable_list GroupIndexSelectDim0Op::forward(
+    torch::autograd::AutogradContext* ctx,
     at::TensorList all_indices_input,
     const int64_t group_size) {
-  throw std::runtime_error(
-      "group_index_select_dim0_gpu_impl is not implemented for CPU");
+  at::AutoDispatchBelowADInplaceOrView guard;
+  static auto forward_op =
+      at::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
+          .typed();
+  auto result = forward_op.call(all_indices_input, group_size);
+  TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);
+
+  auto [input_group, indices_group] =
+      group_index_select_dim0_unpack(all_indices_input, group_size);
+  const auto input_dim = input_group[0].dim();
+  std::vector<c10::SymInt> input_shape_group;
+  input_shape_group.reserve(group_size * input_dim);
+
+  for (const auto i : c10::irange(group_size)) {
+    const auto& input = input_group[i];
+    // Copy input shape
+    auto input_shape = input.sym_sizes().vec();
+    input_shape_group.insert(
+        input_shape_group.end(), input_shape.begin(), input_shape.end());
+  }
+
+  // save indices, args_tensor, saved_data
+  auto saved_tensors = std::vector<Tensor>(indices_group);
+  saved_tensors.insert(
+      saved_tensors.end(), result.cbegin() + group_size, result.cend());
+  saved_tensors.push_back(input_group[0]);
+  ctx->save_for_backward(saved_tensors);
+  ctx->saved_data["input_shape_group"] = input_shape_group;
+
+  return result;
 }
 
-torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu(
-    at::TensorList all_inputs,
-    c10::SymIntArrayRef output_shape_group_ref) {
-  throw std::runtime_error(
-      "group_index_select_dim0_gpu_backward is not implemented for CPU");
+torch::autograd::variable_list GroupIndexSelectDim0Op::backward(
+    torch::autograd::AutogradContext* ctx,
+    torch::autograd::variable_list grad_output_group) {
+  TORCH_CHECK(grad_output_group.size() >= 2);
+  if (grad_output_group.size() == 2) {
+    // empty outputs
+    return torch::autograd::variable_list(1);
+  }
+  // remove redundant grads
+  auto group_size = grad_output_group.size() - 2;
+  grad_output_group.resize(group_size);
+
+  auto saved_tensors = ctx->get_saved_variables();
+  TORCH_CHECK(saved_tensors.size() == group_size + 3);
+  auto output_shape_group =
+      ctx->saved_data["input_shape_group"].toSymIntVector();
+  grad_output_group.insert(
+      grad_output_group.end(), saved_tensors.begin(), saved_tensors.end());
+  static auto backward_op =
+      at::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_backward", "")
+          .typed();
+  auto res = backward_op.call(grad_output_group, output_shape_group);
+  // 1) Add group_size Variable()'s for indices
+  // Replace all empty tensors with Variable(). This must be done after the
+  // op.call to make __torch_dispatch__ work for the backward op.
+  std::fill(res.begin(), res.begin() + group_size, torch::autograd::Variable());
+  // 3) Add 1 Variable() for group_size
+  res.push_back({});
+  return res;
+}
+
+namespace {
+Tensor index_select_dim0(
+    const Tensor& input,
+    const Tensor& indices,
+    std::optional<int64_t> /*consecutive_range_start*/,
+    std::optional<int64_t> /*consecutive_range_length*/,
+    std::optional<bool> /*skip_indices_sorting_fwd*/) {
+  return at::index_select(input, 0, indices);
 }
 
 Tensor bottom_k_per_row(
@@ -3132,13 +3262,14 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
       "pack_segments_backward", fbgemm_gpu::pack_segments_backward_cpu);
   DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
   DISPATCH_TO_CPU(
-      "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
+      "group_index_select_dim0",
+      fbgemm_gpu::group_index_select_dim0_decomposed);
   DISPATCH_TO_CPU(
       "group_index_select_dim0_gpu_impl",
-      fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu);
+      fbgemm_gpu::group_index_select_dim0_forward_impl_cpu);
   DISPATCH_TO_CPU(
       "group_index_select_dim0_gpu_backward",
-      fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu);
+      fbgemm_gpu::group_index_select_dim0_backward_impl_cpu);
   DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row);
 }
 
@@ -3147,11 +3278,14 @@ TORCH_LIBRARY_IMPL(fbgemm, Autograd, m) {
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, AutogradCPU, m) {
-  m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0);
+  m.impl(
+      "group_index_select_dim0",
+      &fbgemm_gpu::group_index_select_dim0_decomposed);
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   // CPU group_index_select_dim0 is decomposable
   m.impl(
-      "group_index_select_dim0", TORCH_FN(fbgemm_gpu::group_index_select_dim0));
+      "group_index_select_dim0",
+      TORCH_FN(fbgemm_gpu::group_index_select_dim0_decomposed));
 }
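Before the GPU changes below, note the packing convention that group_index_select_dim0 and group_index_select_dim0_unpack above share: the fused "gpu_impl" op takes a single TensorList of 2 * group_size entries, indices first and inputs second. A small sketch of that convention (illustrative; pack_for_gpu_impl is not a function in the patch):

#include <ATen/ATen.h>
#include <vector>

// Pack indices_group followed by input_group into one list, matching the
// layout unpacked by group_index_select_dim0_unpack:
//   packed[i]              == indices_group[i]
//   packed[group_size + i] == input_group[i]
std::vector<at::Tensor> pack_for_gpu_impl(
    at::TensorList input_group,
    at::TensorList indices_group) {
  std::vector<at::Tensor> packed;
  packed.reserve(indices_group.size() + input_group.size());
  packed.insert(packed.end(), indices_group.begin(), indices_group.end());
  packed.insert(packed.end(), input_group.begin(), input_group.end());
  return packed;
}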
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 6325017e8..0c3966fc3 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -193,442 +193,346 @@ class IndexSelectDim0GPUOp
   }
 };
 
-std::pair<std::vector<Tensor>, std::vector<Tensor>>
-group_index_select_dim0_unpack(
-    at::TensorList all_indices_input,
-    const int64_t group_size) {
-  std::vector<Tensor> indices_group;
-  std::vector<Tensor> input_group;
-
-  indices_group.reserve(group_size);
-  input_group.reserve(group_size);
-
-  for (const auto i : c10::irange(group_size)) {
-    indices_group.push_back(all_indices_input[i]);
-    input_group.push_back(all_indices_input[group_size + i]);
-  }
-
-  TORCH_CHECK(group_size == static_cast<int64_t>(indices_group.size()));
-
-  return std::make_pair(input_group, indices_group);
-}
-
-class GroupIndexSelectDim0GPUOp
-    : public torch::autograd::Function<GroupIndexSelectDim0GPUOp> {
- public:
-  // need to combine input_group and indices_group into one tensor list
-  // to get this working with autograd.
-  static torch::autograd::variable_list forward_impl(
-      at::TensorList all_indices_input,
-      const int64_t group_size) {
-    // Unpack from TensorList
-    auto [input_group, indices_group] =
-        group_index_select_dim0_unpack(all_indices_input, group_size);
-
-    // args_tensor stores kernel arguments:
-    // input_ptrs (group_size int64_t elements)
-    // output_ptrs (group_size int64_t elements)
-    // indices_ptrs (group_size int64_t elements)
-    // warp_offsets_group (group_size + 1 int64_t elements)
-    // num_cols_group (group_size int32_t elements)
-    int64_t args_ptrs_offsets[NUM_ARGS + 1];
-
-    const int64_t numels_num_cols_group_64 =
-        compute_num_int64s(group_size);
-
-    // Initialize offsets
-    args_ptrs_offsets[P_input_ptrs] = group_size;
-    args_ptrs_offsets[P_output_ptrs] = group_size;
-    args_ptrs_offsets[P_indices_ptrs] = group_size;
-    args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1;
-    args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64;
-
-    // Compute offsets
-    int64_t offset = 0;
-    auto next = args_ptrs_offsets[0];
-    for (const auto i : c10::irange(NUM_ARGS)) {
-      args_ptrs_offsets[i] = offset;
-      offset += next;
-      next = args_ptrs_offsets[i + 1];
-    }
-    // Total number of int64_t elements required
-    args_ptrs_offsets[NUM_ARGS] = offset;
-
-    // Allocate memory for GroupIndexSelectArgs
-    at::Tensor args_tensor = at::empty(
-        {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
-        at::TensorOptions().dtype(at::kByte).pinned_memory(true));
-
-    // Ensure that args_tensor is contiguous
-    TORCH_CHECK(args_tensor.is_contiguous());
-
-    // Initialize raw pointers to point to Tensor args_tensor
-    int64_t* input_ptrs = nullptr;
-    int64_t* output_ptrs = nullptr;
-    int64_t* indices_ptrs = nullptr;
-    int64_t* warp_offsets_group = nullptr;
-    int32_t* num_cols_group = nullptr;
-
-    // Offset host pointers
-    offset_args(
-        &input_ptrs,
-        &output_ptrs,
-        &indices_ptrs,
-        &warp_offsets_group,
-        &num_cols_group,
-        reinterpret_cast(args_tensor.data_ptr()),
-        args_ptrs_offsets);
-
-    auto& first_input = input_group[0];
-    auto& first_indices = indices_group[0];
-
-    const int input_dim = first_input.dim();
-    const int num_output_rows = first_indices.size(0);
-    const int num_input_rows = first_input.size(0);
-    Tensor input_reshaped = first_input.reshape({num_input_rows, -1});
-    const int num_cols = input_reshaped.size(1);
-    const int cols_per_warp = get_group_index_select_cols_per_warp();
-    int64_t warp_offset = 0;
-    bool use_var_cols = false;
-
-    // Allocate memory for output_group
-    std::vector<Tensor> output_group;
-    output_group.reserve(group_size + 2);
-
-    // We need to store contiguous inputs and indices outside the for-loop to
-    // guarantee that the contiguous tensors will outlive the kernel
-    // computation
-    std::vector<c10::MaybeOwned<Tensor>> input_contigs;
-    std::vector<c10::MaybeOwned<Tensor>> index_contigs;
-    input_contigs.reserve(group_size);
-    index_contigs.reserve(group_size);
-
-    // For each group, copy input to output
-    for (const auto i : c10::irange(group_size)) {
-      const auto& input = input_group[i];
-      const auto& indices = indices_group[i];
-
-      // Verify that all input tensors have the same number of dimensions
-      TORCH_CHECK(
-          input_dim == input.dim(),
-          "All inputs in group_index_select must have the same number of dimensions");
-
-      // Verify that all tensors are on the same GPU
-      TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices);
-
-      auto num_output_rows_ = indices.size(0);
-
-      // Verify that all input tensors have the same shape[0]
-      TORCH_CHECK(
-          num_output_rows == num_output_rows_,
-          "The number of indices to be selected must be the same for the entire group");
-      const auto input_reshaped_ = input.reshape({input.size(0), -1});
-
-      // Number of columns can be different
-      auto num_cols_ = input_reshaped_.size(1);
-      auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
-
-      if (num_cols != num_cols_) {
-        use_var_cols = true;
-      }
-
-      // Create output pointers
-      auto input_shape = input.sizes().vec();
-      input_shape[0] = num_output_rows_;
-      Tensor output = at::empty(input_shape, input.options());
-      // Ensure that the allocated output is contiguous
-      TORCH_CHECK(output.is_contiguous())
-      output_group.push_back(output);
-
-      // Store input and indices contigs to keep them alive during the kernel
-      // computation
-      input_contigs.push_back(input.expect_contiguous());
-      index_contigs.push_back(indices.expect_contiguous());
-
-      // Store args
-      input_ptrs[i] = reinterpret_cast<int64_t>(input_contigs[i]->data_ptr());
-      output_ptrs[i] = reinterpret_cast<int64_t>(output.data_ptr());
-      indices_ptrs[i] = reinterpret_cast<int64_t>(index_contigs[i]->data_ptr());
-      warp_offsets_group[i] = warp_offset;
-      num_cols_group[i] = num_cols_;
-
-      warp_offset += warps_per_row * num_output_rows;
-    }
-
-    // Store the last offset
-    warp_offsets_group[group_size] = warp_offset;
-
-    // Transfer args tensor to GPU
-    args_tensor = args_tensor.to(
-        first_input.device(),
-        /*non_blocking=*/true);
-
-    // Offset raw ptrs in GPU memory
-    offset_args(
-        &input_ptrs,
-        &output_ptrs,
-        &indices_ptrs,
-        &warp_offsets_group,
-        &num_cols_group,
-        reinterpret_cast(args_tensor.data_ptr()),
-        args_ptrs_offsets);
-
-    int64_t saved_data[] = {
-        static_cast<int64_t>(group_size),
-        use_var_cols,
-        reinterpret_cast<int64_t>(warp_offsets_group),
-        reinterpret_cast<int64_t>(num_cols_group),
-        warp_offset,
-    };
-    auto saved_data_t = at::empty(
-        {sizeof(saved_data) / sizeof(int64_t)},
-        at::TensorOptions().dtype(at::kLong));
-    TORCH_CHECK(saved_data_t.is_contiguous());
-    memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data));
-
-    group_index_select_or_add_cuda(
-        input_ptrs,
-        output_ptrs,
-        indices_ptrs,
-        warp_offsets_group,
-        num_cols_group,
-        first_input.scalar_type(),
-        first_indices.scalar_type(),
-        first_input.device().index(),
-        num_output_rows,
-        /*total_num_warps=*/warp_offset,
-        group_size,
-        /*use_index_select=*/true,
-        use_var_cols);
-
-    output_group.push_back(args_tensor);
-    output_group.push_back(saved_data_t);
-
-    // return format:
-    // (group_size outputs, 1 args_tensor, 1 saved_data)
-    return output_group;
-  }
-
-  static torch::autograd::variable_list forward(
-      torch::autograd::AutogradContext* ctx,
-      at::TensorList all_indices_input,
-      const int64_t group_size) {
-    at::AutoDispatchBelowADInplaceOrView guard;
-    static auto forward_op =
-        torch::Dispatcher::singleton()
-            .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
-            .typed();
-    auto result = forward_op.call(all_indices_input, group_size);
-    TORCH_CHECK(static_cast<int64_t>(result.size()) == group_size + 2);
-
-    auto [input_group, indices_group] =
-        group_index_select_dim0_unpack(all_indices_input, group_size);
-    const auto input_dim = input_group[0].dim();
-    std::vector<c10::SymInt> input_shape_group;
-    input_shape_group.reserve(group_size * input_dim);
-
-    for (const auto i : c10::irange(group_size)) {
-      const auto& input = input_group[i];
-      // Copy input shape
-      auto input_shape = input.sym_sizes().vec();
-      input_shape_group.insert(
-          input_shape_group.end(), input_shape.begin(), input_shape.end());
-    }
-
-    // save indices, args_tensor, saved_data
-    auto saved_tensors = std::vector<Tensor>(indices_group);
-    saved_tensors.insert(
-        saved_tensors.end(), result.cbegin() + group_size, result.cend());
-    saved_tensors.push_back(input_group[0]);
-    ctx->save_for_backward(saved_tensors);
-    ctx->saved_data["input_shape_group"] = input_shape_group;
-
-    return result;
-  }
-
-  static torch::autograd::variable_list backward_impl(
-      at::TensorList all_inputs,
-      c10::SymIntArrayRef output_shape_group_ref) {
-    TORCH_CHECK(all_inputs.size() > 2);
-
-    const int64_t group_size = (all_inputs.size() - 3) / 2;
-
-    Tensor fwd_input = all_inputs[2 * group_size + 2];
-    const int64_t output_dim = fwd_input.dim();
-    Tensor saved_data = all_inputs[2 * group_size + 1];
-    Tensor args_tensor_old = all_inputs[2 * group_size];
-    Tensor first_indices = all_inputs[group_size];
-
-    auto grad_output_group = std::vector<Tensor>(
-        all_inputs.cbegin(), all_inputs.cbegin() + group_size);
-    std::vector<int64_t> output_shape_group;
-    output_shape_group.reserve(output_shape_group_ref.size());
-    for (const auto& i : output_shape_group_ref) {
-      output_shape_group.push_back(i.as_int_unchecked());
-    }
-
-    auto indices_group = std::vector<Tensor>(
-        all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size);
-
-    // Retrieve saved data
-    TORCH_CHECK(saved_data.device() == at::kCPU);
-    TORCH_CHECK(saved_data.is_contiguous());
-    int64_t* saved_data_ptr = saved_data.data_ptr<int64_t>();
-    // Check that the size is the same
-    TORCH_CHECK(saved_data_ptr[0] == group_size);
-    const bool use_var_cols = saved_data_ptr[1];
-    int64_t* warp_offsets_group = reinterpret_cast<int64_t*>(saved_data_ptr[2]);
-    int32_t* num_cols_group = reinterpret_cast<int32_t*>(saved_data_ptr[3]);
-    int64_t total_num_warps = saved_data_ptr[4];
-
-    // We checked in forward that all output rows are the same for all member
-    // in the group
-    const int num_input_rows = grad_output_group[0].size(0);
-
-    std::vector<Tensor> outputs;
-    // Returning 3 outputs:
-    // 1) group_size Variable()'s for indices
-    // 2) group_size gradients for inputs
-    // 3) 1 Variable() for group_size
-    outputs.reserve(group_size * 2 + 1);
-
-    // 1) Add group_size Variable()'s for indices
-    // c10::irange cannot be used in here as it
-    // triggers a build error of i being an unused variable.
-    // Add empty tensor with zero size here to make __torch_dispatch__ work for
-    // the backward op. Those empty tensors will be replaced with
-    // torch::autograd::Variable() outside of the op call.
-    for (auto i = 0; i < group_size; i++) {
-      outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
-    }
-
-    // Allocate Tensor for ptrs of grad output and input, and indices
-    Tensor args_tensor = at::empty(
-        {group_size * 3},
-        at::TensorOptions().dtype(at::kLong).pinned_memory(true));
-    // Ensure that args_tensor is contiguous
-    TORCH_CHECK(args_tensor.is_contiguous());
-    int64_t* grad_output_ptrs = args_tensor.data_ptr<int64_t>();
-    int64_t* grad_input_ptrs = args_tensor.data_ptr<int64_t>() + group_size;
-    int64_t* indices_ptrs = args_tensor.data_ptr<int64_t>() + 2 * group_size;
-
-    int64_t group_grad_input_numel = 0;
-    std::vector<int64_t> grad_input_numels;
-    grad_input_numels.reserve(group_size);
-
-    // We need to store contiguous gradients outside the for-loop to guarantee
-    // that the contiguous tensors will outlive the kernel computation
-    std::vector<c10::MaybeOwned<Tensor>> grad_output_contigs;
-    grad_output_contigs.reserve(group_size);
-
-    for (const auto i : c10::irange(group_size)) {
-      const auto& grad = grad_output_group[i];
-      TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices);
-
-      // Store grad contigs to keep them alive during the kernel computation
-      grad_output_contigs.push_back(grad.expect_contiguous());
-
-      // Compute the total number of elements for all grad_inputs
-      int64_t grad_input_numel = output_shape_group[i * output_dim];
-      for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) {
-        grad_input_numel *= output_shape_group[j];
-      }
-      grad_input_numels.push_back(grad_input_numel);
-      group_grad_input_numel += grad_input_numel;
-
-      // Put all grad output/input pointers in an array
-      grad_output_ptrs[i] =
-          reinterpret_cast<int64_t>(grad_output_contigs[i]->data_ptr());
-    }
-
-    // Allocate a big tensor to avoid calling many small elementwise kernels
-    const auto group_grad_input =
-        at::zeros({group_grad_input_numel}, fwd_input.options());
-    TORCH_CHECK(group_grad_input.is_contiguous());
-
-    // Split to output_group
-    auto output_group = group_grad_input.split(grad_input_numels, 0);
-
-    TORCH_CHECK(output_group.size() == static_cast(group_size));
-
-    // Reshape grad inputs and obtain their pointers
-    for (int i = 0; i < group_size; i++) {
-      const auto grad_input_shape = std::vector<int64_t>(
-          output_shape_group.begin() + i * output_dim,
-          output_shape_group.begin() + (i + 1) * output_dim);
-      output_group[i] = output_group[i].reshape(grad_input_shape);
-      TORCH_CHECK(output_group[i].is_contiguous());
-      grad_input_ptrs[i] =
-          reinterpret_cast<int64_t>(output_group[i].data_ptr());
-
-      // 2) Add group_size gradients for inputs
-      outputs.push_back(output_group[i]);
-    }
-
-    // Calculate indices_ptrs
-    std::vector<c10::MaybeOwned<Tensor>> index_contigs;
-    index_contigs.reserve(group_size);
-    for (const auto i : c10::irange(group_size)) {
-      const auto& indices = indices_group[i];
-      index_contigs.push_back(indices.expect_contiguous());
-      indices_ptrs[i] = reinterpret_cast<int64_t>(index_contigs[i]->data_ptr());
-    }
-
-    // Transfer grad output pointers to GPU
-    args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);
-
-    group_index_select_or_add_cuda(
-        args_tensor.data_ptr<int64_t>(),
-        args_tensor.data_ptr<int64_t>() + group_size,
-        args_tensor.data_ptr<int64_t>() + 2 * group_size,
-        warp_offsets_group,
-        num_cols_group,
-        fwd_input.scalar_type(),
-        first_indices.scalar_type(),
-        fwd_input.device().index(),
-        num_input_rows,
-        total_num_warps,
-        group_size,
-        /*use_index_select=*/false,
-        use_var_cols);
-
-    return outputs;
-  }
-
-  static torch::autograd::variable_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::variable_list grad_output_group) {
-    TORCH_CHECK(grad_output_group.size() >= 2);
-    if (grad_output_group.size() == 2) {
-      // empty outputs
-      return torch::autograd::variable_list(1);
-    }
-    // remove redundant grads
-    auto group_size = grad_output_group.size() - 2;
-    grad_output_group.resize(group_size);
-
-    auto saved_tensors = ctx->get_saved_variables();
-    TORCH_CHECK(saved_tensors.size() == group_size + 3);
-    auto output_shape_group =
-        ctx->saved_data["input_shape_group"].toSymIntVector();
-    grad_output_group.insert(
-        grad_output_group.end(), saved_tensors.begin(), saved_tensors.end());
-    static auto backward_op =
-        torch::Dispatcher::singleton()
-            .findSchemaOrThrow(
-                "fbgemm::group_index_select_dim0_gpu_backward", "")
-            .typed();
-    auto res = backward_op.call(grad_output_group, output_shape_group);
-    // 1) Add group_size Variable()'s for indices
-    // Replace all empty tensors with Variable(). This must be done after the
-    // op.call to make __torch_dispatch__ work for the backward op.
-    std::fill(
-        res.begin(), res.begin() + group_size, torch::autograd::Variable());
-    // 3) Add 1 Variable() for group_size
-    res.push_back({});
-    return res;
-  }
-};
+// need to combine input_group and indices_group into one tensor list
+// to get this working with autograd.
+static torch::autograd::variable_list group_index_select_dim0_forward_impl_gpu(
+    at::TensorList all_indices_input,
+    const int64_t group_size) {
+  // Unpack from TensorList
+  auto [input_group, indices_group] =
+      group_index_select_dim0_unpack(all_indices_input, group_size);
+
+  // args_tensor stores kernel arguments:
+  // input_ptrs (group_size int64_t elements)
+  // output_ptrs (group_size int64_t elements)
+  // indices_ptrs (group_size int64_t elements)
+  // warp_offsets_group (group_size + 1 int64_t elements)
+  // num_cols_group (group_size int32_t elements)
+  int64_t args_ptrs_offsets[NUM_ARGS + 1];
+
+  const int64_t numels_num_cols_group_64 =
+      compute_num_int64s(group_size);
+
+  // Initialize offsets
+  args_ptrs_offsets[P_input_ptrs] = group_size;
+  args_ptrs_offsets[P_output_ptrs] = group_size;
+  args_ptrs_offsets[P_indices_ptrs] = group_size;
+  args_ptrs_offsets[P_warp_offsets_group_ptrs] = group_size + 1;
+  args_ptrs_offsets[P_num_cols_group_ptrs] = numels_num_cols_group_64;
+
+  // Compute offsets
+  int64_t offset = 0;
+  auto next = args_ptrs_offsets[0];
+  for (const auto i : c10::irange(NUM_ARGS)) {
+    args_ptrs_offsets[i] = offset;
+    offset += next;
+    next = args_ptrs_offsets[i + 1];
+  }
+  // Total number of int64_t elements required
+  args_ptrs_offsets[NUM_ARGS] = offset;
+
+  // Allocate memory for GroupIndexSelectArgs
+  at::Tensor args_tensor = at::empty(
+      {static_cast(args_ptrs_offsets[NUM_ARGS] * sizeof(int64_t))},
+      at::TensorOptions().dtype(at::kByte).pinned_memory(true));
+
+  // Ensure that args_tensor is contiguous
+  TORCH_CHECK(args_tensor.is_contiguous());
+
+  // Initialize raw pointers to point to Tensor args_tensor
+  int64_t* input_ptrs = nullptr;
+  int64_t* output_ptrs = nullptr;
+  int64_t* indices_ptrs = nullptr;
+  int64_t* warp_offsets_group = nullptr;
+  int32_t* num_cols_group = nullptr;
+
+  // Offset host pointers
+  offset_args(
+      &input_ptrs,
+      &output_ptrs,
+      &indices_ptrs,
+      &warp_offsets_group,
+      &num_cols_group,
+      reinterpret_cast(args_tensor.data_ptr()),
+      args_ptrs_offsets);
+
+  auto& first_input = input_group[0];
+  auto& first_indices = indices_group[0];
+
+  const int input_dim = first_input.dim();
+  const int num_output_rows = first_indices.size(0);
+  const int num_input_rows = first_input.size(0);
+  Tensor input_reshaped = first_input.reshape({num_input_rows, -1});
+  const int num_cols = input_reshaped.size(1);
+  const int cols_per_warp = get_group_index_select_cols_per_warp();
+  int64_t warp_offset = 0;
+  bool use_var_cols = false;
+
+  // Allocate memory for output_group
+  std::vector<Tensor> output_group;
+  output_group.reserve(group_size + 2);
+
+  // We need to store contiguous inputs and indices outside the for-loop to
+  // guarantee that the contiguous tensors will outlive the kernel
+  // computation
+  std::vector<c10::MaybeOwned<Tensor>> input_contigs;
+  std::vector<c10::MaybeOwned<Tensor>> index_contigs;
+  input_contigs.reserve(group_size);
+  index_contigs.reserve(group_size);
+
+  // For each group, copy input to output
+  for (const auto i : c10::irange(group_size)) {
+    const auto& input = input_group[i];
+    const auto& indices = indices_group[i];
+
+    // Verify that all input tensors have the same number of dimensions
+    TORCH_CHECK(
+        input_dim == input.dim(),
+        "All inputs in group_index_select must have the same number of dimensions");
+
+    // Verify that all tensors are on the same GPU
+    TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(input, indices);
+
+    auto num_output_rows_ = indices.size(0);
+
+    // Verify that all input tensors have the same shape[0]
+    TORCH_CHECK(
+        num_output_rows == num_output_rows_,
+        "The number of indices to be selected must be the same for the entire group");
+    const auto input_reshaped_ = input.reshape({input.size(0), -1});
+
+    // Number of columns can be different
+    auto num_cols_ = input_reshaped_.size(1);
+    auto warps_per_row = (num_cols_ + cols_per_warp - 1) / cols_per_warp;
+
+    if (num_cols != num_cols_) {
+      use_var_cols = true;
+    }
+
+    // Create output pointers
+    auto input_shape = input.sizes().vec();
+    input_shape[0] = num_output_rows_;
+    Tensor output = at::empty(input_shape, input.options());
+    // Ensure that the allocated output is contiguous
+    TORCH_CHECK(output.is_contiguous())
+    output_group.push_back(output);
+
+    // Store input and indices contigs to keep them alive during the kernel
+    // computation
+    input_contigs.push_back(input.expect_contiguous());
+    index_contigs.push_back(indices.expect_contiguous());
+
+    // Store args
+    input_ptrs[i] = reinterpret_cast<int64_t>(input_contigs[i]->data_ptr());
+    output_ptrs[i] = reinterpret_cast<int64_t>(output.data_ptr());
+    indices_ptrs[i] = reinterpret_cast<int64_t>(index_contigs[i]->data_ptr());
+    warp_offsets_group[i] = warp_offset;
+    num_cols_group[i] = num_cols_;
+
+    warp_offset += warps_per_row * num_output_rows;
+  }
+
+  // Store the last offset
+  warp_offsets_group[group_size] = warp_offset;
+
+  // Transfer args tensor to GPU
+  args_tensor = args_tensor.to(
+      first_input.device(),
+      /*non_blocking=*/true);
+
+  // Offset raw ptrs in GPU memory
+  offset_args(
+      &input_ptrs,
+      &output_ptrs,
+      &indices_ptrs,
+      &warp_offsets_group,
+      &num_cols_group,
+      reinterpret_cast(args_tensor.data_ptr()),
+      args_ptrs_offsets);
+
+  int64_t saved_data[] = {
+      static_cast<int64_t>(group_size),
+      use_var_cols,
+      reinterpret_cast<int64_t>(warp_offsets_group),
+      reinterpret_cast<int64_t>(num_cols_group),
+      warp_offset,
+  };
+  auto saved_data_t = at::empty(
+      {sizeof(saved_data) / sizeof(int64_t)},
+      at::TensorOptions().dtype(at::kLong));
+  TORCH_CHECK(saved_data_t.is_contiguous());
+  memcpy(saved_data_t.data_ptr(), saved_data, sizeof(saved_data));
+
+  group_index_select_or_add_cuda(
+      input_ptrs,
+      output_ptrs,
+      indices_ptrs,
+      warp_offsets_group,
+      num_cols_group,
+      first_input.scalar_type(),
+      first_indices.scalar_type(),
+      first_input.device().index(),
+      num_output_rows,
+      /*total_num_warps=*/warp_offset,
+      group_size,
+      /*use_index_select=*/true,
+      use_var_cols);
+
+  output_group.push_back(args_tensor);
+  output_group.push_back(saved_data_t);
+
+  // return format:
+  // (group_size outputs, 1 args_tensor, 1 saved_data)
+  return output_group;
+}
+
+static torch::autograd::variable_list group_index_select_dim0_backward_impl_gpu(
+    at::TensorList all_inputs,
+    c10::SymIntArrayRef output_shape_group_ref) {
+  TORCH_CHECK(all_inputs.size() > 2);
+
+  const int64_t group_size = (all_inputs.size() - 3) / 2;
+
+  Tensor fwd_input = all_inputs[2 * group_size + 2];
+  const int64_t output_dim = fwd_input.dim();
+  Tensor saved_data = all_inputs[2 * group_size + 1];
+  Tensor args_tensor_old = all_inputs[2 * group_size];
+  Tensor first_indices = all_inputs[group_size];
+
+  auto grad_output_group = std::vector<Tensor>(
+      all_inputs.cbegin(), all_inputs.cbegin() + group_size);
+  std::vector<int64_t> output_shape_group;
+  output_shape_group.reserve(output_shape_group_ref.size());
+  for (const auto& i : output_shape_group_ref) {
+    output_shape_group.push_back(i.as_int_unchecked());
+  }
+
+  auto indices_group = std::vector<Tensor>(
+      all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size);
+
+  // Retrieve saved data
+  TORCH_CHECK(saved_data.device() == at::kCPU);
+  TORCH_CHECK(saved_data.is_contiguous());
+  int64_t* saved_data_ptr = saved_data.data_ptr<int64_t>();
+  // Check that the size is the same
+  TORCH_CHECK(saved_data_ptr[0] == group_size);
+  const bool use_var_cols = saved_data_ptr[1];
+  int64_t* warp_offsets_group = reinterpret_cast<int64_t*>(saved_data_ptr[2]);
+  int32_t* num_cols_group = reinterpret_cast<int32_t*>(saved_data_ptr[3]);
+  int64_t total_num_warps = saved_data_ptr[4];
+
+  // We checked in forward that all output rows are the same for all member
+  // in the group
+  const int num_input_rows = grad_output_group[0].size(0);
+
+  std::vector<Tensor> outputs;
+  // Returning 3 outputs:
+  // 1) group_size Variable()'s for indices
+  // 2) group_size gradients for inputs
+  // 3) 1 Variable() for group_size
+  outputs.reserve(group_size * 2 + 1);
+
+  // 1) Add group_size Variable()'s for indices
+  // c10::irange cannot be used in here as it
+  // triggers a build error of i being an unused variable.
+  // Add empty tensor with zero size here to make __torch_dispatch__ work for
+  // the backward op. Those empty tensors will be replaced with
+  // torch::autograd::Variable() outside of the op call.
+  for (auto i = 0; i < group_size; i++) {
+    outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
+  }
+
+  // Allocate Tensor for ptrs of grad output and input, and indices
+  Tensor args_tensor = at::empty(
+      {group_size * 3},
+      at::TensorOptions().dtype(at::kLong).pinned_memory(true));
+  // Ensure that args_tensor is contiguous
+  TORCH_CHECK(args_tensor.is_contiguous());
+  int64_t* grad_output_ptrs = args_tensor.data_ptr<int64_t>();
+  int64_t* grad_input_ptrs = args_tensor.data_ptr<int64_t>() + group_size;
+  int64_t* indices_ptrs = args_tensor.data_ptr<int64_t>() + 2 * group_size;
+
+  int64_t group_grad_input_numel = 0;
+  std::vector<int64_t> grad_input_numels;
+  grad_input_numels.reserve(group_size);
+
+  // We need to store contiguous gradients outside the for-loop to guarantee
+  // that the contiguous tensors will outlive the kernel computation
+  std::vector<c10::MaybeOwned<Tensor>> grad_output_contigs;
+  grad_output_contigs.reserve(group_size);
+
+  for (const auto i : c10::irange(group_size)) {
+    const auto& grad = grad_output_group[i];
+    TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad, first_indices);
+
+    // Store grad contigs to keep them alive during the kernel computation
+    grad_output_contigs.push_back(grad.expect_contiguous());
+
+    // Compute the total number of elements for all grad_inputs
+    int64_t grad_input_numel = output_shape_group[i * output_dim];
+    for (auto j = (i * output_dim) + 1; j < (i + 1) * output_dim; j++) {
+      grad_input_numel *= output_shape_group[j];
+    }
+    grad_input_numels.push_back(grad_input_numel);
+    group_grad_input_numel += grad_input_numel;
+
+    // Put all grad output/input pointers in an array
+    grad_output_ptrs[i] =
+        reinterpret_cast<int64_t>(grad_output_contigs[i]->data_ptr());
+  }
+
+  // Allocate a big tensor to avoid calling many small elementwise kernels
+  const auto group_grad_input =
+      at::zeros({group_grad_input_numel}, fwd_input.options());
+  TORCH_CHECK(group_grad_input.is_contiguous());
+
+  // Split to output_group
+  auto output_group = group_grad_input.split(grad_input_numels, 0);
+
+  TORCH_CHECK(output_group.size() == static_cast(group_size));
+
+  // Reshape grad inputs and obtain their pointers
+  for (int i = 0; i < group_size; i++) {
+    const auto grad_input_shape = std::vector<int64_t>(
+        output_shape_group.begin() + i * output_dim,
+        output_shape_group.begin() + (i + 1) * output_dim);
+    output_group[i] = output_group[i].reshape(grad_input_shape);
+    TORCH_CHECK(output_group[i].is_contiguous());
+    grad_input_ptrs[i] = reinterpret_cast<int64_t>(output_group[i].data_ptr());
+
+    // 2) Add group_size gradients for inputs
+    outputs.push_back(output_group[i]);
+  }
+
+  // Calculate indices_ptrs
+  std::vector<c10::MaybeOwned<Tensor>> index_contigs;
+  index_contigs.reserve(group_size);
+  for (const auto i : c10::irange(group_size)) {
+    const auto& indices = indices_group[i];
+    index_contigs.push_back(indices.expect_contiguous());
+    indices_ptrs[i] = reinterpret_cast<int64_t>(index_contigs[i]->data_ptr());
+  }
+
+  // Transfer grad output pointers to GPU
+  args_tensor = args_tensor.to(first_indices.device(), /*non_blocking=*/true);
+
+  group_index_select_or_add_cuda(
+      args_tensor.data_ptr<int64_t>(),
+      args_tensor.data_ptr<int64_t>() + group_size,
+      args_tensor.data_ptr<int64_t>() + 2 * group_size,
+      warp_offsets_group,
+      num_cols_group,
+      fwd_input.scalar_type(),
+      first_indices.scalar_type(),
+      fwd_input.device().index(),
+      num_input_rows,
+      total_num_warps,
+      group_size,
+      /*use_index_select=*/false,
+      use_var_cols);
+
+  return outputs;
+}
 
 Tensor pack_segments_cuda(
     const Tensor& t_in,
@@ -654,45 +558,6 @@ Tensor index_select_dim0_gpu(
       user_skip_indices_sorting_fwd && !c10::InferenceMode::is_enabled())[0];
 }
 
-torch::autograd::variable_list group_index_select_dim0_gpu_impl(
-    at::TensorList all_indices_input,
-    const int64_t group_size) {
-  return GroupIndexSelectDim0GPUOp::apply(all_indices_input, group_size);
-}
-
-torch::autograd::variable_list group_index_select_dim0_gpu(
-    at::TensorList input_group,
-    at::TensorList indices_group) {
-  const auto group_size = indices_group.size();
-  std::vector<Tensor> output_group;
-
-  if (group_size == 0) {
-    return std::vector<Tensor>();
-  }
-
-  // Pack input_group and indices_group into TensorList
-  std::vector<Tensor> all_indices_input_vec;
-  all_indices_input_vec.reserve(group_size * 2);
-
-  for (const Tensor& index : indices_group) {
-    all_indices_input_vec.push_back(index);
-  }
-  for (const Tensor& input : input_group) {
-    all_indices_input_vec.push_back(input);
-  }
-
-  at::TensorList all_indices_input_tensor = all_indices_input_vec;
-
-  static auto forward_op =
-      torch::Dispatcher::singleton()
-          .findSchemaOrThrow("fbgemm::group_index_select_dim0_gpu_impl", "")
-          .typed();
-  auto res = forward_op.call(all_indices_input_tensor, group_size);
-  TORCH_CHECK(res.size() == group_size + 2);
-  // only return the outputs (the first group_size elements)
-  res.resize(group_size);
-  return res;
-}
-
 } // namespace fbgemm_gpu
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
@@ -721,17 +586,17 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   DISPATCH_TO_CUDA("index_select_dim0", fbgemm_gpu::index_select_dim0_gpu);
   DISPATCH_TO_CUDA(
       "group_index_select_dim0_gpu_impl",
-      fbgemm_gpu::GroupIndexSelectDim0GPUOp::forward_impl);
+      fbgemm_gpu::group_index_select_dim0_forward_impl_gpu);
   DISPATCH_TO_CUDA(
       "group_index_select_dim0_gpu_backward",
-      fbgemm_gpu::GroupIndexSelectDim0GPUOp::backward_impl);
+      fbgemm_gpu::group_index_select_dim0_backward_impl_gpu);
   DISPATCH_TO_CUDA(
-      "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0_gpu);
+      "group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, AutogradCUDA, m) {
-  m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0_gpu);
+  m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0);
   m.impl(
       "group_index_select_dim0_gpu_impl",
-      &fbgemm_gpu::group_index_select_dim0_gpu_impl);
+      &fbgemm_gpu::group_index_select_dim0_autograd_impl);
 }
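As an end-to-end illustration of how the registrations above are typically exercised (a sketch under the assumption that the fbgemm_gpu library is loaded and that the op keeps a Tensor[] x Tensor[] -> Tensor[] schema; this is not code from the patch):

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <vector>

// Call fbgemm::group_index_select_dim0 through the dispatcher so the right
// backend is chosen: the CUDA/AutogradCUDA path added above for GPU tensors,
// or the decomposed CPU/Meta implementation otherwise.
std::vector<at::Tensor> call_group_index_select_dim0(
    const std::vector<at::Tensor>& input_group,
    const std::vector<at::Tensor>& indices_group) {
  static auto op =
      at::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::group_index_select_dim0", "")
          .typed<std::vector<at::Tensor>(at::TensorList, at::TensorList)>();
  return op.call(input_group, indices_group);
}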