Add memchecks to ssd split embeddings cache (#2589)
Summary:
Pull Request resolved: #2589

- Add memchecks to ssd split embeddings cache

Reviewed By: spcyppt

Differential Revision: D57307547

fbshipit-source-id: 63b43c4b6479d64d0435d373d6e5415ec1d46c44
q10 authored and facebook-github-bot committed May 14, 2024
1 parent 17a4e18 commit af97deb
Showing 1 changed file with 38 additions and 28 deletions.
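
For context on the mechanism before the diff: with memchecks enabled, the kernels take their tensor arguments through pta:: accessor wrappers built by MAKE_PTA_WITH_NAME, which record the tensor's name and the launching kernel's name so that an out-of-bounds access can be reported with both. The sketch below illustrates the general idea only; it is not FBGEMM's implementation (which lives in fbgemm_gpu/fbgemm_tensor_accessor.h), and the CheckedAccessor1D type is hypothetical.

// Minimal sketch of a name-carrying, bounds-checking accessor
// (hypothetical; for illustration only).
#include <cstdint>
#include <cstdio>

template <typename T>
struct CheckedAccessor1D {
  T* data;
  int64_t n;
  const char* tensor_name; // name of the tensor being indexed
  const char* func_name;   // name of the kernel doing the indexing

  __device__ T& operator[](const int64_t i) const {
    // Report the offending kernel and tensor by name on an out-of-bounds
    // access, instead of silently reading or corrupting memory.
    if (i < 0 || i >= n) {
      printf(
          "[%s] index %lld out of bounds for tensor '%s' (size %lld)\n",
          func_name,
          static_cast<long long>(i),
          tensor_name,
          static_cast<long long>(n));
      __trap();
    }
    return data[i];
  }
};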
@@ -15,6 +15,7 @@
 #include <ATen/cuda/Atomic.cuh>
 #include "fbgemm_gpu/dispatch_macros.h"
 #include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
 #include "fbgemm_gpu/sparse_ops_utils.h"
 #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
@@ -25,10 +26,12 @@ using namespace fbgemm_gpu;
 
 template <typename scalar_t>
 __global__ __launch_bounds__(kMaxThreads) void masked_index_put_kernel(
-    at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> self,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
-    const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits> values,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> count,
+    pta::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> self,
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        indices,
+    const pta::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
+        values,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> count,
     TORCH_DSA_KERNEL_ARGS) {
   const int32_t N = indices.size(0);
   const int32_t n = blockIdx.x * blockDim.y + threadIdx.y;
@@ -48,10 +51,11 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_put_kernel(
 
 template <>
 __global__ __launch_bounds__(kMaxThreads) void masked_index_put_kernel(
-    at::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits> self,
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
-    const at::PackedTensorAccessor32<uint8_t, 2, at::RestrictPtrTraits> values,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> count,
+    pta::PackedTensorAccessor64<uint8_t, 2, at::RestrictPtrTraits> self,
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        indices,
+    const pta::PackedTensorAccessor32<uint8_t, 2, at::RestrictPtrTraits> values,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> count,
     TORCH_DSA_KERNEL_ARGS) {
   const int32_t N = indices.size(0);
   const int32_t n = blockIdx.x * blockDim.y + threadIdx.y;
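
For readers unfamiliar with the kernel being instrumented: judging from the visible context, masked_index_put_kernel scatters the first count rows of values into self at the row offsets given by indices, with one block row (threadIdx.y) per index and the x-dimension threads striding across the embedding width. A simplified standalone version of that access pattern is sketched below; it is my reading of the diff context, not the actual FBGEMM kernel.

// Simplified sketch of the scatter pattern (not the actual FBGEMM kernel).
template <typename scalar_t>
__global__ void masked_index_put_sketch(
    scalar_t* self,          // [num_rows, D] destination, row-major
    const int64_t* indices,  // [N] destination row for each source row
    const scalar_t* values,  // [N, D] source rows
    const int32_t* count,    // number of valid rows among the N
    int64_t D) {
  const int32_t n = blockIdx.x * blockDim.y + threadIdx.y;
  if (n >= *count) {
    return; // only the first *count rows are valid
  }
  const int64_t dst_row = indices[n];
  // Threads along x cooperatively copy one row of width D.
  for (int64_t d = threadIdx.x; d < D; d += blockDim.x) {
    self[dst_row * D + d] = values[n * D + d];
  }
}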
@@ -97,43 +101,47 @@ Tensor masked_index_put_cuda(
       [&] {
         const int32_t tx = std::min<int32_t>(D / 4, kMaxThreads);
         const dim3 threads(tx, kMaxThreads / tx);
+#ifdef FBGEMM_GPU_MEMCHECK
+        const auto func_name = "masked_index_put_kernel";
+#endif
         TORCH_DSA_KERNEL_LAUNCH(
             masked_index_put_kernel<scalar_t>,
             div_round_up(N, kMaxThreads / tx),
             dim3(tx, kMaxThreads / tx),
             0,
             at::cuda::getCurrentCUDAStream(),
-            self.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(),
-            indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
-            count.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+            MAKE_PTA_WITH_NAME(func_name, self, scalar_t, 2, 64),
+            MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, values, scalar_t, 2, 32),
+            MAKE_PTA_WITH_NAME(func_name, count, int32_t, 1, 32));
       } // lambda
   );
 
   return self;
 }
 
 __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
-    at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits>
         lxu_cache_state,
-    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         sorted_cache_sets, // [N = \sum_{b} L_{b} total indices, i.e.
                            // flattened
                            // [B][L]
-    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         cache_set_sorted_indices, // [N = \sum_{b} L_{b} total indices, i.e.
                                   // flattened [B][L]
     int64_t time_stamp,
     int64_t prefetch_dist, // Number of batches we can prefetch ahead of a
                            // forward call. A value of 1 means that entries
                            // with insert_time >= time_stamp - prefetch_dist
                            // are locked, and cannot be evicted.
-    at::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits> lru_state,
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int64_t, 2, at::RestrictPtrTraits> lru_state,
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         assigned_cache_slots,
-    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+    pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         evicted_indices,
-    at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> actions_count,
+    pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        actions_count,
     TORCH_DSA_KERNEL_ARGS) {
   // Number of cache sets
   const int32_t C = lxu_cache_state.size(0);
@@ -280,23 +288,25 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> ssd_cache_populate_actions_cuda(
       lxu_cache_locking_counter);
   auto sorted_cache_sets = cache_sets_and_unique_indices.first;
   auto cache_set_sorted_unique_indices = cache_sets_and_unique_indices.second;
+#ifdef FBGEMM_GPU_MEMCHECK
+  const auto func_name = "ssd_cache_actions_insert_kernel";
+#endif
   TORCH_DSA_KERNEL_LAUNCH(
       ssd_cache_actions_insert_kernel,
       div_round_up(N, kMaxThreads / kWarpSize),
       dim3(kWarpSize, kMaxThreads / kWarpSize),
       0,
       at::cuda::getCurrentCUDAStream(),
-      lxu_cache_state.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
-      sorted_cache_sets.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      cache_set_sorted_unique_indices
-          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
+      MAKE_PTA_WITH_NAME(func_name, sorted_cache_sets, int32_t, 1, 32),
+      MAKE_PTA_WITH_NAME(
+          func_name, cache_set_sorted_unique_indices, int64_t, 1, 32),
       time_stamp,
       prefetch_dist,
-      lru_state.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
-      assigned_cache_slots
-          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-      evicted_indices.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-      actions_count.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>());
+      MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
+      MAKE_PTA_WITH_NAME(func_name, assigned_cache_slots, int32_t, 1, 32),
+      MAKE_PTA_WITH_NAME(func_name, evicted_indices, int64_t, 1, 32),
+      MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32));
 
   return std::make_tuple(
       cache_set_sorted_unique_indices,
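
Note that func_name is defined only under #ifdef FBGEMM_GPU_MEMCHECK, which implies MAKE_PTA_WITH_NAME references its name arguments only in memcheck builds and degenerates to a plain packed_accessor call otherwise. One plausible shape for such a macro is sketched below; this is an assumption for illustration, not FBGEMM's actual definition, and make_checked_accessor32/make_checked_accessor64 are hypothetical helpers.

// Hypothetical sketch of the macro's two modes (not FBGEMM's definition).
#ifdef FBGEMM_GPU_MEMCHECK
// Memcheck build: construct a checked accessor that carries the tensor
// and kernel names for error reporting.
#define MAKE_PTA_WITH_NAME(FUNC_NAME, TENSOR, T, N, NBITS) \
  make_checked_accessor##NBITS<T, N>(TENSOR, #TENSOR, FUNC_NAME)
#else
// Regular build: fall back to the plain PyTorch packed accessor; the
// name argument is never referenced, so func_name need not exist.
#define MAKE_PTA_WITH_NAME(FUNC_NAME, TENSOR, T, N, NBITS) \
  (TENSOR).packed_accessor##NBITS<T, N, at::RestrictPtrTraits>()
#endif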