Add memchecks to jagged tensor ops (#2572)
Summary:
Pull Request resolved: #2572

- Add memchecks to jagged tensor ops

Reviewed By: spcyppt

Differential Revision: D57123188

fbshipit-source-id: 2a7831dd8f9c56067411911c5d445f824a410433
q10 authored and facebook-github-bot committed May 9, 2024
1 parent 7d15c59 commit c216005
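The change applied here is mechanical: kernel arguments previously built with tensor.packed_accessor32<T, N, at::RestrictPtrTraits>() become MAKE_PTA_WITH_NAME(func_name, tensor, T, N, 32) over a pta::PackedTensorAccessor32, and each launch site names its kernel in a func_name string. In builds that define FBGEMM_GPU_MEMCHECK, the accessor can then bounds-check every index and report which kernel touched which tensor out of bounds; in ordinary builds it reduces to the plain unchecked at:: accessor. The sketch below illustrates the idea only — CheckedPTA1D, MAKE_CHECKED_PTA_1D, and the demo kernel are hypothetical stand-ins, not FBGEMM's actual implementation:

#include <cstdint>
#include <cstdio>

// Toy bounds-checked 1-D accessor, in the spirit of what a
// FBGEMM_GPU_MEMCHECK build substitutes for the unchecked accessor.
template <typename T>
struct CheckedPTA1D {
  T* data;
  int32_t numel;
  const char* func_name;   // kernel name captured at the launch site
  const char* tensor_name; // stringified tensor argument

  __device__ T& operator[](int32_t i) const {
    if (i < 0 || i >= numel) {
      // The point of the memcheck: report WHICH kernel indexed WHICH
      // tensor out of bounds, then abort, instead of silently reading
      // or writing past the end of the buffer.
      printf(
          "[%s] index %d out of bounds for tensor %s (numel %d)\n",
          func_name,
          i,
          tensor_name,
          numel);
      __trap();
    }
    return data[i];
  }
};

// A MAKE_PTA_WITH_NAME-style macro would build the checked accessor in
// memcheck builds and the plain accessor otherwise; only the checked
// branch is sketched here.
#define MAKE_CHECKED_PTA_1D(FUNC_NAME, DATA, NUMEL, T) \
  CheckedPTA1D<T> { DATA, NUMEL, FUNC_NAME, #DATA }

// Demo kernel: traps with a readable message when launched with more
// threads than vals.numel.
__global__ void demo_fill_kernel_(CheckedPTA1D<float> vals) {
  const int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  vals[i] = 0.0f;
}

A memcheck build simply defines the flag at compile time, e.g. -DFBGEMM_GPU_MEMCHECK on the nvcc command line (illustrative; the repository's build integration may differ).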
Showing 6 changed files with 290 additions and 277 deletions.
177 changes: 90 additions & 87 deletions fbgemm_gpu/src/jagged_tensor_ops/common.cuh
@@ -118,11 +118,11 @@ DEVICE_INLINE bool walk_down_tensor_storage_tree_(
 template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
 __global__
 __launch_bounds__(kMaxThreads) void jagged_dense_elementwise_dense_output_kernel_(
-    const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
         x_values,
     StackArray<index_t*> x_offsets,
-    const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y,
-    at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> output,
+    const pta::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y,
+    pta::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> output,
     StackArray<int64_t> jagged_dims,
     F f,
     const scalar_t padding_value) {
@@ -243,28 +243,28 @@ void jagged_dense_elementwise_dense_output_(
   const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});
   Tensor output_reshaped = output.view(y_reshaped.sizes());

-#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                \
-  {                                                                           \
-    std::vector<Tensor> x_offsets_contig;                                     \
-    x_offsets_contig.resize(num_jagged_dim);                                  \
-    StackArray<index_t*> x_offset_ptrs;                                       \
-    x_offset_ptrs.ndim = num_jagged_dim;                                      \
-    for (int d = 0; d < num_jagged_dim; ++d) {                                \
-      x_offsets_contig[d] = x_offsets[d].contiguous();                        \
-      x_offset_ptrs.vals[d] =                                                 \
-          x_offsets_contig[d].template data_ptr<index_t>();                   \
-    }                                                                         \
-    jagged_dense_elementwise_dense_output_kernel_<NUM_JAGGED_DIM, index_t>    \
-        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(           \
-            x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
-            x_offset_ptrs,                                                    \
-            y_reshaped                                                        \
-                .packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),     \
-            output_reshaped                                                   \
-                .packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),     \
-            jagged_dims_tensor,                                               \
-            f,                                                                \
-            padding_value);                                                   \
+#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                \
+  {                                                                           \
+    std::vector<Tensor> x_offsets_contig;                                     \
+    x_offsets_contig.resize(num_jagged_dim);                                  \
+    StackArray<index_t*> x_offset_ptrs;                                       \
+    x_offset_ptrs.ndim = num_jagged_dim;                                      \
+    for (int d = 0; d < num_jagged_dim; ++d) {                                \
+      x_offsets_contig[d] = x_offsets[d].contiguous();                        \
+      x_offset_ptrs.vals[d] =                                                 \
+          x_offsets_contig[d].template data_ptr<index_t>();                   \
+    }                                                                         \
+    [[maybe_unused]] const auto func_name =                                   \
+        "jagged_dense_elementwise_dense_output_kernel_";                      \
+    jagged_dense_elementwise_dense_output_kernel_<NUM_JAGGED_DIM, index_t>    \
+        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(           \
+            MAKE_PTA_WITH_NAME(func_name, x_values, scalar_t, 2, 32),         \
+            x_offset_ptrs,                                                    \
+            MAKE_PTA_WITH_NAME(func_name, y_reshaped, scalar_t, 3, 32),       \
+            MAKE_PTA_WITH_NAME(func_name, output_reshaped, scalar_t, 3, 32),  \
+            jagged_dims_tensor,                                               \
+            f,                                                                \
+            padding_value);                                                   \
   }

   JAGGED_TENSOR_DISPATCH_DIMS();
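Two idioms for the kernel-name string appear in this diff, and the difference is deliberate: inside the INVOKE_KERNEL_WITH_DIM macros the name is defined unconditionally and marked [[maybe_unused]], since a non-memcheck expansion of MAKE_PTA_WITH_NAME presumably discards it and the variable would otherwise trigger unused-variable warnings; in straight-line code further down, the definition itself is guarded. Side by side (both excerpted from this diff):

// Style 1 (macro bodies): always define, silence the warning.
[[maybe_unused]] const auto func_name =
    "jagged_dense_elementwise_dense_output_kernel_";

// Style 2 (open code): only define when memchecks are compiled in.
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name1 =
    "jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_";
#endif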
@@ -289,13 +289,13 @@ Tensor jagged_dense_elementwise_dense_output_(
 template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
 __global__
 __launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output_kernel_(
-    const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
         x_values,
     StackArray<index_t*> x_offsets,
     StackArray<int64_t> x_offsets_sizes,
-    const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_0,
-    const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_1,
-    at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_0,
+    const pta::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y_1,
+    pta::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
         output_values,
     StackArray<int64_t> jagged_dims,
     F f) {
@@ -380,9 +380,10 @@ __launch_bounds__(kMaxThreads) void jagged_dense_dense_elementwise_jagged_output

 template <typename index_t>
 __global__ void jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_(
-    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
-    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
-    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+        offsets,
+    pta::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
+    pta::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
     int nnz,
     int B) {
   struct SharedMemory<index_t> smem;
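For orientation, this search kernel is the first stage of the optimized half-precision path: it expands the jagged offsets into explicit (row, column) coordinates for each of the nnz values, which the gather kernel below consumes. A simplified host-side equivalent of that computation (illustrative only; the real kernel parallelizes the search across threads and stages the offsets in shared memory via SharedMemory<index_t>):

#include <algorithm>
#include <vector>

// For each flat value index i in [0, nnz), binary-search the offsets
// array (size B + 1, with offsets[B] == nnz) for the row containing i.
void search_rows_cols(
    const std::vector<int>& offsets,
    std::vector<int>& rows,
    std::vector<int>& cols) {
  const int nnz = offsets.back();
  rows.resize(nnz);
  cols.resize(nnz);
  for (int i = 0; i < nnz; ++i) {
    // The first offset strictly greater than i closes row i's segment.
    const auto it = std::upper_bound(offsets.begin() + 1, offsets.end(), i);
    const int row = static_cast<int>(it - offsets.begin()) - 1;
    rows[i] = row;
    cols[i] = i - offsets[row];
  }
}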
@@ -519,13 +520,13 @@ fh(__half& v_out, const __half& x, const __half& y0, const __half& y1, F f) {

 template <typename index_t, typename F>
 __global__ void jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_(
-    at::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits> values,
-    const at::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits> values,
+    const pta::PackedTensorAccessor32<c10::Half, 2, at::RestrictPtrTraits>
         x_values,
-    const at::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y0,
-    const at::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y1,
-    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
-    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
+    const pta::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y0,
+    const pta::PackedTensorAccessor32<c10::Half, 3, at::RestrictPtrTraits> y1,
+    const pta::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> rows,
+    const pta::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> cols,
     const int nnz,
     const int E,
     F f) {
@@ -692,37 +693,39 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
   return matches;
 }

-#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                 \
-  {                                                                            \
-    dim3 threads, blocks;                                                      \
-    StackArray<int64_t> jagged_dims_tensor;                                    \
-    std::tie(threads, blocks, jagged_dims_tensor) =                            \
-        check_shape_and_partition_(x_values, x_offsets, y);                    \
-    blocks.x = div_round_up(x_values.size(0), threads.y);                      \
-    std::vector<Tensor> x_offsets_contig;                                      \
-    x_offsets_contig.resize(num_jagged_dim);                                   \
-    StackArray<index_t*> x_offset_ptrs;                                        \
-    x_offset_ptrs.ndim = num_jagged_dim;                                       \
-    StackArray<int64_t> x_offset_sizes;                                        \
-    x_offset_sizes.ndim = num_jagged_dim;                                      \
-    for (int d = 0; d < num_jagged_dim; ++d) {                                 \
-      x_offsets_contig[d] = x_offsets[d].contiguous();                         \
-      x_offset_ptrs.vals[d] =                                                  \
-          x_offsets_contig[d].template data_ptr<index_t>();                    \
-      x_offset_sizes.vals[d] = x_offsets[d].numel();                           \
-    }                                                                          \
-    jagged_dense_dense_elementwise_jagged_output_kernel_<                      \
-        NUM_JAGGED_DIM,                                                        \
-        index_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(    \
-        x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),      \
-        x_offset_ptrs,                                                         \
-        x_offset_sizes,                                                        \
-        y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),    \
-        y_reshaped.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),    \
-        output_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
-        jagged_dims_tensor,                                                    \
-        [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/)            \
-            -> scalar_t { return f(x, y); });                                  \
+#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                                 \
+  {                                                                            \
+    dim3 threads, blocks;                                                      \
+    StackArray<int64_t> jagged_dims_tensor;                                    \
+    std::tie(threads, blocks, jagged_dims_tensor) =                            \
+        check_shape_and_partition_(x_values, x_offsets, y);                    \
+    blocks.x = div_round_up(x_values.size(0), threads.y);                      \
+    std::vector<Tensor> x_offsets_contig;                                      \
+    x_offsets_contig.resize(num_jagged_dim);                                   \
+    StackArray<index_t*> x_offset_ptrs;                                        \
+    x_offset_ptrs.ndim = num_jagged_dim;                                       \
+    StackArray<int64_t> x_offset_sizes;                                        \
+    x_offset_sizes.ndim = num_jagged_dim;                                      \
+    for (int d = 0; d < num_jagged_dim; ++d) {                                 \
+      x_offsets_contig[d] = x_offsets[d].contiguous();                         \
+      x_offset_ptrs.vals[d] =                                                  \
+          x_offsets_contig[d].template data_ptr<index_t>();                    \
+      x_offset_sizes.vals[d] = x_offsets[d].numel();                           \
+    }                                                                          \
+    [[maybe_unused]] const auto func_name =                                    \
+        "jagged_dense_dense_elementwise_jagged_output_kernel_";                \
+    jagged_dense_dense_elementwise_jagged_output_kernel_<                      \
+        NUM_JAGGED_DIM,                                                        \
+        index_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(    \
+        MAKE_PTA_WITH_NAME(func_name, x_values, scalar_t, 2, 32),              \
+        x_offset_ptrs,                                                         \
+        x_offset_sizes,                                                        \
+        MAKE_PTA_WITH_NAME(func_name, y_reshaped, scalar_t, 3, 32),            \
+        MAKE_PTA_WITH_NAME(func_name, y_reshaped, scalar_t, 3, 32),            \
+        MAKE_PTA_WITH_NAME(func_name, output_values, scalar_t, 2, 32),         \
+        jagged_dims_tensor,                                                    \
+        [f] __device__(scalar_t x, scalar_t y, scalar_t /*unused*/)            \
+            -> scalar_t { return f(x, y); });                                  \
   }

 ///@addtogroup jagged-tensor-ops-cuda
@@ -810,18 +813,19 @@ void jagged_dense_elementwise_jagged_output_opt_(
         }
         dim3 threads_bs = dim3(1024, 1, 1);
         dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
+#ifdef FBGEMM_GPU_MEMCHECK
+        const auto func_name1 =
+            "jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_";
+#endif
         jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
             index_t>
             <<<blocks_bs,
                threads_bs,
                dynamic_smem_size,
                at::cuda::getCurrentCUDAStream()>>>(
-                x_offsets[0]
-                    .packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-                t_rows_after_bs
-                    .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-                t_cols_after_bs
-                    .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
+                MAKE_PTA_WITH_NAME(func_name1, x_offsets[0], index_t, 1, 32),
+                MAKE_PTA_WITH_NAME(func_name1, t_rows_after_bs, int, 1, 32),
+                MAKE_PTA_WITH_NAME(func_name1, t_cols_after_bs, int, 1, 32),
                 nnz,
                 B);
         C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -831,21 +835,20 @@
         if (blocks.y > 65535) {
           blocks.y = 65535;
         }
+#ifdef FBGEMM_GPU_MEMCHECK
+        const auto func_name2 =
+            "jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_";
+#endif
         jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
             index_t>
             <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                output_values
-                    .packed_accessor32<c10::Half, 2, at::RestrictPtrTraits>(),
-                x_values
-                    .packed_accessor32<c10::Half, 2, at::RestrictPtrTraits>(),
-                y_reshaped
-                    .packed_accessor32<c10::Half, 3, at::RestrictPtrTraits>(),
-                y_reshaped
-                    .packed_accessor32<c10::Half, 3, at::RestrictPtrTraits>(),
-                t_rows_after_bs
-                    .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-                t_cols_after_bs
-                    .packed_accessor32<int, 1, at::RestrictPtrTraits>(),
+                MAKE_PTA_WITH_NAME(
+                    func_name2, output_values, c10::Half, 2, 32),
+                MAKE_PTA_WITH_NAME(func_name2, x_values, c10::Half, 2, 32),
+                MAKE_PTA_WITH_NAME(func_name2, y_reshaped, c10::Half, 3, 32),
+                MAKE_PTA_WITH_NAME(func_name2, y_reshaped, c10::Half, 3, 32),
+                MAKE_PTA_WITH_NAME(func_name2, t_rows_after_bs, int, 1, 32),
+                MAKE_PTA_WITH_NAME(func_name2, t_cols_after_bs, int, 1, 32),
                 nnz,
                 E,
                 [f] __device__(__half x, __half y0, __half) -> __half {