diff --git a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu index f32301e47..6239d914f 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu @@ -88,8 +88,18 @@ DLL_PUBLIC Tensor reorder_batched_ad_lengths_gpu( ? at::empty({T * num_ads_in_batch}, cat_ad_lengths.options()) : at::empty_like(cat_ad_lengths); + const int64_t grid_size = (B * T + 32 - 1) / 32; + TORCH_CHECK( + grid_size >= 0, + "grid_size must be positive, got ", + grid_size, + " where B =", + B, + " and T =", + T); + const dim3 threads(32, 32); - const dim3 blocks((B * T + 32 - 1) / 32); + const dim3 blocks(grid_size); FBGEMM_DISPATCH_ALL_TYPES( cat_ad_lengths.scalar_type(),