Add ssd_update_row_addrs (#2953)
Summary:
Pull Request resolved: #2953

X-link: facebookresearch/FBGEMM#53

When pipeline prefetching is enabled, data in a scratch pad of the
current iteration can be moved to L1 or to a scratch pad of the next
iteration during the prefetch step. `ssd_update_row_addrs` updates the
memory addresses of the relocated data so that they point to the
correct locations.
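
As an illustration, the per-row relocation rule that the new kernel applies
reduces to the following (a minimal sketch; `base_l1`, `base_sp_next`, and
`row_bytes` are hypothetical stand-ins for the kernel arguments):

```cpp
#include <cstdint>

// Sketch: compute the updated address for one unique row, mirroring the
// kernel's selection logic. The caller has already checked that the row is
// present in the next iteration (n_next >= 0). cache_loc == -1 means the
// row is not in L1.
inline uint64_t updated_row_addr(
    int32_t cache_loc,      // L1 slot of the row, or -1
    int32_t n_next,         // position in the next scratch pad (>= 0)
    uint64_t base_l1,       // base address of lxu_cache_weights
    uint64_t base_sp_next,  // base address of inserted_ssd_weights_next
    uint64_t row_bytes) {   // bytes per cache row (64-bit to avoid overflow)
  return cache_loc == -1
      ? base_sp_next + n_next * row_bytes  // moved to next scratch pad
      : base_l1 + cache_loc * row_bytes;   // moved to L1 cache
}
```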

Differential Revision: D60983150
sryap authored and facebook-github-bot committed Aug 9, 2024
1 parent 8ec2c66 commit e1cd9dc
Showing 2 changed files with 172 additions and 0 deletions.
@@ -538,3 +538,112 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
return {ssd_row_addrs, post_bwd_evicted_indices};
}
__global__ __launch_bounds__(kMaxThreads) void ssd_update_row_addrs_kernel(
    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        ssd_row_addrs_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        ssd_curr_next_map,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        lxu_cache_locations_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        linear_index_inverse_indices_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        unique_indices_count_cumsum_curr,
    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
        cache_set_inverse_indices_curr,
    const uint64_t lxu_cache_weights_addr,
    const uint64_t inserted_ssd_weights_addr_next,
    const int* N_unique_curr,
    const uint64_t cache_row_bytes // has to be 64 bits to prevent overflow
) {
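  // One warp handles one unique row: threadIdx.y picks the row for this
  // warp, and the warp's lanes (threadIdx.x) stride over all occurrences
  // of the row in the address-update loop below.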
  const auto n_curr = blockDim.y * blockIdx.x + threadIdx.y;
  if (n_curr >= *N_unique_curr) {
    return;
  }

  // Find the mapping between n_curr and n_next
  const auto n_next = ssd_curr_next_map[n_curr];

  // Skip the row if it is not present in the next iteration, i.e., its
  // data was not moved during the prefetch step
  if (n_next < 0) {
    return;
  }

  // Determine whether the row was moved to the next iteration's scratch
  // pad or to L1 by checking lxu_cache_locations_curr
  const auto cache_set_id_curr = cache_set_inverse_indices_curr[n_curr];
  const auto segment_start_curr =
      unique_indices_count_cumsum_curr[cache_set_id_curr];
  const auto segment_end_curr =
      unique_indices_count_cumsum_curr[cache_set_id_curr + 1];
  const auto cache_loc_curr = lxu_cache_locations_curr
      [linear_index_inverse_indices_curr[segment_start_curr]];

  const uint64_t ptr_addr = (cache_loc_curr == -1)
      // The row is moved from the current iteration's scratch pad to the
      // next iteration's scratch pad
      ? (inserted_ssd_weights_addr_next + (n_next * cache_row_bytes))
      // The row is moved from the current iteration's scratch pad to L1 cache
      : (lxu_cache_weights_addr + (cache_loc_curr * cache_row_bytes));

  // Set the pointer address for every occurrence of the row
  for (auto l = segment_start_curr + threadIdx.x; l < segment_end_curr;
       l += blockDim.x) {
    auto dst = linear_index_inverse_indices_curr[l];
    *reinterpret_cast<uint64_t*>(&ssd_row_addrs_curr[dst]) = ptr_addr;
  }
}
void ssd_update_row_addrs_cuda(
    const Tensor& ssd_row_addrs_curr,
    const Tensor& ssd_curr_next_map,
    const Tensor& lxu_cache_locations_curr,
    const Tensor& linear_index_inverse_indices_curr,
    const Tensor& unique_indices_count_cumsum_curr,
    const Tensor& cache_set_inverse_indices_curr,
    const Tensor& lxu_cache_weights,
    const Tensor& inserted_ssd_weights_next,
    const Tensor& unique_indices_length_curr) {
  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
      ssd_row_addrs_curr,
      ssd_curr_next_map,
      lxu_cache_locations_curr,
      linear_index_inverse_indices_curr,
      unique_indices_count_cumsum_curr,
      cache_set_inverse_indices_curr,
      lxu_cache_weights,
      inserted_ssd_weights_next,
      unique_indices_length_curr);

  CUDA_DEVICE_GUARD(ssd_row_addrs_curr);

  const auto lxu_cache_weights_addr =
      reinterpret_cast<uint64_t>(lxu_cache_weights.data_ptr());
  const auto inserted_ssd_weights_addr_next =
      reinterpret_cast<uint64_t>(inserted_ssd_weights_next.data_ptr());
  const auto cache_row_bytes =
      lxu_cache_weights.size(1) * lxu_cache_weights.element_size();
  constexpr auto kNumWarps = kMaxThreads / kWarpSize;
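
  // Launch configuration: each thread block holds kNumWarps warps
  // (dim3(kWarpSize, kNumWarps)), and each warp updates one unique row,
  // so ceil(numel / kNumWarps) blocks cover all rows.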
  ssd_update_row_addrs_kernel<<<
      div_round_up(ssd_row_addrs_curr.numel(), kNumWarps),
      dim3(kWarpSize, kNumWarps),
      0,
      at::cuda::getCurrentCUDAStream()>>>(
      ssd_row_addrs_curr.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
      ssd_curr_next_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      lxu_cache_locations_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      linear_index_inverse_indices_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      unique_indices_count_cumsum_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      cache_set_inverse_indices_curr
          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
      lxu_cache_weights_addr,
      inserted_ssd_weights_addr_next,
      unique_indices_length_curr.data_ptr<int32_t>(),
      cache_row_bytes);

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
@@ -148,6 +148,56 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
const Tensor& unique_indices_length,
const Tensor& cache_set_sorted_unique_indices);

/// @ingroup embedding-ssd
///
/// @brief Update memory addresses for SSD TBE data
///
/// When pipeline prefetching is enabled, data in a scratch pad of the
/// current iteration can be moved to L1 or a scratch pad of the next
/// iteration during the prefetch step. This operator updates the
/// memory addresses of the relocated data so that they point to the
/// correct locations.
///
/// @param ssd_row_addrs_curr The tensor that contains the row addresses
///            of the current iteration
/// @param inserted_ssd_weights_curr_next_map The tensor that maps each
///            index in the current iteration to its location in the
///            next iteration's scratch pad (-1 = the data has not been
///            moved); inserted_ssd_weights_curr_next_map[i] is the
///            location of index i in the next iteration's scratch pad
/// @param lxu_cache_locations_curr The tensor that contains the cache
///            slots where data is stored for the *full* list of
///            indices for the current iteration. -1 is a sentinel
///            value that indicates that the data is not in the cache.
/// @param linear_index_inverse_indices_curr The tensor that contains
///            the original positions of linear indices before being
///            sorted for the current iteration
/// @param unique_indices_count_cumsum_curr The tensor that contains
///            the exclusive prefix sum results of the counts of
///            unique indices for the current iteration
/// @param cache_set_inverse_indices_curr The tensor that contains the
///            original positions of cache sets before being sorted
///            for the current iteration
/// @param lxu_cache_weights The LXU cache tensor
/// @param inserted_ssd_weights_next The scratch pad tensor for the
///            next iteration
/// @param unique_indices_length_curr The tensor that contains the
///            number of unique indices (GPU tensor) for the current
///            iteration
///
/// @return None
void ssd_update_row_addrs_cuda(
    const Tensor& ssd_row_addrs_curr,
    const Tensor& inserted_ssd_weights_curr_next_map,
    const Tensor& lxu_cache_locations_curr,
    const Tensor& linear_index_inverse_indices_curr,
    const Tensor& unique_indices_count_cumsum_curr,
    const Tensor& cache_set_inverse_indices_curr,
    const Tensor& lxu_cache_weights,
    const Tensor& inserted_ssd_weights_next,
    const Tensor& unique_indices_length_curr);

namespace {
class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
public:
@@ -315,5 +365,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" Tensor cache_set_sorted_unique_indices"
") -> (Tensor, Tensor)");
DISPATCH_TO_CUDA("ssd_generate_row_addrs", ssd_generate_row_addrs_cuda);
m.def(
"ssd_update_row_addrs("
" Tensor ssd_row_addrs_curr, "
" Tensor inserted_ssd_weights_curr_next_map, "
" Tensor lxu_cache_locations_curr, "
" Tensor linear_index_inverse_indices_curr, "
" Tensor unique_indices_count_cumsum_curr, "
" Tensor cache_set_inverse_indices_curr, "
" Tensor lxu_cache_weights, "
" Tensor inserted_ssd_weights_next, "
" Tensor unique_indices_length_curr"
") -> ()");
DISPATCH_TO_CUDA("ssd_update_row_addrs", ssd_update_row_addrs_cuda);
}
} // namespace
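
For reference, here is a minimal, hypothetical host-side sketch of how the
new operator could be driven, assuming the FBGEMM SSD TBE headers above are
available. The tensor contents are toy values chosen only to satisfy the
documented invariants (two unique indices, each occurring once; index 0
moves to slot 5 of the next scratch pad, index 1 moves to L1 slot 3). It
illustrates the expected shapes and dtypes; it is not a tested unit test.

```cpp
#include <torch/torch.h>

void example_update_row_addrs() {
  const auto opts_i32 =
      torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
  const auto opts_i64 =
      torch::TensorOptions().dtype(torch::kLong).device(torch::kCUDA);
  const auto opts_f32 =
      torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);

  const int64_t D = 4; // embedding dimension -> cache_row_bytes = 16

  // Two unique indices in the current iteration
  auto ssd_row_addrs_curr = torch::zeros({2}, opts_i64);
  auto unique_indices_length_curr = torch::full({1}, 2, opts_i32);

  // Index 0 -> slot 5 of the next scratch pad; index 1 is also present
  // in the next iteration (slot 0), but sits in L1, so L1 wins below
  auto curr_next_map = torch::tensor({5, 0}, opts_i32);

  // Index 0 is not in L1 (-1); index 1 sits in L1 slot 3
  auto lxu_cache_locations_curr = torch::tensor({-1, 3}, opts_i32);

  // Each unique index occurs exactly once, already in original order
  auto linear_index_inverse_indices_curr = torch::tensor({0, 1}, opts_i32);
  auto unique_indices_count_cumsum_curr = torch::tensor({0, 1, 2}, opts_i32);
  auto cache_set_inverse_indices_curr = torch::tensor({0, 1}, opts_i32);

  auto lxu_cache_weights = torch::zeros({16, D}, opts_f32);
  auto inserted_ssd_weights_next = torch::zeros({8, D}, opts_f32);

  ssd_update_row_addrs_cuda(
      ssd_row_addrs_curr,
      curr_next_map,
      lxu_cache_locations_curr,
      linear_index_inverse_indices_curr,
      unique_indices_count_cumsum_curr,
      cache_set_inverse_indices_curr,
      lxu_cache_weights,
      inserted_ssd_weights_next,
      unique_indices_length_curr);

  // Afterwards, ssd_row_addrs_curr[0] points into inserted_ssd_weights_next
  // (base + 5 * 16 bytes) and ssd_row_addrs_curr[1] points into
  // lxu_cache_weights (base + 3 * 16 bytes).
}
```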
