Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mlas int4 int8 with avx2/512 #20687

Merged
merged 48 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
293f121
quick adapt llama.cpp to experiment performance. Only works with blkl…
liqunfu May 3, 2024
04c2e56
fire
liqunfu May 6, 2024
cdfda6f
tile 2x4 SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symme…
liqunfu May 7, 2024
92dad97
use one_16_epi16 and accumulate_2blk_dot: SQNBITGEMM<4>/BlkLen:32/M:2…
liqunfu May 8, 2024
5418e9c
apply to M1, BQuant layout pack block (subblk) larger than blklen: SQ…
liqunfu May 9, 2024
0401f72
use new AQuant layout (not work if total M is not RangeCountM): SQNBI…
liqunfu May 10, 2024
a57eeba
apply blksum to blklen32 and 64: SQNBITGEMM<4>/BlkLen:32/M:2048/N:409…
liqunfu May 13, 2024
f2c33af
blklen16
liqunfu May 15, 2024
0ca24f4
impl avx512: SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/S…
liqunfu May 26, 2024
7f89d5f
matmul_nbit & fix alignment for sgemm
liqunfu Jun 1, 2024
ed0e666
merge main
liqunfu Jun 4, 2024
35d02a6
fix mlas benchmark not using multi threads
liqunfu Jun 10, 2024
b9493ad
profiling
liqunfu Jun 10, 2024
c443eb5
Merge branch 'liqun/mlas-q4-tile-avx' of https://github.com/microsoft…
liqunfu Jun 10, 2024
ac66951
sgemm after sq4bit for avx2
liqunfu Jun 16, 2024
42a1305
avx512
liqunfu Jun 17, 2024
740031a
layout to follow compute, M1 separate with M > 1
liqunfu Jun 27, 2024
1a6031e
make avx512 run
liqunfu Jun 28, 2024
283fd2d
Merge branch 'main' into liqun/mlas-q4-tile-avx
liqunfu Jun 28, 2024
d035939
avx512 blklen64 pass
liqunfu Jul 4, 2024
f329d2d
pass avx512 blklen32
liqunfu Jul 5, 2024
27cfd9c
pass avx512 blklen 16, 128, 256
liqunfu Jul 5, 2024
edee319
pass fp32, refactor sqnbitgemm
liqunfu Jul 11, 2024
fb9221a
merge main
liqunfu Jul 12, 2024
c109b4b
avx512vnni
liqunfu Jul 18, 2024
6654d22
merge main
liqunfu Jul 18, 2024
4b91bed
avxvnni
liqunfu Jul 20, 2024
8674b9f
rm unused ComputeParallelTasksSGemm
liqunfu Jul 23, 2024
e26e29e
avoid _mm256_dpbusds_avx_epi32 in avx512vnni
liqunfu Jul 24, 2024
2b0307e
fix linux build
liqunfu Jul 24, 2024
40df782
Merge branch 'main' into liqun/mlas-q4-tile-avx
liqunfu Jul 26, 2024
51e97c8
refactor for Arm64
liqunfu Jul 26, 2024
48e8639
more refactor for Arm64
liqunfu Jul 26, 2024
705aa1f
hsum_float_16
liqunfu Jul 29, 2024
012e9c4
hsum_float_16
liqunfu Jul 29, 2024
21b9138
condition for -mavxvnni
liqunfu Jul 30, 2024
1fb1c83
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 10
liqunfu Jul 30, 2024
85918e9
missed 2 files from (__GNUC__ > 10)
liqunfu Jul 30, 2024
9530ac5
missed _mm256_dpbusds_avx_epi32 and print out cmake msgs
liqunfu Jul 30, 2024
f77cffd
unused zp, etc.
liqunfu Jul 30, 2024
a6fd378
unused zp, etc.
liqunfu Jul 30, 2024
c875e5c
remove test code changes
liqunfu Jul 30, 2024
3b56710
remove test code changes
liqunfu Jul 30, 2024
746562f
lint
liqunfu Jul 30, 2024
52fc7fa
lint
liqunfu Jul 30, 2024
0933a6b
code name
liqunfu Jul 30, 2024
2b35c82
update reviewers' comments
liqunfu Jul 31, 2024
caeb35e
Merge branch 'main' into liqun/mlas-q4-tile-avx
liqunfu Aug 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@

ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
const Tensor* tensor_zero_point = nullptr;
has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point);
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
#ifdef ORT_NEURAL_SPEED
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
Expand Down Expand Up @@ -139,6 +141,7 @@
IAllocatorUniquePtr<void> packed_b_{};
size_t packed_b_size_{0};

bool has_zp_input_{false};
#if defined(ORT_NEURAL_SPEED)

bool is_asym_{false};
Expand Down Expand Up @@ -208,9 +211,8 @@
}

#else // defined(ORT_NEURAL_SPEED)

const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (input_idx == InputIndex::B) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
return Status::OK();
}
Expand All @@ -220,13 +222,24 @@
}
auto qptr = tensor.DataRaw();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get());
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr);

Check warning on line 225 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:225: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (prepacked_weights) {
// TODO: cannot use packed_b_ after

Check warning on line 227 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:227: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
assert(false);
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
else if (input_idx == InputIndex::scales && packed_b_ != nullptr) {

Check warning on line 234 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 An else should appear on the same line as the preceding } [whitespace/newline] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:234: An else should appear on the same line as the preceding } [whitespace/newline] [4]

Check warning on line 234 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 If an else has a brace on one side, it should have it on both [readability/braces] [5] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:234: If an else has a brace on one side, it should have it on both [readability/braces] [5]
auto sptr = tensor.Data<float>();
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr);

Check warning on line 236 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:236: Lines should be <= 120 characters long [whitespace/line_length] [2]
is_packed = false;
} else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr);

Check warning on line 240 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:240: Lines should be <= 120 characters long [whitespace/line_length] [2]
is_packed = false;
}
#endif // defined(ORT_NEURAL_SPEED)

return Status::OK();
Expand Down Expand Up @@ -265,6 +278,7 @@
}

Status MatMulNBits::Compute(OpKernelContext* ctx) const {
//auto start = std::chrono::high_resolution_clock::now(); // Start timing here

Check warning on line 281 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:281: Should have a space between // and comment [whitespace/comments] [4]
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const Tensor* a = ctx->Input<Tensor>(InputIndex::A);
const auto* a_data = a->Data<float>();
Expand Down Expand Up @@ -332,9 +346,9 @@
const auto* bias_data = bias == nullptr ? nullptr : bias->Data<float>();

IAllocatorUniquePtr<std::byte> workspace{};
if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count,
nbits_, block_size_, compute_type);
workspace_size > 0) {
const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(
M, N, K, batch_count, nbits_, block_size_, compute_type);
if (workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
workspace = IAllocator::MakeUniquePtr<std::byte>(allocator, workspace_size);
Expand All @@ -344,17 +358,29 @@
for (size_t i = 0; i < batch_count; ++i) {
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].QuantBData = packed_b_.get();
data[i].QuantBDataWorkspace = packed_b_.get();
data[i].QuantBScale = scales_data;
data[i].QuantBZeroPoint = zero_points_data;
data[i].Bias = bias_data;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
data[i].node_name = this->Node().Name();
}
//auto start2 = std::chrono::high_resolution_clock::now(); // Start timing here

Check warning on line 369 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:369: Should have a space between // and comment [whitespace/comments] [4]

//const int CountTotal = 2000;

Check warning on line 371 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:371: Should have a space between // and comment [whitespace/comments] [4]
//int count = CountTotal;

Check warning on line 372 in onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Should have a space between // and comment [whitespace/comments] [4] Raw Output: onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc:372: Should have a space between // and comment [whitespace/comments] [4]
//while (count-- > 0)
MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);

MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(),
thread_pool);
//auto end = std::chrono::high_resolution_clock::now(); // End timing here

//std::chrono::duration<double, std::nano> elapsed2 = end - start2;
//// Calculate and print the duration in nanoseconds
//std::chrono::duration<double, std::nano> elapsed = end - start;
//std::cout << "MlasSQNBitGemmBatch: " << elapsed2.count() / CountTotal << " ns\n";
//std::cout << "main Duration_M" << M << "xN" << N << "xK" << K << ": " << elapsed.count() / CountTotal << " ns\n";
return Status::OK();
}
}
Expand Down
27 changes: 23 additions & 4 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Module Name:

#include "mlas.h"
#include "mlas_gemm_postprocessor.h"
#include <string>
liqunfu marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Define compute types of block quantization, in order of decreasing accuracy.
Expand All @@ -45,15 +46,18 @@ typedef enum {
struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values)
const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values)
const std::byte* PackedQuantBData = nullptr;
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
const float* Bias = nullptr; ///< optional address of Bias, vector size N
float* C = nullptr; ///< address of result matrix
size_t ldc = 0; ///< leading dimension of C

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
std::string node_name = "";
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
};

/**
Expand Down Expand Up @@ -159,14 +163,26 @@ MlasSQNBitGemmPackQuantBDataSize(
/**
* @brief Packs the quantized B data in a format that the kernel expects.
*
* If the function is called without QuantBScale and QuantBZeroPoint,
* it just packs QuantBData into PackedQuantBDataAndOrBlkSum.
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
*
* If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint
* additional BlkSum (Scale * zeropoint) is computed and stored at the second part of PackedQuantBDataAndOrBlkSum.
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
*
* Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData,
* QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData,
* and then QuantBScale and QuantBZeroPoint. The second time the function is called with QuantBScale,
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
* BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum.
* If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint.
*
* @param[in] N column size of matrix B and C
* @param[in] K column size of matrix A and row size of matrix B
* @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
* @param[in] QuantBData quantized B data
* @param[out] PackedQuantBData packed quantized B data
liqunfu marked this conversation as resolved.
Show resolved Hide resolved
* @param[in] ThreadPool optional thread pool to use
* @param[in] ThreadPool thread pool to use (no parallel if nullptr)
*/
void MLASCALL
MlasSQNBitGemmPackQuantBData(
Expand All @@ -176,6 +192,9 @@ MlasSQNBitGemmPackQuantBData(
size_t BlkLen,
MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
const void* QuantBData,
void* PackedQuantBData,
MLAS_THREADPOOL* ThreadPool = nullptr
void* PackedQuantBDataAndOrBlkSum,
const void* QuantBScale,
bool has_zp_input,
const void* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
);
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ Return Value:
this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni;
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni;
this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni;
//this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni;
}
}
}
Expand Down
Loading
Loading