Skip to content

Commit

Permalink
Add int4 to int4 CPU Sequence TBE kernel
Browse files Browse the repository at this point in the history
Differential Revision: D61305980
  • Loading branch information
wsu authored and facebook-github-bot committed Aug 17, 2024
1 parent 4622a72 commit 155e8f1
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 72 deletions.
6 changes: 4 additions & 2 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ FBGEMM_API typename EmbeddingSpMDMKernelSignature<
OffsetType,
OutType>::Type
GenerateEmbeddingSpMDMNBitWithStrides(
int bit_rate,
const int input_bit_rate,
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
Expand All @@ -169,7 +169,9 @@ GenerateEmbeddingSpMDMNBitWithStrides(
std::int64_t output_stride = -1,
std::int64_t input_stride = -1,
bool scale_bias_last = true,
bool is_bf16_out = false);
const bool is_bf16_out = false,
const bool no_bag = false,
int output_bit_rate = -1);

/**
* @param output_stride If -1, output_stride is same as block_size
Expand Down
35 changes: 35 additions & 0 deletions include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -416,4 +417,38 @@ FBGEMM_API bool is_autovec_disabled();
FBGEMM_API bool is_autovec_forced();
FBGEMM_API bool is_asmjit_disabled();

/**
 * @brief Debug-only validation of the bit-rate arguments handed to the nbit
 * CPU TBE kernel.
 *
 * Every check is an assert(), so in builds where NDEBUG is defined this
 * function does nothing.
 *
 * @tparam OutType element type of the kernel's output buffer.
 * @param input_bit_rate bit width of the quantized input; must be 2 or 4.
 * @param output_bit_rate bit width of the output. When OutType is uint8_t it
 * must be 4; for every other OutType it must equal 8 * sizeof(OutType).
 * @param no_bag must be true when OutType is uint8_t; not inspected for any
 * other OutType.
 */
template <typename OutType>
void nbit_embedding_sanity_check(
    // assert() compiles away under NDEBUG, which would otherwise leave these
    // parameters unused.
    [[maybe_unused]] const int input_bit_rate,
    [[maybe_unused]] const int output_bit_rate,
    [[maybe_unused]] const bool no_bag) {
  [[maybe_unused]] const bool input_rate_supported =
      input_bit_rate == 2 || input_bit_rate == 4;
  assert(input_rate_supported && "input_bit_rate must be 2 or 4");

  // Any output type other than uint8_t is a regular (non-packed) output whose
  // bit rate is dictated by the type itself.
  if (!std::is_same<OutType, uint8_t>::value) {
    [[maybe_unused]] const bool output_rate_matches_type =
        output_bit_rate == 8 * sizeof(OutType);
    assert(
        output_rate_matches_type &&
        "output_bit_rate should be equal to 8 * sizeof(OutType)");
    return;
  }

  // uint8_t output is only accepted for the int4 -> int4 sequence (no_bag)
  // case.
  [[maybe_unused]] const bool is_int4_sequence =
      no_bag && input_bit_rate == 4 && output_bit_rate == 4;
  assert(
      is_int4_sequence &&
      "we currently only support int4 to int4 for sequential TBE");
}

// Prints a printf-style message to stderr the first time a given expansion
// site is executed and is a no-op on every later execution of that site. The
// arguments are evaluated only when the message is actually printed. The
// guard is a plain function-local static bool with no synchronization.
#define WARN_ONCE(...)                       \
  do {                                       \
    static bool warn_once_pending_ = true;   \
    if (warn_once_pending_) {                \
      warn_once_pending_ = false;            \
      fprintf(stderr, __VA_ARGS__);          \
    }                                        \
  } while (0)

} // namespace fbgemm
50 changes: 39 additions & 11 deletions src/EmbeddingSpMDMAutovec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ INSTANTIATE_SPMDM_INDEX_T()

template <typename IndexType, typename OffsetType, typename OutType>
bool EmbeddingSpMDMNBit_autovec(
const int bit_rate,
const int input_bit_rate,
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
Expand All @@ -289,9 +289,14 @@ bool EmbeddingSpMDMNBit_autovec(
int64_t output_stride /*=-1*/,
int64_t input_stride /*=-1*/,
const bool scale_bias_last /*=true*/,
const bool is_bf16_out /*=false*/) {
assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
const int num_elem_per_byte = 8 / bit_rate;
const bool is_bf16_out /*=false*/,
const bool no_bag /*=false*/,
int output_bit_rate /*=-1*/) {
if (output_bit_rate == -1) {
output_bit_rate = 8 * sizeof(OutType);
}
nbit_embedding_sanity_check<OutType>(input_bit_rate, output_bit_rate, no_bag);
const int num_elem_per_byte = 8 / input_bit_rate;

if (output_stride == -1) {
output_stride = block_size;
Expand Down Expand Up @@ -335,6 +340,26 @@ bool EmbeddingSpMDMNBit_autovec(
}
}

if (no_bag) {
  // Sequence (no_bag) mode: every looked-up row is copied to the output
  // verbatim instead of being dequantized and pooled. Only int4 -> int4 is
  // implemented. The assert()-based sanity check is compiled out in release
  // builds, so the combination is re-validated here at runtime.
  if (!(input_bit_rate == 4 && output_bit_rate == 4)) {
    WARN_ONCE("no_bag is only supported for int4 to int4\n");
    return false;
  }
  for (int64_t i = 0; i < output_size; ++i) {
    const auto idx = indices[i];
    // data_size is the number of rows in the table, so valid row indices are
    // [0, data_size). idx == data_size must be rejected too, otherwise the
    // memcpy below reads one full row past the end of `input`.
    if (idx < 0 || idx >= data_size) {
      return false;
    }
    const uint8_t* input_row = input + input_stride * idx;
    // The whole input row (input_stride bytes) is copied and the output
    // cursor advances by the same amount; output_stride is not consulted on
    // this path.
    memcpy(out, input_row, sizeof(uint8_t) * input_stride);
    out += input_stride;
  }
  return true;
}

int64_t current = 0;
const int64_t rounded_bs = round_up(block_size, num_elem_per_byte);
vector<float> buf(rounded_bs);
Expand Down Expand Up @@ -387,15 +412,15 @@ bool EmbeddingSpMDMNBit_autovec(
const int64_t offset =
input_stride * idx + (scale_bias_last ? 0 : scale_bias_offset);
const uint8_t* input_row = input + offset;
if (bit_rate == 4) {
if (input_bit_rate == 4) {
const size_t halfbufsz = (block_size + 1) / 2;
for (size_t j = 0; j < halfbufsz; ++j) {
float quantized1 = float(input_row[j] & 0xf);
float quantized2 = float(input_row[j] >> 4);
buf[j * 2] = std::fma(scale, quantized1, buf[j * 2] + bias);
buf[j * 2 + 1] = std::fma(scale, quantized2, buf[j * 2 + 1] + bias);
}
} else if (bit_rate == 2) {
} else if (input_bit_rate == 2) {
size_t qbufsz = (block_size + 3) / 4;
const uint8_t mask1 = 0x3;
const uint8_t mask2 = 0xC;
Expand Down Expand Up @@ -445,7 +470,7 @@ bool EmbeddingSpMDMNBit_autovec(

#define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
template FBGEMM_API bool EmbeddingSpMDMNBit_autovec( \
const int bit_rate, \
const int input_bit_rate, \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
Expand All @@ -461,11 +486,14 @@ bool EmbeddingSpMDMNBit_autovec(
int64_t output_stride, \
int64_t input_stride, \
const bool scale_bias_last, \
const bool is_bf16_out);
const bool is_bf16_out, \
const bool no_bag, \
int output_bit_rate);

#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16)
#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t)

#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \
INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \
Expand Down
6 changes: 4 additions & 2 deletions src/EmbeddingSpMDMAutovec.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ template <
typename OffsetType = std::int32_t,
typename OutType = float>
FBGEMM_API bool EmbeddingSpMDMNBit_autovec(
const int bit_rate,
const int input_bit_rate,
const std::int64_t block_size,
const std::int64_t output_size,
const std::int64_t index_size,
Expand All @@ -67,7 +67,9 @@ FBGEMM_API bool EmbeddingSpMDMNBit_autovec(
std::int64_t output_stride = -1,
std::int64_t input_stride = -1,
const bool scale_bias_last = true,
const bool is_bf16_out = false);
const bool is_bf16_out = false,
const bool no_bag = false,
int output_bit_rate = -1);

} // namespace fbgemm

Expand Down
99 changes: 89 additions & 10 deletions src/EmbeddingSpMDMNBit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ template <
typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
Type
GenerateEmbeddingSpMDMNBitWithStrides(
int bit_rate,
const int input_bit_rate,
const int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
Expand All @@ -1032,8 +1032,20 @@ typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
int64_t output_stride /*=-1*/,
int64_t input_stride /*=-1*/,
bool scale_bias_last /*=true*/,
bool is_bf16_out) {
assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
const bool is_bf16_out /*=false*/,
const bool no_bag /*=false*/,
int output_bit_rate /*=-1*/) {
if (output_bit_rate == -1) {
output_bit_rate = input_bit_rate;
}
assert(
(input_bit_rate == 2 || input_bit_rate == 4) &&
"input_bit_rate must be 2 or 4");
if (std::is_same<outType, uint8_t>::value) {
assert(
(no_bag && input_bit_rate == 4 && output_bit_rate == 4) &&
"we currently only support int4 to int4 when using sequential TBE");
}

if (!cpuinfo_initialize()) {
throw runtime_error("Failed to initialize cpuinfo!");
Expand All @@ -1042,10 +1054,74 @@ typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
output_stride = block_size;
}
if (input_stride == -1) {
int64_t num_elem_per_byte = 8 / bit_rate;
int64_t num_elem_per_byte = 8 / input_bit_rate;
input_stride =
ceil_div(block_size, num_elem_per_byte) + 2 * sizeof(uint16_t);
}
if (no_bag) {
if (!is_autovec_disabled()) {
return [=](int64_t output_size,
int64_t index_size,
int64_t data_size,
const uint8_t* input,
const indxType* indices,
const offsetType* offsets_or_lengths,
const float* weights,
outType* out) {
return EmbeddingSpMDMNBit_autovec(
input_bit_rate,
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets_or_lengths,
weights,
normalize_by_lengths,
out,
is_weight_positional,
use_offsets,
output_stride,
input_stride,
scale_bias_last,
is_bf16_out,
no_bag,
output_bit_rate);
};
} else {
return [=](int64_t output_size,
int64_t index_size,
int64_t data_size,
const uint8_t* input,
const indxType* indices,
const offsetType* offsets_or_lengths,
const float* weights,
outType* out) {
return EmbeddingSpMDMNBit_ref(
input_bit_rate,
block_size,
output_size,
index_size,
data_size,
input,
indices,
offsets_or_lengths,
weights,
normalize_by_lengths,
out,
is_weight_positional,
use_offsets,
output_stride,
input_stride,
scale_bias_last,
is_bf16_out,
no_bag,
output_bit_rate);
};
}
}

if (fbgemmHasAvx512Support() && !is_asmjit_disabled()) {
static GenEmbeddingSpMDMNBitLookup<
indxType,
Expand All @@ -1056,7 +1132,7 @@ typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
THREAD_LOCAL>
kernel_generator;
const auto original_func = kernel_generator.getOrCreate(
bit_rate,
input_bit_rate,
block_size,
has_weight,
is_weight_positional,
Expand Down Expand Up @@ -1096,7 +1172,7 @@ typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
THREAD_LOCAL>
kernel_generator;
const auto original_func = kernel_generator.getOrCreate(
bit_rate,
input_bit_rate,
block_size,
has_weight,
is_weight_positional,
Expand Down Expand Up @@ -1139,7 +1215,7 @@ typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
const float* weights,
outType* out) {
return EmbeddingSpMDMNBit_autovec(
bit_rate,
input_bit_rate,
block_size,
output_size,
index_size,
Expand Down Expand Up @@ -1171,7 +1247,7 @@ typename EmbeddingSpMDMKernelSignature<uint8_t, indxType, offsetType, outType>::
const float* weights,
outType* out) {
return EmbeddingSpMDMNBit_ref(
bit_rate,
input_bit_rate,
block_size,
output_size,
index_size,
Expand Down Expand Up @@ -1364,7 +1440,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
OFFSET_TYPE, \
OUT_TYPE, \
THREAD_LOCAL>( \
int bit_rate, \
const int input_bit_rate, \
const int64_t block_size, \
bool has_weight, \
bool normalize_by_lengths, \
Expand All @@ -1374,7 +1450,9 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
int64_t output_stride, \
int64_t input_stride, \
bool scale_bias_last, \
bool is_bf16_out);
const bool is_bf16_out, \
const bool no_bag, \
int output_bit_rate);

#define INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \
Expand All @@ -1396,6 +1474,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float) \
INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint16_t) \
INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint8_t) \
template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
uint8_t, \
INDEX_TYPE, \
Expand Down
Loading

0 comments on commit 155e8f1

Please sign in to comment.