Commit

lint
Signed-off-by: liqunfu <liqun.fu@microsoft.com>
liqunfu committed Jul 30, 2024
1 parent 746562f commit 52fc7fa
Showing 3 changed files with 18 additions and 17 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -349,7 +349,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {

IAllocatorUniquePtr<std::byte> workspace{};
const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(
-      M, N, K, batch_count, nbits_, block_size_, compute_type);
+      M, N, K, batch_count, nbits_, block_size_, compute_type);
if (workspace_size > 0) {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator));
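
A minimal sketch of the query-then-allocate workspace pattern this hunk touches; GetWorkspaceSize and RunQuantizedGemm are hypothetical stand-ins for MlasSQNBitGemmBatchWorkspaceSize and the kernel's Compute path, not ONNX Runtime code:

#include <cstddef>
#include <memory>

// Hypothetical stand-in for MlasSQNBitGemmBatchWorkspaceSize.
size_t GetWorkspaceSize(size_t M, size_t N, size_t K, size_t batch_count);

void RunQuantizedGemm(size_t M, size_t N, size_t K, size_t batch_count) {
  // Ask for the required scratch size first; allocate only when it is nonzero.
  std::unique_ptr<std::byte[]> workspace;
  const size_t workspace_size = GetWorkspaceSize(M, N, K, batch_count);
  if (workspace_size > 0) {
    workspace = std::make_unique<std::byte[]>(workspace_size);
  }
  // Pass workspace.get() (possibly nullptr) to the GEMM kernel.
}
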
23 changes: 12 additions & 11 deletions onnxruntime/core/mlas/inc/mlas_qnbit.h
@@ -22,7 +22,6 @@ Module Name:

#include "mlas.h"
#include "mlas_gemm_postprocessor.h"
-#include <string>

/**
* @brief Define compute types of block quantization, in order of decreasing accuracy.
@@ -47,7 +46,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {
const float* A = nullptr; ///< address of A (float32 matrix)
size_t lda = 0; ///< leading dimension of A
const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values)
-const std::byte* PackedQuantBData = nullptr;
+const std::byte* PackedQuantBData = nullptr;
const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block
const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block
const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block
@@ -57,7 +56,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS {

///< optional post processing to apply to result matrix
MLAS_GEMM_POSTPROCESSOR<float>* PostProcessor = nullptr;
-std::string node_name = "";
};

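As orientation for the struct above, a hedged sketch of how a caller might populate the B-side fields before a quantized GEMM call; the C and ldc members are assumptions, since this diff does not show the output-side fields:

#include <cstddef>

#include "mlas_qnbit.h"  // declares MLAS_SQNBIT_GEMM_DATA_PARAMS

// Sketch only: field names follow the struct above; C and ldc are assumed
// output-side members not visible in this diff.
void SetupDataParams(MLAS_SQNBIT_GEMM_DATA_PARAMS& params,
                     const float* a, size_t K, size_t N,
                     const std::byte* packed_b, const float* b_scales,
                     const void* b_zero_points, const float* b_blk_sums,
                     float* c) {
  params.A = a;                            // float32 activations, M x K
  params.lda = K;                          // leading dimension of row-major A
  params.PackedQuantBData = packed_b;      // output of MlasSQNBitGemmPackQuantBData
  params.QuantBScale = b_scales;           // one scale per quantization block
  params.QuantBZeroPoint = b_zero_points;  // optional; nullptr for symmetric B
  params.QuantBBlkSum = b_blk_sums;        // optional scale * zero-point sums
  params.C = c;                            // assumed float32 output, M x N
  params.ldc = N;                          // assumed leading dimension of C
}
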
@@ -171,17 +169,20 @@ MlasSQNBitGemmPackQuantBDataSize(
*
* Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData,
* QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData,
- * and then QuantBScale and QuantBZeroPoint. The second time the function is called with QuantBScale,
+ * and then QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale without QuantBZeroPoint,
* BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum.
* If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint.
*
- * @param[in] N column size of matrix B and C
- * @param[in] K column size of matrix A and row size of matrix B
- * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
- * @param[in] BlkLen number of quantized values per block
- * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
- * @param[in] QuantBData quantized B data
- * @param[out] PackedQuantBData packed quantized B data
+ * @param[in] N column size of matrix B and C
+ * @param[in] K column size of matrix A and row size of matrix B
+ * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints)
+ * @param[in] BlkLen number of quantized values per block
+ * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
+ * @param[in] QuantBData quantized B data
+ * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum
+ * @param[in] QuantBScale quantized B scale
+ * @param[in] has_zp_input whether QuantBZeroPoint is provided
+ * @param[in] QuantBZeroPoint quantized B zero point
* @param[in] ThreadPool thread pool to use (no parallel if nullptr)
*/
void MLASCALL
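
The three-call protocol described in the comment above, as a hedged sketch; the MlasSQNBitGemmPackQuantBData signature here is inferred from the @param list and should be treated as an assumption rather than the exact API:

#include <cstddef>

#include "mlas_qnbit.h"

// Sketch of ORT's PrePack driving the packing function once per initializer,
// in the order described above: QuantBData, then QuantBScale, then
// (optionally) QuantBZeroPoint. Parameter order is inferred from the @param list.
void PrePackAllInputs(size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen,
                      MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType,
                      const void* quant_b_data, const float* quant_b_scale,
                      const void* quant_b_zero_point,  // nullptr if B has no zero-point input
                      void* packed_buffer, MLAS_THREADPOOL* thread_pool) {
  // Call 1: pack the quantized weight values.
  MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                               quant_b_data, packed_buffer,
                               /*QuantBScale=*/nullptr, /*has_zp_input=*/false,
                               /*QuantBZeroPoint=*/nullptr, thread_pool);

  // Call 2: scales arrive without zero points, so BlkSum is computed with the
  // default zero point of 8 and stored in the second part of packed_buffer.
  MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                               /*QuantBData=*/nullptr, packed_buffer,
                               quant_b_scale, quant_b_zero_point != nullptr,
                               /*QuantBZeroPoint=*/nullptr, thread_pool);

  // Call 3 (only when a zero-point initializer exists): recompute BlkSum with
  // the provided zero points.
  if (quant_b_zero_point != nullptr) {
    MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                                 /*QuantBData=*/nullptr, packed_buffer,
                                 /*QuantBScale=*/nullptr, /*has_zp_input=*/true,
                                 quant_b_zero_point, thread_pool);
  }
}
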
10 changes: 5 additions & 5 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -263,11 +263,11 @@ void RunTest(const TestOptions& opts,
} // namespace

TEST(MatMulNBits, Float32) {
-//onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling<char>("profile.json");
+// onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling<char>("profile.json");
for (auto M : {1, 2, 100}) {
-for (auto N : {/*2560, */1, 2, 32, 288}) {
-for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) {
-for (auto block_size : {16, 32, 64, 128 }) {
+for (auto N : {/*2560, */ 1, 2, 32, 288}) {
+for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) {
+for (auto block_size : {16, 32, 64, 128}) {
for (auto accuracy_level : {0, 1, 4}) {
TestOptions base_opts{};
base_opts.M = M, base_opts.N = N, base_opts.K = K;
@@ -280,7 +280,7 @@ TEST(MatMulNBits, Float32) {

{
TestOptions opts = base_opts;
-RunTest<float>(opts);
+RunTest<float>(opts);
}

{
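
For reference, one point of the sweep above expands to a call like the following; the block_size and accuracy_level field names are assumptions inferred from the loop variables, since only the M, N, and K assignments are visible in this diff:

// Hedged illustration, not part of the diff: a single sweep point.
TestOptions opts{};
opts.M = 1;
opts.N = 32;
opts.K = 64;
opts.block_size = 32;     // assumed field, from the block_size loop
opts.accuracy_level = 1;  // assumed field, from the accuracy_level loop
RunTest<float>(opts);     // exercises the float32 MatMulNBits path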
