From 6a14a2a7dfeb7575e8028472fceeb9f5096c346a Mon Sep 17 00:00:00 2001 From: wsu Date: Wed, 14 Aug 2024 16:54:55 -0700 Subject: [PATCH] Enable int4 to int4 CPU STBE in fbgemm_gpu TBE API Differential Revision: D61305978 --- ...bedding_forward_quantized_cpu_template.cpp | 28 +++++++++++++------ .../fbgemm_gpu/utils/dispatch_macros.h | 2 ++ include/fbgemm/FbgemmEmbedding.h | 2 +- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp index e6d27d1c5e..f5d47961b4 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp @@ -167,9 +167,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Tensor output; SparseType o_dtype = static_cast(output_dtype); - TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16); + TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT4); bool output_is_bf16 = o_dtype == SparseType::BF16; bool output_is_int8 = o_dtype == SparseType::INT8; + bool output_is_int4 = o_dtype == SparseType::INT4; {% if not nobag %} const int kINT8QparamsBytes = 8; int64_t total_adjusted_D = total_D; @@ -178,10 +179,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ } output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory)); {% else %} - const int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout + constexpr int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout + constexpr int kINT4QparamsElems = 8; // scale + bias takes 4 bytes which are 8 int4 elements int64_t adjusted_D = D; if (o_dtype == SparseType::INT8) { adjusted_D += kINT8QparamsBytes; + } else if (o_dtype == SparseType::INT4) { + adjusted_D += kINT4QparamsElems; } output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory)); @@ -212,7 +216,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ using other_fbgemm_out_t = typename std::conditional< std::is_same::value, float16, - std::conditional::value, bfloat16, float>::type >::type; + std::conditional::value, bfloat16, float>::type> ::type; AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] { const auto* indices_acc = indices.data_ptr(); const auto* offsets_acc = offsets.data_ptr(); @@ -230,7 +234,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const int32_t D_end = D_offsets_acc[t + 1]; const int32_t D = D_end - D_start; {% else %} - const int32_t D_start = offsets_acc[t * B] * adjusted_D; + const int32_t elems_D = (o_dtype == SparseType::INT4) ? at::divup(adjusted_D, 2) : adjusted_D; + const int32_t D_start = offsets_acc[t * B] * elems_D; {% endif %} const auto placement = static_cast(weights_placements_ptr[t]); @@ -266,8 +271,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ {% endif %} const float* indice_weights_ptr = nullptr; - // int8 output only enabled for nobag case with ref impl - const bool nobag_op = {{ "false" if not nobag else "output_is_int8" }}; + // int8/int4 output only enabled for nobag case + const bool nobag_op = {{ "false" if not nobag else "output_is_int8 || output_is_int4" }}; {% if weighted %} indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr; {% endif %} @@ -278,7 +283,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides" if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides") %} - using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base else "other_fbgemm_out_t" }}; + using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base or use_nbit else "other_fbgemm_out_t" }}; + {% if use_nbit %} + const int output_bit_rate = output_is_int4 ? 4 : sizeof(fbgemm_out_t) * 8; + {% endif %} // TODO: merge nobag int8 path with normal asmjit dispatch {% if nobag %} const index_t* offset_ptr = (output_is_int8)? offsets_begin_ptr: offsets_nobag_ptr; @@ -299,7 +307,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ {% endif %} >( {% if use_nbit %} - /*bit_rate=*/bit_rate, + /*input_bit_rate=*/bit_rate, {% endif %} D, {% if has_asmjit %} @@ -324,6 +332,10 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ /*no_bag=*/nobag_op, {% endif %} /*is_bf16_out=*/output_is_bf16 + {% if use_nbit %} + ,/*no_bag=*/nobag_op, + /*output_bit_rate=*/output_bit_rate + {% endif %} ); success = kernel( {{ "B" if not nobag else "index_size"}}, diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h b/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h index 4e56d4369c..41efac8dd5 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/dispatch_macros.h @@ -122,6 +122,8 @@ at::ScalarType::BFloat16, at::BFloat16, __VA_ARGS__) \ PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Float, float, __VA_ARGS__) \ PRIVATE_CASE_TYPE_OUTPUT2(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + PRIVATE_CASE_TYPE_OUTPUT2( \ + at::ScalarType::QUInt4x2, uint8_t, __VA_ARGS__) \ default: \ AT_ERROR( \ #NAME, \ diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index d0a44f296c..5b9fb3d08b 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -171,7 +171,7 @@ GenerateEmbeddingSpMDMNBitWithStrides( bool scale_bias_last = true, const bool is_bf16_out = false, const bool no_bag = false, - const bool output_bit_rate = 8 * sizeof(OutType)); + const int output_bit_rate = 8 * sizeof(OutType)); /** * @param output_stride If -1, output_stride is same as block_size