Mlas int4 int8 with avx2/512 (#20687)
### Description
Model: phi-3-mini-4k-instruct (all numbers are tokens per second, tps).
avx2 symmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |49.5|70.0|-29.2%|9.6|10.8|-34.2%
32 |76.8|52.4|9.7%|15.2|14.6|4.1%
64 |78.2|71.4|9.5%|16.6|16.3|1.8%
128 |72.9|70.6|3.2%|17.1|16.8|1.7%
256 |83.7|63.6|31.6%|18.1|17.4|4%

avx2 asymmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |50.7|61.5|-17.5%|9.6|9.2|4.3%
32 |77.4|52.4|47.7%|14.6|13.9|5.0%
64 |78.7|63.0|24.9%|16.2|15.9|1.8%
128 |80.0|61.9|29.2%|17.2|16.9|1.7%
256 |81.5|63.3|28.7%|17.9|17.3|3.4%

avx2vnni symmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |82.9|117.0|-29.0%|15.9|19.3|-17.6%
32 |133.0|100.4|32.4%|26.1|24.5|6.5%
64 |166.9|118.8|40.4%|28.3|27.1|4.4%
128 |165.9|119.6|38.7%|29.3|28.5|2.8%
256 |165.2|119.6|38.1%|30.2|29.0|4.1%

avx2vnni asymmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |80.2|118.9|-32.5%|15.1|16.7|-9.5%
32 |130.7|99.7|31.0%|25.0|23.8|5.0%
64 |168.7|124.9|35.0%|27.3|26.8|1.8%
128 |169.6|123.8|36.9%|29.2|27.9|4.6%
256 |175.0|125.7|39.0%|30.0|29.7|1.0%

avx512 symmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |135.2|156.5|-13.6%|25.5|23.8|7.1%
32 |150.0|159.5|-5.9%|34.9|29.6|17.9%
64 |167.5|157.5|6.3%|39.7|34.4|15.4%
128 |177.8|158.0|12.5%|40.3|35.4|13.8%
256 |182.6|157.3|16.0%|41.7|37.7|10.6%

avx512 asymmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |136.1|151.4|-10.1%|26.1|19.9|31.1%
32 |150.0|157.8|-4.9%|34.3|29.3|17.0%
64 |165.7|156.6|5.8%|38.7|30.7|26.0%
128 |180.4|156.6|15.1%|40.2|34.7|15.8%
256 |181.3|158.0|14.7%|41.6|36.6|13.6%

avx512vnni symmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |143.4|155.4|-7.7%|25.6|23.3|9.8%
32 |159.2|157.0|1.4%|34.1|29.8|14.4%
64 |182.0|159.5|14.1%|38.4|34.8|10.3%
128 |221.2|160.8|37.5%|41.0|36.4|12.6%
256 |250.5|162.4|54.2%|41.6|37.7|10.3%

avx512vnni asymmetric
blklen | updated prompt tps | baseline prompt tps | prompt tps change% | updated token gen tps | baseline token gen tps | token gen change%
-|-|-|-|-|-|-
16 |142.5|152.3|-6.4%|26.3|19.7|33.5%
32 |158.2|155.0|2.0%|34.3|29.2|17.4%
64 |184.1|156.6|17.5%|38.3|30.9|23.9%
128 |215.8|156.1|17.5%|41.3|35.0|17.9%
256 |249.2|155.9|59.8%|41.1|36.3|13.2%


4-bit GEMM implementation with AVX using tiling.

1.
The tile size is 2 blks by 4. When the remaining size is smaller than a
full tile, it falls back to 1 blk by 4, then 2 blks by 1, and lastly 1 blk
by 1. In the internal kernel, weights and activations are loaded according
to the SIMD register width and blk length (see the sketch after these
lists):
avx2 (256-bit registers): 64 weights and activations are loaded.
   blklen16: 4 blks are computed by the internal kernel
   blklen32: 2 blks are computed by the internal kernel
   blklen64: 1 blk is computed by the internal kernel
   blklen128: 1 blk is computed 2 times by the internal kernel
   blklen256: 1 blk is computed 4 times by the internal kernel

avx512 (512-bit registers): 128 weights and activations are loaded.
   blklen16: 8 blks are computed by the internal kernel
   blklen32: 4 blks are computed by the internal kernel
   blklen64: 2 blks are computed by the internal kernel
   blklen128: 1 blk is computed by the internal kernel
   blklen256: 1 blk is computed 2 times by the internal kernel
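
To make the bookkeeping above concrete: a 256-bit (AVX2) or 512-bit (AVX512) register holds 64 or 128 4-bit values per load, and the per-blklen counts follow from that. A minimal standalone sketch (the `ShapeFor` helper is hypothetical, not MLAS code):

```cpp
#include <cstddef>
#include <cstdio>
#include <initializer_list>

// How many blks one register load covers, and how many loads a single blk
// needs when it is wider than a register. With 4-bit quantized values, a
// register of `register_bits` holds register_bits / 4 of them per load.
struct KernelShape {
    std::size_t blks_per_load;
    std::size_t loads_per_blk;
};

KernelShape ShapeFor(std::size_t register_bits, std::size_t blklen) {
    const std::size_t vals_per_load = register_bits / 4;  // 64 (AVX2), 128 (AVX512)
    if (blklen <= vals_per_load) {
        return {vals_per_load / blklen, 1};
    }
    return {1, blklen / vals_per_load};
}

int main() {
    // Reproduces the two lists above.
    for (std::size_t blklen : {16, 32, 64, 128, 256}) {
        const KernelShape a2 = ShapeFor(256, blklen);
        const KernelShape a512 = ShapeFor(512, blklen);
        std::printf("blklen%-4zu avx2: %zu blk(s) x %zu load(s)  avx512: %zu blk(s) x %zu load(s)\n",
                    blklen, a2.blks_per_load, a2.loads_per_blk,
                    a512.blks_per_load, a512.loads_per_blk);
    }
}
```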

2.
blksum is precomputed during prepacking, and the computation is
reformulated as:
Sum1(scale_a * scale_b * Sum_blk(a_i * b_i)) + Sum2(blksum_a * blksum_b)
  Sum_blk is over one blk
  Sum1 is over all blks for one output
  Sum2 is over all blks for one output
Sum2 is computed with SGEMM in the current implementation. Further
improvement is possible.
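
Written out (my reading of the decomposition above, not lifted verbatim from the sources): with $q$ the quantized integers, $s$ the per-blk scales, and $z_b$ B's zero point, one output element is

$$
C \;=\; \underbrace{\sum_{\mathrm{blk}} s_a\, s_b \sum_{i \in \mathrm{blk}} q_{a,i}\, q_{b,i}}_{\text{Sum1}}
\;+\; \underbrace{\sum_{\mathrm{blk}} \mathrm{blksum}_a \cdot \mathrm{blksum}_b}_{\text{Sum2}},
\qquad
\mathrm{blksum}_a = s_a \sum_{i \in \mathrm{blk}} q_{a,i},
\quad
\mathrm{blksum}_b = -\,s_b\, z_b .
$$

Under that reading, the inner sum in Sum1 is a pure int8 dot product (the VNNI-friendly part), while Sum2 reduces to one scalar per blk on each side — which is why it can be handed off to a dense float SGEMM.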

 

---------

Signed-off-by: Liqun Fu <liqfu@microsoft.com>
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
Signed-off-by: Liqun Fu <liqun_fu@hotmail.com>
liqunfu authored Aug 2, 2024
1 parent d0a6f57 commit b87e8ed
Showing 26 changed files with 8,834 additions and 300 deletions.
13 changes: 11 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -555,8 +555,17 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")

message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")

if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10")
message(STATUS "Using -mavx2 -mfma -mavxvnni flags")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")
else()
message(STATUS "Using -mavx2 -mfma flags")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
endif()
set(mlas_platform_srcs_avx512f
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S
@@ -575,7 +584,7 @@ else()
${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl")
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl")

set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
41 changes: 28 additions & 13 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -104,6 +104,8 @@ class MatMulNBits final : public OpKernel {

ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
const Tensor* tensor_zero_point = nullptr;
has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point);
#ifdef ORT_NEURAL_SPEED
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
@@ -139,6 +141,7 @@ class MatMulNBits final : public OpKernel {
IAllocatorUniquePtr<void> packed_b_{};
size_t packed_b_size_{0};

bool has_zp_input_{false};
#if defined(ORT_NEURAL_SPEED)

bool is_asym_{false};
@@ -207,10 +210,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}

#else // defined(ORT_NEURAL_SPEED)

#else // defined(ORT_NEURAL_SPEED)
ORT_UNUSED_PARAMETER(prepacked_weights);
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (input_idx == InputIndex::B) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
return Status::OK();
}
@@ -220,12 +223,20 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
}
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get());
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr);
is_packed = true;
} else if (compute_type == CompInt8) {
#ifdef MLAS_TARGET_AMD64_IX86
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr);
is_packed = false;
} else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);
is_packed = false;
}
#endif
}
#endif // defined(ORT_NEURAL_SPEED)

@@ -332,9 +343,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const auto* bias_data = bias == nullptr ? nullptr : bias->Data<float>();

IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(
M, N, K, batch_count, nbits_, block_size_, compute_type);
if (workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
@@ -344,14 +355,18 @@
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = packed_b_.get();
#ifdef MLAS_TARGET_AMD64_IX86
if (compute_type == CompInt8) {
data[i].QuantBDataWorkspace = packed_b_.get();
}
#endif
data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].Bias = bias_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}

MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);

56 changes: 38 additions & 18 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -43,14 +43,16 @@ typedef enum {
* @brief Data parameters for float/n-bit quantized int GEMM routine.
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values)
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values)
const std::byte* PackedQuantBData = nullptr; ///< address of packed quantized B data
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
@@ -159,14 +161,29 @@ MlasSQNBitGemmPackQuantBDataSize(
/**
* @brief Packs the quantized B data in a format that the kernel expects.
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] QuantBData quantized B data
* @param[out] PackedQuantBData packed quantized B data
* @param[in] ThreadPool optional thread pool to use
* If the function is called without QuantBScale and QuantBZeroPoint,
* it just packs QuantBData into PackedQuantBDataAndOrBlkSum.
*
* If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint,
* an additional BlkSum (scale * zero point) is computed and stored in the second part of PackedQuantBDataAndOrBlkSum.
*
* Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData,
* QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData,
* and then with QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale but without QuantBZeroPoint,
* BlkSum is computed with the default zero point 8 and stored in the second part of PackedQuantBDataAndOrBlkSum.
* If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with the provided zero point.
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] QuantBData quantized B data
* @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum
* @param[in] QuantBScale quantized B scale
* @param[in] has_zp_input whether QuantBZeroPoint is provided
* @param[in] QuantBZeroPoint quantized B zero point
* @param[in] ThreadPool thread pool to use (no parallel if nullptr)
*/
void MLASCALL
MlasSQNBitGemmPackQuantBData(
@@ -176,6 +193,9 @@ MlasSQNBitGemmPackQuantBData(
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool = nullptr
void* PackedQuantBDataAndOrBlkSum,
const void* QuantBScale,
bool has_zp_input,
const void* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
);
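
Given the three-call protocol described in the comment above, a caller-side sketch (mirroring the PrePack calls visible in matmul_nbits.cc; the allocation and the `PackQuantB` wrapper are illustrative, not how ORT allocates):

```cpp
#include <cstddef>
#include <cstdint>
#include "mlas_qnbit.h"

void PackQuantB(size_t N, size_t K, size_t BlkLen, bool has_zp_input,
                const void* quant_b_data, const float* scales,
                const uint8_t* zero_points, MLAS_THREADPOOL* thread_pool) {
  const size_t packed_size =
      MlasSQNBitGemmPackQuantBDataSize(N, K, /*BlkBitWidth*/ 4, BlkLen, CompInt8);
  void* packed = ::operator new(packed_size);  // illustrative allocation

  // Call 1: B's quantized data arrives first -- pack it, no BlkSum yet.
  MlasSQNBitGemmPackQuantBData(N, K, 4, BlkLen, CompInt8,
                               quant_b_data, packed,
                               /*QuantBScale*/ nullptr, has_zp_input,
                               /*QuantBZeroPoint*/ nullptr, thread_pool);

  // Call 2: scales arrive -- BlkSum is computed with the default zero point 8.
  MlasSQNBitGemmPackQuantBData(N, K, 4, BlkLen, CompInt8,
                               nullptr, packed, scales, has_zp_input,
                               nullptr, thread_pool);

  // Call 3 (only if B is asymmetric): BlkSum is recomputed with real zero points.
  if (has_zp_input) {
    MlasSQNBitGemmPackQuantBData(N, K, 4, BlkLen, CompInt8,
                                 nullptr, packed, nullptr, has_zp_input,
                                 zero_points, thread_pool);
  }
}
```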
2 changes: 2 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -409,6 +409,7 @@ Return Value:
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni;
}

#if !defined(ORT_MINIMAL_BUILD)