From c2160057e8320f2ab058d535366a84d5d3cc4f94 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 9 May 2024 08:00:07 -0700 Subject: [PATCH] Add memchecks to jagged tensor ops (#2572) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2572 - Add memchecks to jagged tensor ops Reviewed By: spcyppt Differential Revision: D57123188 fbshipit-source-id: 2a7831dd8f9c56067411911c5d445f824a410433 --- fbgemm_gpu/src/jagged_tensor_ops/common.cuh | 177 +++++++++--------- ...e_elementwise_add_jagged_output_forward.cu | 102 +++++----- .../jagged_index_add_2d_forward.cu | 38 ++-- .../jagged_index_select_2d_forward.cu | 37 ++-- .../jagged_unique_indices.cu | 75 +++++--- .../keyed_jagged_index_select_dim1.cu | 138 +++++++------- 6 files changed, 290 insertions(+), 277 deletions(-) diff --git a/fbgemm_gpu/src/jagged_tensor_ops/common.cuh b/fbgemm_gpu/src/jagged_tensor_ops/common.cuh index da910d6f6..15e735b8d 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/common.cuh +++ b/fbgemm_gpu/src/jagged_tensor_ops/common.cuh @@ -118,11 +118,11 @@ DEVICE_INLINE bool walk_down_tensor_storage_tree_( template __global__ __launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel_( - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 x_values, StackArray x_offsets, - const at::PackedTensorAccessor32 y, - at::PackedTensorAccessor32 output, + const pta::PackedTensorAccessor32 y, + pta::PackedTensorAccessor32 output, StackArray jagged_dims, F f, const scalar_t padding_value) { @@ -243,28 +243,28 @@ void jagged_dense_elementwise_dense_output_( const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)}); Tensor output_reshaped = output.view(y_reshaped.sizes()); -#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ - { \ - std::vector x_offsets_contig; \ - x_offsets_contig.resize(num_jagged_dim); \ - StackArray x_offset_ptrs; \ - x_offset_ptrs.ndim = num_jagged_dim; \ - for (int d = 0; d < num_jagged_dim; ++d) { \ - x_offsets_contig[d] = x_offsets[d].contiguous(); \ - x_offset_ptrs.vals[d] = \ - x_offsets_contig[d].template data_ptr(); \ - } \ - jagged_dense_elementwise_dense_output_kernel_ \ - <<>>( \ - x_values.packed_accessor32(), \ - x_offset_ptrs, \ - y_reshaped \ - .packed_accessor32(), \ - output_reshaped \ - .packed_accessor32(), \ - jagged_dims_tensor, \ - f, \ - padding_value); \ +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + } \ + [[maybe_unused]] const auto func_name = \ + "jagged_dense_elementwise_dense_output_kernel_"; \ + jagged_dense_elementwise_dense_output_kernel_ \ + <<>>( \ + MAKE_PTA_WITH_NAME(func_name, x_values, scalar_t, 2, 32), \ + x_offset_ptrs, \ + MAKE_PTA_WITH_NAME(func_name, y_reshaped, scalar_t, 3, 32), \ + MAKE_PTA_WITH_NAME(func_name, output_reshaped, scalar_t, 3, 32), \ + jagged_dims_tensor, \ + f, \ + padding_value); \ } JAGGED_TENSOR_DISPATCH_DIMS(); @@ -289,13 +289,13 @@ Tensor jagged_dense_elementwise_dense_output_( template __global__ __launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output_kernel_( - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 x_values, StackArray x_offsets, StackArray x_offsets_sizes, - const at::PackedTensorAccessor32 y_0, - const 
at::PackedTensorAccessor32 y_1, - at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 y_0, + const pta::PackedTensorAccessor32 y_1, + pta::PackedTensorAccessor32 output_values, StackArray jagged_dims, F f) { @@ -380,9 +380,10 @@ __launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output template __global__ void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_( - const at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor32 rows, - at::PackedTensorAccessor32 cols, + const pta::PackedTensorAccessor32 + offsets, + pta::PackedTensorAccessor32 rows, + pta::PackedTensorAccessor32 cols, int nnz, int B) { struct SharedMemory smem; @@ -519,13 +520,13 @@ fh(__half& v_out, const __half& x, const __half& y0, const __half& y1, F f) { template __global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_( - at::PackedTensorAccessor32 values, - const at::PackedTensorAccessor32 + pta::PackedTensorAccessor32 values, + const pta::PackedTensorAccessor32 x_values, - const at::PackedTensorAccessor32 y0, - const at::PackedTensorAccessor32 y1, - const at::PackedTensorAccessor32 rows, - const at::PackedTensorAccessor32 cols, + const pta::PackedTensorAccessor32 y0, + const pta::PackedTensorAccessor32 y1, + const pta::PackedTensorAccessor32 rows, + const pta::PackedTensorAccessor32 cols, const int nnz, const int E, F f) { @@ -692,37 +693,39 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt( return matches; } -#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ - { \ - dim3 threads, blocks; \ - StackArray jagged_dims_tensor; \ - std::tie(threads, blocks, jagged_dims_tensor) = \ - check_shape_and_partition_(x_values, x_offsets, y); \ - blocks.x = div_round_up(x_values.size(0), threads.y); \ - std::vector x_offsets_contig; \ - x_offsets_contig.resize(num_jagged_dim); \ - StackArray x_offset_ptrs; \ - x_offset_ptrs.ndim = num_jagged_dim; \ - StackArray x_offset_sizes; \ - x_offset_sizes.ndim = num_jagged_dim; \ - for (int d = 0; d < num_jagged_dim; ++d) { \ - x_offsets_contig[d] = x_offsets[d].contiguous(); \ - x_offset_ptrs.vals[d] = \ - x_offsets_contig[d].template data_ptr(); \ - x_offset_sizes.vals[d] = x_offsets[d].numel(); \ - } \ - jagged_dense_dense_elementwise_jagged_output_kernel_< \ - NUM_JAGGED_DIM, \ - index_t><<>>( \ - x_values.packed_accessor32(), \ - x_offset_ptrs, \ - x_offset_sizes, \ - y_reshaped.packed_accessor32(), \ - y_reshaped.packed_accessor32(), \ - output_values.packed_accessor32(), \ - jagged_dims_tensor, \ - [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ - -> scalar_t { return f(x, y); }); \ +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + [[maybe_unused]] const auto func_name = \ + "jagged_dense_dense_elementwise_jagged_output_kernel_"; \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + 
NUM_JAGGED_DIM, \ + index_t><<>>( \ + MAKE_PTA_WITH_NAME(func_name, x_values, scalar_t, 2, 32), \ + x_offset_ptrs, \ + x_offset_sizes, \ + MAKE_PTA_WITH_NAME(func_name, y_reshaped, scalar_t, 3, 32), \ + MAKE_PTA_WITH_NAME(func_name, y_reshaped, scalar_t, 3, 32), \ + MAKE_PTA_WITH_NAME(func_name, output_values, scalar_t, 2, 32), \ + jagged_dims_tensor, \ + [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/) \ + -> scalar_t { return f(x, y); }); \ } ///@addtogroup jagged-tensor-ops-cuda @@ -810,18 +813,19 @@ void jagged_dense_elementwise_jagged_output_opt_( } dim3 threads_bs = dim3(1024, 1, 1); dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1); +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name1 = + "jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_"; +#endif jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< index_t> <<>>( - x_offsets[0] - .packed_accessor32(), - t_rows_after_bs - .packed_accessor32(), - t_cols_after_bs - .packed_accessor32(), + MAKE_PTA_WITH_NAME(func_name1, x_offsets[0], index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name1, t_rows_after_bs, int, 1, 32), + MAKE_PTA_WITH_NAME(func_name1, t_cols_after_bs, int, 1, 32), nnz, B); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -831,21 +835,20 @@ void jagged_dense_elementwise_jagged_output_opt_( if (blocks.y > 65535) { blocks.y = 65535; } +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name2 = + "jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_"; +#endif jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_< index_t> <<>>( - output_values - .packed_accessor32(), - x_values - .packed_accessor32(), - y_reshaped - .packed_accessor32(), - y_reshaped - .packed_accessor32(), - t_rows_after_bs - .packed_accessor32(), - t_cols_after_bs - .packed_accessor32(), + MAKE_PTA_WITH_NAME( + func_name2, output_values, c10::Half, 2, 32), + MAKE_PTA_WITH_NAME(func_name2, x_values, c10::Half, 2, 32), + MAKE_PTA_WITH_NAME(func_name2, y_reshaped, c10::Half, 3, 32), + MAKE_PTA_WITH_NAME(func_name2, y_reshaped, c10::Half, 3, 32), + MAKE_PTA_WITH_NAME(func_name2, t_rows_after_bs, int, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, t_cols_after_bs, int, 1, 32), nnz, E, [f] __device__(__half x, __half y0, __half) -> __half { diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu index df40c0caf..c855cc174 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu @@ -12,36 +12,38 @@ using Tensor = at::Tensor; namespace fbgemm_gpu { -#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ - { \ - dim3 threads, blocks; \ - StackArray jagged_dims_tensor; \ - std::tie(threads, blocks, jagged_dims_tensor) = \ - check_shape_and_partition_(x_values, x_offsets, y_0); \ - blocks.x = div_round_up(x_values.size(0), threads.y); \ - std::vector x_offsets_contig; \ - x_offsets_contig.resize(num_jagged_dim); \ - StackArray x_offset_ptrs; \ - x_offset_ptrs.ndim = num_jagged_dim; \ - StackArray x_offset_sizes; \ - x_offset_sizes.ndim = num_jagged_dim; \ - for (int d = 0; d < num_jagged_dim; ++d) { \ - x_offsets_contig[d] = x_offsets[d].contiguous(); \ - x_offset_ptrs.vals[d] = \ - x_offsets_contig[d].template data_ptr(); \ - x_offset_sizes.vals[d] = x_offsets[d].numel(); \ - } \ - jagged_dense_dense_elementwise_jagged_output_kernel_< \ - 
NUM_JAGGED_DIM, \ - index_t><<>>( \ - x_values.packed_accessor32(), \ - x_offset_ptrs, \ - x_offset_sizes, \ - y_0_reshaped.packed_accessor32(), \ - y_1_reshaped.packed_accessor32(), \ - output_values.packed_accessor32(), \ - jagged_dims_tensor, \ - f); \ +#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \ + { \ + dim3 threads, blocks; \ + StackArray jagged_dims_tensor; \ + std::tie(threads, blocks, jagged_dims_tensor) = \ + check_shape_and_partition_(x_values, x_offsets, y_0); \ + blocks.x = div_round_up(x_values.size(0), threads.y); \ + std::vector x_offsets_contig; \ + x_offsets_contig.resize(num_jagged_dim); \ + StackArray x_offset_ptrs; \ + x_offset_ptrs.ndim = num_jagged_dim; \ + StackArray x_offset_sizes; \ + x_offset_sizes.ndim = num_jagged_dim; \ + for (int d = 0; d < num_jagged_dim; ++d) { \ + x_offsets_contig[d] = x_offsets[d].contiguous(); \ + x_offset_ptrs.vals[d] = \ + x_offsets_contig[d].template data_ptr(); \ + x_offset_sizes.vals[d] = x_offsets[d].numel(); \ + } \ + [[maybe_unused]] const auto func_name = \ + "jagged_dense_dense_elementwise_jagged_output_kernel_"; \ + jagged_dense_dense_elementwise_jagged_output_kernel_< \ + NUM_JAGGED_DIM, \ + index_t><<>>( \ + MAKE_PTA_WITH_NAME(func_name, x_values, scalar_t, 2, 32), \ + x_offset_ptrs, \ + x_offset_sizes, \ + MAKE_PTA_WITH_NAME(func_name, y_0_reshaped, scalar_t, 3, 32), \ + MAKE_PTA_WITH_NAME(func_name, y_1_reshaped, scalar_t, 3, 32), \ + MAKE_PTA_WITH_NAME(func_name, output_values, scalar_t, 2, 32), \ + jagged_dims_tensor, \ + f); \ } template @@ -128,18 +130,20 @@ void jagged_dense_dense_elementwise_jagged_output_opt_( } dim3 threads_bs = dim3(1024, 1, 1); dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1); + +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name1 = + "jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_"; +#endif jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_< index_t> <<>>( - x_offsets[0] - .packed_accessor32(), - t_rows_after_bs - .packed_accessor32(), - t_cols_after_bs - .packed_accessor32(), + MAKE_PTA_WITH_NAME(func_name1, x_offsets[0], index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name1, t_rows_after_bs, int, 1, 32), + MAKE_PTA_WITH_NAME(func_name1, t_cols_after_bs, int, 1, 32), nnz, B); C10_CUDA_KERNEL_LAUNCH_CHECK(); @@ -149,21 +153,23 @@ void jagged_dense_dense_elementwise_jagged_output_opt_( if (blocks.y > 65535) { blocks.y = 65535; } + +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name2 = + "jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_"; +#endif jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_< index_t> <<>>( - output_values - .packed_accessor32(), - x_values - .packed_accessor32(), - y_0_reshaped - .packed_accessor32(), - y_1_reshaped - .packed_accessor32(), - t_rows_after_bs - .packed_accessor32(), - t_cols_after_bs - .packed_accessor32(), + MAKE_PTA_WITH_NAME( + func_name2, output_values, c10::Half, 2, 32), + MAKE_PTA_WITH_NAME(func_name2, x_values, c10::Half, 2, 32), + MAKE_PTA_WITH_NAME( + func_name2, y_0_reshaped, c10::Half, 3, 32), + MAKE_PTA_WITH_NAME( + func_name2, y_1_reshaped, c10::Half, 3, 32), + MAKE_PTA_WITH_NAME(func_name2, t_rows_after_bs, int, 1, 32), + MAKE_PTA_WITH_NAME(func_name2, t_cols_after_bs, int, 1, 32), nnz, E, [f] __device__(__half x, __half y0, __half y1) -> __half { diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu index 83953322b..098c8b874 100644 --- 
a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu @@ -14,12 +14,14 @@ namespace fbgemm_gpu { template __global__ __launch_bounds__(kMaxThreads) void jagged_index_add_2d_kernel( - at::PackedTensorAccessor64 output, - const at::PackedTensorAccessor64 values, - const at::PackedTensorAccessor32 + pta::PackedTensorAccessor64 output, + const pta::PackedTensorAccessor64 + values, + const pta::PackedTensorAccessor32 input_offsets, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 + indices, + const pta::PackedTensorAccessor32 output_offsets, const int64_t num_dense_input_rows) { __shared__ int smem[1]; @@ -98,27 +100,21 @@ Tensor jagged_index_add_2d_forward_cuda( indices.scalar_type(), "jagged_index_add_2d_kernel_wrapper_2", [&] { +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "jagged_index_add_2d_kernel"; +#endif jagged_index_add_2d_kernel<<< dim3(num_blocks), dim3(num_cols), 0, at::cuda::getCurrentCUDAStream()>>>( - output.packed_accessor64< - scalar_t, - 2, - at::RestrictPtrTraits>(), - values.packed_accessor64< - scalar_t, - 2, - at::RestrictPtrTraits>(), - input_offsets_contig->packed_accessor32< - int64_t, - 1, - at::RestrictPtrTraits>(), - indices - .packed_accessor32(), - output_offsets - .packed_accessor32(), + MAKE_PTA_WITH_NAME(func_name, output, scalar_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name, values, scalar_t, 2, 64), + MAKE_PTA_WITH_NAME( + func_name, (*input_offsets_contig), int64_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, output_offsets, int64_t, 1, 32), num_dense_input_rows); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu index 81a62ecf5..f02d64e85 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu @@ -14,12 +14,13 @@ namespace fbgemm_gpu { template __global__ __launch_bounds__(kMaxThreads) void jagged_index_select_2d_kernel( - at::PackedTensorAccessor64 output, - const at::PackedTensorAccessor64 input, - const at::PackedTensorAccessor32 + pta::PackedTensorAccessor64 output, + const pta::PackedTensorAccessor64 input, + const pta::PackedTensorAccessor32 input_offsets, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 + indices, + const pta::PackedTensorAccessor32 output_offsets, const int64_t num_dense_output_rows) { __shared__ int smem[1]; @@ -95,27 +96,21 @@ Tensor jagged_index_select_2d_forward_cuda( indices.scalar_type(), "jagged_index_select_2d_kernel_wrapper_2", [&] { +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "jagged_index_select_2d_kernel"; +#endif jagged_index_select_2d_kernel<<< dim3(num_blocks), dim3(num_cols), 0, at::cuda::getCurrentCUDAStream()>>>( - output.packed_accessor64< - scalar_t, - 2, - at::RestrictPtrTraits>(), - values.packed_accessor64< - scalar_t, - 2, - at::RestrictPtrTraits>(), - input_offsets - .packed_accessor32(), - indices - .packed_accessor32(), - output_offsets_contig->packed_accessor32< - int64_t, - 1, - at::RestrictPtrTraits>(), + MAKE_PTA_WITH_NAME(func_name, output, scalar_t, 2, 64), + MAKE_PTA_WITH_NAME(func_name, values, scalar_t, 2, 64), + MAKE_PTA_WITH_NAME( + func_name, input_offsets, int64_t, 1, 32), + 
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, (*output_offsets_contig), int64_t, 1, 32), num_dense_output_rows); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu index 54934a6cb..830ea900e 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_unique_indices.cu @@ -16,11 +16,13 @@ namespace fbgemm_gpu { // can be sorted together. template __global__ __launch_bounds__(kMaxThreads) void linearize_index_wo_infos_kernel( - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 hash_size_cumsum, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 + indices, + const pta::PackedTensorAccessor32 + offsets, + pta::PackedTensorAccessor32 linear_indices, FixedDivisor fd) { const int32_t T = hash_size_cumsum.size(0) - 1; @@ -54,10 +56,11 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_wo_infos_kernel( // the element from the unique indices according to the reverse index info. template __global__ __launch_bounds__(kMaxThreads) void delinearize_unique_index_kernel( - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 + indices, + const pta::PackedTensorAccessor32 reverse_index, - at::PackedTensorAccessor32 + pta::PackedTensorAccessor32 unique_indices) { const auto total_indices = indices.size(0); const auto b_t = blockIdx.x * blockDim.x + threadIdx.x; @@ -73,12 +76,13 @@ __global__ __launch_bounds__(kMaxThreads) void delinearize_unique_index_kernel( // values in the reverse index array. 
template __global__ __launch_bounds__(kMaxThreads) void unique_indices_length_kernel( - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 hash_size_offsets, - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 reverse_index, - const at::PackedTensorAccessor32 offsets, - at::PackedTensorAccessor32 lengths) { + const pta::PackedTensorAccessor32 + offsets, + pta::PackedTensorAccessor32 lengths) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage_max; __shared__ typename BlockReduce::TempStorage temp_storage_min; @@ -144,15 +148,18 @@ std::tuple jagged_unique_indices_cuda( indices.scalar_type(), "linearize_index", ([&] { const auto linearize_index_kernel_ = linearize_index_wo_infos_kernel; +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "linearize_index_kernel_"; +#endif linearize_index_kernel_<<< div_round_up(total_B, kMaxThreads), kMaxThreads, 0, at::cuda::getCurrentCUDAStream()>>>( - hash_size_cumsum.packed_accessor32(), - indices.packed_accessor32(), - offsets.packed_accessor32(), - linear_indices.packed_accessor32(), + MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, linear_indices, index_t, 1, 32), FixedDivisor(total_B / T)); C10_CUDA_KERNEL_LAUNCH_CHECK(); })); @@ -170,14 +177,17 @@ std::tuple jagged_unique_indices_cuda( indices.scalar_type(), "delinearize_unique_index", ([&] { const auto delinearize_unique_index_kernel_ = delinearize_unique_index_kernel; +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "delinearize_unique_index_kernel_"; +#endif delinearize_unique_index_kernel_<<< div_round_up(total_indices + 1, kMaxThreads), kMaxThreads, 0, at::cuda::getCurrentCUDAStream()>>>( - indices.packed_accessor32(), - reverse_index.packed_accessor32(), - unique_indices.packed_accessor32()); + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, reverse_index, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, unique_indices, index_t, 1, 32)); C10_CUDA_KERNEL_LAUNCH_CHECK(); })); @@ -188,16 +198,18 @@ std::tuple jagged_unique_indices_cuda( index_t, std::numeric_limits::max(), std::numeric_limits::min()>; +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "unique_indices_length_kernel_"; +#endif unique_indices_length_kernel_<<< T, kMaxThreads, 0, at::cuda::getCurrentCUDAStream()>>>( - hash_size_offsets - .packed_accessor32(), - reverse_index.packed_accessor32(), - offsets.packed_accessor32(), - output_lengths.packed_accessor32()); + MAKE_PTA_WITH_NAME(func_name, hash_size_offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, reverse_index, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, output_lengths, index_t, 1, 32)); C10_CUDA_KERNEL_LAUNCH_CHECK(); })); @@ -209,10 +221,12 @@ std::tuple jagged_unique_indices_cuda( // Compute hash size for each key using the max value of indices per key. 
template __global__ __launch_bounds__(kMaxThreads) void compute_hash_size_kernel( - const at::PackedTensorAccessor32 offsets, - const at::PackedTensorAccessor32 indices, + const pta::PackedTensorAccessor32 + offsets, + const pta::PackedTensorAccessor32 + indices, const int64_t batch_size, - at::PackedTensorAccessor32 hash_size) { + pta::PackedTensorAccessor32 hash_size) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage_max; @@ -254,15 +268,18 @@ std::tuple jagged_hash_size_cumsum_cuda( const auto compute_hash_size_kernel_ = compute_hash_size_kernel< index_t, std::numeric_limits::min()>; +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "compute_hash_size_kernel_"; +#endif compute_hash_size_kernel_<<< T, kMaxThreads, 0, at::cuda::getCurrentCUDAStream()>>>( - offsets.packed_accessor32(), - indices.packed_accessor32(), + MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), batch_size, - hash_size.packed_accessor32()); + MAKE_PTA_WITH_NAME(func_name, hash_size, index_t, 1, 32)); C10_CUDA_KERNEL_LAUNCH_CHECK(); })); diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu index dd853e5cf..f78b41903 100644 --- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu @@ -19,10 +19,11 @@ template < int NUM_THREADS_PER_BLOCK, int MAX_ENTRIES_PER_BLOCK> __global__ void index_select_scalar_cumsum_kernel( - at::PackedTensorAccessor32 output, - at::PackedTensorAccessor32 output_cumsum, - const at::PackedTensorAccessor32 input, - const at::PackedTensorAccessor32 indices, + pta::PackedTensorAccessor32 output, + pta::PackedTensorAccessor32 output_cumsum, + const pta::PackedTensorAccessor32 input, + const pta::PackedTensorAccessor32 + indices, const int num_batches, const int input_batch_size, const int last_block_num_entries, @@ -73,16 +74,17 @@ template < typename weight_t, bool has_weights> __global__ void keyed_jagged_index_select_dim1_kernel( - at::PackedTensorAccessor64 output, - at::PackedTensorAccessor64 + pta::PackedTensorAccessor64 output, + pta::PackedTensorAccessor64 output_weights, - const at::PackedTensorAccessor64 input, - const at::PackedTensorAccessor64 + const pta::PackedTensorAccessor64 input, + const pta::PackedTensorAccessor64 weights, - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 input_offsets, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 + indices, + const pta::PackedTensorAccessor32 output_offsets, const int num_batches, const int input_batch_size) { @@ -121,12 +123,13 @@ __global__ void keyed_jagged_index_select_dim1_kernel( template __global__ void keyed_jagged_index_add_dim1_kernel( - at::PackedTensorAccessor64 output, - const at::PackedTensorAccessor64 input, - const at::PackedTensorAccessor32 + pta::PackedTensorAccessor64 output, + const pta::PackedTensorAccessor64 input, + const pta::PackedTensorAccessor32 input_offsets, - const at::PackedTensorAccessor32 indices, - const at::PackedTensorAccessor32 + const pta::PackedTensorAccessor32 + indices, + const pta::PackedTensorAccessor32 output_offsets, const int num_batches, const int output_batch_size) { @@ -226,6 +229,10 @@ class KeyedJaggedIndexSelectDim1GPUOp indices.scalar_type(), "index_select_scalar_cumsum_wrapper_3", [&] { +#ifdef 
FBGEMM_GPU_MEMCHECK + const auto func_name = + "index_select_scalar_cumsum_kernel"; +#endif index_select_scalar_cumsum_kernel< length_t, index_t, @@ -236,22 +243,14 @@ class KeyedJaggedIndexSelectDim1GPUOp MAX_CUMSUM_ENTRIES_PER_BLOCK, 0, at::cuda::getCurrentCUDAStream()>>>( - output_lengths.packed_accessor32< - length_t, - 1, - at::RestrictPtrTraits>(), - output_offsets.packed_accessor32< - offset_t, - 1, - at::RestrictPtrTraits>(), - lengths.packed_accessor32< - length_t, - 1, - at::RestrictPtrTraits>(), - indices.packed_accessor32< - index_t, - 1, - at::RestrictPtrTraits>(), + MAKE_PTA_WITH_NAME( + func_name, output_lengths, length_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, output_offsets, offset_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, lengths, length_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, indices, index_t, 1, 32), num_batches, batch_size, num_output_lengths - @@ -282,27 +281,27 @@ class KeyedJaggedIndexSelectDim1GPUOp const auto output_offsets_contig = output_offsets.expect_contiguous(); if (grid_size != 0) { -#define LAUNCH_KERNEL(WEIGHTED, WEIGHT_TYPE, OUTPUT_WEIGHTS, WEIGHTS) \ - { \ - keyed_jagged_index_select_dim1_kernel< \ - value_t, \ - index_t, \ - offset_t, \ - WEIGHT_TYPE, \ - WEIGHTED> \ - <<>>( \ - output.packed_accessor64(), \ - OUTPUT_WEIGHTS \ - .packed_accessor64(), \ - values.packed_accessor64(), \ - WEIGHTS \ - .packed_accessor64(), \ - offsets.packed_accessor32(), \ - indices.packed_accessor32(), \ - output_offsets_contig \ - ->packed_accessor32(), \ - num_batches, \ - batch_size); \ +#define LAUNCH_KERNEL(WEIGHTED, WEIGHT_TYPE, OUTPUT_WEIGHTS, WEIGHTS) \ + { \ + [[maybe_unused]] const auto func_name = \ + "keyed_jagged_index_select_dim1_kernel"; \ + keyed_jagged_index_select_dim1_kernel< \ + value_t, \ + index_t, \ + offset_t, \ + WEIGHT_TYPE, \ + WEIGHTED> \ + <<>>( \ + MAKE_PTA_WITH_NAME(func_name, output, value_t, 1, 64), \ + MAKE_PTA_WITH_NAME(func_name, OUTPUT_WEIGHTS, WEIGHT_TYPE, 1, 64), \ + MAKE_PTA_WITH_NAME(func_name, values, value_t, 1, 64), \ + MAKE_PTA_WITH_NAME(func_name, WEIGHTS, WEIGHT_TYPE, 1, 64), \ + MAKE_PTA_WITH_NAME(func_name, offsets, offset_t, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), \ + MAKE_PTA_WITH_NAME( \ + func_name, (*output_offsets_contig), offset_t, 1, 32), \ + num_batches, \ + batch_size); \ } FBGEMM_DISPATCH_ALL_TYPES( values.scalar_type(), @@ -395,6 +394,9 @@ class KeyedJaggedIndexSelectDim1GPUOp "keyed_jagged_index_add_dim1_wrapper_2", [&] { using offset_t = index_t; +#ifdef FBGEMM_GPU_MEMCHECK + const auto func_name = "keyed_jagged_index_add_dim1_kernel"; +#endif AT_DISPATCH_INDEX_TYPES( indices.scalar_type(), "keyed_jagged_index_add_dim1_wrapper_3", @@ -404,26 +406,20 @@ class KeyedJaggedIndexSelectDim1GPUOp kMaxThreads, 0, at::cuda::getCurrentCUDAStream()>>>( - grad_input.packed_accessor64< - scalar_t, - 1, - at::RestrictPtrTraits>(), - grad.packed_accessor64< - scalar_t, - 1, - at::RestrictPtrTraits>(), - grad_offsets_contig->packed_accessor32< - offset_t, - 1, - at::RestrictPtrTraits>(), - indices.packed_accessor32< - index_t, - 1, - at::RestrictPtrTraits>(), - output_offsets.packed_accessor32< + MAKE_PTA_WITH_NAME( + func_name, grad_input, scalar_t, 1, 64), + MAKE_PTA_WITH_NAME( + func_name, grad, scalar_t, 1, 64), + MAKE_PTA_WITH_NAME( + func_name, + (*grad_offsets_contig), offset_t, 1, - at::RestrictPtrTraits>(), + 32), + MAKE_PTA_WITH_NAME( + func_name, indices, index_t, 1, 32), + MAKE_PTA_WITH_NAME( + func_name, output_offsets, offset_t, 1, 32), num_batches, output_batch_size); 
C10_CUDA_KERNEL_LAUNCH_CHECK();
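
Reviewer note (illustration only, not part of the diff): the sketch below shows the pattern this patch applies throughout the jagged tensor ops: kernel parameters switch from at::PackedTensorAccessor* to pta::PackedTensorAccessor*, and host-side launches build the accessors through MAKE_PTA_WITH_NAME so that, in FBGEMM_GPU_MEMCHECK builds, out-of-bounds accesses are reported with the kernel and tensor names. The toy_copy_kernel/toy_copy names and the include path are hypothetical; only MAKE_PTA_WITH_NAME, pta::PackedTensorAccessor32, and the FBGEMM_GPU_MEMCHECK guard come from the patch itself.

// Minimal, hypothetical sketch (not part of this patch). Assumes the
// FBGEMM-GPU accessor utilities (pta::, MAKE_PTA_WITH_NAME,
// FBGEMM_GPU_MEMCHECK) are provided by the project's headers; the include
// path below is an assumption.
#include "fbgemm_gpu/fbgemm_tensor_accessor.h" // assumed header location

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

namespace {

// Toy elementwise-copy kernel written against pta:: accessors, mirroring
// the signature changes made throughout this patch.
template <typename scalar_t>
__global__ void toy_copy_kernel(
    pta::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits> dst,
    const pta::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
        src) {
  const auto row = blockIdx.x;
  for (auto col = threadIdx.x; col < src.size(1); col += blockDim.x) {
    // Under FBGEMM_GPU_MEMCHECK, an out-of-bounds index here is reported
    // with the kernel/tensor names bound by MAKE_PTA_WITH_NAME below.
    dst[row][col] = src[row][col];
  }
}

void toy_copy(at::Tensor dst, const at::Tensor& src) {
  TORCH_CHECK(dst.sizes() == src.sizes());
  if (src.numel() == 0) {
    return;
  }
  // func_name only exists for the bounds-checking accessor; guard it so
  // non-memcheck builds do not see an unused variable (macro bodies in the
  // patch use [[maybe_unused]] instead, since #ifdef cannot appear there).
#ifdef FBGEMM_GPU_MEMCHECK
  const auto func_name = "toy_copy_kernel";
#endif
  AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "toy_copy", [&] {
    toy_copy_kernel<scalar_t>
        <<<static_cast<uint32_t>(src.size(0)),
           256,
           0,
           at::cuda::getCurrentCUDAStream()>>>(
            MAKE_PTA_WITH_NAME(func_name, dst, scalar_t, 2, 32),
            MAKE_PTA_WITH_NAME(func_name, src, scalar_t, 2, 32));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace

The guarded definitions in this patch imply that in non-memcheck builds MAKE_PTA_WITH_NAME ignores its name argument and degrades to the plain packed-accessor construction, which is why func_name can be passed unconditionally while only being defined under FBGEMM_GPU_MEMCHECK (or marked [[maybe_unused]] inside macro bodies).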