From 5d4e59b605548df78c7d996d517f23cfee8e8433 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Tue, 13 Aug 2024 11:09:48 -0700
Subject: [PATCH] Fix masked_index_* values index type

Summary: Use 64-bit index for `values` in `masked_index_kernel`

Reviewed By: chrisxcai

Differential Revision: D61216281
---
 .../ssd_split_embeddings_cache_cuda.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
index ec11e86e8..a8bf1752d 100644
--- a/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
+++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_embeddings_cache_cuda.cu
@@ -54,7 +54,7 @@ __global__ __launch_bounds__(kMaxThreads) void masked_index_kernel(
     pta::PackedTensorAccessor64 self,
     const pta::PackedTensorAccessor32
         indices,
-    const pta::PackedTensorAccessor32
+    const pta::PackedTensorAccessor64
         values,
     const pta::PackedTensorAccessor32
         count) {
@@ -118,7 +118,7 @@ Tensor masked_index_impl(
                at::cuda::getCurrentCUDAStream()>>>(
                 MAKE_PTA_WITH_NAME(func_name, self, scalar_t, 2, 64),
                 MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
-                MAKE_PTA_WITH_NAME(func_name, values, scalar_t, 2, 32),
+                MAKE_PTA_WITH_NAME(func_name, values, scalar_t, 2, 64),
                 MAKE_PTA_WITH_NAME(func_name, count, int32_t, 1, 32));
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       } // lambda