From 155e8f1a8ef593d5365beff00f944fafac39d6c4 Mon Sep 17 00:00:00 2001 From: wsu Date: Fri, 16 Aug 2024 17:13:04 -0700 Subject: [PATCH] Add int4 to int4 CPU Sequence TBE kernel Differential Revision: D61305980 --- include/fbgemm/FbgemmEmbedding.h | 6 +- include/fbgemm/Utils.h | 35 +++++++++ src/EmbeddingSpMDMAutovec.cc | 50 ++++++++++--- src/EmbeddingSpMDMAutovec.h | 6 +- src/EmbeddingSpMDMNBit.cc | 99 ++++++++++++++++++++++--- src/RefImplementations.cc | 120 +++++++++++++++++++------------ src/RefImplementations.h | 8 ++- 7 files changed, 252 insertions(+), 72 deletions(-) diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index f787a637e3..15d287a54b 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -159,7 +159,7 @@ FBGEMM_API typename EmbeddingSpMDMKernelSignature< OffsetType, OutType>::Type GenerateEmbeddingSpMDMNBitWithStrides( - int bit_rate, + const int input_bit_rate, const std::int64_t block_size, bool has_weight, bool normalize_by_lengths, @@ -169,7 +169,9 @@ GenerateEmbeddingSpMDMNBitWithStrides( std::int64_t output_stride = -1, std::int64_t input_stride = -1, bool scale_bias_last = true, - bool is_bf16_out = false); + const bool is_bf16_out = false, + const bool no_bag = false, + int output_bit_rate = -1); /** * @param output_stride If -1, output_stride is same as block_size diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 69b2e6d94f..6ad32bf860 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -416,4 +417,38 @@ FBGEMM_API bool is_autovec_disabled(); FBGEMM_API bool is_autovec_forced(); FBGEMM_API bool is_asmjit_disabled(); +/** + * @brief A function to check if the input parameter in the nbit CPU TBE kernel + * is valid. 
+ */ +template +void nbit_embedding_sanity_check( + // assertions are ignored in release mode, in which case these parameters + // will be unused + [[maybe_unused]] const int input_bit_rate, + [[maybe_unused]] const int output_bit_rate, + [[maybe_unused]] const bool no_bag) { + assert( + (input_bit_rate == 2 || input_bit_rate == 4) && + "input_bit_rate must be 2 or 4"); + if (std::is_same::value) { + assert( + (no_bag && input_bit_rate == 4 && output_bit_rate == 4) && + "we currently only support int4 to int4 for sequential TBE"); + } else { + assert( + (output_bit_rate == 8 * sizeof(OutType)) && + "output_bit_rate should be equal to 8 * sizeof(OutType)"); + } +} + +#define WARN_ONCE(...) \ + do { \ + static bool _warned = false; \ + if (!_warned) { \ + _warned = true; \ + fprintf(stderr, __VA_ARGS__); \ + } \ + } while (0) + } // namespace fbgemm diff --git a/src/EmbeddingSpMDMAutovec.cc b/src/EmbeddingSpMDMAutovec.cc index 3247c9026b..67398ab461 100644 --- a/src/EmbeddingSpMDMAutovec.cc +++ b/src/EmbeddingSpMDMAutovec.cc @@ -273,7 +273,7 @@ INSTANTIATE_SPMDM_INDEX_T() template bool EmbeddingSpMDMNBit_autovec( - const int bit_rate, + const int input_bit_rate, const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -289,9 +289,14 @@ bool EmbeddingSpMDMNBit_autovec( int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, const bool scale_bias_last /*=true*/, - const bool is_bf16_out /*=false*/) { - assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4"); - const int num_elem_per_byte = 8 / bit_rate; + const bool is_bf16_out /*=false*/, + const bool no_bag /*=false*/, + int output_bit_rate /*=-1*/) { + if (output_bit_rate == -1) { + output_bit_rate = 8 * sizeof(OutType); + } + nbit_embedding_sanity_check(input_bit_rate, output_bit_rate, no_bag); + const int num_elem_per_byte = 8 / input_bit_rate; if (output_stride == -1) { output_stride = block_size; @@ -335,6 +340,26 @@ bool EmbeddingSpMDMNBit_autovec( } } + if (no_bag) { 
+ // We currently only support int4 to int4 for sequential TBE in this nbit + // kernel. assert() is compiled out in release mode, so re-check at runtime + // (also avoids an "unused variable" warning); indices must be < data_size. + if (!(input_bit_rate == 4 && output_bit_rate == 4)) { + WARN_ONCE("no_bag is only supported for int4 to int4\n"); + return false; + } + for (int64_t i = 0; i < output_size; ++i) { + const auto idx = indices[i]; + if (idx < 0 || idx >= data_size) { + return false; + } + const uint8_t* input_row = input + input_stride * idx; + memcpy(out, input_row, sizeof(uint8_t) * input_stride); + out += input_stride; + } + return true; + } + int64_t current = 0; const int64_t rounded_bs = round_up(block_size, num_elem_per_byte); vector buf(rounded_bs); @@ -387,7 +412,7 @@ bool EmbeddingSpMDMNBit_autovec( const int64_t offset = input_stride * idx + (scale_bias_last ? 0 : scale_bias_offset); const uint8_t* input_row = input + offset; - if (bit_rate == 4) { + if (input_bit_rate == 4) { const size_t halfbufsz = (block_size + 1) / 2; for (size_t j = 0; j < halfbufsz; ++j) { float quantized1 = float(input_row[j] & 0xf); @@ -395,7 +420,7 @@ bool EmbeddingSpMDMNBit_autovec( buf[j * 2] = std::fma(scale, quantized1, buf[j * 2] + bias); buf[j * 2 + 1] = std::fma(scale, quantized2, buf[j * 2 + 1] + bias); } - } else if (bit_rate == 2) { + } else if (input_bit_rate == 2) { size_t qbufsz = (block_size + 3) / 4; const uint8_t mask1 = 0x3; const uint8_t mask2 = 0xC; @@ -445,7 +470,7 @@ bool EmbeddingSpMDMNBit_autovec( #define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ template FBGEMM_API bool EmbeddingSpMDMNBit_autovec( \ - const int bit_rate, \ + const int input_bit_rate, \ const int64_t block_size, \ const int64_t output_size, \ const int64_t index_size, \ @@ -461,11 +486,14 @@ bool EmbeddingSpMDMNBit_autovec( int64_t output_stride, \ int64_t input_stride, \ const bool scale_bias_last, \ - const bool is_bf16_out); + const bool is_bf16_out, \ + const bool no_bag, 
\ + int output_bit_rate); -#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) +#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ + INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ + INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t) #define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \ INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \ diff --git a/src/EmbeddingSpMDMAutovec.h b/src/EmbeddingSpMDMAutovec.h index 997c02cdc3..2e7632c785 100644 --- a/src/EmbeddingSpMDMAutovec.h +++ b/src/EmbeddingSpMDMAutovec.h @@ -51,7 +51,7 @@ template < typename OffsetType = std::int32_t, typename OutType = float> FBGEMM_API bool EmbeddingSpMDMNBit_autovec( - const int bit_rate, + const int input_bit_rate, const std::int64_t block_size, const std::int64_t output_size, const std::int64_t index_size, @@ -67,7 +67,9 @@ FBGEMM_API bool EmbeddingSpMDMNBit_autovec( std::int64_t output_stride = -1, std::int64_t input_stride = -1, const bool scale_bias_last = true, - const bool is_bf16_out = false); + const bool is_bf16_out = false, + const bool no_bag = false, + int output_bit_rate = -1); } // namespace fbgemm diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc index f916c568b7..c0e4429bb5 100644 --- a/src/EmbeddingSpMDMNBit.cc +++ b/src/EmbeddingSpMDMNBit.cc @@ -1022,7 +1022,7 @@ template < typename EmbeddingSpMDMKernelSignature:: Type GenerateEmbeddingSpMDMNBitWithStrides( - int bit_rate, + const int input_bit_rate, const int64_t block_size, bool has_weight, bool normalize_by_lengths, @@ -1032,8 +1032,20 @@ typename EmbeddingSpMDMKernelSignature:: int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, bool scale_bias_last /*=true*/, - bool is_bf16_out) { - assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4"); + const bool is_bf16_out /*=false*/, + const 
bool no_bag /*=false*/, + int output_bit_rate /*=-1*/) { + if (output_bit_rate == -1) { + output_bit_rate = input_bit_rate; + } + assert( + (input_bit_rate == 2 || input_bit_rate == 4) && + "input_bit_rate must be 2 or 4"); + if (std::is_same::value) { + assert( + (no_bag && input_bit_rate == 4 && output_bit_rate == 4) && + "we currently only support int4 to int4 when using sequential TBE"); + } if (!cpuinfo_initialize()) { throw runtime_error("Failed to initialize cpuinfo!"); @@ -1042,10 +1054,74 @@ typename EmbeddingSpMDMKernelSignature:: output_stride = block_size; } if (input_stride == -1) { - int64_t num_elem_per_byte = 8 / bit_rate; + int64_t num_elem_per_byte = 8 / input_bit_rate; input_stride = ceil_div(block_size, num_elem_per_byte) + 2 * sizeof(uint16_t); } + if (no_bag) { + if (!is_autovec_disabled()) { + return [=](int64_t output_size, + int64_t index_size, + int64_t data_size, + const uint8_t* input, + const indxType* indices, + const offsetType* offsets_or_lengths, + const float* weights, + outType* out) { + return EmbeddingSpMDMNBit_autovec( + input_bit_rate, + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets_or_lengths, + weights, + normalize_by_lengths, + out, + is_weight_positional, + use_offsets, + output_stride, + input_stride, + scale_bias_last, + is_bf16_out, + no_bag, + output_bit_rate); + }; + } else { + return [=](int64_t output_size, + int64_t index_size, + int64_t data_size, + const uint8_t* input, + const indxType* indices, + const offsetType* offsets_or_lengths, + const float* weights, + outType* out) { + return EmbeddingSpMDMNBit_ref( + input_bit_rate, + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets_or_lengths, + weights, + normalize_by_lengths, + out, + is_weight_positional, + use_offsets, + output_stride, + input_stride, + scale_bias_last, + is_bf16_out, + no_bag, + output_bit_rate); + }; + } + } + if (fbgemmHasAvx512Support() && !is_asmjit_disabled()) { static 
GenEmbeddingSpMDMNBitLookup< indxType, @@ -1056,7 +1132,7 @@ typename EmbeddingSpMDMKernelSignature:: THREAD_LOCAL> kernel_generator; const auto original_func = kernel_generator.getOrCreate( - bit_rate, + input_bit_rate, block_size, has_weight, is_weight_positional, @@ -1096,7 +1172,7 @@ typename EmbeddingSpMDMKernelSignature:: THREAD_LOCAL> kernel_generator; const auto original_func = kernel_generator.getOrCreate( - bit_rate, + input_bit_rate, block_size, has_weight, is_weight_positional, @@ -1139,7 +1215,7 @@ typename EmbeddingSpMDMKernelSignature:: const float* weights, outType* out) { return EmbeddingSpMDMNBit_autovec( - bit_rate, + input_bit_rate, block_size, output_size, index_size, @@ -1171,7 +1247,7 @@ typename EmbeddingSpMDMKernelSignature:: const float* weights, outType* out) { return EmbeddingSpMDMNBit_ref( - bit_rate, + input_bit_rate, block_size, output_size, index_size, @@ -1364,7 +1440,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( OFFSET_TYPE, \ OUT_TYPE, \ THREAD_LOCAL>( \ - int bit_rate, \ + const int input_bit_rate, \ const int64_t block_size, \ bool has_weight, \ bool normalize_by_lengths, \ @@ -1374,7 +1450,9 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( int64_t output_stride, \ int64_t input_stride, \ bool scale_bias_last, \ - bool is_bf16_out); + const bool is_bf16_out, \ + const bool no_bag, \ + int output_bit_rate); #define INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \ @@ -1396,6 +1474,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( #define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, float) \ INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint16_t) \ + INSTANTIATE_SPMDM_THREAD_LOCAL(INDEX_TYPE, OFFSET_TYPE, uint8_t) \ template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \ uint8_t, \ INDEX_TYPE, \ diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 
0399750b27..685a0f4a10 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -1413,7 +1413,7 @@ bool EmbeddingSpMDM_ref( template bool EmbeddingSpMDMNBit_ref( - int bit_rate, + int input_bit_rate, const int64_t block_size, const int64_t output_size, const int64_t index_size, @@ -1429,9 +1429,14 @@ bool EmbeddingSpMDMNBit_ref( int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, bool scale_bias_last /*=true*/, - bool is_bf16_out /*=false*/) { - assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4"); - int num_elem_per_byte = 8 / bit_rate; + const bool is_bf16_out /*=false*/, + const bool no_bag /*=false*/, + int output_bit_rate /*=-1*/) { + if (output_bit_rate == -1) { + output_bit_rate = 8 * sizeof(OutType); + } + nbit_embedding_sanity_check(input_bit_rate, output_bit_rate, no_bag); + int num_elem_per_byte = 8 / input_bit_rate; if (output_stride == -1) { output_stride = block_size; @@ -1444,6 +1449,27 @@ bool EmbeddingSpMDMNBit_ref( input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte + scale_bias_offset; } + + if (no_bag) { + // We currently only support int4 to int4 for sequential TBE in this nbit + // kernel. assert() is compiled out in release mode, so re-check at runtime + // (also avoids an "unused variable" warning); indices must be < data_size. + if (!(input_bit_rate == 4 && output_bit_rate == 4)) { + WARN_ONCE("no_bag is only supported for int4 to int4\n"); + return false; + } + for (int64_t i = 0; i < output_size; ++i) { + const auto idx = indices[i]; + if (idx < 0 || idx >= data_size) { + return false; + } + const uint8_t* input_row = input + input_stride * idx; + memcpy(out, input_row, sizeof(uint8_t) * input_stride); + out += input_stride; + } + return true; + } + int64_t current = 0; vector buf(block_size); for (int m = 0; m < output_size; ++m) { @@ -1481,8 +1507,8 @@ bool EmbeddingSpMDMNBit_ref( uint8_t quantized = input [input_stride * idx + j / num_elem_per_byte + (scale_bias_last ? 
0 : scale_bias_offset)]; - quantized >>= (j % num_elem_per_byte) * bit_rate; - quantized &= (1 << bit_rate) - 1; + quantized >>= (j % num_elem_per_byte) * input_bit_rate; + quantized &= (1 << input_bit_rate) - 1; buf[j] = std::fma(scale, quantized, buf[j] + bias); } @@ -2105,47 +2131,53 @@ INSTANTIATE_SPMDM_INDEX_T(std::uint8_t) #undef INSTANTIATE_SPMDM_OUT_T #undef INSTANTIATE_SPMDM_BASE -#define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ - template FBGEMM_API bool EmbeddingSpMDMNBit_ref( \ - int bit_rate, \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const INDEX_TYPE* indices, \ - const OFFSET_TYPE* offsets_or_lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OUT_TYPE* out, \ - bool is_weight_positional, \ - bool use_offsets, \ - int64_t output_stride, \ - int64_t input_stride, \ - bool scale_bias_last, \ - bool is_bf16_out); \ - template FBGEMM_API bool EmbeddingSpMDMFP8_ref( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ - const uint8_t* input, \ - const INDEX_TYPE* indices, \ - const OFFSET_TYPE* offsets_or_lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OUT_TYPE* out, \ - bool is_weight_positional, \ - bool use_offsets, \ - int64_t output_stride, \ - int64_t input_stride, \ - int exponent_bits, \ - int exponent_bias, \ +#define INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ + template FBGEMM_API bool EmbeddingSpMDMNBit_ref( \ + const int input_bit_rate, \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const uint8_t* input, \ + const INDEX_TYPE* indices, \ + const OFFSET_TYPE* offsets_or_lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OUT_TYPE* out, \ + bool is_weight_positional, \ + bool use_offsets, \ + int64_t 
output_stride, \ + int64_t input_stride, \ + const bool scale_bias_last, \ + const bool is_bf16_out, \ + const bool no_bag, \ + int output_bit_rate); +#define INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ + template FBGEMM_API bool EmbeddingSpMDMFP8_ref( \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ + const uint8_t* input, \ + const INDEX_TYPE* indices, \ + const OFFSET_TYPE* offsets_or_lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OUT_TYPE* out, \ + bool is_weight_positional, \ + bool use_offsets, \ + int64_t output_stride, \ + int64_t input_stride, \ + int exponent_bits, \ + int exponent_bias, \ bool is_bf16_out); #define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ + INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float) \ + INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \ + INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t) \ template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( \ int bit_rate, \ const int64_t block_size, \ diff --git a/src/RefImplementations.h b/src/RefImplementations.h index f01aa57d5a..076e4e7fc4 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -246,7 +246,7 @@ template < typename OffsetType = std::int32_t, typename OutType = float> FBGEMM_API bool EmbeddingSpMDMNBit_ref( - int bit_rate, + const int input_bit_rate, const std::int64_t block_size, const std::int64_t output_size, const std::int64_t index_size, @@ -261,8 +261,10 @@ FBGEMM_API bool EmbeddingSpMDMNBit_ref( bool use_offsets = true, std::int64_t output_stride = -1, std::int64_t input_stride = -1, - bool scale_bias_last = true, - bool is_bf16_out = false); + 
const bool scale_bias_last = true, + const bool is_bf16_out = false, + const bool no_bag = false, + int output_bit_rate = -1); template < typename IndexType = std::int64_t,