Skip to content

Commit

Permalink
Add support for int64_t indices and offsets in TBE inference [4/N]
Browse files Browse the repository at this point in the history
Summary: - Convert `pruned_array_lookup_cpu` to use `index_t`

Differential Revision: D62470736
  • Loading branch information
q10 authored and facebook-github-bot committed Sep 12, 2024
1 parent 3de1f11 commit de30b80
Showing 1 changed file with 33 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ for (const auto l : c10::irange(L)) {
}

{% if not weighted %}

Tensor pruned_array_lookup_cpu(
Tensor indices,
Tensor offsets,
Expand All @@ -469,33 +470,41 @@ Tensor pruned_array_lookup_cpu(
int32_t T = index_remappings_offsets.size(0) - 1;
int32_t B = (offsets.size(0) - 1) / T;
TORCH_CHECK(B > 0);

auto dense_indices = empty_like(indices);
const auto* indices_acc = indices.data_ptr<int32_t>();
auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
const auto* offsets_acc = offsets.data_ptr<int32_t>();

const auto index_remappings_acc = index_remappings.data_ptr<int32_t>();
const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
for (const auto t : c10::irange(begin, end)) {
int64_t index_remappings_start = index_remappings_offsets_acc[t];
int64_t index_remappings_end = index_remappings_offsets_acc[t + 1];
int64_t capacity = index_remappings_end - index_remappings_start;
int32_t indices_start = offsets_acc[t * B];
int32_t indices_end = offsets_acc[(t + 1) * B];
if (capacity > 0) {
for (const auto i : c10::irange(indices_start,indices_end)) {
int32_t idx = indices_acc[i];
dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
}
} else {
std::memcpy(
dense_indices_acc + indices_start,
indices_acc + indices_start,
(indices_end - indices_start) * sizeof(int32_t));
}
}
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] {
const auto* indices_acc = indices.data_ptr<index_t>();
auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
const auto* offsets_acc = offsets.data_ptr<index_t>();

const auto index_remappings_acc = index_remappings.data_ptr<index_t>();
const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();

at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
for (const auto t : c10::irange(begin, end)) {
const auto index_remappings_start = index_remappings_offsets_acc[t];
const auto index_remappings_end = index_remappings_offsets_acc[t + 1];
const auto capacity = index_remappings_end - index_remappings_start;

const auto indices_start = offsets_acc[t * B];
const auto indices_end = offsets_acc[(t + 1) * B];

if (capacity > 0) {
for (const auto i : c10::irange(indices_start, indices_end)) {
auto idx = indices_acc[i];
dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
}
} else {
std::memcpy(
dense_indices_acc + indices_start,
indices_acc + indices_start,
(indices_end - indices_start) * sizeof(index_t));
}
}
});
});

return dense_indices;
}

Expand Down

0 comments on commit de30b80

Please sign in to comment.