Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for int64_t indices and offsets in TBE inference [4/N] #3128

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ for (const auto l : c10::irange(L)) {
}

{% if not weighted %}

Tensor pruned_array_lookup_cpu(
Tensor indices,
Tensor offsets,
Expand All @@ -469,33 +470,41 @@ Tensor pruned_array_lookup_cpu(
int32_t T = index_remappings_offsets.size(0) - 1;
int32_t B = (offsets.size(0) - 1) / T;
TORCH_CHECK(B > 0);

auto dense_indices = empty_like(indices);
const auto* indices_acc = indices.data_ptr<int32_t>();
auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
const auto* offsets_acc = offsets.data_ptr<int32_t>();

const auto index_remappings_acc = index_remappings.data_ptr<int32_t>();
const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
for (const auto t : c10::irange(begin, end)) {
int64_t index_remappings_start = index_remappings_offsets_acc[t];
int64_t index_remappings_end = index_remappings_offsets_acc[t + 1];
int64_t capacity = index_remappings_end - index_remappings_start;
int32_t indices_start = offsets_acc[t * B];
int32_t indices_end = offsets_acc[(t + 1) * B];
if (capacity > 0) {
for (const auto i : c10::irange(indices_start,indices_end)) {
int32_t idx = indices_acc[i];
dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
}
} else {
std::memcpy(
dense_indices_acc + indices_start,
indices_acc + indices_start,
(indices_end - indices_start) * sizeof(int32_t));
}
}
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();

const auto index_remappings_acc = index_remappings.data_ptr<index_t>();
const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();

at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
for (const auto t : c10::irange(begin, end)) {
const auto index_remappings_start = index_remappings_offsets_acc[t];
const auto index_remappings_end = index_remappings_offsets_acc[t + 1];
const auto capacity = index_remappings_end - index_remappings_start;

const auto indices_start = offsets_acc[t * B];
const auto indices_end = offsets_acc[(t + 1) * B];

if (capacity > 0) {
for (const auto i : c10::irange(indices_start, indices_end)) {
auto idx = indices_acc[i];
dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
}
} else {
std::memcpy(
dense_indices_acc + indices_start,
indices_acc + indices_start,
(indices_end - indices_start) * sizeof(index_t));
}
}
});
});

return dense_indices;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@ using Tensor = at::Tensor;

namespace nbit {

template <typename index_t>
__global__
__launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
indices,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
offsets,
const pta::PackedTensorAccessor64<int32_t, 2, at::RestrictPtrTraits>
const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
hash_table,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
hash_table_offsets,
const int32_t B,
const int32_t T,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
dense_indices) {
// uint32_t capacity = hash_table.size(0);
const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
Expand All @@ -35,9 +36,9 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
if (b_t >= B * T) {
return;
}
const int32_t indices_start = offsets[t * B + b];
const int32_t indices_end = offsets[t * B + b + 1];
const int32_t L = indices_end - indices_start;
const index_t indices_start = offsets[t * B + b];
const index_t indices_end = offsets[t * B + b + 1];
const index_t L = indices_end - indices_start;

const int64_t table_start = hash_table_offsets[t];
const int64_t table_end = hash_table_offsets[t + 1];
Expand All @@ -51,20 +52,25 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
return;
}

using hash_t =
std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;

const uint32_t subwarp_id = threadIdx.x / 4;
const uint32_t subwarp_tid = threadIdx.x % 4;
#ifdef USE_ROCM
const uint64_t subwarp_mask = static_cast<uint64_t>(0xF) << (4 * subwarp_id);
#else
const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);
#endif

for (int32_t l_start = 0; l_start + subwarp_id < L;
l_start += kWarpSize / 4) {
const int32_t idx = indices[indices_start + l_start + subwarp_id];
uint32_t slot_start =
pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
const index_t idx = indices[indices_start + l_start + subwarp_id];
hash_t slot_start =
pruned_hash_function(static_cast<hash_t>(idx)) % capacity;

while (true) {
const uint32_t slot = (slot_start + subwarp_tid) % capacity;
const hash_t slot = (slot_start + subwarp_tid) % capacity;
const int2 val = *reinterpret_cast<const int2*>(
&hash_table[table_start + static_cast<int64_t>(slot)][0]);
const int32_t slot_sparse_idx = val.x;
Expand All @@ -78,6 +84,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
found = true;
dense_indices[indices_start + l_start + subwarp_id] = slot_dense_idx;
}

if (__any_sync(subwarp_mask, found)) {
break;
} else if (__any_sync(subwarp_mask, empty)) {
Expand All @@ -89,56 +96,60 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
}
}

template <typename index_t>
__global__
__launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
indices,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
offsets,
const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
index_remappings,
const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
index_remappings_offsets,
const int32_t B,
const int32_t T,
pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
dense_indices) {
const int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
const int32_t t = b_t / B;
const int32_t b = b_t % B;
if (b_t >= B * T) {
return;
}
const int32_t indices_start = offsets[t * B + b];
const int32_t indices_end = offsets[t * B + b + 1];
const int32_t L = indices_end - indices_start;
const index_t indices_start = offsets[t * B + b];
const index_t indices_end = offsets[t * B + b + 1];
const index_t L = indices_end - indices_start;

const int64_t index_remappings_start = index_remappings_offsets[t];
const int64_t index_remappings_end = index_remappings_offsets[t + 1];
const int64_t capacity = index_remappings_end - index_remappings_start;

if (capacity > 0) {
for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
int32_t idx = indices[indices_start + l];
for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
index_t idx = indices[indices_start + l];
dense_indices[indices_start + l] =
index_remappings[index_remappings_start + idx];
}
} else {
for (int32_t l = threadIdx.x; l < L; l += blockDim.x) {
for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
dense_indices[indices_start + l] = indices[indices_start + l];
}
}
}

} // namespace nbit

using namespace nbit;

Tensor pruned_hashmap_lookup_cuda(
Tensor indices,
Tensor offsets,
Tensor hash_table,
Tensor hash_table_offsets) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
indices, offsets, hash_table, hash_table_offsets);
TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);

CUDA_DEVICE_GUARD(indices);

Expand All @@ -149,23 +160,25 @@ Tensor pruned_hashmap_lookup_cuda(
TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
constexpr size_t kForwardMaxThreads = 256;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
#endif

nbit::int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, hash_table, int32_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
});

C10_CUDA_KERNEL_LAUNCH_CHECK();
return dense_indices;
Expand All @@ -178,6 +191,7 @@ Tensor pruned_array_lookup_cuda(
Tensor index_remappings_offsets) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
indices, offsets, index_remappings, index_remappings_offsets);
TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);

CUDA_DEVICE_GUARD(indices);

Expand All @@ -204,23 +218,26 @@ Tensor pruned_array_lookup_cuda(
TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
constexpr size_t kForwardMaxThreads = 256;

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup", [&] {
#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
const auto func_name =
"int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
#endif

nbit::int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, int32_t, 1, 32));
int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
B,
T,
MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
});

C10_CUDA_KERNEL_LAUNCH_CHECK();
return dense_indices;
}
Loading
Loading