Mlas int4 int8 with avx2/512 #20687

Merged Aug 2, 2024 (48 commits)

Commits
293f121
quick adapt llama.cpp to experiment performance. Only works with blkl…
liqunfu May 3, 2024
04c2e56
fire
liqunfu May 6, 2024
cdfda6f
tile 2x4 SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symme…
liqunfu May 7, 2024
92dad97
use one_16_epi16 and accumulate_2blk_dot: SQNBITGEMM<4>/BlkLen:32/M:2…
liqunfu May 8, 2024
5418e9c
apply to M1, BQuant layout pack block (subblk) larger than blklen: SQ…
liqunfu May 9, 2024
0401f72
use new AQuant layout (not work if total M is not RangeCountM): SQNBI…
liqunfu May 10, 2024
a57eeba
apply blksum to blklen32 and 64: SQNBITGEMM<4>/BlkLen:32/M:2048/N:409…
liqunfu May 13, 2024
f2c33af
blklen16
liqunfu May 15, 2024
0ca24f4
impl avx512: SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/S…
liqunfu May 26, 2024
7f89d5f
matmul_nbit & fix alignment for sgemm
liqunfu Jun 1, 2024
ed0e666
merge main
liqunfu Jun 4, 2024
35d02a6
fix mlas benchmark not using multi threads
liqunfu Jun 10, 2024
b9493ad
profiling
liqunfu Jun 10, 2024
c443eb5
Merge branch 'liqun/mlas-q4-tile-avx' of https://github.com/microsoft…
liqunfu Jun 10, 2024
ac66951
sgemm after sq4bit for avx2
liqunfu Jun 16, 2024
42a1305
avx512
liqunfu Jun 17, 2024
740031a
layout to follow compute, M1 separate with M > 1
liqunfu Jun 27, 2024
1a6031e
make avx512 run
liqunfu Jun 28, 2024
283fd2d
Merge branch 'main' into liqun/mlas-q4-tile-avx
liqunfu Jun 28, 2024
d035939
avx512 blklen64 pass
liqunfu Jul 4, 2024
f329d2d
pass avx512 blklen32
liqunfu Jul 5, 2024
27cfd9c
pass avx512 blklen 16, 128, 256
liqunfu Jul 5, 2024
edee319
pass fp32, refactor sqnbitgemm
liqunfu Jul 11, 2024
fb9221a
merge main
liqunfu Jul 12, 2024
c109b4b
avx512vnni
liqunfu Jul 18, 2024
6654d22
merge main
liqunfu Jul 18, 2024
4b91bed
avxvnni
liqunfu Jul 20, 2024
8674b9f
rm unused ComputeParallelTasksSGemm
liqunfu Jul 23, 2024
e26e29e
avoid _mm256_dpbusds_avx_epi32 in avx512vnni
liqunfu Jul 24, 2024
2b0307e
fix linux build
liqunfu Jul 24, 2024
40df782
Merge branch 'main' into liqun/mlas-q4-tile-avx
liqunfu Jul 26, 2024
51e97c8
refactor for Arm64
liqunfu Jul 26, 2024
48e8639
more refactor for Arm64
liqunfu Jul 26, 2024
705aa1f
hsum_float_16
liqunfu Jul 29, 2024
012e9c4
hsum_float_16
liqunfu Jul 29, 2024
21b9138
condition for -mavxvnni
liqunfu Jul 30, 2024
1fb1c83
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 10
liqunfu Jul 30, 2024
85918e9
missed 2 files from (__GNUC__ > 10)
liqunfu Jul 30, 2024
9530ac5
missed _mm256_dpbusds_avx_epi32 and print out cmake msgs
liqunfu Jul 30, 2024
f77cffd
unused zp, etc.
liqunfu Jul 30, 2024
a6fd378
unused zp, etc.
liqunfu Jul 30, 2024
c875e5c
remove test code changes
liqunfu Jul 30, 2024
3b56710
remove test code changes
liqunfu Jul 30, 2024
746562f
lint
liqunfu Jul 30, 2024
52fc7fa
lint
liqunfu Jul 30, 2024
0933a6b
code name
liqunfu Jul 30, 2024
2b35c82
update reviewers' comments
liqunfu Jul 31, 2024
caeb35e
Merge branch 'main' into liqun/mlas-q4-tile-avx
liqunfu Aug 1, 2024
13 changes: 11 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -555,8 +555,17 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")

message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")

if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10")
message(STATUS "Using -mavx2 -mfma -mavxvnni flags")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")
else()
message(STATUS "Using -mavx2 -mfma flags")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
endif()
set(mlas_platform_srcs_avx512f
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S
@@ -575,7 +584,7 @@ else()
${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl")
set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl")

set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
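
The -mavxvnni gate above pairs with a source-level guard around the AVX-VNNI dot-product intrinsic mentioned in the commit log (_mm256_dpbusds_avx_epi32). As a rough illustration only (not the PR's kernel code), a guarded helper could look like the sketch below; the helper name is hypothetical and the AVX2 fallback uses the usual maddubs/madd widening pattern.

// Sketch only: accumulate 4-element dot products of unsigned 8-bit (a) and signed
// 8-bit (b) values into 32-bit lanes, preferring the AVX-VNNI instruction when the
// compiler targets it. The helper name is hypothetical, not taken from the PR.
#include <immintrin.h>

static inline __m256i Dot4PairsAccumulate(__m256i acc, __m256i a_u8, __m256i b_s8)
{
#if defined(__AVXVNNI__)
    // Single VPDPBUSDS instruction: acc += dot4(a_u8, b_s8) per 32-bit lane (saturating).
    return _mm256_dpbusds_avx_epi32(acc, a_u8, b_s8);
#else
    // AVX2 fallback: u8*s8 pairs -> saturated 16-bit products, then widen pairs to 32-bit.
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i prod16 = _mm256_maddubs_epi16(a_u8, b_s8);
    const __m256i prod32 = _mm256_madd_epi16(prod16, ones);
    return _mm256_add_epi32(acc, prod32);
#endif
}
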
41 changes: 28 additions & 13 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -104,6 +104,8 @@

ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
const Tensor* tensor_zero_point = nullptr;
has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point);
#ifdef ORT_NEURAL_SPEED
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
Expand Down Expand Up @@ -139,6 +141,7 @@
IAllocatorUniquePtr<void> packed_b_{};
size_t packed_b_size_{0};

bool has_zp_input_{false};
#if defined(ORT_NEURAL_SPEED)

bool is_asym_{false};
Expand Down Expand Up @@ -207,10 +210,10 @@
is_packed = true;
}

#else // defined(ORT_NEURAL_SPEED)

#else // defined(ORT_NEURAL_SPEED)
ORT_UNUSED_PARAMETER(prepacked_weights);
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (input_idx == InputIndex::B) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
return Status::OK();
}
@@ -220,12 +223,20 @@
}
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get());
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr);

is_packed = true;
} else if (compute_type == CompInt8) {
#ifdef MLAS_TARGET_AMD64_IX86
if (input_idx == InputIndex::scales && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr);

is_packed = false;
} else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);

is_packed = false;
}
#endif
}
#endif // defined(ORT_NEURAL_SPEED)

@@ -332,9 +343,9 @@
const auto* bias_data = bias == nullptr ? nullptr : bias->Data<float>();

IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(
M, N, K, batch_count, nbits_, block_size_, compute_type);
if (workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
@@ -344,14 +355,18 @@
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = packed_b_.get();
#ifdef MLAS_TARGET_AMD64_IX86
if (compute_type == CompInt8) {
data[i].QuantBDataWorkspace = packed_b_.get();
}
#endif
data[i].PackedQuantBData = static_cast<std::byte*>(packed_b_.get());
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].Bias = bias_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
}

MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);

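
For orientation, the Compute path in this file fills one MLAS_SQNBIT_GEMM_DATA_PARAMS entry per batch and hands the array to MlasSQNBitGemmBatch. A minimal single-batch sketch follows; the wrapper function, buffer pointers, and M/N/K values are placeholders, not code from the PR.

// Sketch only: single-batch SQNBit GEMM call mirroring the MatMulNBits compute path.
// a_data, packed_b, scales_data, zp_data, y_data and M/N/K are placeholders.
#include <cstddef>
#include <vector>
#include "mlas_qnbit.h"  // onnxruntime/core/mlas/inc/

void RunSQ4BitGemm(const float* a_data, const void* packed_b, const float* scales_data,
                   const void* zp_data, float* y_data, size_t M, size_t N, size_t K,
                   MLAS_THREADPOOL* thread_pool)
{
    constexpr size_t BlkBitWidth = 4;
    constexpr size_t BlkLen = 32;
    const MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompInt8;

    // CompInt8 needs scratch space for the quantized A rows.
    std::vector<std::byte> workspace(
        MlasSQNBitGemmBatchWorkspaceSize(M, N, K, /*BatchN=*/1, BlkBitWidth, BlkLen, ComputeType));

    MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
    params.A = a_data;
    params.lda = K;
#ifdef MLAS_TARGET_AMD64_IX86
    // The int8 x86 kernels read the re-packed B (and its BlkSum) through this field.
    params.QuantBDataWorkspace = packed_b;
#endif
    params.PackedQuantBData = static_cast<const std::byte*>(packed_b);
    params.QuantBScale = scales_data;
    params.QuantBZeroPoint = zp_data;  // nullptr when B is symmetrically quantized
    params.C = y_data;
    params.ldc = N;

    MlasSQNBitGemmBatch(M, N, K, /*BatchN=*/1, BlkBitWidth, BlkLen, ComputeType, &params,
                        workspace.empty() ? nullptr : workspace.data(), thread_pool);
}
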
56 changes: 38 additions & 18 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -43,14 +43,16 @@ typedef enum {
* @brief Data parameters for float/n-bit quantized int GEMM routine.
*/
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values)
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values)
const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
@@ -159,14 +161,29 @@ MlasSQNBitGemmPackQuantBDataSize(
/**
* @brief Packs the quantized B data in a format that the kernel expects.
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] QuantBData quantized B data
* @param[out] PackedQuantBData packed quantized B data
* @param[in] ThreadPool optional thread pool to use
* If the function is called without QuantBScale and QuantBZeroPoint,
* it just packs QuantBData into PackedQuantBDataAndOrBlkSum.
*
* If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint,
* an additional BlkSum (Scale * ZeroPoint) is computed and stored in the second part of PackedQuantBDataAndOrBlkSum.
*
* Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData,
* QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData,
* and then with QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale but without QuantBZeroPoint,
* BlkSum is computed with the default zero point 8 and stored in the second part of PackedQuantBDataAndOrBlkSum.
* If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with the provided zero point.
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] QuantBData quantized B data
* @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum
* @param[in] QuantBScale quantized B scale
* @param[in] has_zp_input whether QuantBZeroPoint is provided
* @param[in] QuantBZeroPoint quantized B zero point
* @param[in] ThreadPool thread pool to use (no parallel if nullptr)
*/
void MLASCALL
MlasSQNBitGemmPackQuantBData(
Expand All @@ -176,6 +193,9 @@ MlasSQNBitGemmPackQuantBData(
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool = nullptr
void* PackedQuantBDataAndOrBlkSum,
const void* QuantBScale,
bool has_zp_input,
const void* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
);
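
To make the three-stage packing contract above concrete, here is a minimal caller-side sketch. PackQuantB is a hypothetical helper (not part of the PR or of MLAS); it assumes the packed buffer is sized with MlasSQNBitGemmPackQuantBDataSize, and in MatMulNBits the scale/zero-point calls are actually issued only for CompInt8 on x86/x64.

// Sketch only: PrePack-style usage of MlasSQNBitGemmPackQuantBData, mirroring the
// doc comment above. ORT calls PrePack once per constant input, so the same packed
// buffer is passed back in on each call; allocation details here are placeholders.
#include <cstddef>
#include <cstdint>
#include <vector>
#include "mlas_qnbit.h"  // onnxruntime/core/mlas/inc/

void PackQuantB(size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen,
                MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
                const void* quant_b_data, const float* quant_b_scale,
                const uint8_t* quant_b_zero_point,   // may be nullptr
                std::vector<std::byte>& packed,      // reused across the three calls
                MLAS_THREADPOOL* thread_pool)
{
    const bool has_zp_input = (quant_b_zero_point != nullptr);
    packed.resize(MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType));

    // Call 1: pack the quantized B data only.
    MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                                 quant_b_data, packed.data(),
                                 nullptr, has_zp_input, nullptr, thread_pool);

    // Call 2: scales arrive; BlkSum is computed with the default zero point (8).
    // (MatMulNBits does this only for CompInt8 on x86/x64.)
    MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                                 nullptr, packed.data(),
                                 quant_b_scale, has_zp_input, nullptr, thread_pool);

    // Call 3 (only if zero points exist): BlkSum is recomputed with the real zero points.
    if (has_zp_input) {
        MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                                     nullptr, packed.data(),
                                     nullptr, has_zp_input, quant_b_zero_point, thread_pool);
    }
}
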
2 changes: 2 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;

extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -409,6 +409,7 @@ Return Value:
this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni;
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni;
}

#if !defined(ORT_MINIMAL_BUILD)