From e0b56ac28fc6ea19c46807213194073256ba6c65 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Tue, 13 Aug 2024 12:41:48 -0700
Subject: [PATCH] Add ssd_update_row_addrs (#2953)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2953

X-link: https://github.com/facebookresearch/FBGEMM/pull/53

When pipeline prefetching is enabled, data in a scratch pad of the
current iteration can be moved to L1 or a scratch pad of the next
iteration during the prefetch step. `ssd_update_row_addrs` updates the
memory addresses of the relocated data so that they point to the
correct locations.

Reviewed By: ehsanardestani

Differential Revision: D60983150

fbshipit-source-id: a5e898c9a567e549e6fc439bdc0ccad04909ebf1
---
 .../ssd_split_embeddings_cache_cuda.cu     | 109 ++++++++++++++++++
 .../ssd_split_table_batched_embeddings.cpp |  63 ++++++++++
 2 files changed, 172 insertions(+)

diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
index 4dbb5795a..10189fcf7 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
@@ -538,3 +538,112 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
 
   return {ssd_row_addrs, post_bwd_evicted_indices};
 }
+
+__global__ __launch_bounds__(kMaxThreads) void ssd_update_row_addrs_kernel(
+    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        ssd_row_addrs_curr,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        ssd_curr_next_map,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        lxu_cache_locations_curr,
+    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        linear_index_inverse_indices_curr,
+    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
+        unique_indices_count_cumsum_curr,
+    const at::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        cache_set_inverse_indices_curr,
+    const uint64_t lxu_cache_weights_addr,
+    const uint64_t inserted_ssd_weights_addr_next,
+    const int* N_unique_curr,
+    const uint64_t cache_row_bytes // has to be 64 bits to prevent overflow
+) {
+  const auto n_curr = blockDim.y * blockIdx.x + threadIdx.y;
+  if (n_curr >= *N_unique_curr) {
+    return;
+  }
+
+  // Find mapping between n_curr and n_next
+  const auto n_next = ssd_curr_next_map[n_curr];
+
+  // Return if the row is not reused in the next iteration
+  if (n_next < 0) {
+    return;
+  }
+
+  // Find out whether the row gets moved to the next iteration's scratch pad
+  // or to L1 by checking lxu_cache_locations_curr
+  const auto cache_set_id_curr = cache_set_inverse_indices_curr[n_curr];
+  const auto segment_start_curr =
+      unique_indices_count_cumsum_curr[cache_set_id_curr];
+  const auto segment_end_curr =
+      unique_indices_count_cumsum_curr[cache_set_id_curr + 1];
+  const auto cache_loc_curr = lxu_cache_locations_curr
+      [linear_index_inverse_indices_curr[segment_start_curr]];
+
+  const uint64_t ptr_addr = (cache_loc_curr == -1)
+      // The row is moved from the current iteration's scratch pad to the
+      // next iteration's scratch pad
+      ? (inserted_ssd_weights_addr_next + (n_next * cache_row_bytes))
+      // The row is moved from the current iteration's scratch pad to L1 cache
+      : (lxu_cache_weights_addr + (cache_loc_curr * cache_row_bytes));
+
+  // Set the pointer address
+  for (auto l = segment_start_curr + threadIdx.x; l < segment_end_curr;
+       l += blockDim.x) {
+    auto dst = linear_index_inverse_indices_curr[l];
+    *reinterpret_cast<uint64_t*>(&ssd_row_addrs_curr[dst]) = ptr_addr;
+  }
+}
+
+void ssd_update_row_addrs_cuda(
+    const Tensor& ssd_row_addrs_curr,
+    const Tensor& ssd_curr_next_map,
+    const Tensor& lxu_cache_locations_curr,
+    const Tensor& linear_index_inverse_indices_curr,
+    const Tensor& unique_indices_count_cumsum_curr,
+    const Tensor& cache_set_inverse_indices_curr,
+    const Tensor& lxu_cache_weights,
+    const Tensor& inserted_ssd_weights_next,
+    const Tensor& unique_indices_length_curr) {
+  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
+      ssd_row_addrs_curr,
+      ssd_curr_next_map,
+      lxu_cache_locations_curr,
+      linear_index_inverse_indices_curr,
+      unique_indices_count_cumsum_curr,
+      cache_set_inverse_indices_curr,
+      lxu_cache_weights,
+      inserted_ssd_weights_next,
+      unique_indices_length_curr);
+
+  CUDA_DEVICE_GUARD(ssd_row_addrs_curr);
+
+  const auto lxu_cache_weights_addr =
+      reinterpret_cast<uint64_t>(lxu_cache_weights.data_ptr());
+  const auto inserted_ssd_weights_addr_next =
+      reinterpret_cast<uint64_t>(inserted_ssd_weights_next.data_ptr());
+  const auto cache_row_bytes =
+      lxu_cache_weights.size(1) * lxu_cache_weights.element_size();
+  constexpr auto kNumWarps = kMaxThreads / kWarpSize;
+
+  ssd_update_row_addrs_kernel<<<
+      div_round_up(ssd_row_addrs_curr.numel(), kNumWarps),
+      dim3(kWarpSize, kNumWarps),
+      0,
+      at::cuda::getCurrentCUDAStream()>>>(
+      ssd_row_addrs_curr.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      ssd_curr_next_map.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      lxu_cache_locations_curr
+          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      linear_index_inverse_indices_curr
+          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      unique_indices_count_cumsum_curr
+          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+      cache_set_inverse_indices_curr
+          .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
+      lxu_cache_weights_addr,
+      inserted_ssd_weights_addr_next,
+      unique_indices_length_curr.data_ptr<int32_t>(),
+      cache_row_bytes);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index 86eb3c0a7..1cc67815a 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -148,6 +148,56 @@ std::tuple<Tensor, Tensor> ssd_generate_row_addrs_cuda(
     const Tensor& unique_indices_length,
     const Tensor& cache_set_sorted_unique_indices);
 
+/// @ingroup embedding-ssd
+///
+/// @brief Update memory addresses for SSD TBE data
+///
+/// When pipeline prefetching is enabled, data in a scratch pad of the
+/// current iteration can be moved to L1 or a scratch pad of the next
+/// iteration during the prefetch step. This operator updates the
+/// memory addresses of the relocated data so that they point to the
+/// correct locations.
+///
+/// @param ssd_row_addrs_curr The tensor that contains the row addresses
+///            for the current iteration
+/// @param inserted_ssd_weights_curr_next_map The tensor that contains
+///            the mapping between each index's location in the
+///            current iteration and its location in the next
+///            iteration's scratch pad (-1 = the data has not been
+///            moved). inserted_ssd_weights_curr_next_map[i] is the
+///            location of index i in the next iteration's scratch pad.
+/// @param lxu_cache_locations_curr The tensor that contains cache
+///            slots where data is stored for the *full* list of
+///            indices for the current iteration. -1 is a sentinel
+///            value that indicates that data is not in cache.
+/// @param linear_index_inverse_indices_curr The tensor that contains
+///            the original positions of linear indices before being
+///            sorted for the current iteration
+/// @param unique_indices_count_cumsum_curr The tensor that contains
+///            the exclusive prefix sum results of the counts of
+///            unique indices for the current iteration
+/// @param cache_set_inverse_indices_curr The tensor that contains the
+///            original positions of cache sets before being sorted
+///            for the current iteration
+/// @param lxu_cache_weights The LXU cache tensor
+/// @param inserted_ssd_weights_next The scratch pad tensor for the
+///            next iteration
+/// @param unique_indices_length_curr The tensor that contains the
+///            number of unique indices (GPU tensor) for the current
+///            iteration
+///
+/// @return None
+void ssd_update_row_addrs_cuda(
+    const Tensor& ssd_row_addrs_curr,
+    const Tensor& inserted_ssd_weights_curr_next_map,
+    const Tensor& lxu_cache_locations_curr,
+    const Tensor& linear_index_inverse_indices_curr,
+    const Tensor& unique_indices_count_cumsum_curr,
+    const Tensor& cache_set_inverse_indices_curr,
+    const Tensor& lxu_cache_weights,
+    const Tensor& inserted_ssd_weights_next,
+    const Tensor& unique_indices_length_curr);
+
 namespace {
 class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
  public:
@@ -319,5 +319,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "    Tensor cache_set_sorted_unique_indices"
       ") -> (Tensor, Tensor)");
   DISPATCH_TO_CUDA("ssd_generate_row_addrs", ssd_generate_row_addrs_cuda);
+  m.def(
+      "ssd_update_row_addrs("
+      "    Tensor ssd_row_addrs_curr, "
+      "    Tensor inserted_ssd_weights_curr_next_map, "
+      "    Tensor lxu_cache_locations_curr, "
+      "    Tensor linear_index_inverse_indices_curr, "
+      "    Tensor unique_indices_count_cumsum_curr, "
+      "    Tensor cache_set_inverse_indices_curr, "
+      "    Tensor lxu_cache_weights, "
+      "    Tensor inserted_ssd_weights_next, "
+      "    Tensor unique_indices_length_curr"
+      ") -> ()");
+  DISPATCH_TO_CUDA("ssd_update_row_addrs", ssd_update_row_addrs_cuda);
 }
 } // namespace
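
For readers following the kernel logic above: the address each relocated row receives is a plain base-plus-offset computation. The standalone C++ sketch below is not part of the patch; the base addresses, slot numbers, and row size in it are hypothetical. It only illustrates the selection rule: a row whose current L1 cache location is -1 is redirected to its slot in the next iteration's scratch pad, otherwise to its row in L1. In the kernel itself, one warp handles one unique row and writes the chosen address into ssd_row_addrs_curr for every occurrence of that row via linear_index_inverse_indices_curr.

// Standalone illustration of the per-row address selection described above.
// Sketch only; not the FBGEMM implementation. All values are hypothetical.
#include <cstdint>
#include <cstdio>

// cache_loc_curr == -1: the row stays on the scratch-pad path, so its new
// address is slot n_next in the next iteration's scratch pad. Otherwise the
// row was inserted into L1 at slot cache_loc_curr.
uint64_t relocated_row_addr(
    int32_t cache_loc_curr,
    int64_t n_next,
    uint64_t lxu_cache_weights_addr,
    uint64_t inserted_ssd_weights_addr_next,
    uint64_t cache_row_bytes) {
  return (cache_loc_curr == -1)
      ? inserted_ssd_weights_addr_next +
          static_cast<uint64_t>(n_next) * cache_row_bytes
      : lxu_cache_weights_addr +
          static_cast<uint64_t>(cache_loc_curr) * cache_row_bytes;
}

int main() {
  // Hypothetical base addresses and a 128-dim fp32 row (512 bytes).
  const uint64_t l1_base = 0x7f0000000000ULL;
  const uint64_t sp_next_base = 0x7f8000000000ULL;
  const uint64_t row_bytes = 128 * sizeof(float);

  // Row not in L1 (location -1): redirected to slot 3 of the next scratch pad.
  std::printf("scratch pad dst: 0x%llx\n",
              static_cast<unsigned long long>(
                  relocated_row_addr(-1, 3, l1_base, sp_next_base, row_bytes)));
  // Row inserted into L1 at slot 7 (n_next is >= 0 here, since the kernel
  // only reaches this point for rows that are reused in the next iteration).
  std::printf("L1 dst:          0x%llx\n",
              static_cast<unsigned long long>(
                  relocated_row_addr(7, 5, l1_base, sp_next_base, row_bytes)));
  return 0;
}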