diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
index ec11e86e8..4dbb5795a 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
@@ -166,6 +166,9 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
         evicted_indices,
     pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
         actions_count,
+    const bool lock_cache_line,
+    pta::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits>
+        lxu_cache_locking_counter,
     TORCH_DSA_KERNEL_ARGS) {
   // Number of cache sets
   const int32_t C = lxu_cache_state.size(0);
@@ -216,51 +219,65 @@ __global__ __launch_bounds__(kMaxThreads) void ssd_cache_actions_insert_kernel(
       SL += 1;
     }
 
-    // now, we need to insert the (unique!) values in indices[n:n + SL] into
+    // Now, we need to insert the (unique!) values in indices[n:n + SL] into
     // our slots.
     const int32_t slot = threadIdx.x;
     const int64_t slot_time = lru_state[cache_set][slot];
-    int64_t costs[1] = {slot_time};
+
+    // Check if the slot is locked
+    const bool is_slot_locked =
+        lock_cache_line && (lxu_cache_locking_counter[cache_set][slot] > 0);
+    // Check if the slot has the inserted row that was a cache hit.
+    const int64_t slot_idx = lxu_cache_state[cache_set][slot];
+    const bool slot_has_idx = slot_idx != -1 && slot_time == time_stamp;
+    // Check if the slot is unavailable: either it is locked or contains
+    // a cache hit inserted row
+    const bool is_slot_unavailable = is_slot_locked || slot_has_idx;
+
+    // Set the slot cost: if the slot is not available, set it to the
+    // maximum timestamp which is the current timestamp. After sorting,
+    // the unavailable slots will be in the bottom, while the available
+    // slots will be bubbled to the top
+    const int64_t slot_cost = is_slot_unavailable ? time_stamp : slot_time;
+
+    // Prepare key-value pair for sorting
+    int64_t costs[1] = {slot_cost};
     int32_t slots[1] = {slot};
 
+    // Sort the slots based on their costs
     BitonicSort<int64_t, int32_t, 1, Comparator<int64_t>>::sort(costs, slots);
-    const int32_t sorted_slot = slots[0];
-    const int64_t sorted_time = costs[0];
+
+    // Get the sorted results
+    const int32_t insert_slot = slots[0];
+    const int64_t insert_cost = costs[0];
 
     auto l = threadIdx.x;
 
+    // Get the current index
+    const int64_t current_idx = shfl_sync(slot_idx, insert_slot);
+    // Insert rows
     if (l < SL) {
       // Insert indices
-      const int32_t insert_slot = sorted_slot;
-      const int64_t insert_time = sorted_time;
-
       const int64_t insert_idx = cache_set_sorted_indices[n + l];
-      const int64_t current_idx = lxu_cache_state[cache_set][insert_slot];
-
-#if 0
-      // TODO: Check whether to uncomment this
-      // Only check insert_time if tag is for valid entry
-      if (current_idx != -1) {
-        // We need to ensure if prefetching (prefetch_dist) batches ahead
-        // No entries that are younger than (time_stamp - prefetch_dist) are
-        // evicted from the cache. This will break the guarantees required
-        // for the SSD embedding.
-        // If you hit this assert, increase the cache size.
-        CUDA_KERNEL_ASSERT2(insert_time < (time_stamp - prefetch_dist));
-      }
-#endif
-      if (current_idx != -1 && insert_time == time_stamp) {
-        // Skip this slot as the inserted row was a cache hit
-        // This is conflict miss
+      if (insert_cost == time_stamp) {
+        // Skip this slot as it is not available
         evicted_indices[n + l] = -1;
         assigned_cache_slots[n + l] = -1;
       } else {
         evicted_indices[n + l] = current_idx; // -1 if not set, >= 0 if valid.
         assigned_cache_slots[n + l] = cache_set * kWarpSize + insert_slot;
+
+        // TODO: Check if we can do contiguous writes here.
+        // Update cache states
         lxu_cache_state[cache_set][insert_slot] = insert_idx;
         lru_state[cache_set][insert_slot] = time_stamp;
+
+        // Lock cache line
+        if (lock_cache_line) {
+          lxu_cache_locking_counter[cache_set][insert_slot] += 1;
+        }
       }
     }
 
@@ -280,9 +297,11 @@ ssd_cache_populate_actions_cuda(
     int64_t prefetch_dist,
     Tensor lru_state,
     bool gather_cache_stats,
-    std::optional<Tensor> ssd_cache_stats) {
+    std::optional<Tensor> ssd_cache_stats,
+    const bool lock_cache_line,
+    const c10::optional<Tensor>& lxu_cache_locking_counter) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      linear_indices, lxu_cache_state, lru_state);
+      linear_indices, lxu_cache_state, lru_state, lxu_cache_locking_counter);
 
   CUDA_DEVICE_GUARD(linear_indices);
 
@@ -332,9 +351,17 @@ ssd_cache_populate_actions_cuda(
         /*cache_set_inverse_indices=*/at::empty({0}, int_options));
   }
 
+  Tensor lxu_cache_locking_counter_;
+  if (lock_cache_line) {
+    TORCH_CHECK(lxu_cache_locking_counter.has_value());
+    lxu_cache_locking_counter_ = lxu_cache_locking_counter.value();
+  } else {
+    lxu_cache_locking_counter_ =
+        at::empty({0, 0}, lxu_cache_state.options().dtype(at::kInt));
+  }
+
   auto actions_count = at::empty({1}, int_options);
   // Find uncached indices
-  Tensor lxu_cache_locking_counter = at::empty({0, 0}, int_options);
   auto
       [sorted_cache_sets,
        cache_set_sorted_unique_indices,
@@ -348,8 +375,8 @@ ssd_cache_populate_actions_cuda(
           lru_state,
           gather_cache_stats,
           ssd_cache_stats_,
-          /*lock_cache_line=*/false,
-          lxu_cache_locking_counter,
+          lock_cache_line,
+          lxu_cache_locking_counter_,
           /*compute_inverse_indices=*/true);
 
   TORCH_CHECK(cache_set_inverse_indices.has_value());
@@ -373,7 +400,10 @@ ssd_cache_populate_actions_cuda(
       MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
       MAKE_PTA_WITH_NAME(func_name, assigned_cache_slots, int32_t, 1, 32),
       MAKE_PTA_WITH_NAME(func_name, evicted_indices, int64_t, 1, 32),
-      MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32));
+      MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32),
+      lock_cache_line,
+      MAKE_PTA_WITH_NAME(
+          func_name, lxu_cache_locking_counter_, int32_t, 2, 32));
 
   return std::make_tuple(
       cache_set_sorted_unique_indices,
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
index b9abcbbd5..86eb3c0a7 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp
@@ -26,7 +26,9 @@ ssd_cache_populate_actions_cuda(
     int64_t prefetch_dist,
     Tensor lru_state,
    bool gather_cache_stats,
-    std::optional<Tensor> ssd_cache_stats);
+    std::optional<Tensor> ssd_cache_stats,
+    const bool lock_cache_line,
+    const c10::optional<Tensor>& lxu_cache_locking_counter);
 
 /// @ingroup embedding-ssd
 ///
@@ -298,7 +300,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "    int prefetch_dist, "
       "    Tensor lru_state, "
       "    bool gather_cache_stats=False, "
-      "    Tensor? ssd_cache_stats=None"
+      "    Tensor? ssd_cache_stats=None, "
+      "    bool lock_cache_line=False, "
+      "    Tensor? lxu_cache_locking_counter=None"
       ") -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   DISPATCH_TO_CUDA(
       "ssd_cache_populate_actions", ssd_cache_populate_actions_cuda);
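Note on the new arguments (not part of the patch above): the insert kernel reads `lxu_cache_locking_counter[cache_set][slot]` and treats any value greater than zero as a locked slot, so the counter passed with `lock_cache_line=true` is expected to be an int32 tensor indexed the same way as `lxu_cache_state` and zero-initialized before the first prefetch. A minimal caller-side sketch, assuming the counter mirrors the `[C, kWarpSize]` shape of `lxu_cache_state`; the helper name is hypothetical:

```cpp
#include <ATen/ATen.h>

// Hypothetical helper (not in the patch): allocate the locking counter that
// ssd_cache_populate_actions_cuda consumes when lock_cache_line is true.
// A value of 0 means "unlocked"; the insert kernel increments the counter
// for every slot it assigns while lock_cache_line is enabled.
at::Tensor make_lxu_cache_locking_counter(const at::Tensor& lxu_cache_state) {
  // One int32 counter per (cache set, slot), same shape as lxu_cache_state.
  return at::zeros_like(
      lxu_cache_state, lxu_cache_state.options().dtype(at::kInt));
}
```

When `lock_cache_line` is false, the host code in the patch substitutes an empty `{0, 0}` int tensor, so the counter only needs to be allocated when locking is actually requested.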