Add support for int64_t indices and offsets in TBE inference [7B/N] (#3189)

Summary:
X-link: facebookresearch/FBGEMM#284

Pull Request resolved: #3189

- Fix `index_remapping` in `pruned_array_lookup` to always be `int64_t`, since it is set up before `indices` and `offsets` are provided.

Differential Revision: D63567553
q10 authored and facebook-github-bot committed Oct 1, 2024
1 parent 4a8b5d0 commit 16acb01
Showing 5 changed files with 116 additions and 72 deletions.
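
Illustrative sketch (editorial, not part of the commit): after this change the remapping buffers are always int64, while lookups may be driven by either int32 or int64 indices and offsets. The operator name, argument order, and the claim that the output dtype follows the indices dtype are assumptions based on the functions touched in this diff.

import torch
import fbgemm_gpu  # noqa: F401  # assumption: importing this registers torch.ops.fbgemm

# One table with 4 rows; row 1 is pruned (mapped to -1). The remapping buffers
# are built as int64 regardless of the dtype later used for indices/offsets.
index_remappings = torch.tensor([0, -1, 1, 2], dtype=torch.int64)
index_remappings_offsets = torch.tensor([0, 4], dtype=torch.int64)

for dtype in (torch.int32, torch.int64):
    indices = torch.tensor([0, 2, 3], dtype=dtype)
    offsets = torch.tensor([0, 3], dtype=dtype)  # single batch, single table
    dense = torch.ops.fbgemm.pruned_array_lookup(
        indices, offsets, index_remappings, index_remappings_offsets
    )
    # Assumption: the dense indices come back in the same dtype as `indices`.
    print(dtype, dense)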
@@ -65,54 +65,64 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
 
-  int32_t T = hash_table_offsets.size(0) - 1;
-  int32_t B = (offsets.size(0) - 1) / T;
+  const int32_t T = hash_table_offsets.size(0) - 1;
+  const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
-  const auto* indices_acc = indices.data_ptr<int32_t>();
-  const auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
 
-  const auto* offsets_acc = offsets.data_ptr<int32_t>();
-  auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-  const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-  for (const auto t : c10::irange(T)) {
-    int64_t table_start = hash_table_offsets_acc[t];
-    int64_t table_end = hash_table_offsets_acc[t + 1];
-    if (table_start == table_end) {
-      continue;
-    }
-    int64_t capacity = table_end - table_start;
-    for (const auto b : c10::irange(B)) {
-      int32_t indices_start = offsets_acc[t * B + b];
-      int32_t indices_end = offsets_acc[t * B + b + 1];
-      int32_t L = indices_end - indices_start;
-      for (const auto l : c10::irange(L)) {
-        int32_t idx = indices_acc[indices_start + l];
-        int32_t dense_idx = dense_indices_acc[indices_start + l];
-        if (dense_idx == -1) {
-          // -1 means this row has been pruned, do not insert it.
-          continue;
-        }
-
-        uint32_t slot = pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
-        while (true) {
-          int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-          // empty slot
-          if (slot_sparse_idx == -1) {
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][0] = idx;
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-            break;
-          }
-          // already exists (shouldn't happen in practice)
-          if (slot_sparse_idx == idx) {
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-            break;
-          }
-          // linear probe
-          slot = (slot + 1) % capacity;
-        }
-      }
-    }
-  }
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using hash_t =
+        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+    const auto* indices_acc = indices.data_ptr<index_t>();
+    const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+    const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+    auto hash_table_acc = hash_table.accessor<int64_t, 2>();
+    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+    for (const auto t : c10::irange(T)) {
+      const auto table_start = hash_table_offsets_acc[t];
+      const auto table_end = hash_table_offsets_acc[t + 1];
+      if (table_start == table_end) {
+        continue;
+      }
+      const auto capacity = table_end - table_start;
+
+      for (const auto b : c10::irange(B)) {
+        const auto indices_start = offsets_acc[t * B + b];
+        const auto indices_end = offsets_acc[t * B + b + 1];
+        const auto L = indices_end - indices_start;
+
+        for (const auto l : c10::irange(L)) {
+          const auto idx = indices_acc[indices_start + l];
+          const auto dense_idx = dense_indices_acc[indices_start + l];
+          if (dense_idx == -1) {
+            // -1 means this row has been pruned, do not insert it.
+            continue;
+          }
+
+          auto slot = pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+          while (true) {
+            const auto slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
+            // empty slot
+            if (slot_sparse_idx == -1) {
+              hash_table_acc[table_start + static_cast<int64_t>(slot)][0] = idx;
+              hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
+              break;
+            }
+            // already exists (shouldn't happen in practice)
+            if (slot_sparse_idx == idx) {
+              hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
+              break;
+            }
+            // linear probe
+            slot = (slot + 1) % capacity;
+          }
+        }
+      }
+    }
+  });
 
   return;
 }
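
For orientation, a minimal pure-Python model of the insert path in the hunk above: open addressing with linear probing over a flat (sparse_idx, dense_idx) table stored as int64. The hash function here is a stand-in for pruned_hash_function, which the C++ code applies to a uint32_t or uint64_t depending on the dispatched index_t; this is a sketch only.

import torch

def hashmap_insert(indices, dense_indices, hash_table, table_start, table_end):
    capacity = table_end - table_start
    for idx, dense_idx in zip(indices.tolist(), dense_indices.tolist()):
        if dense_idx == -1:
            # -1 means this row has been pruned, do not insert it.
            continue
        slot = hash(idx) % capacity  # stand-in hash, not pruned_hash_function
        while True:
            row = table_start + slot
            slot_sparse_idx = int(hash_table[row, 0])
            if slot_sparse_idx == -1:  # empty slot: claim it
                hash_table[row, 0] = idx
                hash_table[row, 1] = dense_idx
                break
            if slot_sparse_idx == idx:  # already present: update the mapping
                hash_table[row, 1] = dense_idx
                break
            slot = (slot + 1) % capacity  # linear probe

# Example: one table with capacity 8, int32 indices, int64 table storage.
table = torch.full((8, 2), -1, dtype=torch.int64)
hashmap_insert(
    indices=torch.tensor([3, 7, 11], dtype=torch.int32),
    dense_indices=torch.tensor([0, -1, 1], dtype=torch.int32),
    hash_table=table,
    table_start=0,
    table_end=8,
)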

@@ -414,7 +424,7 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = hash_table_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
@@ -428,9 +438,9 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
 
    const auto* indices_acc = indices.data_ptr<index_t>();
    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
 
    const auto* offsets_acc = offsets.data_ptr<index_t>();
-    const auto hash_table_acc = hash_table.accessor<index_t, 2>();
+    const auto hash_table_acc = hash_table.accessor<int64_t, 2>();
    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
 
    for (const auto t : c10::irange(T)) {
@@ -463,7 +473,7 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
            }
            // already exists
            if (slot_sparse_idx == idx) {
-              dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
+              dense_indices_acc[indices_start + l] = static_cast<index_t>(hash_table_acc[table_start + static_cast<int64_t>(slot)][1]);
              break;
            }
            // linear probe
@@ -489,7 +499,7 @@ Tensor pruned_array_lookup_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(index_remappings);
   TENSOR_ON_CPU(index_remappings_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = index_remappings_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
@@ -502,7 +512,7 @@ Tensor pruned_array_lookup_cpu(
    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
    const auto* offsets_acc = offsets.data_ptr<index_t>();
 
-    const auto index_remappings_acc = index_remappings.data_ptr<index_t>();
+    const auto index_remappings_acc = index_remappings.data_ptr<int64_t>();
    const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
 
    at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
@@ -517,7 +527,7 @@ Tensor pruned_array_lookup_cpu(
        if (capacity > 0) {
          for (const auto i : c10::irange(indices_start, indices_end)) {
            auto idx = indices_acc[i];
-            dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
+            dense_indices_acc[i] = static_cast<index_t>(index_remappings_acc[index_remappings_start + idx]);
          }
        } else {
          std::memcpy(
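
On the lookup side, the remapping storage is now read as int64 and the hit is cast back to the dtype of the incoming indices (the added static_cast<index_t>). A rough Python model of the array-based variant, under the same toy assumptions as the sketch above:

import torch

def array_lookup(indices, index_remappings, remap_start, remap_end):
    # index_remappings is always int64; the result is cast back to the dtype of
    # `indices`, mirroring the static_cast<index_t>(...) in the C++ above.
    capacity = remap_end - remap_start
    if capacity > 0:
        remapped = index_remappings[remap_start + indices.long()]
        return remapped.to(indices.dtype)
    # No remapping section for this table: the C++ falls back to copying the
    # original indices (std::memcpy), modeled here with a clone.
    return indices.clone()

# int32 and int64 indices both map through the same int64 remapping table.
remap = torch.tensor([0, -1, 1, 2], dtype=torch.int64)
for dtype in (torch.int32, torch.int64):
    out = array_lookup(torch.tensor([0, 2, 3], dtype=dtype), remap, 0, 4)
    assert out.dtype == dtype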
@@ -21,7 +21,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
        indices,
    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        offsets,
-    const pta::PackedTensorAccessor64<index_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<int64_t, 2, at::RestrictPtrTraits>
        hash_table,
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        hash_table_offsets,
@@ -103,7 +103,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
        indices,
    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
        offsets,
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        index_remappings,
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
        index_remappings_offsets,
@@ -129,7 +129,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
      index_t idx = indices[indices_start + l];
      dense_indices[indices_start + l] =
-          index_remappings[index_remappings_start + idx];
+          static_cast<index_t>(index_remappings[index_remappings_start + idx]);
    }
  } else {
    for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
@@ -149,7 +149,7 @@ Tensor pruned_hashmap_lookup_cuda(
    Tensor hash_table_offsets) {
  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
      indices, offsets, hash_table, hash_table_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
  CUDA_DEVICE_GUARD(indices);
 
@@ -173,7 +173,7 @@ Tensor pruned_hashmap_lookup_cuda(
        at::cuda::getCurrentCUDAStream()>>>(
        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
+        MAKE_PTA_WITH_NAME(func_name, hash_table, int64_t, 2, 64),
        MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
        B,
        T,
@@ -191,7 +191,7 @@ Tensor pruned_array_lookup_cuda(
    Tensor index_remappings_offsets) {
  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
      indices, offsets, index_remappings, index_remappings_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
  CUDA_DEVICE_GUARD(indices);
@@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(
        at::cuda::getCurrentCUDAStream()>>>(
        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
+        MAKE_PTA_WITH_NAME(func_name, index_remappings, int64_t, 1, 32),
        MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
        B,
        T,
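
Note that the relaxed checks above still require indices and offsets to share a scalar type on both CPU and CUDA; only the table/remapping tensors are decoupled and pinned to int64. A small, hypothetical helper (not part of FBGEMM) for callers that may hold mismatched dtypes:

import torch

def match_index_dtypes(indices: torch.Tensor, offsets: torch.Tensor):
    # Hypothetical convenience: promote both tensors to a common integer dtype
    # so the indices/offsets same-scalar-type check passes.
    if indices.dtype == offsets.dtype:
        return indices, offsets
    common = (
        torch.int64
        if torch.int64 in (indices.dtype, offsets.dtype)
        else torch.int32
    )
    return indices.to(common), offsets.to(common)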
@@ -403,15 +403,15 @@ def max_ty_D(ty: SparseType) -> int:
         )
         self.register_buffer(
             "index_remappings_array",
-            torch.empty(0, device=self.current_device, dtype=torch.int32),
+            torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "index_remapping_hash_table_offsets",
             torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "index_remapping_hash_table",
-            torch.empty(0, device=self.current_device, dtype=torch.int32),
+            torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "original_rows_per_table",
@@ -946,8 +946,9 @@ def reset_embedding_spec_location(
     @torch.jit.export
     def recompute_module_buffers(self) -> None:
         """
-        Compute module buffers that're on meta device and are not materialized in reset_weights_placements_and_offsets().
-        Currently those buffers are `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
+        Compute module buffers that're on meta device and are not materialized
+        in reset_weights_placements_and_offsets(). Currently those buffers are
+        `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
         Pruning related or uvm related buffers are not computed right now.
         """
         if (
@@ -1527,11 +1528,11 @@ def set_index_remappings_array(
                 index_remappings_filter_nones.append(mapping)
         if len(index_remappings_filter_nones) == 0:
             self.index_remappings_array = torch.empty(
-                0, dtype=torch.int32, device=self.current_device
+                0, dtype=torch.int64, device=self.current_device
             )
         else:
             self.index_remappings_array = torch.cat(index_remappings_filter_nones).to(
-                self.current_device
+                dtype=torch.int64, device=self.current_device
             )
 
     def set_index_remappings(
@@ -1554,7 +1555,7 @@ def set_index_remappings(
             ]
             hash_table = torch.empty(
                 (sum(capacities), 2),
-                dtype=torch.int32,
+                dtype=torch.int64,
             )
             hash_table[:, :] = -1
             hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()
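
The Python-side setup above matches what the kernels now expect: remapping buffers are int64 regardless of the index dtype used at lookup time. A standalone sketch of the hash-table construction, with illustrative capacities:

from itertools import accumulate

import torch

# Per-table hash-table capacities (illustrative values); 0 means no pruning
# entries for that table.
capacities = [8, 0, 16]

# Always int64 now, regardless of the dtype later used for indices/offsets.
hash_table = torch.empty((sum(capacities), 2), dtype=torch.int64)
hash_table[:, :] = -1  # -1 marks an empty slot

# Offsets delimit each table's slice of the flat hash table.
hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()
assert hash_table_offsets.tolist() == [0, 8, 8, 24]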
17 changes: 12 additions & 5 deletions fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
@@ -87,7 +87,7 @@ def get_nbit_weights_ty(draw) -> Optional[SparseType]:
 
 
 # @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
-class NBitFowardTest(unittest.TestCase):
+class NBitFowardAutovecTest(unittest.TestCase):
     def execute_nbit_forward_(  # noqa C901
         self,
         T: int,
@@ -105,6 +105,7 @@ def execute_nbit_forward_(  # noqa C901
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         mixed_weights_ty: bool,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         # NOTE: weighted operation can be done only for SUM.
@@ -311,19 +312,22 @@ def execute_nbit_forward_(  # noqa C901
             fp8_config=fp8_config if has_fp8_weight else None,
         )
 
+        indices = indices.to(dtype=indices_dtype)
+        offsets = offsets.to(dtype=indices_dtype)
+
         if not use_cpu:
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1))
+                else cc(indices, offsets, xw.contiguous().view(-1))
             )
         else:
             cc = cc.cpu()
             indices, offsets = indices.cpu(), offsets.cpu()
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu())
+                else cc(indices, offsets, xw.contiguous().view(-1).cpu())
             )
 
         if do_pooling and B == 0:
@@ -373,6 +377,7 @@ def execute_nbit_forward_(  # noqa C901
         pooling_mode=st.sampled_from(
             [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
         ),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
         output_dtype=st.sampled_from(
             [SparseType.FP32, SparseType.FP16, SparseType.BF16]
         ),
@@ -386,6 +391,7 @@ def test_nbit_forward_cpu_autovec(
         self,
         nbit_weights_ty: Optional[SparseType],
         pooling_mode: PoolingMode,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         use_cpu = True
@@ -432,6 +438,7 @@ def test_nbit_forward_cpu_autovec(
             False,
             False,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )
 
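
The test change follows a simple pattern: draw the index dtype with hypothesis, cast indices and offsets before calling the module, and leave the rest of the harness untouched. A condensed, self-contained sketch of that pattern, with stand-in tensors in place of the real TBE request generator:

import torch
from hypothesis import given, settings, strategies as st

@given(indices_dtype=st.sampled_from([torch.int32, torch.int64]))
@settings(deadline=None, max_examples=4)
def test_indices_dtype_cast(indices_dtype: torch.dtype) -> None:
    # Stand-in tensors; the real test builds them from the TBE request generator.
    indices = torch.arange(6)
    offsets = torch.tensor([0, 3, 6])

    indices = indices.to(dtype=indices_dtype)
    offsets = offsets.to(dtype=indices_dtype)

    assert indices.dtype == offsets.dtype == indices_dtype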