From 293f121d24a8ea83a189227f9e242ec5745aeade Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 2 May 2024 20:00:51 -0700 Subject: [PATCH 01/41] quick adapt llama.cpp to experiment performance. Only works with blklen32, symmetric1 hasBias0 Int8 Signed-off-by: Liqun Fu --- cmake/onnxruntime_mlas.cmake | 2 + onnxruntime/core/mlas/lib/llama.cpp.sgemm.cpp | 321 ++++++++++++++++++ onnxruntime/core/mlas/lib/llama.cpp.sgemm.h | 5 + onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 31 +- 4 files changed, 358 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/mlas/lib/llama.cpp.sgemm.cpp create mode 100644 onnxruntime/core/mlas/lib/llama.cpp.sgemm.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 682dcfc5fe84..0d8f6e2df66f 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -38,6 +38,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp ${MLAS_SRC_DIR}/sqnbitgemm.h ${MLAS_SRC_DIR}/sqnbitgemm.cpp + ${MLAS_SRC_DIR}/llama.cpp.sgemm.h + ${MLAS_SRC_DIR}/llama.cpp.sgemm.cpp ) target_sources(onnxruntime_mlas PRIVATE diff --git a/onnxruntime/core/mlas/lib/llama.cpp.sgemm.cpp b/onnxruntime/core/mlas/lib/llama.cpp.sgemm.cpp new file mode 100644 index 000000000000..5b414120ec62 --- /dev/null +++ b/onnxruntime/core/mlas/lib/llama.cpp.sgemm.cpp @@ -0,0 +1,321 @@ +// ported/adapted from https://github.com/ggerganov/llama.cpp/pull/6414 +#define __AVX2__ 1 + +#include +#include +#include "llama.cpp.sgemm.h" +#include "sqnbitgemm.h" +//#include "sqnbitgemm_kernel_avx_common.h" +#include +#include + +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + +#if defined(__ARM_NEON) || defined(__AVX512F__) +#define VECTOR_REGISTERS 32 +#else +#define VECTOR_REGISTERS 16 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED FUSED MULTIPLY ADD + +/** + * Computes a * b + c. 
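+ *
+ * Illustrative only: with AVX2, madd(_mm256_set1_ps(2.0f), _mm256_set1_ps(3.0f),
+ * _mm256_set1_ps(1.0f)) resolves to the _mm256_fmadd_ps specialization below and
+ * yields 7.0f in every lane; the generic template falls back to add(mul(a, b), c).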
+ */ +template +inline U +madd(T a, T b, U c) +{ + return add(mul(a, b), c); +} + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> +inline __m256 +madd(__m256 a, __m256 b, __m256 c) +{ + return _mm256_fmadd_ps(a, b, c); +} +#endif +#if defined(__AVX512F__) +template <> +inline __m512 +madd(__m512 a, __m512 b, __m512 c) +{ + return _mm512_fmadd_ps(a, b, c); +} +#endif + + +template +class tinyBLAS_Q0_AVX2 +{ + public: + tinyBLAS_Q0_AVX2(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, + const float *QuantBScale, int64_t StrideQuantBScale) + : A_q4_(A), B_q8_(B), C(C), k(k), lda_q4_(lda), ldb_q8_(ldb), ldc_(ldc), + Quant4Scale_(QuantBScale), StrideQuant4Scale_(StrideQuantBScale) + { + } + + void matmul(int64_t m, int64_t n) + { + mnpack(0, m, 0, n); + } + + private: + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) + { + int64_t mc, nc, mp, np; + switch ((std::min(m - m0, (int64_t)4) << 4) | std::min(n - n0, (int64_t)4)) { +#if VECTOR_REGISTERS == 32 + case 0x44: + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); + break; + case 0x43: + mc = 4; + nc = 3; + gemm<4, 3>(m0, m, n0, n); + break; + case 0x34: + mc = 3; + nc = 4; + gemm<3, 4>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm<2, 4>(m0, m, n0, n); + break; +#else + case 0x44: + case 0x43: + case 0x42: + mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x34: + case 0x24: + mc = 2; + nc = 4; + gemm<2, 4>(m0, m, n0, n); + break; + case 0x33: +#endif + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; + gemm<4, 1>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm<1, 4>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) + { + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + for (int64_t tile = 0; tile < tiles; ++tile) { + int64_t ii = m0 + tile / xtiles * RM; + int64_t jj = n0 + tile % xtiles * RN; + __m256 Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; ++l) // blk count (BlockCountK) + for (int64_t j = 0; j < RN; ++j) // + for (int64_t i = 0; i < RM; ++i) { + const std::byte *Quant4ABlk = A_q4_ + lda_q4_ * (ii + i) + l * BlkDataSizeInBytes16; + const std::byte *Quant8BBlk = B_q8_ + ldb_q8_ * (jj + j) + l * Q8BlkSize(BlkLen32); + const float &scale_q8 = Q8BlkScale(Quant8BBlk); + const float &scale_q4 = *(Quant4Scale_ + (ii + i) * StrideQuant4Scale_ + l); + + const int8_t zp = 8; + const __m256i q4_v = load_q4(Quant4ABlk, zp); + const __m256i q8_v = 
load_q8(Quant8BBlk); + Cv[j][i] = madd( + _mm256_set1_ps(scale_q8 * scale_q4), + updot(_mm256_sign_epi8(q4_v, q4_v), _mm256_sign_epi8(q8_v, q4_v)), + Cv[j][i] + ); + } + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc_ * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + inline float hsum(__m128 x) + { + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); + } + inline float hsum(__m256 x) + { + return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x))); + } + + inline __m256i load_q8(const std::byte *Quant8Blk) + { + return _mm256_loadu_si256((const __m256i *)Q8BlkData(Quant8Blk)); + } + + inline __m256i load_q4(const std::byte *Quant4DataPtr, const int8_t zp) + { + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(Quant4DataPtr)); + + const __m128i low_mask = _mm_set1_epi8(15); + const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask); // 0, 1, 2, 3,... + const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... + __m256i bv_32_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); + const __m256i bzp0 = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp0); + + return bv_32_epi8; + } + + inline __m256 updot(__m256i u, __m256i s) + { + __m256i res; +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#else + res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); +#endif + return _mm256_cvtepi32_ps(res); + } + + const TA *const A_q4_; + const TB *const B_q8_; + TC *const C; + const int64_t k; + const int64_t lda_q4_; + const int64_t ldb_q8_; + const int64_t ldc_; + const float *Quant4Scale_; + int64_t StrideQuant4Scale_; +}; + +/** + * Performs optimized matrix multiplication on CPU. + * + * This subroutine may compute C = Aᵀ * B with column major ordering. + * Despite its name, this isn't a generalized implementation. Work is + * only performed when a handwritten kernel is written and available. + * Otherwise the caller should fall back to a general matmul routine. 
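+ *
+ * In this MLAS adaptation, `A` receives the packed 4-bit quantized weight blocks
+ * (MLAS's quantized B) and `B` receives the block-quantized int8 activations
+ * (MLAS's quantized A); QuantBScale/StrideQuantBScale carry the per-block scales
+ * of `A`. The thread-id/task/type parameters shown in the upstream llama.cpp
+ * example below are not part of this adapted signature. Roughly, the call made
+ * from SQ4BitGemm_CompInt8 in sqnbitgemm.cpp is:
+ *
+ *     llamafile_sgemm(CountN, RangeCountM, k_blks,
+ *                     b_col, ldb,        // packed 4-bit weight blocks
+ *                     a_row, lda,        // block-quantized int8 activations
+ *                     c_blk, ldc,
+ *                     b_col_scale, k_blks);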
+ * + * For example, for single-threaded single-precision GEMM you can say + * + * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, + * 0, 1, GGML_TASK_TYPE_COMPUTE, + * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32); + * + * @param m is rows in `A` and `C` + * @param n is cols in `B` and `C` + * @param k is cols in `A` and rows in `B` + * @param A is first input matrix (always transposed) + * @param lda is row stride of `A` + * @param B is second input matrix (never transposed) + * @param ldb is row stride of `B` + * @param C is input/output array of output matrices + * @param ldc is row stride of `C` + * @param ith is thread id (must be less than `nth`) + * @param nth is number of threads (must be greater than zero) + * @param task is GGML task type + * @param Atype is GGML data type of `A` + * @param Btype is GGML data type of `B` + * @param Ctype is GGML data type of `C` + * @return true if this function was able to service the matmul request + */ +bool +llamafile_sgemm( + int64_t m, + int64_t n, + int64_t k, + const std::byte *A, + int64_t lda, + const std::byte *B, + int64_t ldb, + float *C, + int64_t ldc, + const float *QuantBScale, + int64_t StrideQuantBScale +) +{ + tinyBLAS_Q0_AVX2 tb{k, A, lda, B, ldb, C, ldc, QuantBScale, StrideQuantBScale}; + tb.matmul(m, n); + return true; +} diff --git a/onnxruntime/core/mlas/lib/llama.cpp.sgemm.h b/onnxruntime/core/mlas/lib/llama.cpp.sgemm.h new file mode 100644 index 000000000000..a98ef9f689ee --- /dev/null +++ b/onnxruntime/core/mlas/lib/llama.cpp.sgemm.h @@ -0,0 +1,5 @@ +#include +#include + +bool +llamafile_sgemm(int64_t m, int64_t n, int64_t k, const std::byte *A, int64_t lda, const std::byte *B, int64_t ldb, float *C, int64_t ldc, const float *QuantBScale, int64_t StrideQuantBScale); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 4fd28f5e6998..2b713147be86 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,6 +19,11 @@ Module Name: #include +#define SQ4BITGEMM_USE_TILE +#if defined SQ4BITGEMM_USE_TILE +#include "llama.cpp.sgemm.h" +#endif + namespace { @@ -394,6 +399,7 @@ SQ4BitGemm_CompInt8( const size_t RangeCountN ) { +#ifndef SQ4BITGEMM_USE_TILE #ifdef MLAS_TARGET_AMD64_IX86 if (RangeCountM != 1) { // perf experiment shows fp32 is faster than int8 in M > 1 cases. @@ -404,6 +410,7 @@ SQ4BitGemm_CompInt8( ); return; } +#endif #endif constexpr size_t BlkBitWidth = 4; @@ -468,7 +475,7 @@ SQ4BitGemm_CompInt8( (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - +#ifndef SQ4BITGEMM_USE_TILE for (size_t m = 0; m < RangeCountM; ++m) { GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( BlkLen, @@ -485,6 +492,28 @@ SQ4BitGemm_CompInt8( c_blk += ldc; a_row += lda; } +#else + int64_t llama_cpp_m = CountN; + int64_t llama_cpp_n = RangeCountM; + int64_t llama_cpp_k = k_blks; + const std::byte* llama_cpp_A = b_col; + int64_t llama_cpp_lda = ldb; + const std::byte* llama_cpp_B = a_row; + int64_t llama_cpp_ldb = lda; + float* llama_cpp_C = c_blk; + int64_t llama_cpp_ldc = ldc; + const float* llama_cpp_QuantBScale = b_col_scale; + int64_t llama_cpp_StrideQuantBScale = k_blks; + llamafile_sgemm( + llama_cpp_m, llama_cpp_n, llama_cpp_k, + llama_cpp_A, llama_cpp_lda, + llama_cpp_B, llama_cpp_ldb, + llama_cpp_C, llama_cpp_ldc, + llama_cpp_QuantBScale, llama_cpp_StrideQuantBScale + ); + (void)bias; + (void)b_col_zp; +#endif } } From 04c2e5603ac1596ad348554e2d2075c1767315c5 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Sun, 5 May 2024 23:14:52 -0700 Subject: [PATCH 02/41] fire Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 88 ++- onnxruntime/core/mlas/lib/sqnbitgemm.h | 18 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 95 +++ .../mlas/lib/sqnbitgemm_kernel_avx2_int8.h | 632 ++++++++++++++++++ .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 14 + .../test/mlas/bench/bench_sqnbitgemm.cpp | 2 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 13 + 7 files changed, 828 insertions(+), 34 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 2b713147be86..f3657ef01e74 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,8 +19,8 @@ Module Name: #include -#define SQ4BITGEMM_USE_TILE -#if defined SQ4BITGEMM_USE_TILE +//#define SQ4BITGEMM_USE_TILE +#ifdef SQ4BITGEMM_USE_TILE #include "llama.cpp.sgemm.h" #endif @@ -399,19 +399,19 @@ SQ4BitGemm_CompInt8( const size_t RangeCountN ) { -#ifndef SQ4BITGEMM_USE_TILE -#ifdef MLAS_TARGET_AMD64_IX86 - if (RangeCountM != 1) { - // perf experiment shows fp32 is faster than int8 in M > 1 cases. - // route to fp32 compute before int8 compute is improved. - SQ4BitGemm_CompFp32( - BlkLen, - K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN - ); - return; - } -#endif -#endif +//#ifndef SQ4BITGEMM_USE_TILE +//#ifdef MLAS_TARGET_AMD64_IX86 +// if (RangeCountM != 1) { +// // perf experiment shows fp32 is faster than int8 in M > 1 cases. +// // route to fp32 compute before int8 compute is improved. +// SQ4BitGemm_CompFp32( +// BlkLen, +// K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN +// ); +// return; +// } +//#endif +//#endif constexpr size_t BlkBitWidth = 4; const size_t k_blks = MlasDivRoundup(K, BlkLen); @@ -475,24 +475,7 @@ SQ4BitGemm_CompInt8( (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; -#ifndef SQ4BITGEMM_USE_TILE - for (size_t m = 0; m < RangeCountM; ++m) { - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - - c_blk += ldc; - a_row += lda; - } -#else +#ifdef SQ4BITGEMM_USE_TILE int64_t llama_cpp_m = CountN; int64_t llama_cpp_n = RangeCountM; int64_t llama_cpp_k = k_blks; @@ -513,6 +496,45 @@ SQ4BitGemm_CompInt8( ); (void)bias; (void)b_col_zp; +#else +#if 0 + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + //GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + // BlkLen, + // a_row, b_col, b_col_scale, b_col_zp, c_blk, /*RangeCountM*/1, CountN, + // K, k_blks, bias, lda, ldc + //); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + + c_blk += ldc; + a_row += lda; + } +#else + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + BlkLen, + a_row, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + lda, + ldc); +#endif #endif } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 318a51e1c80a..f5952b49ff46 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -228,6 +228,24 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmM1Kernel_CompInt8_Fn* SQ4BitGemmM1Kernel_CompInt8 = nullptr; + typedef void(SQ4BitGemmKernel_CompInt8_Fn)( + const size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ); + + SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; + /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. 
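 *
 * Layout assumption (inferred from the CompInt8 kernels that consume QuantA): each
 * output block holds a 4-byte float scale followed by BlkLen int8 values, i.e.
 * Q8BlkSize(BlkLen) bytes per block, read back via Q8BlkScale()/Q8BlkData().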
* diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index b5d7a4e78fbe..ee33f7a617db 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -22,6 +22,7 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx2_int8.h" MLAS_FORCEINLINE __m256 @@ -338,6 +339,99 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( } } +MLAS_FORCEINLINE +void +SQ4BitGemmKernel_CompInt8_avx2( + const size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + if (BlkLen == 16) { + assert(false); + } else if (BlkLen == 32) { + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + lda, + ldc + ); + //MlasQ4Int8GemmKernelBlkLen32Avx2>( + // QuantA, + // QuantBData, + // QuantBScale, + // QuantBZeroPoint, + // C, + // CountM, + // CountN, + // CountK, + // BlockCountK, + // Bias, + // lda, + // ldc + //); + } else { + assert(false); + } + } else { + constexpr bool HasZeroPoint = false; + if (BlkLen == 16) { + assert(false); + } else if (BlkLen == 32) { + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + lda, + ldc + ); + // MlasQ4Int8GemmKernelBlkLen32Avx2>( + // QuantA, + // QuantBData, + // QuantBScale, + // QuantBZeroPoint, + // C, + // CountM, + // CountN, + // CountK, + // BlockCountK, + // Bias, + // lda, + // ldc + //); + } else { + assert(false); + } + } +} + MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompInt8_avx2( @@ -1107,6 +1201,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; return d; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h new file mode 100644 index 000000000000..81b3fd426ad7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h @@ -0,0 +1,632 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE void +accumulate_dot(const __m256i av_32_epi8, const __m256i bv_32_epi8, const float combined_scale, const __m256i one, __m256& acc) +{ + const __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) + ); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk2_avx2( + const __m256i av00_32_epi8, + const __m256i av01_32_epi8, + const __m256i av10_32_epi8, + const __m256i av11_32_epi8, + const std::byte* QuantBDataPtr, + const 
std::byte* QuantBZeroPointPtr, + const float combined_scale00, + const float combined_scale01, + const float combined_scale10, + const float combined_scale11, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const __m256i bv0 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this be faster and save a use of low_mask? + // const __m256i bv1 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed - bv0), 4); + const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + __m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + __m256i one = _mm256_set1_epi16(1); + accumulate_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one, acc0); + accumulate_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one, acc0); + accumulate_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one, acc1); + accumulate_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx2( + const __m256i av00_32_epi8, + const __m256i av10_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float combined_scale00, + const float combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + + const __m128i low_mask = _mm_set1_epi8(0x0F); + const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask); // 0, 1, 2, 3,... + const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... + __m256i bv_32_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one = _mm256_set1_epi16(1); + accumulate_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one, acc0); + accumulate_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one, acc1); +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + // We process 32 quantized values in a batch. 
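+    // Register tiling: each outer step covers NRows2 (2) rows of A and NCols4 (4)
+    // columns of B, keeping NCols4 * NRows2 = 8 __m256 accumulators live, and walks
+    // the K dimension two 32-value quantized blocks at a time (PerAccuBlk2).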
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m += NRows2) { + // accumulate_blklen32_r2c4_avx2 + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4 * NRows2]; + for (int i = 0; i < NCols4 * NRows2; i++) { + acc[i] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + const float& 
scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr 
!= nullptr ? NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + // accumulate_blklen32_r2c1_avx2 + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + lda) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + lda) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} + +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
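+    // Single-row variant: each iteration handles one row of A against NCols4 columns
+    // of B, with the per-block dot product supplied by the `accumulator` template
+    // parameter; K is again consumed two 32-value blocks at a time (PerAccuBlk2).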
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
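+            // With PerAccuBlk2 == 2, at most one 32-value block remains here
+            // (only when BlockCountK is odd).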
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index abace949a1c5..7e0f672fc01d 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -239,6 +239,20 @@ load_and_mul_sum_s8_quads_with_zp_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale0, acc0); } +template +void MLAS_FORCEINLINE +get_2_zps(const std::byte* QuantBZeroPointPtr, int8_t& zp0, int8_t& zp1) +{ + if constexpr (HasZeroPoint) { + zp0 = std::to_integer((*QuantBZeroPointPtr) & std::byte{0x0F}); + zp1 = std::to_integer((*QuantBZeroPointPtr) >> 4); + } else { + zp0 = 8; + zp1 = 8; + (void)QuantBZeroPointPtr; + } +} + template int8_t MLAS_FORCEINLINE get_zp(bool is_lower_half_byte_zp, const std::byte* QuantBZeroPointPtr) diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 903c5a498575..00a7bf3d8af4 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -110,7 +110,7 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgsProduct({ {16, 32, 64, 128, 256}, // BlkLen - {1, 1024, 2048}, // M + {1, 2, 1024, 2048}, // M {4096, 11008}, // N {4096, 11008}, // K {1, 8}, // Threads diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 71a6123b868b..654a5e0189a0 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -394,6 +394,19 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Mon, 6 May 2024 19:04:31 -0700 Subject: [PATCH 03/41] tile 2x4 SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1542487160 ns 1539062500 ns Signed-off-by: Liqun Fu --- .../mlas/lib/sqnbitgemm_kernel_avx2_int8.h | 535 ++++++++++++++++-- .../test/mlas/unittest/test_sqnbitgemm.cpp | 5 +- 2 files changed, 487 insertions(+), 53 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h index 81b3fd426ad7..31a93e30b4fb 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h @@ -39,12 +39,16 @@ accumulate_blklen32_r2c1blk2_avx2( const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); const __m256i low_mask = _mm256_set1_epi8(0x0F); - const __m256i bv0 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + const __m256i bv0 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 // TODO: will this be faster and save a use of low_mask? 
- // const __m256i bv1 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed - bv0), 4); - const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + const __m256i bv1 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - __m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + + // This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + //__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + __m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); int8_t zp0, zp1; get_2_zps(QuantBZeroPointPtr, zp0, zp1); @@ -73,11 +77,8 @@ accumulate_blklen32_r2c1blk1_avx2( { // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); - - const __m128i low_mask = _mm_set1_epi8(0x0F); - const __m128i bv_lo0 = _mm_and_si128(bv_packed0, low_mask); // 0, 1, 2, 3,... - const __m128i bv_hi0 = _mm_and_si128(_mm_srli_epi16(bv_packed0, 4), low_mask); // 16, 17, 18, 19,... - __m256i bv_32_epi8 = _mm256_set_m128i(bv_hi0, bv_lo0); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); const int8_t zp = get_zp(true, QuantBZeroPointPtr); const __m256i bzp = _mm256_set1_epi8(zp); @@ -89,30 +90,88 @@ accumulate_blklen32_r2c1blk1_avx2( } template -MLAS_FORCEINLINE - size_t - MlasQ4Int8TileGemmKernelBlkLen32Avx2( - const std::byte* QuantA, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t /*CountK*/, - size_t BlockCountK, - const float* Bias, - size_t lda, - size_t ldc - ) +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i av00_32_epi8, + const __m256i av01_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float combined_scale00, + const float combined_scale01, + __m256& acc0) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const __m256i bv0 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this be faster and save a use of low_mask? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + const __m256i bv1 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + __m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + // This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. 
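+    // bv1 already holds bytes 48..63 in its upper 128-bit lane, so inserting bv0's
+    // upper lane (bytes 32..47) into lane 0 of bv1 yields bytes 32..63 with a single
+    // _mm256_extracti128_si256.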
+ //__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + __m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + __m256i one = _mm256_set1_epi16(1); + accumulate_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one, acc0); + accumulate_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one, acc0); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i av00_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one = _mm256_set1_epi16(1); + accumulate_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one, acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) { - // We process 32 quantized values in a batch. constexpr size_t BlkLen32 = 32; constexpr size_t BlkBitWidth4 = 4; constexpr size_t NCols4 = 4; constexpr size_t NRows2 = 2; constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + constexpr size_t Q8Blk32Size = Q8BlkSize(BlkLen32); + // process 2 blks of 64 4b weights a time constexpr size_t PerAccuBlk2 = 2; @@ -121,38 +180,36 @@ MLAS_FORCEINLINE const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); for (size_t m = 0; m < CountM; m += NRows2) { - // accumulate_blklen32_r2c4_avx2 - // for each row of A, reset B pointers const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; const float* BiasPtr = Bias; auto* SumPtr = C + m * ldc; - int64_t nblk = (int64_t)(CountN)-NCols4; - while (nblk >= 0) { + for (size_t n = 0; n < CountN; n += NCols4) { const std::byte* QuantAPtr = QuantA + m * lda; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - __m256 acc[NCols4 * NRows2]; - for (int i = 0; i < NCols4 * NRows2; i++) { - acc[i] = _mm256_setzero_ps(); - } + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; size_t k_blks_remaining = BlockCountK; // process 2 blks of 64 4b weights a time for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - 
const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk01 = QuantABlk00 + Q8Blk32Size; const std::byte* QuantABlk10 = QuantAPtr + lda; - const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk11 = QuantABlk10 + Q8Blk32Size; // load A: const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); @@ -180,7 +237,7 @@ MLAS_FORCEINLINE const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); } { @@ -189,7 +246,7 @@ MLAS_FORCEINLINE const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); } { @@ -198,11 +255,11 @@ MLAS_FORCEINLINE const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); } // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantAPtr += Q8Blk32Size * PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; if constexpr (HasZeroPoint) { @@ -231,21 +288,21 @@ MLAS_FORCEINLINE // Col1 const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); } { // Col2 const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, 
av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_10, acc[3], acc[NCols4 + 3]); } } // k_blks_remaining @@ -268,11 +325,49 @@ MLAS_FORCEINLINE BiasPtr += BiasPtr != nullptr ? NCols4 : 0; SumPtr += NCols4; - nblk -= NCols4; - } // while (nblk >= 0) + } + } +} - nblk += NCols4; - for (int64_t n = 0; n < nblk; n++) { +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { // accumulate_blklen32_r2c1_avx2 const std::byte* QuantAPtr = QuantA + m * lda; const std::byte* QuantBDataPtr = QuantBDataColPtr; @@ -330,10 +425,10 @@ MLAS_FORCEINLINE } *SumPtr = hsum_float_8(acc0); - *(SumPtr + lda) = hsum_float_8(acc1); + *(SumPtr + ldc) = hsum_float_8(acc1); if (BiasPtr) { *SumPtr += *BiasPtr; - *(SumPtr + lda) += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; } // move to next column @@ -346,10 +441,348 @@ MLAS_FORCEINLINE BiasPtr += BiasPtr != nullptr ? 
1 : 0; SumPtr += 1; } - } // m + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, 
scale_01, acc[3]); + } + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4BlkLen32Avx2( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc + ); + } + if (remainingCols > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + if (remainingRows > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + return CountM; } +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. 
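// A sketch (illustrative names, not part of this patch) of how MlasQ4Int8TileGemmKernelBlkLen32Avx2
// above carves up the output, assuming NRows2 == 2 and NCols4 == 4. For CountM = 5, CountN = 7:
//   multipleRows = 4, remainingRows = 1, multipleCols = 4, remainingCols = 3
//   Q4Int8Gemm2x4BlkLen32Avx2 -> rows [0, 4) x cols [0, 4)
//   Q4Int8Gemm2xXBlkLen32Avx2 -> rows [0, 4) x cols [4, 7)
//   Q4Int8GemmXx4BlkLen32Avx2 -> rows [4, 5) x cols [0, 4)
//   Q4Int8GemmXxXBlkLen32Avx2 -> rows [4, 5) x cols [4, 7)
// Every element the four kernels produce follows the same scalar model (zp is 8 when no zero
// points are stored; a_q8/b_q4/a_scale/b_scale are illustrative stand-ins for the packed buffers):
//   float sum = 0.0f;
//   for (size_t k = 0; k < BlockCountK; ++k) {
//       int32_t blk = 0;
//       for (size_t t = 0; t < 32; ++t)
//           blk += int32_t(a_q8[row][k][t]) * (int32_t(b_q4[col][k][t]) - zp);
//       sum += a_scale[row][k] * b_scale[col][k] * float(blk);
//   }
//   C[row * ldc + col] = sum + (Bias != nullptr ? Bias[col] : 0.0f);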
template accumulator> MLAS_FORCEINLINE size_t diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 654a5e0189a0..f55fcbca4dc1 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -404,8 +404,9 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Tue, 7 May 2024 19:59:39 -0700 Subject: [PATCH 04/41] use one_16_epi16 and accumulate_2blk_dot: SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1434872720 ns Signed-off-by: Liqun Fu --- .../mlas/lib/sqnbitgemm_kernel_avx2_int8.h | 109 ++++++++++++------ 1 file changed, 73 insertions(+), 36 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h index 31a93e30b4fb..6fbeb7bf2063 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h @@ -8,29 +8,55 @@ MLAS_FORCEINLINE void -accumulate_dot(const __m256i av_32_epi8, const __m256i bv_32_epi8, const float combined_scale, const __m256i one, __m256& acc) +accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, + const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) { const __m256i dot_16_epi16 = _mm256_maddubs_epi16( _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) ); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one, dot_16_epi16); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } +MLAS_FORCEINLINE void +accumulate_2blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float& combined_scale0, const float& combined_scale1, + const __m256i& one_16_epi16, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale_8_ps = _mm256_set_ps( + combined_scale0, combined_scale1, combined_scale0, combined_scale1, + combined_scale0, combined_scale1, combined_scale0, combined_scale1 + ); + acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk2_avx2( - const __m256i av00_32_epi8, - const __m256i av01_32_epi8, - const __m256i av10_32_epi8, - const __m256i av11_32_epi8, + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, const std::byte* QuantBDataPtr, const std::byte* QuantBZeroPointPtr, - const float combined_scale00, - const float combined_scale01, - const float combined_scale10, - const float combined_scale11, + const float& combined_scale00, + const float& combined_scale01, + const float& combined_scale10, + const float& combined_scale11, __m256& acc0, __m256& acc1 ) @@ -38,15 +64,19 @@ accumulate_blklen32_r2c1blk2_avx2( // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). + // however, it is faster to generate one_16_epi16 than calling _mm256_set1_ep16(1); const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + //low_mask = _mm256_packus_epi16(low_mask, low_mask); const __m256i bv0 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 - // TODO: will this be faster and save a use of low_mask? + // TODO: will this (the second line below) be faster and not keep low_mask in use? // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 const __m256i bv1 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 __m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - // This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + // This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. //__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); __m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); @@ -54,23 +84,29 @@ accumulate_blklen32_r2c1blk2_avx2( get_2_zps(QuantBZeroPointPtr, zp0, zp1); bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); - - __m256i one = _mm256_set1_epi16(1); - accumulate_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one, acc0); - accumulate_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one, acc0); - accumulate_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one, acc1); - accumulate_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one, acc1); + + // generating constant 1s is fater here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + // performance gains 7% by calling this (accumulate_2blk_dot) instead of 2 accumulate_1blk_dot calls. 
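// Why the fused call is cheaper: both paths issue one _mm256_maddubs_epi16 per block, but two
// accumulate_1blk_dot calls then each pay their own _mm256_madd_epi16 + _mm256_cvtepi32_ps +
// _mm256_fmadd_ps, while accumulate_2blk_dot folds the pair with one _mm256_hadd_epi16 and
// finishes with a single madd/convert/fma. In scalar terms the call accumulates
//   acc += combined_scale0 * dot32(av0, bv0) + combined_scale1 * dot32(av1, bv1)
// where dot32() is the plain 32-element int8 dot product; the _mm256_sign_epi8 pair exists only
// because _mm256_maddubs_epi16 wants unsigned x signed inputs, and |b| * (a * sign(b)) == a * b.
// Note that _mm256_hadd_epi16 groups its results per 128-bit lane, so after the following madd
// each lane holds partial sums ordered {block0, block0, block1, block1}; the combined scales must
// be laid out in that same per-lane order.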
+ // accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + // accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + // accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one_16_epi16, acc1); + // accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one_16_epi16, acc1); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); } template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk1_avx2( - const __m256i av00_32_epi8, - const __m256i av10_32_epi8, + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, const std::byte* QuantBDataPtr, const std::byte* QuantBZeroPointPtr, - const float combined_scale00, - const float combined_scale10, + const float& combined_scale00, + const float& combined_scale10, __m256& acc0, __m256& acc1 ) @@ -84,20 +120,20 @@ accumulate_blklen32_r2c1blk1_avx2( const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); - __m256i one = _mm256_set1_epi16(1); - accumulate_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one, acc0); - accumulate_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one, acc1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); } template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( - const __m256i av00_32_epi8, - const __m256i av01_32_epi8, + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, const std::byte* QuantBDataPtr, const std::byte* QuantBZeroPointPtr, - const float combined_scale00, - const float combined_scale01, + const float& combined_scale00, + const float& combined_scale01, __m256& acc0) { // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... 
| v46 v62 | v47 v63 | @@ -120,18 +156,19 @@ accumulate_blklen32_r1c1blk2_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); - __m256i one = _mm256_set1_epi16(1); - accumulate_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one, acc0); - accumulate_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one, acc0); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + //accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + //accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); } template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_avx2( - const __m256i av00_32_epi8, + const __m256i& av00_32_epi8, const std::byte* QuantBDataPtr, const std::byte* QuantBZeroPointPtr, - const float combined_scale00, + const float& combined_scale00, __m256& acc0 ) { @@ -144,8 +181,8 @@ accumulate_blklen32_r1c1blk1_avx2( const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); - __m256i one = _mm256_set1_epi16(1); - accumulate_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one, acc0); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); } template From 5418e9c04b0aca4dcf1a41b17294838a1f6c1dc4 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Wed, 8 May 2024 19:51:53 -0700 Subject: [PATCH 05/41] apply to M1, BQuant layout pack block (subblk) larger than blklen: SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1265060620 ns 1265625000 ns Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 67 ++++++++++--------- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 36 ++++++++-- .../mlas/lib/sqnbitgemm_kernel_avx2_int8.h | 32 ++++----- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 49 +++++++++----- .../test/mlas/unittest/test_sqnbitgemm.cpp | 2 + 5 files changed, 116 insertions(+), 70 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index f3657ef01e74..5c3a17334e82 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -497,44 +497,45 @@ SQ4BitGemm_CompInt8( (void)bias; (void)b_col_zp; #else -#if 0 - for (size_t m = 0; m < RangeCountM; ++m) { - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + if (BlkLen == 32) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + a_row, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + lda, + ldc ); - //GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( - // BlkLen, - // a_row, b_col, b_col_scale, b_col_zp, c_blk, /*RangeCountM*/1, CountN, - // K, k_blks, bias, lda, ldc - //); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc + } else { + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); - } + // 
GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + // BlkLen, + // a_row, b_col, b_col_scale, b_col_zp, c_blk, /*RangeCountM*/1, CountN, + // K, k_blks, bias, lda, ldc + //); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } - c_blk += ldc; - a_row += lda; + c_blk += ldc; + a_row += lda; + } } -#else - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( - BlkLen, - a_row, - b_col, - b_col_scale, - b_col_zp, - c_blk, - RangeCountM, - CountN, - K, - k_blks, - bias, - lda, - ldc); -#endif #endif } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index ee33f7a617db..bb7467ac093f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -462,16 +462,30 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8TileGemmKernelBlkLen32Avx2( QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + 1, // CountM CountN, + CountK, BlockStrideQuantB, - Bias + Bias, + 0, // lda, not needed when CountM = 1 + 0 // ldc, not needed when CountM = 1 ); + // SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + // QuantA, + // QuantBData, + // QuantBScale, + // QuantBZeroPoint, + // C, + // CountN, + // BlockStrideQuantB, + // Bias + //); } else { SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( BlkLen, @@ -501,16 +515,30 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + MlasQ4Int8TileGemmKernelBlkLen32Avx2( QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + 1, // CountM CountN, + CountK, BlockStrideQuantB, - Bias + Bias, + 0, // lda, not needed when CountM = 1 + 0 // ldc, not needed when CountM = 1 ); + // SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( + // QuantA, + // QuantBData, + // QuantBScale, + // QuantBZeroPoint, + // C, + // CountN, + // BlockStrideQuantB, + // Bias + //); } else { SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( BlkLen, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h index 6fbeb7bf2063..f78bffd93816 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h @@ -38,8 +38,8 @@ accumulate_2blk_dot( const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); const __m256 scale_8_ps = _mm256_set_ps( - combined_scale0, combined_scale1, combined_scale0, combined_scale1, - combined_scale0, combined_scale1, combined_scale0, combined_scale1 + combined_scale1, combined_scale1, combined_scale0, combined_scale0, + combined_scale1, combined_scale1, combined_scale0, combined_scale0 ); acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); } @@ -69,16 +69,16 @@ accumulate_blklen32_r2c1blk2_avx2( const __m256i low_mask = _mm256_set1_epi8(0x0F); //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); //low_mask = _mm256_packus_epi16(low_mask, low_mask); - const __m256i bv0 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 // TODO: will this (the second line below) be faster and not keep low_mask in use? 
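// Both forms cost a shift plus one more instruction; the subtract form just lets low_mask retire
// sooner, which is what the TODO above is weighing. The subtract-then-shift trick is safe because
// subtracting the low nibbles zeroes the bottom four bits of every byte, so the 16-bit right shift
// only moves zeros across each byte boundary. Per byte it is simply:
//   uint8_t lo = b & 0x0F;              // bv0 element (low nibble)
//   uint8_t hi = uint8_t(b - lo) >> 4;  // bv1 element (high nibble), no second AND needed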
// const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 - const __m256i bv1 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - __m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - // This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. - //__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); - __m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); int8_t zp0, zp1; get_2_zps(QuantBZeroPointPtr, zp0, zp1); @@ -140,16 +140,16 @@ accumulate_blklen32_r1c1blk2_avx2( const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); const __m256i low_mask = _mm256_set1_epi8(0x0F); - const __m256i bv0 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 // TODO: will this be faster and save a use of low_mask? // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 - const __m256i bv1 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - __m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - // This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. - //__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); - __m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + //// This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. 
+ ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); int8_t zp0, zp1; get_2_zps(QuantBZeroPointPtr, zp0, zp1); @@ -770,7 +770,7 @@ MLAS_FORCEINLINE ldc ); } - if (remainingCols > 0) { + if (remainingCols > 0 && multipleRows > 0) { Q4Int8Gemm2xXBlkLen32Avx2( QuantA, QuantBData + multipleCols * StrideQuantBData, @@ -785,7 +785,7 @@ MLAS_FORCEINLINE ldc); } - if (remainingRows > 0) { + if (remainingRows > 0 && multipleCols > 0) { Q4Int8GemmXx4BlkLen32Avx2( QuantA + multipleRows * lda, QuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 7e0f672fc01d..48419fca2efc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -27,7 +27,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -39,12 +39,17 @@ SQ4BitGemmPackQuantBData( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block + const size_t BlkBytePairCount = BlkLen / 4; size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == CompInt8) { + SubBlkLen = 64; + } const size_t SubBlkDataSize = SubBlkLen / 2; const size_t SubBlkBytePairCount = SubBlkLen / 4; + const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block // // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: @@ -65,35 +70,45 @@ SQ4BitGemmPackQuantBData( // // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: // - // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | + // src: | v0 v1 | v2 v3 | ... | v60 v61 | v62 v63 | // => // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | // + // When BlkLen = 32 for the remaining blk, it shall be: + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // MlasTrySimpleParallel( ThreadPool, Iterations, [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; + const size_t n = tid / SubBlkCountK; + const size_t k_subblk = tid % SubBlkCountK; - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const size_t data_offset = n * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; const std::byte* QuantBData = QuantBDataBegin + data_offset; std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { - for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + size_t PackBytePairCount = SubBlkBytePairCount; + size_t PackDataSize = SubBlkDataSize; + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1) { + // this is the last subblk of the column. check if it extends out of the + // BlockCountK. 
If it does, we shall pack per blocks so that can compute + // on each block instead of each subblk. + if (SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + } + } - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + for (size_t byte_pair_idx = 0; byte_pair_idx < PackBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + PackDataSize / 2]; - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - QuantBData += SubBlkDataSize; - PackedQuantBData += SubBlkDataSize; + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); } } ); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index f55fcbca4dc1..929f8aefc9ed 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -396,6 +396,8 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Thu, 9 May 2024 20:17:54 -0700 Subject: [PATCH 06/41] use new AQuant layout (not work if total M is not RangeCountM): SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 1214042220 ns Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 29 +- onnxruntime/core/mlas/lib/sqnbitgemm.h | 10 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 23 +- .../mlas/lib/sqnbitgemm_kernel_avx2_int8.h | 570 +++++--- .../mlas/lib/sqnbitgemm_kernel_avx512_int8.h | 1171 +++++++++++++++++ .../test/mlas/unittest/test_sqnbitgemm.cpp | 1 + 6 files changed, 1634 insertions(+), 170 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 5c3a17334e82..9ea8514cc8b5 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -84,8 +84,9 @@ MlasIsSQNBitGemmAvailable( Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { - return Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && - Dispatch->QuantizeARow_CompInt8 != nullptr; + return + (Dispatch->SQ4BitGemmM1Kernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || + (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8_2 != nullptr); } default: { return false; @@ -498,9 +499,12 @@ SQ4BitGemm_CompInt8( (void)b_col_zp; #else if (BlkLen == 32) { + // TODO: this does not work is RangeCountM is not the total M. 
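// That is: this does not work if RangeCountM is not the total M. With the new A-quant layout the
// per-GEMM workspace stores all of the quantized data first and all of the block scales after it
// (see InitializeWorkspace_CompInt8 later in this file):
//   [ int8 data    : M * BlockCountK * BlkLen bytes ]
//   [ float scales : M * BlockCountK floats         ]
// so the scale region begins (total M) * k_blks * BlkLen bytes past the workspace base. The
// expression below substitutes RangeCountM for that total, which only matches when this thread
// range covers all rows.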
+ const float* a_row_scale = (const float*)(QuantA + RangeCountM * k_blks * BlkLen); GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, a_row, + a_row_scale, b_col, b_col_scale, b_col_zp, @@ -568,6 +572,7 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8_2; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); @@ -577,12 +582,22 @@ InitializeWorkspace_CompInt8( const float* ARowPtr = data.A; std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + if (QuantizeARow) { + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + } else { + float* QuantARowScalePtr = (float*)(QuantARowPtr + M * BlockCountK * BlkLen); + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr); - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + } } }); } @@ -712,6 +727,6 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index f5952b49ff46..eed866a70a30 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -231,6 +231,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { typedef void(SQ4BitGemmKernel_CompInt8_Fn)( const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -263,4 +264,13 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { ); QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; + + typedef void(QuantizeARow_CompInt8_Fn2)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale + ); + QuantizeARow_CompInt8_Fn2* QuantizeARow_CompInt8_2 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index bb7467ac093f..005846340ba4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -344,6 +344,7 @@ void SQ4BitGemmKernel_CompInt8_avx2( const size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -353,7 +354,7 @@ SQ4BitGemmKernel_CompInt8_avx2( size_t CountK, size_t BlockCountK, const float* Bias, - size_t lda, + size_t /*lda*/, size_t ldc ) { @@ -364,6 +365,7 @@ SQ4BitGemmKernel_CompInt8_avx2( } else if (BlkLen == 32) { MlasQ4Int8TileGemmKernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -373,7 +375,6 @@ SQ4BitGemmKernel_CompInt8_avx2( CountK, BlockCountK, Bias, - lda, ldc ); //MlasQ4Int8GemmKernelBlkLen32Avx2>( @@ -400,6 +401,7 @@ 
SQ4BitGemmKernel_CompInt8_avx2( } else if (BlkLen == 32) { MlasQ4Int8TileGemmKernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -409,7 +411,6 @@ SQ4BitGemmKernel_CompInt8_avx2( CountK, BlockCountK, Bias, - lda, ldc ); // MlasQ4Int8GemmKernelBlkLen32Avx2>( @@ -462,8 +463,10 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else if (BlkLen == 32) { + const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); MlasQ4Int8TileGemmKernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -473,7 +476,6 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( CountK, BlockStrideQuantB, Bias, - 0, // lda, not needed when CountM = 1 0 // ldc, not needed when CountM = 1 ); // SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( @@ -515,8 +517,10 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias ); } else if (BlkLen == 32) { + const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); MlasQ4Int8TileGemmKernelBlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -526,7 +530,6 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( CountK, BlockStrideQuantB, Bias, - 0, // lda, not needed when CountM = 1 0 // ldc, not needed when CountM = 1 ); // SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( @@ -1154,13 +1157,15 @@ QuantizeARow_CompInt8_avx2( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m256 signBit = _mm256_set1_ps(-0.0f); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -1181,8 +1186,8 @@ QuantizeARow_CompInt8_avx2( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + *scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps(inverse_scale); @@ -1230,7 +1235,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx2; + d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h index f78bffd93816..f31e19e688ea 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h @@ -44,6 +44,121 @@ accumulate_2blk_dot( acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); } +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + // TODO: will this (the second line below) be faster and not keep low_mask in use? 
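// Sketch of what this variant accumulates for one B column and two A rows (av00/bv0 etc.
// abbreviate the *_32_epi8 arguments): scale_a0 / scale_a1 point at two consecutive per-block A
// scales for row 0 / row 1, scale_b at the two matching per-block B scales, and dot32() stands
// for the 32-element int8 dot product:
//   acc0 += scale_a0[0] * scale_b[0] * dot32(av00, bv0) + scale_a0[1] * scale_b[1] * dot32(av01, bv1);
//   acc1 += scale_a1[0] * scale_b[0] * dot32(av10, bv0) + scale_a1[1] * scale_b[1] * dot32(av11, bv1);
// The _mm256_broadcast_sd / _mm256_permute_ps pair further down just loads each adjacent scale
// pair once and expands it to the per-lane {s0, s0, s1, s1} order produced by the hadd-combined
// integer sums.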
+ __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); + //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_mul_ps( + _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), + _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0))); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + + + const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av10_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av11_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); + const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps_ = _mm256_mul_ps( + _mm256_permute_ps(scale_a1_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), + _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0))); + acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_mul_ps( + _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), + _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)) + ); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk2_avx2( @@ -108,7 +223,8 @@ accumulate_blklen32_r2c1blk1_avx2( const float& combined_scale00, const float& combined_scale10, __m256& acc0, - __m256& acc1 + __m256& acc1, + bool zp_low_half = true ) { // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | @@ -116,7 +232,7 @@ accumulate_blklen32_r2c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const int8_t zp = get_zp(zp_low_half, QuantBZeroPointPtr); const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); @@ -187,8 +303,9 @@ accumulate_blklen32_r1c1blk1_avx2( template MLAS_FORCEINLINE void -Q4Int8Gemm2x4BlkLen32Avx2( +Q4Int8Gemm2x4x2BlkLen32Avx2( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -197,7 +314,6 @@ Q4Int8Gemm2x4BlkLen32Avx2( size_t CountN, size_t BlockCountK, const float* Bias, - size_t lda, size_t ldc ) { @@ -207,11 +323,10 @@ Q4Int8Gemm2x4BlkLen32Avx2( constexpr size_t NRows2 = 2; constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - constexpr size_t Q8Blk32Size = Q8BlkSize(BlkLen32); - // process 2 blks of 64 4b weights a time constexpr size_t PerAccuBlk2 = 2; + const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); @@ -229,74 +344,83 @@ Q4Int8Gemm2x4BlkLen32Avx2( for (size_t n = 0; n < CountN; n += NCols4) { const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; __m256 acc[NCols4 * NRows2] = { - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() }; size_t k_blks_remaining = BlockCountK; - // process 2 blks of 64 4b weights a time for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8Blk32Size; + const std::byte* QuantABlk01 = QuantABlk00 + 32; const std::byte* QuantABlk10 = QuantAPtr + lda; - const std::byte* QuantABlk11 = QuantABlk10 + Q8Blk32Size; + const std::byte* QuantABlk11 = QuantABlk10 + 32; // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); - const float& scale_a10 = Q8BlkScale(QuantABlk10); - const float& scale_a11 = Q8BlkScale(QuantABlk11); + //const float& scale_a00 = 
Q8BlkScale(QuantABlk00); + //const float& scale_a01 = Q8BlkScale(QuantABlk01); + //const float& scale_a10 = Q8BlkScale(QuantABlk10); + //const float& scale_a11 = Q8BlkScale(QuantABlk11); { // Col0 - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - const float& scale_10 = scale_a10 * QuantBScalePtr[0]; - const float& scale_11 = scale_a11 * QuantBScalePtr[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); + //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + //const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + //const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); } { // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + //const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + //const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + //const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + //const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; + //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); } { // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + //const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + //const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + //const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + //const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + 
//accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); } { // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + //const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + //const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + //const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + //const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); } // increment block pointers - QuantAPtr += Q8Blk32Size * PerAccuBlk2; + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; if constexpr (HasZeroPoint) { @@ -308,30 +432,30 @@ Q4Int8Gemm2x4BlkLen32Avx2( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); - const float& scale_a00 = Q8BlkScale(QuantABlk0); - const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); { // Col0 - const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); } { // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 
= scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); } { // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); } @@ -366,9 +490,130 @@ Q4Int8Gemm2x4BlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4x1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + constexpr size_t Q8Blk32Size = BlkLen32; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 2 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; k++) { + const __m256i av0_32_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av1_32_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + { + // Col0 + const float& scale0 = scale_a0 * QuantBScalePtr[0]; + const float& scale1 = scale_a1 * QuantBScalePtr[0]; + accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale0, scale1, acc[0], acc[NCols4], k % 2 == 0); + } + + { + // Col1 + const float& scale0 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale1 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr + StrideQuantBData, 
QuantBZeroPointPtr + StrideQuantBZeroPoint, scale0, scale1, acc[1], acc[NCols4 + 1], k % 2 == 0); + } + + { + // Col2 + const float& scale0 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale1 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale0, scale1, acc[2], acc[NCols4 + 2], k % 2 == 0); + } + + { + // Col3 + const float& scale0 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale1 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale0, scale1, acc[3], acc[NCols4 + 3], k % 2 == 0); + } + + // increment block pointers + QuantAPtr += Q8Blk32Size; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -377,7 +622,6 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( size_t CountN, size_t BlockCountK, const float* Bias, - size_t lda, size_t ldc) { constexpr size_t BlkLen32 = 32; @@ -389,6 +633,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( // process 2 blks of 64 4b weights a time constexpr size_t PerAccuBlk2 = 2; + const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); @@ -405,8 +650,9 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( float* SumPtr = C + m * ldc; for (size_t n = 0; n < CountN; n++) { - // accumulate_blklen32_r2c1_avx2 const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -414,31 +660,36 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk01 = QuantABlk00 + BlkLen32; const std::byte* QuantABlk10 = QuantAPtr + lda; - const 
std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk11 = QuantABlk10 + BlkLen32; // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); - - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); - const float& scale_a10 = Q8BlkScale(QuantABlk10); - const float& scale_a11 = Q8BlkScale(QuantABlk11); - - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - const float& scale_10 = scale_a10 * QuantBScalePtr[0]; - const float& scale_11 = scale_a11 * QuantBScalePtr[1]; - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + //const float& scale_a00 = Q8BlkScale(QuantABlk00); + //const float& scale_a01 = Q8BlkScale(QuantABlk01); + //const float& scale_a10 = Q8BlkScale(QuantABlk10); + //const float& scale_a11 = Q8BlkScale(QuantABlk11); + + //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + //const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + //const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); + accumulate_blklen32_r2c1blk2_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; if constexpr (HasZeroPoint) { @@ -450,11 +701,11 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); - const float& scale_a00 = Q8BlkScale(QuantABlk0); - const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; @@ -485,6 +736,7 @@ template MLAS_FORCEINLINE void Q4Int8GemmXx4BlkLen32Avx2( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -493,7 
+745,6 @@ Q4Int8GemmXx4BlkLen32Avx2( size_t CountN, size_t BlockCountK, const float* Bias, - size_t lda, size_t ldc ) { @@ -506,6 +757,7 @@ Q4Int8GemmXx4BlkLen32Avx2( // process 2 blks of 64 4b weights a time constexpr size_t PerAccuBlk2 = 2; + const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); @@ -523,6 +775,8 @@ Q4Int8GemmXx4BlkLen32Avx2( for (size_t n = 0; n < CountN; n += NCols4) { const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -530,41 +784,50 @@ Q4Int8GemmXx4BlkLen32Avx2( __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); { // Col0 - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); + //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); } { // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); + //const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + //const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1] + ); } { // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); + //const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + //const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2] + ); } { // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); + //const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + //const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); } // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; if constexpr (HasZeroPoint) { @@ -576,13 +839,13 @@ Q4Int8GemmXx4BlkLen32Avx2( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a00 = *QuantAScalePtr; { // Col0 const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); } { // Col1 @@ -624,6 +887,7 @@ template MLAS_FORCEINLINE void Q4Int8GemmXxXBlkLen32Avx2( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -632,7 +896,6 @@ Q4Int8GemmXxXBlkLen32Avx2( size_t CountN, size_t BlockCountK, const float* Bias, - size_t lda, size_t ldc ) { @@ -645,6 +908,7 @@ Q4Int8GemmXxXBlkLen32Avx2( // process 2 blks of 64 4b weights a time constexpr size_t PerAccuBlk2 = 2; + const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); @@ -662,6 +926,7 @@ Q4Int8GemmXxXBlkLen32Avx2( for (size_t n = 0; n < CountN; n++) { const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr 
= QuantBZeroPointColPtr; @@ -669,22 +934,19 @@ Q4Int8GemmXxXBlkLen32Avx2( __m256 acc0 = _mm256_setzero_ps(); size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); - - const float& scale_a00 = Q8BlkScale(QuantABlk00); - const float& scale_a01 = Q8BlkScale(QuantABlk01); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); + //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, + QuantAScalePtr, QuantBScalePtr, acc0); // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; if constexpr (HasZeroPoint) { @@ -694,11 +956,8 @@ Q4Int8GemmXxXBlkLen32Avx2( // TODO: use a loop in case PerAccuBlk2 is not 2. if (k_blks_remaining > 0) { - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - - const float& scale_a00 = Q8BlkScale(QuantABlk0); - + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); } @@ -726,6 +985,7 @@ MLAS_FORCEINLINE size_t MlasQ4Int8TileGemmKernelBlkLen32Avx2( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -735,7 +995,6 @@ MLAS_FORCEINLINE size_t /*CountK*/, size_t BlockCountK, const float* Bias, - size_t lda, size_t ldc ) { @@ -744,6 +1003,8 @@ MLAS_FORCEINLINE constexpr size_t NCols4 = 4; constexpr size_t NRows2 = 2; + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); @@ -756,8 +1017,10 @@ MLAS_FORCEINLINE size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8Gemm2x4BlkLen32Avx2( + //Q4Int8Gemm2x4x1BlkLen32Avx2( + Q4Int8Gemm2x4x2BlkLen32Avx2( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -766,53 +1029,52 @@ MLAS_FORCEINLINE multipleCols, BlockCountK, Bias, - lda, ldc ); } if (remainingCols > 0 && multipleRows > 0) { Q4Int8Gemm2xXBlkLen32Avx2( - QuantA, - QuantBData + multipleCols * StrideQuantBData, - 
QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, - C + multipleCols, - multipleRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - lda, - ldc); + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); } if (remainingRows > 0 && multipleCols > 0) { Q4Int8GemmXx4BlkLen32Avx2( - QuantA + multipleRows * lda, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C + multipleRows * ldc, - remainingRows, - multipleCols, - BlockCountK, - Bias, - lda, - ldc); + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); } if (remainingCols > 0 && remainingRows > 0) { Q4Int8GemmXxXBlkLen32Avx2( - QuantA + multipleRows * lda, - QuantBData + multipleCols * StrideQuantBData, - QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, - C + multipleRows * ldc + multipleCols, - remainingRows, - remainingCols, - BlockCountK, - Bias ? Bias + multipleCols : nullptr, - lda, - ldc); + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); } return CountM; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h new file mode 100644 index 000000000000..7d9dc3685462 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8.h @@ -0,0 +1,1171 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +MLAS_FORCEINLINE void +accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, + const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) +{ + const __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) + ); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +MLAS_FORCEINLINE void +accumulate_2blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float& combined_scale0, const float& combined_scale1, + const __m256i& one_16_epi16, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale_8_ps = _mm256_set_ps( + combined_scale1, 
combined_scale1, combined_scale0, combined_scale0,
+        combined_scale1, combined_scale1, combined_scale0, combined_scale0
+    );
+    acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc);
+}
+
+template
+static MLAS_FORCEINLINE void
+accumulate_blklen32_r2c1blk2_avx2(
+    const __m256i& av00_32_epi8,
+    const __m256i& av01_32_epi8,
+    const __m256i& av10_32_epi8,
+    const __m256i& av11_32_epi8,
+    const std::byte* QuantBDataPtr,
+    const std::byte* QuantBZeroPointPtr,
+    const float* scale_a0,
+    const float* scale_a1,
+    const float* scale_b,
+    __m256& acc0,
+    __m256& acc1
+)
+{
+    // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 |
+    const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr));
+    const __m256i low_mask = _mm256_set1_epi8(0x0F);
+    __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31
+    // TODO: will this (the second line below) be faster and not keep low_mask in use?
+    __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63
+
+    int8_t zp0, zp1;
+    get_2_zps(QuantBZeroPointPtr, zp0, zp1);
+    bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0));
+    bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1));
+
+    //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0);
+    //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1);
+    const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)
+    );
+    const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)
+    );
+    const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16);
+
+    __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15);
+    const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16);
+    const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32);
+
+    __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0));
+    __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b));
+    // per 128-bit lane: a0 a1 a0 a1 -> a0 a0 a1 a1, to line up with the hadd'ed block sums
+    __m256 scale_8_ps = _mm256_mul_ps(
+        _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+        _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+
+    acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0);
+
+    const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av10_32_epi8, bv0_32_epi8)
+    );
+    const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(
+        _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av11_32_epi8, bv1_32_epi8)
+    );
+    const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_);
+    const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_);
+    const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_);
+
+    __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1));
+    __m256 scale_8_ps_ = _mm256_mul_ps(
+        _mm256_permute_ps(scale_a1_2_ps, _MM_SHUFFLE(1, 1, 0, 0)),
+        _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)));
+    acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1);
+}
+
+template
+static MLAS_FORCEINLINE void
+accumulate_blklen32_r2c1blk2_avx2(
+    const __m256i& av00_32_epi8,
+    const __m256i& av01_32_epi8,
+    const __m256i& av10_32_epi8,
+    const __m256i& av11_32_epi8,
+    const std::byte* QuantBDataPtr,
+    const std::byte* QuantBZeroPointPtr,
+    const float& combined_scale00,
+    const float& combined_scale01,
+    const float& combined_scale10,
+    const float& combined_scale11,
+    __m256& acc0,
+    __m256& acc1
+)
+{
+    // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 |
+    const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr));
+
+    // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F).
+    // however, it is faster to generate one_16_epi16 than calling _mm256_set1_epi16(1);
+    const __m256i low_mask = _mm256_set1_epi8(0x0F);
+    //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12);
+    //low_mask = _mm256_packus_epi16(low_mask, low_mask);
+    __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47
+    // TODO: will this (the second line below) be faster and not keep low_mask in use?
+    // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63
+    __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63
+
+    //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0));
+
+    //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i.
+    ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1));
+    //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0);
+
+    int8_t zp0, zp1;
+    get_2_zps(QuantBZeroPointPtr, zp0, zp1);
+    bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0));
+    bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1));
+
+    // generating constant 1s is faster here.
+    // __m256i one = _mm256_set1_epi16(1);
+    __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15);
+
+    // ~7% performance gain by calling this (accumulate_2blk_dot) instead of 2 accumulate_1blk_dot calls.
+    // accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0);
+    // accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0);
+    // accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one_16_epi16, acc1);
+    // accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one_16_epi16, acc1);
+    accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0);
+    accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1);
+}
+
+template
+static MLAS_FORCEINLINE void
+accumulate_blklen32_r2c1blk1_avx2(
+    const __m256i& av00_32_epi8,
+    const __m256i& av10_32_epi8,
+    const std::byte* QuantBDataPtr,
+    const std::byte* QuantBZeroPointPtr,
+    const float& combined_scale00,
+    const float& combined_scale10,
+    __m256& acc0,
+    __m256& acc1
+)
+{
+    // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + const float& combined_scale01, + __m256& acc0) +{ + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 + // TODO: will this be faster and save a use of low_mask? + // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 + + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + + //// This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. + ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); + //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); + + int8_t zp0, zp1; + get_2_zps(QuantBZeroPointPtr, zp0, zp1); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + //accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); + //accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); + accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const std::byte* QuantBZeroPointPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + const int8_t zp = get_zp(true, QuantBZeroPointPtr); + const __m256i bzp = _mm256_set1_epi8(zp); + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8Gemm2x4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + constexpr size_t Q8Blk32Size = Q8BlkSize(BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8Blk32Size; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8Blk32Size; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * 
QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += Q8Blk32Size * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
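+                // For reference, the single-block tail below is equivalent to this scalar form per row r and
+                // column c, where b4[] are the unpacked 4-bit weights and zp is the block zero point
+                // (illustrative names only, not symbols from this file):
+                //     int32_t sum = 0;
+                //     for (size_t i = 0; i < BlkLen32; ++i) {
+                //         sum += int32_t(a[r][i]) * (int32_t(b4[c][i]) - zp);
+                //     }
+                //     acc[r][c] += (scale_a[r] * scale_b[c]) * float(sum);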
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + // accumulate_blklen32_r2c1_avx2 + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk10)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk11)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + const float& scale_a10 = Q8BlkScale(QuantABlk10); + const float& scale_a11 = Q8BlkScale(QuantABlk11); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + const float& scale_10 = scale_a10 * QuantBScalePtr[0]; + const float& scale_11 = scale_a11 * QuantBScalePtr[1]; + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
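+                // Layout assumed from the Q8BlkScale/Q8BlkData accessors used here: each quantized A block is
+                //     [ float scale | BlkLen32 x int8 ]
+                // so Q8BlkSize(BlkLen32) is sizeof(float) + BlkLen32 bytes, and this tail handles the one
+                // leftover block when BlockCountK is odd.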
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0 + lda)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + const float& scale_a10 = Q8BlkScale(QuantABlk0 + lda); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXx4BlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + { + // Col0 + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); + } + { + // Col1 + const float& scale_00 = 
scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); + } + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, acc[3]); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmXxXBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk00)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk01)); + + const float& scale_a00 = Q8BlkScale(QuantABlk00); + const float& scale_a01 = Q8BlkScale(QuantABlk01); + + const float& scale_00 = scale_a00 * QuantBScalePtr[0]; + const float& scale_01 = scale_a01 * QuantBScalePtr[1]; + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a00 = Q8BlkScale(QuantABlk0); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE + size_t + MlasQ4Int8TileGemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8Gemm2x4BlkLen32Avx2( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8Gemm2xXBlkLen32Avx2( + QuantA, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmXx4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + lda, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmXxXBlkLen32Avx2( + QuantA + multipleRows * lda, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + lda, + ldc); + } + + return CountM; +} + +// this function is to explore larger NCols. With Avx2 it does not improve performance. +// Leave it here until the same is implemented in avx512. +template accumulator> +MLAS_FORCEINLINE +size_t +MlasQ4Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t lda, + size_t ldc +) +{ + // We process 32 quantized values in a batch. 
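+    // Loop structure below: for each row of A, walk the B columns NCols4 at a time
+    // (the `if constexpr (NCols4 == 8)` branches are the wider-tile experiment and
+    // only run if NCols4 is raised to 8), consuming PerAccuBlk2 = 2 quantized K
+    // blocks per inner iteration; a scalar tail loop then handles the leftover columns.
+    // The `accumulator` callback is expected to unpack one 16-byte block of 4-bit B
+    // values with `low_mask`, apply the zero point from QuantBZeroPointPtr (the bool
+    // selects the low/high nibble), and add the scaled int8 dot product into the
+    // given __m256 accumulator.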
+ constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t m = 0; m < CountM; m++) { + // for each row of A, reset B pointers + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + int64_t nblk = (int64_t)(CountN)-NCols4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA + m * lda; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4]; + + acc[0] = _mm256_setzero_ps(); + acc[1] = _mm256_setzero_ps(); + acc[2] = _mm256_setzero_ps(); + acc[3] = _mm256_setzero_ps(); + + if constexpr (NCols4 == 8) { + acc[4] = _mm256_setzero_ps(); + acc[5] = _mm256_setzero_ps(); + acc[6] = _mm256_setzero_ps(); + acc[7] = _mm256_setzero_ps(); + } + + size_t k_blks_remaining = BlockCountK; + + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * 
StrideQuantBZeroPoint, false, scale_21, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + const float& scale_41 = scale_a1 * (QuantBScalePtr + 4 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr, true, scale_40, acc[4]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_41, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + const float& scale_51 = scale_a1 * (QuantBScalePtr + 5 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_50, acc[5]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_51, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + const float& scale_61 = scale_a1 * (QuantBScalePtr + 6 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, false, scale_61, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + const float& scale_71 = scale_a1 * (QuantBScalePtr + 7 * StrideQuantBScale)[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, false, scale_71, acc[7]); + } + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } // k_blks_remaining + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
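+            // If BlockCountK is odd, the paired loop above leaves exactly one
+            // 32-value block of K; handle it with single-block accumulator calls.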
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc[0]); + + // Col1 + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc[1]); + + // Col2 + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc[2]); + + // Col3 + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc[3]); + + if constexpr (NCols4 == 8) { + // Col4 + const float& scale_40 = scale_a0 * (QuantBScalePtr + 4 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 4 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 4 * StrideQuantBZeroPoint, true, scale_40, acc[4]); + + // Col5 + const float& scale_50 = scale_a0 * (QuantBScalePtr + 5 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 5 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 5 * StrideQuantBZeroPoint, true, scale_50, acc[5]); + + // Col6 + const float& scale_60 = scale_a0 * (QuantBScalePtr + 6 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 6 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 6 * StrideQuantBZeroPoint, true, scale_60, acc[6]); + + // Col7 + const float& scale_70 = scale_a0 * (QuantBScalePtr + 7 * StrideQuantBScale)[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr + 7 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 7 * StrideQuantBZeroPoint, true, scale_70, acc[7]); + } + } // k_blks_remaining + + if constexpr (NCols4 == 8) { + __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + if (BiasPtr != nullptr) { + acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); + acc_1 = _mm_add_ps(acc_1, _mm_loadu_ps(BiasPtr + 4)); + } + _mm_storeu_ps(SumPtr, acc_0); + _mm_storeu_ps(SumPtr+4, acc_1); + } else { + __m128 acc_x = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + } + + // move to next NCols columns + + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + nblk -= NCols4; + } // while (nblk >= 0) + + nblk += NCols4; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen32); + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a1 = Q8BlkScale(QuantABlk1); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); + + // increment block pointers + QuantAPtr += Q8BlkSize(BlkLen32) * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + + const float& scale_a0 = Q8BlkScale(QuantABlk0); + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + accumulator(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } // m + return CountM; +} diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 929f8aefc9ed..0e1537c3bc0c 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -402,6 +402,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Sun, 12 May 2024 19:51:31 -0700 Subject: [PATCH 07/41] apply blksum to blklen32 and 64: SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 784668090 ns; SQNBITGEMM<4>/BlkLen:64/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 754939430 ns Signed-off-by: Liqun Fu --- .../cpu/quantization/matmul_nbits.cc | 8 + onnxruntime/core/mlas/inc/mlas_qnbit.h | 17 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 350 +++++++++---- onnxruntime/core/mlas/lib/sqnbitgemm.h | 18 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 110 ++-- ...=> sqnbitgemm_kernel_avx2_int8_blklen32.h} | 79 +-- .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 492 ++++++++++++++++++ .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 3 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 3 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 31 +- .../test/mlas/bench/bench_sqnbitgemm.cpp | 1 + .../test/mlas/unittest/test_fgemm_fixture.h | 1 + .../test/mlas/unittest/test_sqnbitgemm.cpp | 2 + 13 files changed, 920 insertions(+), 195 deletions(-) rename onnxruntime/core/mlas/lib/{sqnbitgemm_kernel_avx2_int8.h => sqnbitgemm_kernel_avx2_int8_blklen32.h} (96%) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index dbc678c9bc9c..24d96699952d 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -199,10 +199,18 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); if (prepacked_weights) { + // TODO: cannot use packed_b_ after + assert(false); prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); } is_packed = true; + } else if (input_idx == 2) { + // MlasSQNBitGemmPackQuantBData with scales + assert(false); + } else if (input_idx == 3) { + // MlasSQNBitGemmPackQuantBData with zp + assert(false); } #endif // defined(ORT_NEURAL_SPEED) diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 32e9cc98106d..b7b347232518 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -48,6 +48,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block const float* Bias = nullptr; ///< optional address of Bias, vector size N float* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C @@ -159,6 +160,18 @@ 
MlasSQNBitGemmPackQuantBDataSize( /** * @brief Packs the quantized B data in a format that the kernel expects. * + * If the function is called without QuantBScale and QuantBZeroPoint, + * it just packs QuantBData into PackedQuantBDataAndOrBlkSum. + * + * If the function is called with QuantBData, QuantBScale, and QuantBZeroPoint + * additional BlkSum (Scale * zeropoint) is computed and stored at the second part of PackedQuantBDataAndOrBlkSum. + * + * Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData, + * QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData, + * and then QuantBScale and QuantBZeroPoint. The second time the function is called with QuantBScale, + * BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum. + * If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint. + * * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) @@ -176,6 +189,8 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, + void* PackedQuantBDataAndOrBlkSum, + const void* QuantBScale, + const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool = nullptr ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 9ea8514cc8b5..c7aa152180a9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,11 +19,6 @@ Module Name: #include -//#define SQ4BITGEMM_USE_TILE -#ifdef SQ4BITGEMM_USE_TILE -#include "llama.cpp.sgemm.h" -#endif - namespace { @@ -125,7 +120,8 @@ SQNBitGemmPerGemmWorkspaceSize( case SQNBitGemmVariant_BitWidth4_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); return PerGemmWorkspaceSize; } default: { @@ -198,6 +194,37 @@ MlasSQNBitGemmPackQuantBDataSize( return 0; } +struct PackedQuantBDataStruct { + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) + : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + constexpr size_t BlkBitWidth = 4; + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; + QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); + } + std::byte* PackedQuantBData; + float* QuantBBlkSum; + + void* QuantBWorkspace_; + size_t N_, BlockCountK_, BlkLen_; +}; + +struct PerGemmQuantAWorkspace { + PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) + : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + QuantData = (std::byte*)PerGemmWorkspace; + QuantScale = (float*)(QuantData + M * BlockCountK * BlkLen); + BlockSum = QuantScale + M * BlockCountK; + } + std::byte* QuantData; // NxBlockCountKxBlkLen + float* QuantScale; // NxBlockCountK + float* BlockSum; // NxBlockCountK + void* PerGemmWorkspace_; // memory for above data + size_t M_, 
BlockCountK_, BlkLen_; +}; + void MLASCALL MlasSQNBitGemmPackQuantBData( size_t N, @@ -206,7 +233,9 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBData, + void* PackedQuantBDataAndOrBlkSum, + const void* QuantBScale, + const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ) { @@ -215,17 +244,38 @@ MlasSQNBitGemmPackQuantBData( return; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - Dispatch->SQ4BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBData), - ThreadPool - ); - return; + if (BlkBitWidth == 4) { + if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + assert(QuantBScale == nullptr); + assert(QuantBZeroPoint == nullptr); + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSum), + ThreadPool + ); + return; + } else if (Dispatch->SQ4BitGemmPackQuantBDataAndSumBlk != nullptr) { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSum, N, BlockCountK, BlkLen); + assert(QuantBScale); + //assert(QuantBZeroPoint); // QuantBZeroPoint is nullptr if symetric quantization. + Dispatch->SQ4BitGemmPackQuantBDataAndSumBlk( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + packed_quant_b.PackedQuantBData, + static_cast(QuantBScale), + static_cast(QuantBZeroPoint), + packed_quant_b.QuantBBlkSum, + ThreadPool + ); + } } } @@ -262,7 +312,7 @@ typedef void(SQNBitGemmFn)( size_t BlkLen, size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* PerGemmWorkspace, + PerGemmQuantAWorkspace* PerGemmWorkspace, size_t RangeStartM, size_t RangeCountM, size_t RangeStartN, @@ -274,7 +324,7 @@ SQ4BitGemm_CompFp32( const size_t BlkLen, const size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - void* const PerGemmWorkspace, + PerGemmQuantAWorkspace* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, const size_t RangeStartN, @@ -393,14 +443,13 @@ SQ4BitGemm_CompInt8( const size_t BlkLen, const size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - void* const PerGemmWorkspace, + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace, const size_t RangeStartM, const size_t RangeCountM, const size_t RangeStartN, const size_t RangeCountN ) { -//#ifndef SQ4BITGEMM_USE_TILE //#ifdef MLAS_TARGET_AMD64_IX86 // if (RangeCountM != 1) { // // perf experiment shows fp32 is faster than int8 in M > 1 cases. @@ -411,18 +460,20 @@ SQ4BitGemm_CompInt8( // ); // return; // } -//#endif //#endif constexpr size_t BlkBitWidth = 4; const size_t k_blks = MlasDivRoundup(K, BlkLen); - const size_t lda = k_blks * Q8BlkSize(BlkLen); + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? 
BlkLen : Q8BlkSize(BlkLen)); const size_t ldc = DataParams->ldc; const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); - const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; @@ -430,27 +481,129 @@ SQ4BitGemm_CompInt8( (DataParams->QuantBZeroPoint == nullptr) ? nullptr : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; if (RangeCountM == 1) { + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8) + { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().GemmFloatKernel( + ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, true + ); + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + BlkLen, + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + lda, + ldc + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + } else { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + } + return; + } + + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8) + { size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); - const std::byte* a_row = QuantA; const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; const std::byte* b_col_zp = (QuantBZeroPoint == nullptr) ? 
nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndSumBlk) { + size_t RowsRemaining = RangeCountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, k_blks, RowsRemaining, CountN, k_blks, ldc, 1.f, true + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += k_blks * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + + c_blk = C + n; + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + lda, + ldc ); if (DataParams->PostProcessor != nullptr) { @@ -459,6 +612,7 @@ SQ4BitGemm_CompInt8( RangeCountM, CountN, ldc ); } + } return; } @@ -474,73 +628,31 @@ SQ4BitGemm_CompInt8( const float* b_col_scale = QuantBScale + n * k_blks; const std::byte* b_col_zp = (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; -#ifdef SQ4BITGEMM_USE_TILE - int64_t llama_cpp_m = CountN; - int64_t llama_cpp_n = RangeCountM; - int64_t llama_cpp_k = k_blks; - const std::byte* llama_cpp_A = b_col; - int64_t llama_cpp_lda = ldb; - const std::byte* llama_cpp_B = a_row; - int64_t llama_cpp_ldb = lda; - float* llama_cpp_C = c_blk; - int64_t llama_cpp_ldc = ldc; - const float* llama_cpp_QuantBScale = b_col_scale; - int64_t llama_cpp_StrideQuantBScale = k_blks; - llamafile_sgemm( - llama_cpp_m, llama_cpp_n, llama_cpp_k, - llama_cpp_A, llama_cpp_lda, - llama_cpp_B, llama_cpp_ldb, - llama_cpp_C, llama_cpp_ldc, - llama_cpp_QuantBScale, llama_cpp_StrideQuantBScale - ); - (void)bias; - (void)b_col_zp; -#else - if (BlkLen == 32) { - // TODO: this does not work is RangeCountM is not the total M. 
- const float* a_row_scale = (const float*)(QuantA + RangeCountM * k_blks * BlkLen); - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( BlkLen, - a_row, - a_row_scale, - b_col, - b_col_scale, - b_col_zp, - c_blk, - RangeCountM, - CountN, - K, - k_blks, - bias, - lda, - ldc + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); - } else { - for (size_t m = 0; m < RangeCountM; ++m) { - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - // GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( - // BlkLen, - // a_row, b_col, b_col_scale, b_col_zp, c_blk, /*RangeCountM*/1, CountN, - // K, k_blks, bias, lda, ldc - //); - - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } + // GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + // BlkLen, + // a_row, b_col, b_col_scale, b_col_zp, c_blk, /*RangeCountM*/1, CountN, + // K, k_blks, bias, lda, ldc + //); - c_blk += ldc; - a_row += lda; + // TODO: shall be processed outsize the loop + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); } + + c_blk += ldc; + a_row += lda; } -#endif } } @@ -577,29 +689,38 @@ InitializeWorkspace_CompInt8( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; + if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - if (QuantizeARow) { + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; for (size_t m = 0; m < M; ++m) { QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); ARowPtr += data.lda; QuantARowPtr += QuantAStride; } - } else { - float* QuantARowScalePtr = (float*)(QuantARowPtr + M * BlockCountK * BlkLen); + }); + } else { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; for (size_t m = 0; m < M; ++m) { - QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr); - + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); ARowPtr += data.lda; QuantARowPtr += BlockCountK * BlkLen; QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; } - } - }); + }); + } } struct Operations { @@ -659,12 +780,22 @@ MlasSQNBitGemmBatch( const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* 
Data = &DataParams[gemm_i]; - void* PerGemmWorkspace = - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); + if (ComputeType == CompInt8) { + // TODO: shall sepqrate QuantBBlkSum from QuantBData + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBData), N, BlockCountK, BlkLen); + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else { + ComputeOperation(BlkLen, K, Data, nullptr, 0, M, 0, N); + } } return; } @@ -714,9 +845,6 @@ MlasSQNBitGemmBatch( const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; const auto* Data = &DataParams[gemm_i]; - void* PerGemmWorkspace = reinterpret_cast( - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride - ); const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; @@ -727,6 +855,16 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + if (ComputeType == CompInt8) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBData), N, BlockCountK, BlkLen); + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else { + ComputeOperation(BlkLen, K, Data, nullptr, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index eed866a70a30..2b947a9637dc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -126,6 +126,21 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + const float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, // BlockCountK by N => (BlockCountK * N) / 16 by 16 + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndSumBlk = nullptr; + // // CompFp32 kernel function prototypes. 
// @@ -270,7 +285,8 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { const float* A, size_t CountK, std::byte* QuantA, - float* QuantAScale + float* QuantAScale, + float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); QuantizeARow_CompInt8_Fn2* QuantizeARow_CompInt8_2 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 005846340ba4..ec9fab9260c0 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -22,7 +22,8 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" -#include "sqnbitgemm_kernel_avx2_int8.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen64.h" MLAS_FORCEINLINE __m256 @@ -363,7 +364,7 @@ SQ4BitGemmKernel_CompInt8_avx2( if (BlkLen == 16) { assert(false); } else if (BlkLen == 32) { - MlasQ4Int8TileGemmKernelBlkLen32Avx2( + MlasQ4Int8GemmKernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -377,21 +378,21 @@ SQ4BitGemmKernel_CompInt8_avx2( Bias, ldc ); - //MlasQ4Int8GemmKernelBlkLen32Avx2>( - // QuantA, - // QuantBData, - // QuantBScale, - // QuantBZeroPoint, - // C, - // CountM, - // CountN, - // CountK, - // BlockCountK, - // Bias, - // lda, - // ldc - //); - } else { + } else if (BlkLen >= 64) { + MlasQ4Int8GemmKernelBlkLen64Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + else { assert(false); } } else { @@ -399,7 +400,7 @@ SQ4BitGemmKernel_CompInt8_avx2( if (BlkLen == 16) { assert(false); } else if (BlkLen == 32) { - MlasQ4Int8TileGemmKernelBlkLen32Avx2( + MlasQ4Int8GemmKernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -413,20 +414,19 @@ SQ4BitGemmKernel_CompInt8_avx2( Bias, ldc ); - // MlasQ4Int8GemmKernelBlkLen32Avx2>( - // QuantA, - // QuantBData, - // QuantBScale, - // QuantBZeroPoint, - // C, - // CountM, - // CountN, - // CountK, - // BlockCountK, - // Bias, - // lda, - // ldc - //); + } else if (BlkLen >= 64) { + MlasQ4Int8GemmKernelBlkLen64Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); } else { assert(false); } @@ -464,7 +464,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( ); } else if (BlkLen == 32) { const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); - MlasQ4Int8TileGemmKernelBlkLen32Avx2( + MlasQ4Int8GemmKernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -478,16 +478,6 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( Bias, 0 // ldc, not needed when CountM = 1 ); - // SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( - // QuantA, - // QuantBData, - // QuantBScale, - // QuantBZeroPoint, - // C, - // CountN, - // BlockStrideQuantB, - // Bias - //); } else { SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( BlkLen, @@ -518,7 +508,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( ); } else if (BlkLen == 32) { const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); - MlasQ4Int8TileGemmKernelBlkLen32Avx2( + MlasQ4Int8GemmKernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -1152,18 +1142,32 @@ convert_2_ps_to_epi8(__m256 v0, __m256 v1) return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); } +// horizontally add 8 int32_t +static inline int +hsum_8_epi32(const __m256i a_8_epi32) +{ + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a_8_epi32), _mm256_extractf128_si256(a_8_epi32, 1)); + const __m128i hi64 = 
_mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + void MLASCALL QuantizeARow_CompInt8_avx2( size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA, - float* QuantAScale + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m256 signBit = _mm256_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_srli_epi16( + _mm256_cmpeq_epi16(_mm256_castps_si256(signBit), _mm256_castps_si256(signBit)), 15); int8_t* blob = reinterpret_cast(QuantA); float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { @@ -1193,6 +1197,7 @@ QuantizeARow_CompInt8_avx2( const __m256 mul = _mm256_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const int klen = std::min(16, (int)(step - kk)); @@ -1211,12 +1216,20 @@ QuantizeARow_CompInt8_avx2( v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST); } - __m128i i_8 = convert_2_ps_to_epi8(v0, v1); - _mm_storeu_si128(dst++, i_8); + __m128i i_16_epi8 = convert_2_ps_to_epi8(v0, v1); + _mm_storeu_si128(dst++, i_16_epi8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i_16_epi8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } @@ -1228,7 +1241,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { MLAS_SQNBIT_GEMM_DISPATCH d; d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBData = nullptr; + d.SQ4BitGemmPackQuantBDataAndSumBlk = SQ4BitGemmPackQuantBDataAndSumBlk; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h similarity index 96% rename from onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h rename to onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index f31e19e688ea..85db4f832272 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -12,7 +12,7 @@ accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, const __m256i& one_16_epi16, __m256& acc) { const __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8) + bv_32_epi8, av_32_epi8 ); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -52,7 +52,7 @@ accumulate_blklen32_r2c1blk2_avx2( const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, + const std::byte* /*QuantBZeroPointPtr*/, const float* scale_a0, const float* scale_a1, const float* scale_b, @@ -67,18 +67,18 @@ accumulate_blklen32_r2c1blk2_avx2( // TODO: will this (the second line below) be 
faster and not keep low_mask in use? __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - int8_t zp0, zp1; - get_2_zps(QuantBZeroPointPtr, zp0, zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + //int8_t zp0, zp1; + //get_2_zps(QuantBZeroPointPtr, zp0, zp1); + //bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + //bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8) + bv0_32_epi8, av00_32_epi8 ); const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) + bv1_32_epi8, av01_32_epi8 ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); @@ -97,10 +97,10 @@ accumulate_blklen32_r2c1blk2_avx2( const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av10_32_epi8, bv0_32_epi8) + bv0_32_epi8, av10_32_epi8 ); const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av11_32_epi8, bv1_32_epi8) + bv1_32_epi8, av11_32_epi8 ); const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); @@ -119,7 +119,7 @@ accumulate_blklen32_r1c1blk2_avx2( const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, + const std::byte* /*QuantBZeroPointPtr*/, const float* scale_a0, const float* scale_b, __m256& acc0 @@ -131,16 +131,16 @@ accumulate_blklen32_r1c1blk2_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - int8_t zp0, zp1; - get_2_zps(QuantBZeroPointPtr, zp0, zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + //int8_t zp0, zp1; + //get_2_zps(QuantBZeroPointPtr, zp0, zp1); + //bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); + //bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8) + bv0_32_epi8, av00_32_epi8 ); const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) + bv1_32_epi8, av01_32_epi8 ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); @@ -219,7 +219,7 @@ accumulate_blklen32_r2c1blk1_avx2( const __m256i& av00_32_epi8, const __m256i& av10_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, + const std::byte* /*QuantBZeroPointPtr*/, const float& combined_scale00, const float& combined_scale10, __m256& acc0, @@ -227,14 +227,15 @@ accumulate_blklen32_r2c1blk1_avx2( bool zp_low_half = true ) { + 
(void)zp_low_half; // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - const int8_t zp = get_zp(zp_low_half, QuantBZeroPointPtr); - const __m256i bzp = _mm256_set1_epi8(zp); - bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + //const int8_t zp = get_zp(zp_low_half, QuantBZeroPointPtr); + //const __m256i bzp = _mm256_set1_epi8(zp); + //bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); @@ -283,7 +284,7 @@ static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_avx2( const __m256i& av00_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, + const std::byte* /*QuantBZeroPointPtr*/, const float& combined_scale00, __m256& acc0 ) @@ -293,9 +294,9 @@ accumulate_blklen32_r1c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - const int8_t zp = get_zp(true, QuantBZeroPointPtr); - const __m256i bzp = _mm256_set1_epi8(zp); - bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); + //const int8_t zp = get_zp(true, QuantBZeroPointPtr); + //const __m256i bzp = _mm256_set1_epi8(zp); + //bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); @@ -474,8 +475,11 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -594,8 +598,11 @@ Q4Int8Gemm2x4x1BlkLen32Avx2( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - _mm_storeu_ps(SumPtr, acc_r0); - _mm_storeu_ps(SumPtr + ldc, acc_r1); + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -712,8 +719,8 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); } - *SumPtr = hsum_float_8(acc0); - *(SumPtr + ldc) = hsum_float_8(acc1); + *SumPtr = hsum_float_8(acc0) - *SumPtr; + *(SumPtr + ldc) = hsum_float_8(acc1) - *(SumPtr + ldc); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -868,7 +875,9 @@ Q4Int8GemmXx4BlkLen32Avx2( if (BiasPtr != nullptr) { acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - _mm_storeu_ps(SumPtr, acc_r0); + + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -962,7 +971,7 @@ 
Q4Int8GemmXxXBlkLen32Avx2( accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); } - *SumPtr = hsum_float_8(acc0); + *SumPtr = hsum_float_8(acc0) - *SumPtr; if (BiasPtr) { *SumPtr += *BiasPtr; } @@ -983,7 +992,7 @@ Q4Int8GemmXxXBlkLen32Avx2( template MLAS_FORCEINLINE size_t - MlasQ4Int8TileGemmKernelBlkLen32Avx2( + MlasQ4Int8GemmKernelBlkLen32Avx2( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -1085,7 +1094,7 @@ MLAS_FORCEINLINE template accumulator> MLAS_FORCEINLINE size_t -MlasQ4Int8GemmKernelBlkLen32Avx2( +MlasQ4Int8TileGemmKernelBlkLen32Avx2( const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h new file mode 100644 index 000000000000..2117a11a7682 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -0,0 +1,492 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + + dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + // process 1 blks of 64 4b weights a time + constexpr size_t PerAccuBlk1 = 1; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk1; + QuantAScalePtr += PerAccuBlk1; + QuantBDataPtr += BlkDataSizeInBytes32 * PerAccuBlk1; + QuantBScalePtr += PerAccuBlk1; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64; + QuantAScalePtr += 1; + QuantBDataPtr += BlkDataSizeInBytes32; + QuantBScalePtr += 1; + } + + *SumPtr = hsum_float_8(acc0) - *SumPtr; + *(SumPtr + ldc) = hsum_float_8(acc1) - *(SumPtr + ldc); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += 
StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + // process 1 blks of 64 4b weights a time + constexpr size_t PerAccuBlk1 = 1; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, + QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, + QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk1; + QuantAScalePtr += PerAccuBlk1; + QuantBDataPtr += BlkDataSizeInBytes32 * PerAccuBlk1; + QuantBScalePtr += PerAccuBlk1; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk1 = 1; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen64_r1c1blk1_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk1; + QuantAScalePtr += PerAccuBlk1; + QuantBDataPtr += BlkDataSizeInBytes32 * PerAccuBlk1; + QuantBScalePtr += PerAccuBlk1; + } + + *SumPtr = hsum_float_8(acc0) - *SumPtr; + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen64Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen64 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen64Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen64Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 1eca0960cf67..fb82bbd79bed 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -231,7 +231,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { MLAS_SQNBIT_GEMM_DISPATCH d; d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBData = nullptr; + d.SQ4BitGemmPackQuantBDataAndSumBlk = SQ4BitGemmPackQuantBDataAndSumBlk; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 45a69c4f2060..9b9064b175f3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -252,7 +252,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { MLAS_SQNBIT_GEMM_DISPATCH d; d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBData = nullptr; + d.SQ4BitGemmPackQuantBDataAndSumBlk = SQ4BitGemmPackQuantBDataAndSumBlk; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 48419fca2efc..c60892a54b56 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -19,17 +19,21 @@ SQ4BitGemmPackQuantBDataSize( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + const size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + return PackedQuantBDataSize + BlkSumSize; } static void -SQ4BitGemmPackQuantBData( +SQ4BitGemmPackQuantBDataAndSumBlk( size_t N, size_t K, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, + const float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, MLAS_THREADPOOL* ThreadPool ) { @@ -112,6 +116,29 @@ SQ4BitGemmPackQuantBData( } } ); + + MlasTrySimpleParallel( + ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const 
size_t src_blk_offset = n * BlockCountK + k_blk; + const float* QuantBScale = QuantBScaleBegin + src_blk_offset; + uint8_t zp = 8; + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + size_t src_zp_offset = ZPCountK * n + k_blk / 2; + bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset; + const std::byte low_mask{0X0F}; + zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = *QuantBScale * zp; + } + ); } void diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 00a7bf3d8af4..ab86a30d94a0 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -67,6 +67,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), QuantBZeroPoint.data(), tp.get()); } diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h index 05c6a0098eec..78615aad6d51 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h @@ -70,6 +70,7 @@ class FgemmShortExecuteTest : public MlasTestFixture 0) { PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + QuantBScale, QuantBZeroPoint, GetMlasThreadPool()); } @@ -402,6 +403,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Tue, 14 May 2024 20:31:08 -0700 Subject: [PATCH 08/41] blklen16 Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 14 +- onnxruntime/core/mlas/lib/sqnbitgemm.h | 3 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 258 +++--- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 734 ++++++++++++++++++ .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 299 ++----- .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 213 ++--- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 2 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 48 +- 10 files changed, 1032 insertions(+), 543 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index c7aa152180a9..9cfbce3fda17 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -258,12 +258,12 @@ MlasSQNBitGemmPackQuantBData( ThreadPool ); return; - } else if (Dispatch->SQ4BitGemmPackQuantBDataAndSumBlk != nullptr) { + } else if (Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSum, N, BlockCountK, BlkLen); assert(QuantBScale); //assert(QuantBZeroPoint); // QuantBZeroPoint is nullptr if symetric quantization. 
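+    // Sketch of why the per-block sums are packed as scale_b * zp (illustrative
+    // comment only; ABlockSum below is assumed to hold scale_a * sum(qa) per
+    // block, which is what the CompInt8 A-quantization path is expected to
+    // produce). For one block of BlkLen weights, the dequantized dot product is
+    //   sum_k scale_a * scale_b * qa[k] * (qb[k] - zp)
+    //     = scale_a * scale_b * dot(qa, qb) - (scale_a * sum_k qa[k]) * (scale_b * zp)
+    // The second term equals ABlockSum * BlkSum, so it can be pre-accumulated
+    // into C by a plain float GEMM over the block sums; the int8 kernels then
+    // only compute the first term and subtract what is already stored in C
+    // (hence "*SumPtr = hsum_float_8(acc) - *SumPtr" in the AVX2 kernels).
+    // BlkSum itself is written out in 16-column row-major tiles, presumably so
+    // the float GEMM kernel can consume it without further repacking.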
- Dispatch->SQ4BitGemmPackQuantBDataAndSumBlk( + Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, BlkLen, @@ -496,13 +496,11 @@ SQ4BitGemm_CompInt8( const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; const float* b_blk_sum = QuantBBlkSum + n * k_blks; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().GemmFloatKernel( + GetMlasPlatform().GemmFloatKernel( ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, true ); @@ -512,7 +510,6 @@ SQ4BitGemm_CompInt8( QuantAScale, b_col, b_col_scale, - b_col_zp, c_blk, RangeCountM, CountN, @@ -567,14 +564,12 @@ SQ4BitGemm_CompInt8( const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; const float* b_blk_sum = QuantBBlkSum + n * k_blks; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndSumBlk) { + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum) { size_t RowsRemaining = RangeCountM; const float* a_blksum_row = ABlockSum; while (RowsRemaining > 0) { @@ -595,7 +590,6 @@ SQ4BitGemm_CompInt8( QuantAScale, b_col, b_col_scale, - b_col_zp, c_blk, RangeCountM, CountN, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 2b947a9637dc..bef6f2b9d725 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -139,7 +139,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { MLAS_THREADPOOL* ThreadPool ); - SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndSumBlk = nullptr; + SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; // // CompFp32 kernel function prototypes. 
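+    // Selection sketch (mirrors MlasSQNBitGemmPackQuantBData in sqnbitgemm.cpp;
+    // illustrative pseudocode only, pack(...) is not a real helper):
+    //   if (Dispatch->SQ4BitGemmPackQuantBData)                // legacy: pack quant data only
+    //       pack(QuantBData);
+    //   else if (Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum)  // new: also needs QuantBScale
+    //       pack(QuantBData, QuantBScale, QuantBZeroPoint, BlockSum);
+    // A backend is expected to register at most one of the two hooks; the AVX2 and
+    // AVX512 dispatches in this patch set the first to nullptr and register the
+    // AndBlkSum variant instead.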
@@ -249,7 +249,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, float* C, size_t CountM, size_t CountN, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index ec9fab9260c0..3a138bc07619 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -22,6 +22,7 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx2_int8_blklen64.h" @@ -348,7 +349,6 @@ SQ4BitGemmKernel_CompInt8_avx2( const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, float* C, size_t CountM, size_t CountN, @@ -359,77 +359,48 @@ SQ4BitGemmKernel_CompInt8_avx2( size_t ldc ) { - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - if (BlkLen == 16) { - assert(false); - } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountM, - CountN, - CountK, - BlockCountK, - Bias, - ldc - ); - } else if (BlkLen >= 64) { - MlasQ4Int8GemmKernelBlkLen64Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } - else { - assert(false); - } + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); } else { - constexpr bool HasZeroPoint = false; - if (BlkLen == 16) { - assert(false); - } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountM, - CountN, - CountK, - BlockCountK, - Bias, - ldc - ); - } else if (BlkLen >= 64) { - MlasQ4Int8GemmKernelBlkLen64Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - CountM, - CountN, - BlockCountK, - Bias, - ldc - ); - } else { - assert(false); - } + MlasQ4Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); } } @@ -440,7 +411,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( const std::byte* QuantA, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, + const std::byte* /*QuantBZeroPoint*/, float* C, size_t CountN, size_t CountK, @@ -448,104 +419,51 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( const float* Bias ) { - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); - MlasQ4Int8GemmKernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, // CountM - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0 // ldc, not needed when CountM = 1 - ); - } 
else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } + if (BlkLen == 16) { + const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); + MlasQ4Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + 1, // CountM + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0 // ldc, not needed when CountM = 1 + ); + } else if (BlkLen == 32) { + const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); + MlasQ4Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + 1, // CountM + CountN, + CountK, + BlockStrideQuantB, + Bias, + 0 // ldc, not needed when CountM = 1 + ); } else { - constexpr bool HasZeroPoint = false; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); - MlasQ4Int8GemmKernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - 1, // CountM - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0 // ldc, not needed when CountM = 1 - ); - // SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( - // QuantA, - // QuantBData, - // QuantBScale, - // QuantBZeroPoint, - // C, - // CountN, - // BlockStrideQuantB, - // Bias - //); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } + const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); + MlasQ4Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + 1, // CountM + CountN, + BlockStrideQuantB, + Bias, + 0 // ldc, not needed when CountM = 1 + ); } } @@ -1242,7 +1160,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = nullptr; - d.SQ4BitGemmPackQuantBDataAndSumBlk = SQ4BitGemmPackQuantBDataAndSumBlk; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h new file mode 100644 index 000000000000..77696cdb6054 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -0,0 +1,734 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +__m256 +load_and_broadcast_4_scale_2(const float* scale) +{ + // 3 2 1 0 3 2 1 0 (7) + __m256 scale_2_4_ps = _mm256_broadcast_ps((__m128 const*)scale); + + // 2 1 0 0 2 1 0 0 (1) + __m256 scale_2_4_ps_shifted = _mm256_castsi256_ps( + _mm256_bslli_epi128(_mm256_castps_si256(scale_2_4_ps), 4) + ); + + // 3 2 1 0 2 1 0 0: (3) cross lane + __m256 scale_2_4_ps_permutted = _mm256_permute2f128_ps( + scale_2_4_ps_shifted, scale_2_4_ps, 0b00110000 + ); + + // in accumulate_r1_4blk_dot and accumulate_r2_4blk_dot + // _mm256_hadd_epi16 inter leaved dot sum, resulting: + // 
a31b31|a30b30|a11b11|a10b10|a21b21|a20b20|a01b01|a00b00 + // therefore we need weight to be: + // 3 3 1 1 2 2 0 0 (1) + return _mm256_permute_ps(scale_2_4_ps_permutted, 0b11110101); +} + +MLAS_FORCEINLINE +__m256i +load_16_epi8_as_epi16(const std::byte* ablob) +{ + const __m128i av_epi8 = _mm_lddqu_si128(reinterpret_cast(ablob)); + __m256i av_epi16 = _mm256_cvtepi8_epi16(av_epi8); + return av_epi16; +} + +MLAS_FORCEINLINE void +accumulate_r1_4blk_dot( + const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a, const float* scale_b, + __m256& acc) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av0_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av1_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a_4_ps = load_and_broadcast_4_scale_2(scale_a); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a_4_ps, scale_b_4_ps); + acc = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc); +} + +MLAS_FORCEINLINE void +accumulate_r2_4blk_dot( + const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, + const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, + const float* scale_a0, const float* scale_a1, const float* scale_b, + __m256& acc0, __m256& acc1 +) +{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_inter_leaved_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_inter_leaved_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16); + const __m256 sum_8_inter_leaved_ps = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32); + + // load 4 scales + __m256 scale_a0_4_ps = load_and_broadcast_4_scale_2(scale_a0); + __m256 scale_b_4_ps = load_and_broadcast_4_scale_2(scale_b); + __m256 scale_8_ps = _mm256_mul_ps(scale_a0_4_ps, scale_b_4_ps); + acc0 = _mm256_fmadd_ps(sum_8_inter_leaved_ps, scale_8_ps, acc0); + + const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_inter_leaved_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_inter_leaved_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_inter_leaved_epi16_); + const __m256 sum_inter_leaved_ps_ = _mm256_cvtepi32_ps(sum_8_inter_leaved_epi32_); + + __m256 scale_a1_4_ps = load_and_broadcast_4_scale_2(scale_a1); + scale_8_ps = _mm256_mul_ps(scale_a1_4_ps, scale_b_4_ps); + acc1 = _mm256_fmadd_ps(sum_inter_leaved_ps_, scale_8_ps, acc1); +} + +static MLAS_FORCEINLINE __m256i +load_4b_packed_1blk_blklen16(const std::byte* QuantBDataPtr) +{ + // | 0 8 |...| 7 15 | + const __m128i bv_packed_64 = _mm_loadl_epi64(reinterpret_cast(QuantBDataPtr)); + const __m128i low_mask = _mm_set1_epi8(0xF); + const __m128i lower_8_epu8 = 
_mm_and_si128(bv_packed_64, low_mask); // 0~7 + const __m128i upper_8_epu8 = _mm_bslli_si128(_mm_and_si128(_mm_srli_epi16(bv_packed_64, 4), low_mask), 8); // 8~15 + const __m256i bv_16_epu16 = _mm256_cvtepi8_epi16(_mm_add_epi8(upper_8_epu8, lower_8_epu8)); // 0~15 + return bv_16_epu16; +} + +static MLAS_FORCEINLINE void +load_4b_packed_4blk_blklen16(const std::byte* QuantBDataPtr, __m256i& bv0_32_epi8, __m256i& bv1_32_epi8) +{ + // | 0 8 |...| 7 15 | 16 24 |...| 23 31 ||| 32 40 |...| 39 47 | 48 56 |...| 55 63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + // 0~7, 16~22, 32~39, 48~55 + __m256i bv0_32_epi8_ = _mm256_and_si256(bv_packed, low_mask); + // 8~15, 24~31, 40~47, 56~63: (1) + __m256i bv1_32_epi8_ = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8_), 4); + // 0~7, 32~39, 16~22, 48~55 <- cross lane (3) + bv0_32_epi8_ = _mm256_permute4x64_epi64(bv0_32_epi8_, 0b11011000); + // 40~47, 8~15, 56~63, 24~31 <- cross lane (3) + bv1_32_epi8_ = _mm256_permute4x64_epi64(bv1_32_epi8_, 0b01110010); + + // 0~7, 8~15, 16~22, 24~31: (1) + bv0_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b11001100); + + // 40~47, 32~39, 56~63, 48~55: (1) + bv1_32_epi8 = _mm256_blend_epi32(bv0_32_epi8_, bv1_32_epi8_, 0b00110011); + + // 32~39, 40~47, 48~55, 56~63: (1) + bv1_32_epi8 = _mm256_shuffle_epi32(bv1_32_epi8, 0b01001110); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r2_4blk_dot(av00_32_epi8, av01_32_epi8, av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, + scale_a0, scale_a1, scale_b, acc0, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk4_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc +) +{ + __m256i bv0_32_epi8, bv1_32_epi8; + load_4b_packed_4blk_blklen16(QuantBDataPtr, bv0_32_epi8, bv1_32_epi8); + accumulate_r1_4blk_dot(av0_32_epi8, av1_32_epi8, bv0_32_epi8, bv1_32_epi8, scale_a, scale_b, acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk1_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale0, + const float& combined_scale1, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av0_32_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc0 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale0), prod_8_ps, acc0); + + prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av1_32_epi8); + prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc1 = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale1), prod_8_ps, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk1_avx2( + const __m256i& av_16_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + __m256& acc +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m256i bv_16_epu16 = load_4b_packed_1blk_blklen16(QuantBDataPtr); + + __m256i prod_8_epi32 = _mm256_madd_epi16(bv_16_epu16, av_16_epi8); + __m256 prod_8_ps = _mm256_cvtepi32_ps(prod_8_epi32); + acc = _mm256_fmadd_ps(_mm256_set1_ps(combined_scale), prod_8_ps, acc); +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = 
load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + } + + { + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + } + + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + // process 4 blks of 64 4b weights a time + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + 32; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + 32; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + accumulate_blklen16_r2c1blk4_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0) - *SumPtr; + *(SumPtr + ldc) = hsum_float_8(acc1) - *(SumPtr + ldc); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + 
QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + // Col0 + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + } + { + // Col1 + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc[1]); + } + { + // Col2 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc[2]); + } + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 
3 * StrideQuantBData, scale_00, acc[3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen16_r1c1blk4_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, + QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes8 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + while (k_blks_remaining-- > 0) { + const __m256i av_16_epi16 = load_16_epi8_as_epi16(QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_16_epi16, QuantBDataPtr, scale_00, acc0); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes8; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0) - *SumPtr; + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 85db4f832272..def16c7068b5 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -19,32 +19,6 @@ accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } -MLAS_FORCEINLINE void -accumulate_2blk_dot( - const __m256i& av0_32_epi8, const __m256i& av1_32_epi8, - const __m256i& bv0_32_epi8, const __m256i& bv1_32_epi8, - const float& combined_scale0, const float& combined_scale1, - const __m256i& one_16_epi16, - __m256& acc) -{ - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) - ); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) - ); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256 scale_8_ps = _mm256_set_ps( - combined_scale1, combined_scale1, combined_scale0, combined_scale0, - combined_scale1, combined_scale1, combined_scale0, combined_scale0 - ); - acc = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc); -} - -template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk2_avx2( const __m256i& av00_32_epi8, @@ -52,7 +26,6 @@ accumulate_blklen32_r2c1blk2_avx2( const __m256i& av10_32_epi8, const __m256i& av11_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* /*QuantBZeroPointPtr*/, const float* scale_a0, const float* scale_a1, const float* scale_b, @@ -60,20 +33,18 @@ accumulate_blklen32_r2c1blk2_avx2( __m256& acc1 ) { - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - // TODO: will this (the second line below) be faster and not keep low_mask in use? - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - //int8_t zp0, zp1; - //get_2_zps(QuantBZeroPointPtr, zp0, zp1); - //bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - //bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + // low_mask = _mm256_packus_epi16(low_mask, low_mask); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + // TODO: this (the second line below) is faster and does not keep low_mask in use. 
+ // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - //accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); - //accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( bv0_32_epi8, av00_32_epi8 ); @@ -82,6 +53,8 @@ accumulate_blklen32_r2c1blk2_avx2( ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + // generating constant 1s is faster here. + // __m256i one = _mm256_set1_epi16(1); __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -113,13 +86,11 @@ accumulate_blklen32_r2c1blk2_avx2( acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); } -template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( const __m256i& av00_32_epi8, const __m256i& av01_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* /*QuantBZeroPointPtr*/, const float* scale_a0, const float* scale_b, __m256& acc0 @@ -128,20 +99,11 @@ accumulate_blklen32_r1c1blk2_avx2( // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - - //int8_t zp0, zp1; - //get_2_zps(QuantBZeroPointPtr, zp0, zp1); - //bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - //bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - bv0_32_epi8, av00_32_epi8 - ); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - bv1_32_epi8, av01_32_epi8 - ); + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); @@ -159,132 +121,31 @@ accumulate_blklen32_r1c1blk2_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); } -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk2_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, - const float& combined_scale00, - const float& combined_scale01, - const float& combined_scale10, - const float& combined_scale11, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - - // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). 
- // however, it is faster to generate one_16_epi16 than calling _mm256_set1_ep16(1); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); - //low_mask = _mm256_packus_epi16(low_mask, low_mask); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 - // TODO: will this (the second line below) be faster and not keep low_mask in use? - // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - - //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. - ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); - //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); - - int8_t zp0, zp1; - get_2_zps(QuantBZeroPointPtr, zp0, zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); - - // generating constant 1s is fater here. - // __m256i one = _mm256_set1_epi16(1); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - - // performance gains 7% by calling this (accumulate_2blk_dot) instead of 2 accumulate_1blk_dot calls. - // accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); - // accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); - // accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, combined_scale10, one_16_epi16, acc1); - // accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, combined_scale11, one_16_epi16, acc1); - accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); - accumulate_2blk_dot(av10_32_epi8, av11_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale10, combined_scale11, one_16_epi16, acc1); -} - -template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk1_avx2( const __m256i& av00_32_epi8, const __m256i& av10_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* /*QuantBZeroPointPtr*/, const float& combined_scale00, const float& combined_scale10, __m256& acc0, - __m256& acc1, - bool zp_low_half = true + __m256& acc1 ) { - (void)zp_low_half; // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - //const int8_t zp = get_zp(zp_low_half, QuantBZeroPointPtr); - //const __m256i bzp = _mm256_set1_epi8(zp); - //bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); } -template -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const std::byte* QuantBZeroPointPtr, - const float& combined_scale00, - const float& combined_scale01, - __m256& acc0) -{ - // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | v32 v48 | v33 v49 | ... | v46 v62 | v47 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...14, 15, 32, 33,...46, 47 - // TODO: will this be faster and save a use of low_mask? - // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); - - //// This saves one _mm256_extracti128_si256 against using _mm256_set_m128i. 
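// Illustrative scalar sketch (not part of the kernel) of the single-block unpack used by
// the accumulate_blklen32_r*c1blk1_avx2 helpers: 16 packed bytes expand to 32 weights laid
// out | v0 v16 | v1 v17 | ... | v15 v31 |, i.e. byte i carries weight i in its low nibble
// and weight i + 16 in its high nibble. Here the AND with 0x0F after
// _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0) is required, because the low
// nibbles are still set when the shift happens.
#include <cstdint>

inline void unpack_blklen32_single_ref(const uint8_t packed[16], uint8_t weights[32])
{
    for (int i = 0; i < 16; ++i) {
        weights[i]      = packed[i] & 0x0F;  // lower 128-bit half of bv_32_epi8
        weights[i + 16] = packed[i] >> 4;    // upper 128-bit half of bv_32_epi8
    }
}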
- ////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); - //__m256i bv1_32_epi8 = _mm256_insertf128_si256(bv1, _mm256_extracti128_si256(bv0, 1), 0); - - int8_t zp0, zp1; - get_2_zps(QuantBZeroPointPtr, zp0, zp1); - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(zp0)); - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(zp1)); - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - //accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, combined_scale00, one_16_epi16, acc0); - //accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, combined_scale01, one_16_epi16, acc0); - accumulate_2blk_dot(av00_32_epi8, av01_32_epi8, bv0_32_epi8, bv1_32_epi8, combined_scale00, combined_scale01, one_16_epi16, acc0); -} - -template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_avx2( const __m256i& av00_32_epi8, const std::byte* QuantBDataPtr, - const std::byte* /*QuantBZeroPointPtr*/, const float& combined_scale00, __m256& acc0 ) @@ -294,22 +155,16 @@ accumulate_blklen32_r1c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - //const int8_t zp = get_zp(true, QuantBZeroPointPtr); - //const __m256i bzp = _mm256_set1_epi8(zp); - //bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); } -template MLAS_FORCEINLINE void Q4Int8Gemm2x4x2BlkLen32Avx2( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, float* C, size_t CountM, size_t CountN, @@ -330,7 +185,6 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM % NRows2 == 0); @@ -339,7 +193,6 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( for (size_t m = 0; m < CountM; m += NRows2) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; const float* BiasPtr = Bias; auto* SumPtr = C + m * ldc; @@ -349,7 +202,6 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; __m256 acc[NCols4 * NRows2] = { _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), @@ -382,7 +234,7 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( //const float& scale_10 = scale_a10 * QuantBScalePtr[0]; //const float& scale_11 = scale_a11 * QuantBScalePtr[1]; //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, 
QuantBScalePtr, acc[0], acc[NCols4]); } @@ -393,7 +245,7 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( //const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; //const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); } @@ -404,7 +256,7 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( //const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; //const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); } @@ -415,7 +267,7 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( //const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; //const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 3]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); } @@ -424,9 +276,6 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } } // k_blks_remaining // TODO: use a loop in case PerAccuBlk2 is not 2. 
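// Scalar reference (not part of the kernel, and ignoring the int16 saturation that
// _mm256_maddubs_epi16 can exhibit for extreme inputs) for the per-block contribution the
// accumulate_blklen32_* helpers add to an accumulator: the B nibbles enter as unsigned
// values in [0, 15], the A values as signed int8, and the combined scale is applied once
// per 32-weight block instead of being pre-multiplied by the caller.
#include <cstdint>

inline float blklen32_dot_ref(const int8_t a[32],   // quantized activations
                              const uint8_t q[32],  // unpacked 4-bit weights, 0..15
                              float scale_a, float scale_b)
{
    int32_t sum = 0;
    for (int i = 0; i < 32; ++i) {
        sum += (int32_t)a[i] * (int32_t)q[i];  // what the maddubs/hadd/madd chain reduces to
    }
    return scale_a * scale_b * (float)sum;     // the final _mm256_fmadd_ps with the combined scale
}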
@@ -443,28 +292,28 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( // Col0 const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); } { // Col1 const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, acc[1], acc[NCols4 + 1]); } { // Col2 const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, acc[2], acc[NCols4 + 2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, acc[3], acc[NCols4 + 3]); } } // k_blks_remaining @@ -484,9 +333,6 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; SumPtr += NCols4; @@ -617,13 +463,11 @@ Q4Int8Gemm2x4x1BlkLen32Avx2( } } -template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, float* C, size_t CountM, size_t CountN, @@ -643,7 +487,6 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM % NRows2 == 0); @@ -652,7 +495,6 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( for (size_t m = 0; m < CountM; m += NRows2) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; const float* BiasPtr = Bias; float* SumPtr = C + m * ldc; @@ -662,7 +504,6 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); @@ -680,18 +521,8 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); - //const float& scale_a00 = Q8BlkScale(QuantABlk00); - //const float& scale_a01 = Q8BlkScale(QuantABlk01); - //const float& scale_a10 = Q8BlkScale(QuantABlk10); - //const float& scale_a11 = Q8BlkScale(QuantABlk11); - - //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - //const float& scale_10 = scale_a10 * QuantBScalePtr[0]; - //const float& scale_11 = scale_a11 * QuantBScalePtr[1]; - //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc0, acc1); - accumulate_blklen32_r2c1blk2_avx2( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, + accumulate_blklen32_r2c1blk2_avx2( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); // increment block pointers @@ -699,9 +530,6 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } } // TODO: use a loop in case PerAccuBlk2 is not 2. 
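// Why the kernels store "hsum_float_8(acc) - *SumPtr" (a sketch of the intended math,
// assuming C has been pre-filled with the zero-point correction before these kernels run):
// with symmetric 4-bit weights kept as q in [0, 15] and an implicit zero point of 8,
//
//   true_out = sum_k scale_a[k] * scale_b[k] * sum_i a_i * (q_i - 8)
//            = sum_k scale_a[k] * scale_b[k] * sum_i a_i * q_i       // accumulated in acc
//            - 8 * sum_k scale_b[k] * (scale_a[k] * sum_i a_i)       // zero-point correction
//
// The correction only needs per-block sums of A (the AScaledBlkSum = scale * sum(a_i)
// produced during A quantization later in this patch series) and per-block scales of B
// (ComputePackBlkSum), so it can be computed outside the inner loop; the kernel then only
// has to subtract the pre-filled value from its own accumulation before adding the bias.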
@@ -716,7 +544,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_10, acc0, acc1); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); } *SumPtr = hsum_float_8(acc0) - *SumPtr; @@ -729,9 +557,6 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( // move to next column QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } BiasPtr += BiasPtr != nullptr ? 1 : 0; SumPtr += 1; @@ -739,14 +564,12 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( } } -template MLAS_FORCEINLINE void Q4Int8GemmXx4BlkLen32Avx2( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, float* C, size_t CountM, size_t CountN, @@ -767,16 +590,13 @@ Q4Int8GemmXx4BlkLen32Avx2( const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM < NRows2); assert(CountN % NCols4 == 0); for (size_t m = 0; m < CountM; m++) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; const float* BiasPtr = Bias; auto* SumPtr = C + m * ldc; @@ -786,7 +606,6 @@ Q4Int8GemmXx4BlkLen32Avx2( const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; size_t k_blks_remaining = BlockCountK; @@ -799,8 +618,7 @@ Q4Int8GemmXx4BlkLen32Avx2( //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); } { @@ -808,8 +626,8 @@ Q4Int8GemmXx4BlkLen32Avx2( //const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; //const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1] ); } @@ -818,8 +636,8 @@ Q4Int8GemmXx4BlkLen32Avx2( //const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; //const float& scale_01 = scale_a01 * 
(QuantBScalePtr + 2 * StrideQuantBScale)[1]; //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, acc[2]); - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2] ); } @@ -828,8 +646,8 @@ Q4Int8GemmXx4BlkLen32Avx2( //const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; //const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); } // increment block pointers @@ -837,9 +655,6 @@ Q4Int8GemmXx4BlkLen32Avx2( QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } } // TODO: use a loop in case PerAccuBlk2 is not 2. @@ -852,22 +667,22 @@ Q4Int8GemmXx4BlkLen32Avx2( { // Col0 const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc[0]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); } { // Col1 const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, acc[1]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc[1]); } { // Col2 const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, acc[2]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc[2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, acc[3]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc[3]); } } @@ -882,24 +697,18 @@ Q4Int8GemmXx4BlkLen32Avx2( // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; SumPtr += NCols4; } } } -template MLAS_FORCEINLINE void Q4Int8GemmXxXBlkLen32Avx2( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, float* C, size_t CountM, size_t CountN, @@ -920,7 +729,6 @@ Q4Int8GemmXxXBlkLen32Avx2( const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM < NRows2); @@ -929,7 +737,6 @@ Q4Int8GemmXxXBlkLen32Avx2( for (size_t m = 0; m < CountM; m++) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; const float* BiasPtr = Bias; auto* SumPtr = C + m * ldc; @@ -938,7 +745,6 @@ Q4Int8GemmXxXBlkLen32Avx2( const float* QuantAScalePtr = QuantAScale + m * BlockCountK; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; __m256 acc0 = _mm256_setzero_ps(); size_t k_blks_remaining = BlockCountK; @@ -949,8 +755,8 @@ Q4Int8GemmXxXBlkLen32Avx2( //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); - accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, + accumulate_blklen32_r1c1blk2_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); // increment block pointers @@ -958,9 +764,6 @@ Q4Int8GemmXxXBlkLen32Avx2( QuantAScalePtr += PerAccuBlk2; QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; QuantBScalePtr += PerAccuBlk2; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += 1; - } } // TODO: use a loop in case PerAccuBlk2 is not 2. @@ -968,7 +771,7 @@ Q4Int8GemmXxXBlkLen32Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, acc0); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); } *SumPtr = hsum_float_8(acc0) - *SumPtr; @@ -979,17 +782,12 @@ Q4Int8GemmXxXBlkLen32Avx2( // move to next column QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += StrideQuantBZeroPoint; - } - BiasPtr += BiasPtr != nullptr ? 
1 : 0; SumPtr += 1; } } } -template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen32Avx2( @@ -997,7 +795,6 @@ MLAS_FORCEINLINE const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* QuantBZeroPoint, float* C, size_t CountM, size_t CountN, @@ -1016,7 +813,6 @@ MLAS_FORCEINLINE const size_t lda_scale = BlockCountK; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer @@ -1026,13 +822,11 @@ MLAS_FORCEINLINE size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - //Q4Int8Gemm2x4x1BlkLen32Avx2( - Q4Int8Gemm2x4x2BlkLen32Avx2( + Q4Int8Gemm2x4x2BlkLen32Avx2( QuantA, QuantAScale, QuantBData, QuantBScale, - QuantBZeroPoint, C, multipleRows, multipleCols, @@ -1042,12 +836,11 @@ MLAS_FORCEINLINE ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8Gemm2xXBlkLen32Avx2( + Q4Int8Gemm2xXBlkLen32Avx2( QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData, QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, C + multipleCols, multipleRows, remainingCols, @@ -1057,12 +850,11 @@ MLAS_FORCEINLINE } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmXx4BlkLen32Avx2( + Q4Int8GemmXx4BlkLen32Avx2( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData, QuantBScale, - QuantBZeroPoint, C + multipleRows * ldc, remainingRows, multipleCols, @@ -1072,12 +864,11 @@ MLAS_FORCEINLINE } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmXxXBlkLen32Avx2( + Q4Int8GemmXxXBlkLen32Avx2( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData + multipleCols * StrideQuantBData, QuantBScale + multipleCols * StrideQuantBScale, - QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, C + multipleRows * ldc + multipleCols, remainingRows, remainingCols, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 2117a11a7682..01853b9b18dc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -83,6 +83,7 @@ accumulate_blklen64_r1c1blk1_avx2( MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen64Avx2( + const size_t BlkLen, const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -95,17 +96,17 @@ Q4Int8GemmR2xC4BlkLen64Avx2( size_t ldc ) { - constexpr size_t BlkLen64 = 64; constexpr size_t BlkBitWidth4 = 4; constexpr size_t NCols4 = 4; constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); - - // process 1 blks of 64 4b weights a time - constexpr size_t PerAccuBlk1 = 1; - - const size_t lda = BlockCountK * BlkLen64; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * 
BlkDataSizeInBytes; const size_t StrideQuantBScale = BlockCountK; assert(CountM % NRows2 == 0); @@ -131,25 +132,23 @@ Q4Int8GemmR2xC4BlkLen64Avx2( // process 1 blks of 64 4b weights a time for (size_t k = 0; k < BlockCountK; ++k) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); - - // increment block pointers - QuantAPtr += BlkLen64 * PerAccuBlk1; - QuantAScalePtr += PerAccuBlk1; - QuantBDataPtr += BlkDataSizeInBytes32 * PerAccuBlk1; - QuantBScalePtr += PerAccuBlk1; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; } // k_blks_remaining __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); @@ -176,6 +175,7 @@ Q4Int8GemmR2xC4BlkLen64Avx2( void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx2( + const size_t BlkLen, const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -188,14 +188,17 @@ Q4Int8GemmR2xC1BlkLen64Avx2( size_t ldc ) { - constexpr size_t BlkLen64 = 64; constexpr size_t BlkBitWidth4 = 4; [[maybe_unused]] constexpr size_t NCols4 = 4; constexpr size_t NRows2 = 2; - 
constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - const size_t lda = BlockCountK * BlkLen64; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); const size_t StrideQuantBScale = BlockCountK; assert(CountM % NRows2 == 0); @@ -217,19 +220,20 @@ Q4Int8GemmR2xC1BlkLen64Avx2( __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); for (size_t k = 0; k < BlockCountK; ++k) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); - - // increment block pointers - QuantAPtr += BlkLen64; - QuantAScalePtr += 1; - QuantBDataPtr += BlkDataSizeInBytes32; - QuantBScalePtr += 1; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; } *SumPtr = hsum_float_8(acc0) - *SumPtr; @@ -250,6 +254,7 @@ Q4Int8GemmR2xC1BlkLen64Avx2( MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx2( + const size_t BlkLen, const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -262,17 +267,17 @@ Q4Int8GemmR1xC4BlkLen64Avx2( size_t ldc ) { - constexpr size_t BlkLen64 = 64; constexpr size_t BlkBitWidth4 = 4; constexpr size_t NCols4 = 4; [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + constexpr size_t SubblkLen = 64; - // process 1 blks of 64 4b weights a time - constexpr size_t PerAccuBlk1 = 1; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - const size_t lda = BlockCountK * BlkLen64; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); const size_t StrideQuantBScale = BlockCountK; assert(CountM < NRows2); @@ -293,25 +298,20 @@ Q4Int8GemmR1xC4BlkLen64Avx2( __m256 acc[NCols4] = 
{_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; for (size_t k = 0; k < BlockCountK; ++k) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, - QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, - QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); - - // increment block pointers - QuantAPtr += BlkLen64 * PerAccuBlk1; - QuantAScalePtr += PerAccuBlk1; - QuantBDataPtr += BlkDataSizeInBytes32 * PerAccuBlk1; - QuantBScalePtr += PerAccuBlk1; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; } __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); @@ -333,6 +333,7 @@ Q4Int8GemmR1xC4BlkLen64Avx2( MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx2( + const size_t BlkLen, const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -345,17 +346,17 @@ Q4Int8GemmR1xC1BlkLen64Avx2( size_t ldc ) { - constexpr size_t BlkLen64 = 64; constexpr size_t BlkBitWidth4 = 4; [[maybe_unused]] constexpr size_t NCols4 = 4; [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes32 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + constexpr size_t SubblkLen = 64; - // process 2 blks of 64 4b weights a time - constexpr size_t PerAccuBlk1 = 1; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; - const size_t lda = BlockCountK * BlkLen64; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); const size_t StrideQuantBScale = BlockCountK; assert(CountM < NRows2); @@ -375,17 +376,20 @@ Q4Int8GemmR1xC1BlkLen64Avx2( __m256 acc0 = _mm256_setzero_ps(); for (size_t k = 0; k < BlockCountK; ++k) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); - const __m256i 
av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - - accumulate_blklen64_r1c1blk1_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); - - // increment block pointers - QuantAPtr += BlkLen64 * PerAccuBlk1; - QuantAScalePtr += PerAccuBlk1; - QuantBDataPtr += BlkDataSizeInBytes32 * PerAccuBlk1; - QuantBScalePtr += PerAccuBlk1; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_blklen64_r1c1blk1_avx2( + av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; } *SumPtr = hsum_float_8(acc0) - *SumPtr; @@ -402,29 +406,28 @@ Q4Int8GemmR1xC1BlkLen64Avx2( } } -MLAS_FORCEINLINE - size_t - MlasQ4Int8GemmKernelBlkLen64Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc - ) +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) { - constexpr size_t BlkLen64 = 64; constexpr size_t BlkBitWidth4 = 4; constexpr size_t NCols4 = 4; constexpr size_t NRows2 = 2; - const size_t lda = BlockCountK * BlkLen64 * sizeof(int8_t); + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); const size_t lda_scale = BlockCountK; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); const size_t StrideQuantBScale = BlockCountK; size_t remainingRows = CountM % NRows2; @@ -434,6 +437,7 @@ MLAS_FORCEINLINE if (multipleRows > 0 && multipleCols > 0) { Q4Int8GemmR2xC4BlkLen64Avx2( + BlkLen, QuantA, QuantAScale, QuantBData, @@ -448,6 +452,7 @@ MLAS_FORCEINLINE } if (remainingCols > 0 && multipleRows > 0) { Q4Int8GemmR2xC1BlkLen64Avx2( + BlkLen, QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData, @@ -462,6 +467,7 @@ MLAS_FORCEINLINE if (remainingRows > 0 && multipleCols > 0) { Q4Int8GemmR1xC4BlkLen64Avx2( + BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData, @@ -476,6 +482,7 @@ MLAS_FORCEINLINE if (remainingCols > 0 && remainingRows > 0) { Q4Int8GemmR1xC1BlkLen64Avx2( + BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData + multipleCols * StrideQuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index fb82bbd79bed..8345302de749 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -232,7 +232,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = nullptr; - d.SQ4BitGemmPackQuantBDataAndSumBlk = SQ4BitGemmPackQuantBDataAndSumBlk; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmM1Kernel_CompFp32 
= SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 9b9064b175f3..546e897865ed 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -253,7 +253,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = nullptr; - d.SQ4BitGemmPackQuantBDataAndSumBlk = SQ4BitGemmPackQuantBDataAndSumBlk; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index c60892a54b56..0836e6252b91 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -24,7 +24,7 @@ SQ4BitGemmPackQuantBDataSize( } static void -SQ4BitGemmPackQuantBDataAndSumBlk( +SQ4BitGemmPackQuantBDataAndBlkSum( size_t N, size_t K, size_t BlkLen, diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 8c55b023c710..f8a30f88796b 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -213,12 +213,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { auto print_matrix = [](size_t nrows, size_t ncols, const float* data) { for (size_t row = 0; row < nrows; ++row) { for (size_t col = 0; col < ncols; ++col) { - std::cout << data[row * ncols + col] << "\t"; + std::cout << data[row * ncols + col] << ", "; } std::cout << "\n"; } }; + auto print_matrix_col = [](size_t nrows, size_t ncols, size_t col, const float* data) { + for (size_t row = 0; row < nrows; ++row) { + std::cout << data[row * ncols + col] << ", "; + } + std::cout << "\n"; + }; + std::cout << "A:\n"; print_matrix(M, K, A); std::cout << "B:\n"; @@ -384,6 +391,42 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Sat, 25 May 2024 19:19:02 -0700 Subject: [PATCH 09/41] impl avx512: SQNBITGEMM<4>/BlkLen:32/M:2048/N:4096/K:4096/Threads:1/Symmetric:1/ComputeType:4/real_time_mean 664029830 ns Signed-off-by: liqunfu --- onnxruntime/core/mlas/lib/platform.cpp | 2 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 54 +- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 4 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 138 +++- .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 544 +++++++++++++ .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 675 ++++++++++++++++ .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 680 ++++++++++++++++ .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 745 ++++++++++++++++++ .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 39 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 153 ++-- .../test/mlas/unittest/test_sqnbitgemm.cpp | 4 + 11 files changed, 2933 insertions(+), 105 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h create mode 100644 
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 3f86b3f7c506..5f29cff04034 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -453,7 +453,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; + //this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 3a138bc07619..446f78190539 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1048,29 +1048,6 @@ SQ4BitGemmM1Kernel_CompFp32_avx2( } } -MLAS_FORCEINLINE __m128i -convert_2_ps_to_epi8(__m256 v0, __m256 v1) -{ - __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); - __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); - - __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); - __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); - - return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); -} - -// horizontally add 8 int32_t -static inline int -hsum_8_epi32(const __m256i a_8_epi32) -{ - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a_8_epi32), _mm256_extractf128_si256(a_8_epi32, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - void MLASCALL QuantizeARow_CompInt8_avx2( size_t BlkLen, @@ -1152,6 +1129,37 @@ QuantizeARow_CompInt8_avx2( } } +static void +SQ4BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + const float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == CompInt8) { + SubBlkLen = 64; + } + + PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + + if (QuantBScaleBegin) { + ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); + } +} + // // Kernel dispatch structure definition. 
// diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 77696cdb6054..13f7db8bfffa 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -7,7 +7,7 @@ #include "sqnbitgemm_kernel_avx_common.h" -__m256 +MLAS_FORCEINLINE __m256 load_and_broadcast_4_scale_2(const float* scale) { // 3 2 1 0 3 2 1 0 (7) @@ -506,7 +506,6 @@ Q4Int8GemmR1xC4BlkLen16Avx2( } while (k_blks_remaining-- > 0) { - // load A const std::byte* QuantABlk0 = QuantAPtr; const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); @@ -614,7 +613,6 @@ Q4Int8GemmR1xC1BlkLen16Avx2( QuantBScalePtr += PerAccuBlk4; } - // TODO: use a loop in case PerAccuBlk2 is not 2. while (k_blks_remaining-- > 0) { const __m256i av_16_epi16 = load_16_epi8_as_epi16(QuantAPtr); const float& scale_a00 = *QuantAScalePtr; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 8345302de749..0691cd29005a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -22,6 +22,10 @@ Module Name: #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // // CompFp32 kernel implementation. @@ -150,18 +154,97 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( // CompInt8 kernel implementation. // +MLAS_FORCEINLINE +void +SQ4BitGemmKernel_CompInt8_avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t /*lda*/, + size_t ldc +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } +} + void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ) { // port from MlasQ80BlkQuantRow assert(BlkLen % 16 == 0); const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m256i one_16_epi16 = _mm256_set1_epi16(1); int8_t* blob = reinterpret_cast(QuantA); + float* scale_ptr = QuantAScale; for (size_t k = 0; k < CountK; k += BlkLen) { const size_t step = std::min(BlkLen, CountK - k); @@ -185,13 +268,14 @@ MlasQ80BlkQuantRow_avx512( // Quantize these floats const float scale = maxScalar / 127.f; - *reinterpret_cast(blob) = scale; - blob += sizeof(float); + 
*scale_ptr = scale; + scale_ptr++; const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; const __m512 mul = _mm512_set1_ps(inverse_scale); __m128i* dst = reinterpret_cast<__m128i*>(blob); + __m256i sum_16_epi16 = _mm256_setzero_si256(); for (size_t kk = 0; kk < step; kk += 16) { const size_t klen = std::min(size_t(16), step - kk); @@ -208,23 +292,50 @@ MlasQ80BlkQuantRow_avx512( // Convert int32 to int8 __m128i i0_8 = _mm512_cvtepi32_epi8(i0); _mm_storeu_si128(dst++, i0_8); + + // accumulate Sum(a_i) + __m256i i_16_epi16 = _mm256_cvtepi8_epi16(i0_8); + sum_16_epi16 = _mm256_hadds_epi16(sum_16_epi16, i_16_epi16); + } if (step < BlkLen) { memset(blob + step, 0, BlkLen - step); } + + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + *AScaledBlkSum = scale * hsum_8_epi32(sum_8_epi32); + AScaledBlkSum++; blob += BlkLen; } } -void MLASCALL -QuantizeARow_CompInt8_avx512( +static void +SQ4BitGemmPackQuantBDataAndBlkSum512( + size_t N, + size_t K, size_t BlkLen, - const float* A, - size_t CountK, - std::byte* QuantA + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + const float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool ) { - MlasQ80BlkQuantRow_avx512(BlkLen, A, CountK, QuantA); + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == CompInt8) { + SubBlkLen = 128; + } + PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + + if (QuantBScaleBegin) { + ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); + } } const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { @@ -232,13 +343,14 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = nullptr; - d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; - d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8_avx512; + d.SQ4BitGemmM1Kernel_CompInt8 = nullptr; + d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx512; + d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h new file mode 100644 index 000000000000..d759bfc4a176 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -0,0 +1,544 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +//static MLAS_FORCEINLINE __m512i +//combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +//{ +// __m512i result = _mm512_castsi256_si512(a); +// result = _mm512_inserti64x4(result, b, 1); +// return result; +//} + +//static MLAS_FORCEINLINE void +//load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +//{ 
+// // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | v64 v96 | ... | v95 v127 | +// const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); +// const __m512i low_mask = _mm512_set1_epi8(0x0F); +// __m512i bv0_64_epi8_ = _mm512_and_si512(bv_packed, low_mask); // 0~31, 64~95 +// __m512i bv1_64_epi8_ = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 32~63, 96~127 +// +// // Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 +// __m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); +// __m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); +// __m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); +// __m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); +// +// // Compose new __m512i variables +// bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); +// bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +//} + +static MLAS_FORCEINLINE void +dot_accumulate_1blk( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i zeros = _mm512_setzero_si512(); + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen128_r1c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen128_r2c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t 
BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + // process 1 blks of 64 4b weights a time + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } // k_blks_remaining + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]) - *SumPtr; + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]) - *(SumPtr + 1); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]) - *(SumPtr + 2); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]) - *(SumPtr + 3); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]) - *(SumPtr + ldc); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]) - *(SumPtr + ldc + 1); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]) - *(SumPtr + ldc + 2); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]) - *(SumPtr + ldc + 3); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += 
*(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0) - *SumPtr; + *(SumPtr + ldc) = hsum_float_16(acc1) - *(SumPtr + ldc); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_blklen128_r1c1blk1_avx512( + av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 + ); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0) - *SumPtr; + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2xC1BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h new file mode 100644 index 000000000000..52eeba8fe215 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -0,0 +1,675 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + + + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_load_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1,2~2,3~3 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 4~4,5~5,6~6,7~7 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00004444111155552222666633337777 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0044115522663377 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_load_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps = 
_mm256_load_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4]); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr 
+ StrideQuantBScale, + acc[1], acc[NCols4 + 1]); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2]); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, + acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, + acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r2c1blk1_avx2( + av0_16_epi16, av1_16_epi16, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, + acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av0_16_epi16 = load_16_epi8_as_epi16(QuantABlk0); + const __m256i av1_16_epi16 = load_16_epi8_as_epi16(QuantABlk0 + lda); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen16_r2c1blk1_avx2(av0_16_epi16, av1_16_epi16, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20) - *SumPtr; + *(SumPtr + ldc) = hsum_float_8(acc21) - *(SumPtr + ldc); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc2[1]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc2[2]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, 
acc2[3]); + } + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = load_16_epi8_as_epi16(QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2) - *SumPtr; + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE + size_t +MlasQ4Int8GemmKernelBlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h new file mode 100644 index 000000000000..aa7648af208f --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -0,0 +1,680 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" +#include "sqnbitgemm_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" + +static MLAS_FORCEINLINE __m256 +h_add_512(__m512 a) +{ + return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); +} + + +static MLAS_FORCEINLINE void +load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 +} + +static const uint32_t index_array[16] = {0, 0, 2, 2, 0, 0, 2, 2, 1, 1, 3, 3, 1, 1, 3, 3}; + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_load_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + //__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_load_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + //__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i 
sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_load_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + //__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + __m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1 + const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3 + + const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const 
__m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + { + // Col0 + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } + + { + // Col1 + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, + acc2[1], acc2[NCols4 + 1]); + } + + { + // Col2 + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, + acc2[2], acc2[NCols4 + 2]); + } + + { + // Col3 + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r2c1blk1_avx2( + av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, + acc2[3], acc2[NCols4 + 3]); + } + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, 
level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE +Q4Int8GemmR2C1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20) - *SumPtr; + 
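+            // NOTE (descriptive comment): both row results are written as `accumulator - C`.
+            // This appears to assume the caller pre-fills C with the zero-point correction
+            // term derived from the per-block A sums (AScaledBlkSum) and the packed B block
+            // sums produced earlier in this patch, so the symmetric zero point of 8 never has
+            // to be subtracted inside the inner dot-product loop. This is an inference from
+            // the packing/quantization code, not something stated explicitly here.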
*(SumPtr + ldc) = hsum_float_8(acc21) - *(SumPtr + ldc); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc2[1]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * 
StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc2[2]); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc2[3]); + } + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2) - *SumPtr; + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += 
StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE + size_t + MlasQ4Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q4Int8GemmR2xC4BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q4Int8GemmR2C1BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h new file mode 100644 index 000000000000..0af21c037629 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -0,0 +1,745 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +static MLAS_FORCEINLINE __m512i +combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) +{ + __m512i result = _mm512_castsi256_si512(a); + result = _mm512_inserti64x4(result, b, 1); + return result; +} + +static MLAS_FORCEINLINE void +load_2blk_4b_packed_blklen64(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) +{ + // | v0 v64 | v1 v65 | ... 
| v62 v126 | v63 v127 | + const __m512i bv_packed = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i low_mask = _mm512_set1_epi8(0x0F); + bv0_64_epi8 = _mm512_and_si512(bv_packed, low_mask); // 0~63 + bv1_64_epi8 = _mm512_srli_epi16(_mm512_sub_epi8(bv_packed, bv0_64_epi8), 4); // 64~127 + + //// Extract lower and higher 256 bits from bv0_64_epi8 and bv1_64_epi8 + //__m256i bv0_lower = _mm512_castsi512_si256(bv0_64_epi8_); + //__m256i bv0_higher = _mm512_extracti64x4_epi64(bv0_64_epi8_, 1); + //__m256i bv1_lower = _mm512_castsi512_si256(bv1_64_epi8_); + //__m256i bv1_higher = _mm512_extracti64x4_epi64(bv1_64_epi8_, 1); + + //// Compose new __m512i variables + //bv0_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_lower), bv1_lower, 1); + //bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); +} + +static MLAS_FORCEINLINE __m512i +load_1blk_4b_packed_blklen64(const std::byte* QuantBDataPtr) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16( + _mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + __m512i bv_64_epi8 = combine_two_m256i_to_m512i(bv0_32_epi8, bv1_32_epi8); + return bv_64_epi8; +} + +static MLAS_FORCEINLINE __m512i +horizontal_add_epi32(__m512i a, __m512i b) +{ + __m512i t1 = _mm512_unpacklo_epi32(a, b); + __m512i t2 = _mm512_unpackhi_epi32(a, b); + __m512i sum = _mm512_add_epi32(t1, t2); + return sum; +} + +static MLAS_FORCEINLINE __m512i +generate_ones_32_epi16() +{ + const __m512i zeros = _mm512_setzero_si512(); + return _mm512_srli_epi16(_mm512_ternarylogic_epi64(zeros, zeros, zeros, 1), 15); +} + +static MLAS_FORCEINLINE void +dot_accumulate_2blk( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + //const __m512i& one_32_epi16, + __m512& acc) +{ + __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); + __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av1_64_epi8); + + __m512i t1 = _mm512_unpacklo_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); + __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // sum for blk: 0 0 1 1 0 0 1 1... + __m512i one_32_epi16 = generate_ones_32_epi16(); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // sum for blk: 0 1 0 1... 
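The multiply-add steps above are the standard int8 dot-product reduction: _mm512_maddubs_epi16 multiplies unsigned B bytes against signed A bytes and sums adjacent pairs into int16 lanes, and _mm512_madd_epi16 with a vector of ones widens and sums adjacent int16 pairs into int32 lanes (the unpack/add in between additionally merges the two blocks' partial sums before the widening step). A scalar model of the plain maddubs-then-madd(ones) combination; the int16 saturation of maddubs cannot trigger here because the B values are at most 15:

#include <cstdint>

// Each int32 lane after maddubs + madd(ones, .) is the sum of four neighbouring
// (unsigned B byte) * (signed A byte) products.
int32_t MaddubsThenMaddLaneRef(const uint8_t b[4], const int8_t a[4])
{
    const int32_t pair0 = int32_t(b[0]) * a[0] + int32_t(b[1]) * a[1];  // maddubs, pair 0
    const int32_t pair1 = int32_t(b[2]) * a[2] + int32_t(b[3]) * a[3];  // maddubs, pair 1
    return pair0 + pair1;                                               // madd with ones
}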
+ __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + dot_accumulate_2blk( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blk( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk2_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + dot_accumulate_2blk( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r2c1blk1_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + const __m512i zeros = _mm512_setzero_si512(); + //const __m512i one_32_epi16_ = _mm512_andnot_epi32(zeros, zeros); + //const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_andnot_epi32(zeros, zeros), 15); + + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av0_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } + + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av1_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_avx512( + const __m512i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + 
__m512& acc +) +{ + __m512i bv_64_epi8 = load_1blk_4b_packed_blklen64(QuantBDataPtr); + + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + const __m512i one_32_epi16 = _mm512_set1_epi16(1); + + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av_32_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +MLAS_FORCEINLINE void +Q4Int8GemmR2xC4BlkLen64Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen64); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + 
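The four accumulate calls above form a 2-row by 4-column register tile: each pair of 64-byte A-row loads is reused against four adjacent B columns, filling all eight zmm accumulators, so the cost of streaming A is amortized over four outputs per row. The same tiling idea on plain floats, as a minimal illustrative sketch:

#include <cstddef>

static void Tile2x4Ref(const float* a_row0, const float* a_row1,
                       const float* b_cols[4], size_t K, float acc[2][4])
{
    for (size_t k = 0; k < K; ++k) {
        const float a0 = a_row0[k];            // two A loads per k ...
        const float a1 = a_row1[k];
        for (int c = 0; c < 4; ++c) {          // ... reused across four B columns
            acc[0][c] += a0 * b_cols[c][k];
            acc[1][c] += a1 * b_cols[c][k];
        }
    }
}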
+ // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } // k_blks_remaining + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + +#if 1 + *SumPtr = _mm512_reduce_add_ps(acc[0]) - *SumPtr; + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]) - *(SumPtr + 1); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]) - *(SumPtr + 2); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]) - *(SumPtr + 3); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]) - *(SumPtr + ldc); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]) - *(SumPtr + ldc + 1); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]) - *(SumPtr + ldc + 2); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]) - *(SumPtr + ldc + 3); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } +#else + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); + _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); +#endif + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +void MLAS_FORCEINLINE +Q4Int8GemmR2xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM % NRows2 == 0); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0) - *SumPtr; + *(SumPtr + ldc) = hsum_float_16(acc1) - *(SumPtr + ldc); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { + const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, + QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, + QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } 
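FoldAccumulators reduces the four 16-float accumulators to a single __m128 that holds one horizontal sum per column, so the four column results can be bias-adjusted and written back with the single 128-bit load/store pair below. A scalar equivalent of that fold (reference only):

static void FoldAccumulatorsRef(const float acc[4][16], float folded[4])
{
    for (int col = 0; col < 4; ++col) {
        float sum = 0.0f;
        for (int lane = 0; lane < 16; ++lane) {
            sum += acc[col][lane];   // horizontal sum of one 512-bit accumulator
        }
        folded[col] = sum;           // becomes lane 'col' of the __m128 result
    }
}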
+ + const __m128 level_r0 = _mm_loadu_ps(SumPtr); + _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q4Int8GemmR1xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + + // process 2 blks of 128 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + //assert(CountM < NRows2); + //assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 128 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + while (k_blks_remaining-- > 0) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + + accumulate_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0) - *SumPtr; + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +MLAS_FORCEINLINE size_t +MlasQ4Int8GemmKernelBlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC4BlkLen64Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + else + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + if (NRows2 == 2) + Q4Int8GemmR2xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + else + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q4Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + if (remainingCols > 0 && remainingRows > 0) { + Q4Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 546e897865ed..5721bad243dc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -238,13 +238,45 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } void MLASCALL -MlasQ80BlkQuantRow_avx512( +QuantizeARow_CompInt8_avx512( size_t BlkLen, const float* A, size_t CountK, - std::byte* QuantA + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) ); +static void +SQ4BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + const float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == CompInt8) { + SubBlkLen = 64; + } + + PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + + if (QuantBScaleBegin) { + ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); + } +} + // // Kernel dispatch structure definition. // @@ -259,7 +291,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx512vnni; - d.QuantizeARow_CompInt8 = MlasQ80BlkQuantRow_avx512; + d.QuantizeARow_CompInt8 = nullptr; + d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 0836e6252b91..dbdd9f48636f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -24,64 +24,36 @@ SQ4BitGemmPackQuantBDataSize( } static void -SQ4BitGemmPackQuantBDataAndBlkSum( - size_t N, - size_t K, - size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - const float* QuantBScaleBegin, - const std::byte* QuantBZPBegin, - float* BlockSumBegin, - MLAS_THREADPOOL* ThreadPool -) +PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen) { constexpr size_t BlkBitWidth = 4; - - assert(BlkLen >= 16 && BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t BlkBytePairCount = BlkLen / 4; - - size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { - SubBlkLen = 64; - } + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t SubBlkDataSize = SubBlkLen / 2; const size_t SubBlkBytePairCount = SubBlkLen / 4; const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); const size_t Iterations = N * SubBlkCountK; // one iteration per sub block - // - // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - - // - // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | - // => - // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - // - - // - // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | ... | v60 v61 | v62 v63 | - // => - // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - // - // When BlkLen = 32 for the remaining blk, it shall be: - // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | - // - + // for avx2 + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // for the remaining blk, it shall be: + // dst blklen32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | + + // for avx512 + // dst: | v0 v64 | v1 v65 | ... | v62 v126 | v63 v127 | + // for the remaining blk, it shall be: + // dst blklen64: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // dst blklen32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // dst blklen16: | v0 v8 | v1 v9 | v2 v11 | v3 v12 | v4 v13 | v5 v14 | v6 v15 | v7 v16 | MlasTrySimpleParallel( ThreadPool, Iterations, [&](ptrdiff_t tid) { @@ -94,31 +66,52 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t PackBytePairCount = SubBlkBytePairCount; size_t PackDataSize = SubBlkDataSize; - if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1) { - // this is the last subblk of the column. check if it extends out of the - // BlockCountK. If it does, we shall pack per blocks so that can compute - // on each block instead of each subblk. - if (SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { - PackBytePairCount = BlkBytePairCount; - PackDataSize = BlkDataSize; - } - } - for (size_t byte_pair_idx = 0; byte_pair_idx < PackBytePairCount; ++byte_pair_idx) { + auto pack_subblk = []( + const std::byte* QuantBData, std::byte* PackedQuantBData, + size_t pack_byte_pair_count, size_t pack_data_size) { + for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + PackDataSize / 2]; + const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } }; + + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && + SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + // this is the last subblk of the column. check if it extends out of the + // BlockCountK. If it does, we shall pack per blocks so that can compute + // on each block instead of each subblk. 
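The interleaved destination layouts listed above pair nibble i with nibble i + SubBlkLen/2 in a single byte, so one AND plus one shift at load time splits a packed vector into the two contiguous halves of a sub-block. A standalone sketch of that repack for one sub-block, assuming the source holds nibbles in plain order v0, v1, ... (illustrative; mirrors what the pack_subblk helper above does byte pair by byte pair):

#include <cstddef>
#include <cstdint>

static void PackSubBlkRef(const uint8_t* src, uint8_t* dst, size_t sub_blk_len)
{
    const size_t half = sub_blk_len / 2;
    auto nibble = [&](size_t j) -> uint8_t {                      // nibble j of the source
        return (j & 1) ? uint8_t(src[j / 2] >> 4) : uint8_t(src[j / 2] & 0x0F);
    };
    for (size_t i = 0; i < half; ++i) {
        dst[i] = uint8_t(nibble(i) | (nibble(i + half) << 4));    // | v_i  v_{i+half} |
    }
}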
+ PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; + for (size_t k = 0; k < k_blks_remaining; k++) { + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } + } + else + { + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); } + } ); +} - MlasTrySimpleParallel( - ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { +static void +ComputePackBlkSum( + size_t N, + const float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK) +{ + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { const size_t n = tid / BlockCountK; const size_t k_blk = tid % BlockCountK; @@ -392,6 +385,19 @@ hsum_float_8(const __m256 x) return _mm_cvtss_f32(res); } +static inline float +hsum_float_16(const __m512 x) +{ + __m256 hi = _mm512_extractf32x8_ps(x, 1); + __m256 lo = _mm512_castps512_ps256(x); + hi = _mm256_add_ps(hi, lo); + __m128 hi128 = _mm256_extractf128_ps(hi, 1); + __m128 lo128 = _mm256_castps256_ps128(hi); + hi128 = _mm_add_ps(hi128, lo128); + hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); + hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); + return _mm_cvtss_f32(hi128); +} /** * @brief Horizontally sum 4 vectors and store * the results in the returned vector @@ -424,4 +430,27 @@ FoldAccumulators(const __m512& acc0, const __m512& acc1, const __m512& acc2, con _mm256_add_ps(_mm512_extractf32x8_ps(acc_lo0123, 0), _mm512_extractf32x8_ps(acc_lo0123, 1)); return _mm_add_ps(_mm256_extractf32x4_ps(acc_y, 0), _mm256_extractf32x4_ps(acc_y, 1)); } + +static MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8(__m256 v0, __m256 v1) +{ + __m256i v0_8_epi32 = _mm256_cvtps_epi32(v0); + __m256i v1_8_epi32 = _mm256_cvtps_epi32(v1); + + __m128i v0_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v0_8_epi32, 0), _mm256_extractf128_si256(v0_8_epi32, 1)); + __m128i v1_8_epi16 = _mm_packs_epi32(_mm256_extractf128_si256(v1_8_epi32, 0), _mm256_extractf128_si256(v1_8_epi32, 1)); + + return _mm_packs_epi16(v0_8_epi16, v1_8_epi16); +} + +// horizontally add 8 int32_t +static MLAS_FORCEINLINE int +hsum_8_epi32(const __m256i a_8_epi32) +{ + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a_8_epi32), _mm256_extractf128_si256(a_8_epi32, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} } // namespace diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index f8a30f88796b..6a092f11993d 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -397,6 +397,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Sat, 1 Jun 2024 02:33:27 -0700 Subject: [PATCH 10/41] matmul_nbit & fix alignment for sgemm Signed-off-by: Liqun Fu --- .../cpu/quantization/matmul_nbits.cc | 33 +++++++++++-------- onnxruntime/core/mlas/inc/mlas_qnbit.h | 1 + onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 20 ++--------- onnxruntime/core/mlas/lib/sqnbitgemm.h | 30 +++++++++++++++++ .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 33 +++++++++++++++++-- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 3 +- 
.../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 1 + .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 1 + .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 6 +++- .../test/mlas/unittest/test_sqnbitgemm.cpp | 12 ++++--- 10 files changed, 99 insertions(+), 41 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 24d96699952d..417b117dea11 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -89,6 +89,8 @@ class MatMulNBits final : public OpKernel { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + const Tensor* tensor_zero_point = nullptr; + has_zp_input_ = info.TryGetConstantInput(3, &tensor_zero_point); #ifdef ORT_NEURAL_SPEED const Tensor* tensor_B = nullptr; const Tensor* tensor_scale = nullptr; @@ -99,7 +101,7 @@ class MatMulNBits final : public OpKernel { is_asym_ = info.GetInputCount() >= 4; all_constant_ = B_constant && scale_constant; all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; -#endif +#endif } Status Compute(OpKernelContext* context) const override; @@ -123,8 +125,9 @@ class MatMulNBits final : public OpKernel { IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; + bool has_zp_input_{false}; #if defined(ORT_NEURAL_SPEED) - + bool is_asym_{false}; bool all_constant_{false}; @@ -185,9 +188,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } #else // defined(ORT_NEURAL_SPEED) - + const auto compute_type = static_cast(accuracy_level_); if (input_idx == 1) { - const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { return Status::OK(); } @@ -197,7 +199,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr); if (prepacked_weights) { // TODO: cannot use packed_b_ after assert(false); @@ -205,12 +207,15 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat prepacked_weights->buffer_sizes_.push_back(packed_b_size_); } is_packed = true; - } else if (input_idx == 2) { - // MlasSQNBitGemmPackQuantBData with scales - assert(false); - } else if (input_idx == 3) { - // MlasSQNBitGemmPackQuantBData with zp - assert(false); + } + else if (input_idx == 2 && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr); + is_packed = false; + } else if (input_idx == 3 && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr); + is_packed = false; } #endif // defined(ORT_NEURAL_SPEED) @@ -333,9 +338,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // mlas nbits implementation requires packed b. update this logic if it changes. 
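With this change PrePack runs once for each constant initializer, so the MLAS packing entry point is invoked up to three times with only the operand that has just arrived; an abridged view of that call sequence (argument lists shortened here, not the exact signatures):

// input 1, quantized weights: allocate packed_b_ and pack the 4-bit data
//   MlasSQNBitGemmPackQuantBData(..., qptr, packed_b, /*scales*/ nullptr, has_zp_input, /*zp*/ nullptr, ...);
// input 2, scales: fold the scales in; block sums can be finished now if no zero points are expected
//   MlasSQNBitGemmPackQuantBData(..., /*qdata*/ nullptr, packed_b, sptr, has_zp_input, /*zp*/ nullptr, ...);
// input 3, zero points: final call, block sums computed with both scales and zero points available
//   MlasSQNBitGemmPackQuantBData(..., /*qdata*/ nullptr, packed_b, /*scales*/ nullptr, has_zp_input, zptr, ...);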
if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type) && packed_b_) { IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { + const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( + M, N, K, batch_count, nbits_, block_size_, compute_type); + if (workspace_size > 0) { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index b7b347232518..2c34015b227d 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -191,6 +191,7 @@ MlasSQNBitGemmPackQuantBData( const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* QuantBScale, + bool has_zp_input, const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool = nullptr ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 9cfbce3fda17..c1ea47abbae5 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -194,22 +194,6 @@ MlasSQNBitGemmPackQuantBDataSize( return 0; } -struct PackedQuantBDataStruct { - PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) - : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) - { - constexpr size_t BlkBitWidth = 4; - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - PackedQuantBData = (std::byte*)PackedQuantBWorkspace; - QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - } - std::byte* PackedQuantBData; - float* QuantBBlkSum; - - void* QuantBWorkspace_; - size_t N_, BlockCountK_, BlkLen_; -}; - struct PerGemmQuantAWorkspace { PerGemmQuantAWorkspace(void* PerGemmWorkspace, size_t M, size_t BlockCountK, size_t BlkLen) : PerGemmWorkspace_(PerGemmWorkspace), M_(M), BlockCountK_(BlockCountK), BlkLen_(BlkLen) @@ -235,6 +219,7 @@ MlasSQNBitGemmPackQuantBData( const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* QuantBScale, + bool has_zp_input, const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ) @@ -261,8 +246,6 @@ MlasSQNBitGemmPackQuantBData( } else if (Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSum, N, BlockCountK, BlkLen); - assert(QuantBScale); - //assert(QuantBZeroPoint); // QuantBZeroPoint is nullptr if symetric quantization. Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -271,6 +254,7 @@ MlasSQNBitGemmPackQuantBData( static_cast(QuantBData), packed_quant_b.PackedQuantBData, static_cast(QuantBScale), + has_zp_input, static_cast(QuantBZeroPoint), packed_quant_b.QuantBBlkSum, ThreadPool diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index bef6f2b9d725..7f29c3045e1c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -27,12 +27,41 @@ Module Name: #include "mlas_qnbit.h" #include "mlasi.h" +constexpr MLAS_FORCEINLINE size_t +MlasQNBitQuantBBlkSumAlignment() +{ + // 16 floats. 
this alignment is required by GemmFloatKernel + return 16 * sizeof(float); +} + constexpr MLAS_FORCEINLINE size_t MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) { return BlkLen * BlkBitWidth / 8; } +struct PackedQuantBDataStruct { + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) + : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + constexpr size_t BlkBitWidth = 4; + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; + QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); + + const size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); + const uintptr_t QuantBBlkSumAddr = reinterpret_cast(QuantBBlkSum); + QuantBBlkSum = reinterpret_cast( + (QuantBBlkSumAddr + Alignment - 1) & (~(Alignment - 1)) + ); + } + std::byte* PackedQuantBData; + float* QuantBBlkSum; + + void* QuantBWorkspace_; + size_t N_, BlockCountK_, BlkLen_; +}; + template constexpr MLAS_FORCEINLINE size_t MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) @@ -134,6 +163,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, + bool has_zp_input, const std::byte* QuantBZPBegin, float* BlockSumBegin, // BlockCountK by N => (BlockCountK * N) / 16 by 16 MLAS_THREADPOOL* ThreadPool diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 446f78190539..3d357643e179 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1138,11 +1138,13 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, + bool has_zp_input, const std::byte* QuantBZPBegin, float* BlockSumBegin, MLAS_THREADPOOL* ThreadPool ) { + constexpr size_t BlkBitWidth = 4; assert(BlkLen >= 16 && BlkLen % 16 == 0); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -1153,11 +1155,38 @@ SQ4BitGemmPackQuantBDataAndBlkSum( SubBlkLen = 64; } - PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + // TODO: move to avx_common + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + if (QuantBScaleBegin && has_zp_input && !QuantBZPBegin) { + // scale is provided but still missing zp in order to compute the blksum. + // cache the scale in the later half of PackedQuantBData. 
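PackedQuantBDataStruct above rounds the block-sum pointer up to the next multiple of MlasQNBitQuantBBlkSumAlignment() with the usual round-up-and-mask idiom; the same computation as a small standalone helper (illustrative only):

#include <cstddef>
#include <cstdint>

// Round a pointer up to the next multiple of a power-of-two alignment (64 bytes here).
static float* AlignUpRef(void* p, size_t alignment)
{
    const uintptr_t addr = reinterpret_cast<uintptr_t>(p);
    return reinterpret_cast<float*>((addr + alignment - 1) & ~uintptr_t(alignment - 1));
}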
+ std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, (float*)(PackedQuantBDataBegin + PackedQuantBDataSize)); + return; + } - if (QuantBScaleBegin) { + // if called with QuantBZPBegin and without QuantBScaleBegin it must be that + // the scale is already cached in PackedQuantBData (offset PackedQuantBDataSize) + bool delete_quant_b_scale_begin = false; + if (!QuantBScaleBegin && QuantBZPBegin) { + QuantBScaleBegin = new float[N * BlockCountK]; + const float* QuantBScaleBeginSaved = reinterpret_cast(PackedQuantBDataBegin + PackedQuantBDataSize); + std::copy(QuantBScaleBeginSaved, QuantBScaleBeginSaved + N * BlockCountK, const_cast(QuantBScaleBegin)); + delete_quant_b_scale_begin = true; + } + + bool last_call = QuantBScaleBegin && (!has_zp_input || QuantBZPBegin); + + if (last_call) { ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); } + if (delete_quant_b_scale_begin) { + delete[] QuantBScaleBegin; + } } // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 13f7db8bfffa..f82dc993ab30 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -625,12 +625,11 @@ Q4Int8GemmR1xC1BlkLen16Avx2( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc0) - *SumPtr; + *SumPtr = hsum_float_8(acc0) - *SumPtr; if (BiasPtr) { *SumPtr += *BiasPtr; } - // move to next column QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; BiasPtr += BiasPtr != nullptr ? 1 : 0; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 0691cd29005a..fc6efce979c9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -318,6 +318,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, + bool /*has_zp_input*/, const std::byte* QuantBZPBegin, float* BlockSumBegin, MLAS_THREADPOOL* ThreadPool diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 5721bad243dc..d64c74273595 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -256,6 +256,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, + bool /*has_zp_input*/, const std::byte* QuantBZPBegin, float* BlockSumBegin, MLAS_THREADPOOL* ThreadPool diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index dbdd9f48636f..860e90a599e4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -19,7 +19,11 @@ SQ4BitGemmPackQuantBDataSize( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + const size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += Alignment - 1; + return PackedQuantBDataSize + 
BlkSumSize; } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 6a092f11993d..4b0262436aae 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -269,8 +269,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); + bool has_zp_input = QuantBZeroPoint != nullptr; MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, - QuantBScale, QuantBZeroPoint, + QuantBScale, has_zp_input, QuantBZeroPoint, GetMlasThreadPool()); } @@ -391,7 +392,6 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Mon, 10 Jun 2024 11:47:36 -0700 Subject: [PATCH 11/41] fix mlas benchmark not using multi threads Signed-off-by: Liqun Fu --- onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 4abab182201d..8dc1f61894a2 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -53,6 +53,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, std::vector QuantBData(QuantBDataSizeInBytes); std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); + bool has_zp_input = !Symmetric; MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), @@ -71,7 +72,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), QuantBZeroPoint.data(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), tp.get()); } From b9493adbe88c4681fcae71774ec3685d1390bd46 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Mon, 10 Jun 2024 12:37:23 -0700 Subject: [PATCH 12/41] profiling Signed-off-by: Liqun Fu --- .../cpu/quantization/matmul_nbits.cc | 22 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 6 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 309 +++++++++++++++++- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 2 +- .../test/contrib_ops/matmul_4bits_test.cc | 71 ++-- .../test/mlas/bench/bench_sqnbitgemm.cpp | 13 +- 6 files changed, 365 insertions(+), 58 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 4f6e8d0edcc8..5836e42531e5 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -222,7 +222,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); if (prepacked_weights) { // TODO: cannot use packed_b_ after assert(false); @@ -233,11 +233,11 @@ 
Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } else if (input_idx == InputIndex::scales && packed_b_ != nullptr) { auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } #endif // defined(ORT_NEURAL_SPEED) @@ -278,6 +278,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } Status MatMulNBits::Compute(OpKernelContext* ctx) const { + //auto start = std::chrono::high_resolution_clock::now(); // Start timing here concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); const auto* a_data = a->Data(); @@ -363,11 +364,22 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { data[i].Bias = bias_data; data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; + data[i].node_name = this->Node().Name(); } + //auto start2 = std::chrono::high_resolution_clock::now(); // Start timing here - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); + //int count = 200; + //while (count-- > 0) + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + //auto end = std::chrono::high_resolution_clock::now(); // End timing here + + //std::chrono::duration elapsed2 = end - start2; + //// Calculate and print the duration in nanoseconds + // std::chrono::duration elapsed = end - start; + //std::cout << "MlasSQNBitGemmBatch: " << elapsed2.count() << " ns\n"; + //std::cout << "main Duration_M" << M << "xN" << N << "xK" << K << ": " << elapsed.count() << " ns\n"; return Status::OK(); } } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 2c34015b227d..09c8c2311425 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -22,6 +22,7 @@ Module Name: #include "mlas.h" #include "mlas_gemm_postprocessor.h" +#include /** * @brief Define compute types of block quantization, in order of decreasing accuracy. 
@@ -55,6 +56,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; + std::string node_name = ""; }; /** @@ -179,7 +181,7 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[in] QuantBData quantized B data * @param[out] PackedQuantBData packed quantized B data - * @param[in] ThreadPool optional thread pool to use + * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL MlasSQNBitGemmPackQuantBData( @@ -193,5 +195,5 @@ MlasSQNBitGemmPackQuantBData( const void* QuantBScale, bool has_zp_input, const void* QuantBZeroPoint, - MLAS_THREADPOOL* ThreadPool = nullptr + MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index fac0385fae38..253931c02043 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,6 +19,36 @@ Module Name: #include "sqnbitgemm_q8_block.h" #include +#include +#include +#include "core/common/profiler.h" + +class ProfilerWrapper +{ + public: + ProfilerWrapper() + { + profiler_ = std::make_unique(); + profiler_->StartProfiling("profile.json"); + } + + ~ProfilerWrapper() + { + if (profiler_) { + profiler_->EndProfiling(); + } + } + + onnxruntime::profiling::Profiler* operator->() + { + return profiler_.get(); + } + + private: + std::unique_ptr profiler_; +}; + +static ProfilerWrapper profiler_; namespace { @@ -427,6 +457,84 @@ SQ4BitGemm_CompFp32( } } +//#define CALL_SGEMM_SEPARATELY 1 +#if defined(CALL_SGEMM_SEPARATELY) +void +SQ4BitGemm_CompInt8_0( + const size_t BlkLen, + const size_t K, + const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. 
+ const size_t ldc = DataParams->ldc; + + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + if (RangeCountM == 1) { + // auto start = std::chrono::high_resolution_clock::now(); // Start timing here + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + float* c_blk = C + n; + std::chrono::high_resolution_clock::time_point tp; + if (profiler_->IsEnabled()) { + tp = profiler_->Start(); + } + GetMlasPlatform().GemmFloatKernel( + ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, true + ); + if (profiler_->IsEnabled()) { + std::string eventName = DataParams->node_name + "Sep GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + } + } + // auto end = std::chrono::high_resolution_clock::now(); // End timing here + //// Calculate and print the duration in nanoseconds + // std::chrono::duration elapsed = end - start; + // std::cout << "Duration_M" << RangeCountM << "xN" << RangeCountN << "xK" << K << ": " << elapsed.count() << " ns\n"; + return; + } else { + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + + float* c_blk = C + n; + + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum) { + size_t RowsRemaining = RangeCountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, k_blks, RowsRemaining, CountN, k_blks, ldc, 1.f, true + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += k_blks * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + } +} +#endif + void SQ4BitGemm_CompInt8( const size_t BlkLen, @@ -462,7 +570,6 @@ SQ4BitGemm_CompInt8( const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; - const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; @@ -470,8 +577,10 @@ SQ4BitGemm_CompInt8( (DataParams->QuantBZeroPoint == nullptr) ? nullptr : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; +#ifndef CALL_SGEMM_SEPARATELY + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; - +#endif float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; @@ -479,20 +588,34 @@ SQ4BitGemm_CompInt8( if (RangeCountM == 1) { if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8) { + //auto start = std::chrono::high_resolution_clock::now(); // Start timing here + size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; - const float* b_blk_sum = QuantBBlkSum + n * k_blks; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + std::chrono::high_resolution_clock::time_point tp; +#ifndef CALL_SGEMM_SEPARATELY + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + if (profiler_->IsEnabled()) { + tp = profiler_->Start(); + } GetMlasPlatform().GemmFloatKernel( ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, true - ); - + ); + if (profiler_->IsEnabled()) { + std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + } +#endif + if (profiler_->IsEnabled()) { + tp = profiler_->Start(); + } GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, QuantA, @@ -508,6 +631,10 @@ SQ4BitGemm_CompInt8( lda, ldc ); + if (profiler_->IsEnabled()) { + std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + } if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( @@ -516,6 +643,11 @@ SQ4BitGemm_CompInt8( ); } } + //auto end = std::chrono::high_resolution_clock::now(); // End timing here + //// Calculate and print the duration in nanoseconds + //std::chrono::duration elapsed = end - start; + //std::cout << "Duration_M" << RangeCountM << "xN" << RangeCountN << "xK" << K << ": " << elapsed.count() << " ns\n"; + return; } else { size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { @@ -553,11 +685,15 @@ SQ4BitGemm_CompInt8( const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; - const float* b_blk_sum = QuantBBlkSum + n * k_blks; - float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; + std::chrono::high_resolution_clock::time_point tp; +#ifndef CALL_SGEMM_SEPARATELY + if (profiler_->IsEnabled()) { + tp = profiler_->Start(); + } + const float* b_blk_sum = QuantBBlkSum + n * k_blks; if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum) { size_t RowsRemaining = RangeCountM; const float* a_blksum_row = ABlockSum; @@ -571,7 +707,14 @@ SQ4BitGemm_CompInt8( RowsRemaining -= RowsHandled; } } - + if (profiler_->IsEnabled()) { + std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + } +#endif + if (profiler_->IsEnabled()) { + tp = profiler_->Start(); + } c_blk = C + n; GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, @@ -588,6 +731,10 @@ SQ4BitGemm_CompInt8( lda, ldc ); + if (profiler_->IsEnabled()) { + std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt8_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + } if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( @@ -722,6 +869,55 @@ constexpr auto OperationMap = []() { return ops; }(); +void +ComputeParallelTasksSGemm(const size_t M, const size_t N, const size_t CountK, const size_t BatchN, + MLAS_THREADPOOL* ThreadPool, + size_t& ThreadCountM, size_t& ThreadCountN, size_t& ThreadsPerGemm) +{ + const double Complexity = double(M) * double(N) * double(CountK); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. 
+ // + + ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + if (N > M) { + const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / + MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } +} } // namespace void MLASCALL @@ -738,6 +934,8 @@ MlasSQNBitGemmBatch( MLAS_THREADPOOL* ThreadPool ) { + //auto start_batch = std::chrono::high_resolution_clock::now(); // Start timing here + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); @@ -756,16 +954,22 @@ MlasSQNBitGemmBatch( if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; InitializeWorkspaceOperation != nullptr) { + //auto start = std::chrono::high_resolution_clock::now(); // Start timing here InitializeWorkspaceOperation( M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); + //auto end = std::chrono::high_resolution_clock::now(); // End timing here + //// Calculate and print the duration in nanoseconds + //std::chrono::duration elapsed = end - start; + //std::cout << "InitializeWorkspaceOperation: " << elapsed.count() << " ns\n"; } const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (ThreadPool == nullptr) { + if (/*true || */ThreadPool == nullptr) { + //auto start = std::chrono::high_resolution_clock::now(); // Start timing here for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; if (ComputeType == CompInt8) { @@ -775,11 +979,19 @@ MlasSQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); +#if defined(CALL_SGEMM_SEPARATELY) + SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); +#endif ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { ComputeOperation(BlkLen, K, Data, nullptr, 0, M, 0, N); } } + //auto end = std::chrono::high_resolution_clock::now(); // End timing here + //// Calculate and print the duration in nanoseconds + //std::chrono::duration elapsed = end - start; + //std::cout << "ThreadPool == nullptr: " << elapsed.count() << " ns\n"; + return; } @@ -787,7 +999,66 @@ MlasSQNBitGemmBatch( // Compute the number of target threads given the complexity of the SGEMM // operation. Small requests should run using the single threaded path. 
// + //auto start = std::chrono::high_resolution_clock::now(); // Start timing here + +#if defined(CALL_SGEMM_SEPARATELY) + if (ComputeType == CompInt8) { + size_t ThreadCountM, ThreadCountN, ThreadsPerGemm; + ComputeParallelTasksSGemm(M, N, BlockCountK, BatchN, ThreadPool, + ThreadCountM, ThreadCountN, ThreadsPerGemm); + + //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << "\t" + // << "ThreadCountM: " << ThreadCountM << "\t" + // << "ThreadCountN: " << ThreadCountN << "\n"; + //auto start = std::chrono::high_resolution_clock::now(); // Start timing here + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + + // MlasSgemmThreaded + ptrdiff_t ThreadId = ThreadIdx; + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / + MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + const auto* Data = &DataParams[GemmIdx]; + + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBData), N, BlockCountK, BlkLen); + const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + GemmIdx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + //ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); + } +#endif const double Complexity = double(M) * double(N) * double(K) * double(BatchN); ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; @@ -824,6 +1095,14 @@ MlasSQNBitGemmBatch( const size_t ThreadCountN = MlasDivRoundup(N, StrideN); ThreadsPerGemm = ThreadCountM * ThreadCountN; + //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << "\t" + // << "ThreadCountM: " << ThreadCountM << "\t" + // << "ThreadCountN: " << ThreadCountN << "\n"; + std::chrono::high_resolution_clock::time_point tp; + if (profiler_->IsEnabled()) { + tp = profiler_->Start(); + } + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; @@ -850,4 +1129,16 @@ MlasSQNBitGemmBatch( ComputeOperation(BlkLen, K, Data, nullptr, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } }); + if (profiler_->IsEnabled()) { + std::string eventName = DataParams->node_name + "MlasTrySimpleParallel" + std::to_string(ThreadsPerGemm) + "-" + std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K); + profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + } + + // auto end = std::chrono::high_resolution_clock::now(); // End timing here + //// Calculate and print the duration in 
nanoseconds + // std::chrono::duration elapsed = end - start; + // std::chrono::duration elapsed_batch = end - start_batch; + + // std::cout << "ThreadPool kernel: " << elapsed.count() << " ns\n"; + // std::cout << "Batch Internal: " << elapsed_batch.count() << " ns\n"; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 0db7a4b655de..8e1f7171e6e4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1181,7 +1181,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( bool last_call = QuantBScaleBegin && (!has_zp_input || QuantBZPBegin); - if (last_call) { + if (last_call) { ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); } if (delete_quant_b_scale_begin) { diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index dedc01de9655..3fc47287f75b 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -263,11 +263,12 @@ void RunTest(const TestOptions& opts, } // namespace TEST(MatMulNBits, Float32) { - for (auto M : {1, 2, 100}) { - for (auto N : {1, 2, 32, 288}) { - for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { - for (auto block_size : {16, 32, 64, 128}) { - for (auto accuracy_level : {0, 1, 4}) { + //onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); + for (auto M : {1/*, 2, 100*/}) { + for (auto N : {2560/*1, 2, 32, 288*/}) { + for (auto K : {2560/*16, 32, 64, 128, 256, 1024, 93, 1234*/}) { + for (auto block_size : {/*16, */32/*, 64, 128*/}) { + for (auto accuracy_level : {/*0, 1, */4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; @@ -282,36 +283,36 @@ TEST(MatMulNBits, Float32) { RunTest(opts); } - { - TestOptions opts = base_opts; - opts.has_zero_point = true; - RunTest(opts); - } - -#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) - { - TestOptions opts = base_opts; - opts.has_g_idx = true; - RunTest(opts); - } - - { - TestOptions opts = base_opts; - opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); - } -#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) - - { - TestOptions opts = base_opts; - opts.has_bias = true; - - // only enabled for CPU EP for now - std::vector> explicit_eps; - explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - - RunTest(opts, std::move(explicit_eps)); - } +// { +// TestOptions opts = base_opts; +// opts.has_zero_point = true; +// RunTest(opts); +// } +// +//#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +// { +// TestOptions opts = base_opts; +// opts.has_g_idx = true; +// RunTest(opts); +// } +// +// { +// TestOptions opts = base_opts; +// opts.has_zero_point = true, opts.zp_is_4bit = false; +// RunTest(opts); +// } +//#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +// +// { +// TestOptions opts = base_opts; +// opts.has_bias = true; +// +// // only enabled for CPU EP for now +// std::vector> explicit_eps; +// explicit_eps.emplace_back(DefaultCpuExecutionProvider()); +// +// RunTest(opts, std::move(explicit_eps)); +// } } } } diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 4abab182201d..62f84a6bcc36 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -53,6 
+53,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, std::vector QuantBData(QuantBDataSizeInBytes); std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); + bool has_zp_input = !Symmetric; MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), @@ -71,7 +72,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), QuantBZeroPoint.data(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), tp.get()); } @@ -88,7 +89,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + //MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); @@ -116,10 +117,10 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgsProduct({ {16, 32, 64, 128, 256}, // BlkLen - {1, 2, 1024, 2048}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads + {1, 1024, 2048}, // M + {48, 2560, 4096, 11008}, // N + {4096, 2560, 10240, 11008}, // K + {1, 8, 64}, // Threads {int64_t{false}, int64_t{true}}, // Symmetric {int64_t{false}, int64_t{true}}, // HasBias {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType From ac66951c53d41aa5cca80133a3ce4088b590cacd Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Sun, 16 Jun 2024 16:02:34 -0700 Subject: [PATCH 13/41] sgemm after sq4bit for avx2 Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 158 ++++++++++-------- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 16 +- .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 21 +-- .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 16 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 12 +- .../test/contrib_ops/matmul_4bits_test.cc | 68 ++++---- .../test/mlas/unittest/test_sqnbitgemm.cpp | 15 ++ 7 files changed, 167 insertions(+), 139 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 253931c02043..814d8daa50df 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -29,13 +29,13 @@ class ProfilerWrapper ProfilerWrapper() { profiler_ = std::make_unique(); - profiler_->StartProfiling("profile.json"); + //profiler_->StartProfiling("profile.json"); } ~ProfilerWrapper() { if (profiler_) { - profiler_->EndProfiling(); + //profiler_->EndProfiling(); } } @@ -457,6 +457,7 @@ SQ4BitGemm_CompFp32( } } +//#define BlockSumM1Layout 1 //#define CALL_SGEMM_SEPARATELY 1 #if defined(CALL_SGEMM_SEPARATELY) void @@ -478,7 +479,12 @@ SQ4BitGemm_CompInt8_0( const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; +#if defined(BlockSumM1Layout) + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN; +#else const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; +#endif + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; @@ -489,19 +495,25 @@ SQ4BitGemm_CompInt8_0( for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); - const float* b_blk_sum = QuantBBlkSum + n * k_blks; float* c_blk = C + n; - 
std::chrono::high_resolution_clock::time_point tp; - if (profiler_->IsEnabled()) { - tp = profiler_->Start(); - } + //std::chrono::high_resolution_clock::time_point tp; + //if (profiler_->IsEnabled()) { + // tp = profiler_->Start(); + //} +#if defined(BlockSumM1Layout) + const float* b_blk_sum = QuantBBlkSum + n; + GetMlasPlatform().KernelM1Routine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); + // GetMlasPlatform().KernelM1TransposeBRoutine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); +#else + const float* b_blk_sum = QuantBBlkSum + n * k_blks; GetMlasPlatform().GemmFloatKernel( ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, true ); - if (profiler_->IsEnabled()) { - std::string eventName = DataParams->node_name + "Sep GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - } +#endif + //if (profiler_->IsEnabled()) { + // std::string eventName = DataParams->node_name + "Sep GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); + // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + //} } // auto end = std::chrono::high_resolution_clock::now(); // End timing here //// Calculate and print the duration in nanoseconds @@ -579,7 +591,11 @@ SQ4BitGemm_CompInt8( : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; #ifndef CALL_SGEMM_SEPARATELY const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; +#if defined(BlockSumM1Layout) + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN; +#else const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; +#endif #endif float* C = DataParams->C + RangeStartM * ldc + RangeStartN; @@ -598,24 +614,10 @@ SQ4BitGemm_CompInt8( const float* b_col_scale = QuantBScale + n * k_blks; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - std::chrono::high_resolution_clock::time_point tp; -#ifndef CALL_SGEMM_SEPARATELY - const float* b_blk_sum = QuantBBlkSum + n * k_blks; - if (profiler_->IsEnabled()) { - tp = profiler_->Start(); - } - - GetMlasPlatform().GemmFloatKernel( - ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, true - ); - if (profiler_->IsEnabled()) { - std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - } -#endif - if (profiler_->IsEnabled()) { - tp = profiler_->Start(); - } + //std::chrono::high_resolution_clock::time_point tp; + //if (profiler_->IsEnabled()) { + // tp = profiler_->Start(); + //} GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, QuantA, @@ -631,11 +633,30 @@ SQ4BitGemm_CompInt8( lda, ldc ); - if (profiler_->IsEnabled()) { - std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - } - + //if (profiler_->IsEnabled()) { + // std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); + // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + //} + +// #ifndef CALL_SGEMM_SEPARATELY +// if (profiler_->IsEnabled()) { +// tp = profiler_->Start(); +// } +#if defined(BlockSumM1Layout) + const float* b_blk_sum = QuantBBlkSum + n; + GetMlasPlatform().KernelM1Routine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); + // GetMlasPlatform().KernelM1TransposeBRoutine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); +#else + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + GetMlasPlatform().GemmFloatKernel( + ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, false + ); +#endif + // if (profiler_->IsEnabled()) { + // std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); + // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + // } +// #endif if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM, RangeStartN + n, @@ -687,35 +708,10 @@ SQ4BitGemm_CompInt8( const float* b_col_scale = QuantBScale + n * k_blks; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - std::chrono::high_resolution_clock::time_point tp; -#ifndef CALL_SGEMM_SEPARATELY - if (profiler_->IsEnabled()) { - tp = profiler_->Start(); - } - - const float* b_blk_sum = QuantBBlkSum + n * k_blks; - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum) { - size_t RowsRemaining = RangeCountM; - const float* a_blksum_row = ABlockSum; - while (RowsRemaining > 0) { - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_blksum_row, b_blk_sum, c_blk, k_blks, RowsRemaining, CountN, k_blks, ldc, 1.f, true - ); - - c_blk += ldc * RowsHandled; - a_blksum_row += k_blks * RowsHandled; - RowsRemaining -= RowsHandled; - } - } - if (profiler_->IsEnabled()) { - std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - } -#endif - if (profiler_->IsEnabled()) { - tp = profiler_->Start(); - } - c_blk = C + n; + //std::chrono::high_resolution_clock::time_point tp; + //if (profiler_->IsEnabled()) { + // tp = profiler_->Start(); + //} GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, QuantA, @@ -731,11 +727,35 @@ SQ4BitGemm_CompInt8( lda, ldc ); - if (profiler_->IsEnabled()) { - std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt8_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - } + //if (profiler_->IsEnabled()) { + // std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt8_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); + // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + //} +#ifndef CALL_SGEMM_SEPARATELY + //if (profiler_->IsEnabled()) { + // tp = profiler_->Start(); + //} + + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum) { + size_t RowsRemaining = RangeCountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, k_blks, RowsRemaining, CountN, k_blks, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += k_blks * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + //if (profiler_->IsEnabled()) { + // std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); + // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + //} +#endif if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM, RangeStartN + n, @@ -980,7 +1000,7 @@ MlasSQNBitGemmBatch( reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); #if defined(CALL_SGEMM_SEPARATELY) - SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + //SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); #endif ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -1054,7 +1074,7 @@ MlasSQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + GemmIdx * 
PerGemmWorkspaceStride; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); - SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + //SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); //ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index f82dc993ab30..80d67806ea6e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -319,11 +319,8 @@ Q4Int8GemmR2xC4BlkLen16Avx2( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); - - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -423,8 +420,8 @@ void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc0) - *SumPtr; - *(SumPtr + ldc) = hsum_float_8(acc1) - *(SumPtr + ldc); + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -541,8 +538,7 @@ Q4Int8GemmR1xC4BlkLen16Avx2( acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -625,7 +621,7 @@ Q4Int8GemmR1xC1BlkLen16Avx2( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc0) - *SumPtr; + *SumPtr = hsum_float_8(acc0); if (BiasPtr) { *SumPtr += *BiasPtr; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index def16c7068b5..ff059cb698cf 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -324,11 +324,9 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -444,11 +442,9 @@ Q4Int8Gemm2x4x1BlkLen32Avx2( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + _mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -547,8 +543,8 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( 
accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); } - *SumPtr = hsum_float_8(acc0) - *SumPtr; - *(SumPtr + ldc) = hsum_float_8(acc1) - *(SumPtr + ldc); + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -691,8 +687,7 @@ Q4Int8GemmXx4BlkLen32Avx2( acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -774,7 +769,7 @@ Q4Int8GemmXxXBlkLen32Avx2( accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); } - *SumPtr = hsum_float_8(acc0) - *SumPtr; + *SumPtr = hsum_float_8(acc0); if (BiasPtr) { *SumPtr += *BiasPtr; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 01853b9b18dc..8e6f0995a24f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -158,11 +158,8 @@ Q4Int8GemmR2xC4BlkLen64Avx2( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); - - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -236,8 +233,8 @@ Q4Int8GemmR2xC1BlkLen64Avx2( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc0) - *SumPtr; - *(SumPtr + ldc) = hsum_float_8(acc1) - *(SumPtr + ldc); + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -319,8 +316,7 @@ Q4Int8GemmR1xC4BlkLen64Avx2( acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -392,7 +388,7 @@ Q4Int8GemmR1xC1BlkLen64Avx2( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc0) - *SumPtr; + *SumPtr = hsum_float_8(acc0); if (BiasPtr) { *SumPtr += *BiasPtr; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 9a6510890cba..ccee4ed8f0df 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -132,9 +132,15 @@ ComputePackBlkSum( zp = (uint8_t)(low_zp ? 
((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); } - // BlockSum is a width 16 row major matrix - const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; - *(BlockSumBegin + dst_offset) = *QuantBScale * zp; +//#define BlockSumM1Layout 1 +#if defined(BlockSumM1Layout) + // BlockSum is a regular row major matrix + const size_t dst_offset = k_blk * N + n; +#else + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; +#endif + *(BlockSumBegin + dst_offset) = -(*QuantBScale) * zp; } ); } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3fc47287f75b..b3563799116e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -264,10 +264,10 @@ void RunTest(const TestOptions& opts, TEST(MatMulNBits, Float32) { //onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); - for (auto M : {1/*, 2, 100*/}) { - for (auto N : {2560/*1, 2, 32, 288*/}) { - for (auto K : {2560/*16, 32, 64, 128, 256, 1024, 93, 1234*/}) { - for (auto block_size : {/*16, */32/*, 64, 128*/}) { + for (auto M : {1, 2, 100}) { + for (auto N : {2560, 1, 2, 32, 288}) { + for (auto K : {2560, 16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { for (auto accuracy_level : {/*0, 1, */4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; @@ -280,39 +280,39 @@ TEST(MatMulNBits, Float32) { { TestOptions opts = base_opts; + RunTest(opts); + } + + { + TestOptions opts = base_opts; + opts.has_zero_point = true; + RunTest(opts); + } + +#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) + { + TestOptions opts = base_opts; + opts.has_g_idx = true; RunTest(opts); } -// { -// TestOptions opts = base_opts; -// opts.has_zero_point = true; -// RunTest(opts); -// } -// -//#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) -// { -// TestOptions opts = base_opts; -// opts.has_g_idx = true; -// RunTest(opts); -// } -// -// { -// TestOptions opts = base_opts; -// opts.has_zero_point = true, opts.zp_is_4bit = false; -// RunTest(opts); -// } -//#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) -// -// { -// TestOptions opts = base_opts; -// opts.has_bias = true; -// -// // only enabled for CPU EP for now -// std::vector> explicit_eps; -// explicit_eps.emplace_back(DefaultCpuExecutionProvider()); -// -// RunTest(opts, std::move(explicit_eps)); -// } + { + TestOptions opts = base_opts; + opts.has_zero_point = true, opts.zp_is_4bit = false; + RunTest(opts); + } +#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) + + { + TestOptions opts = base_opts; + opts.has_bias = true; + + // only enabled for CPU EP for now + std::vector> explicit_eps; + explicit_eps.emplace_back(DefaultCpuExecutionProvider()); + + RunTest(opts, std::move(explicit_eps)); + } } } } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 4b0262436aae..381634bd3643 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -438,6 +438,20 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Mon, 17 Jun 2024 16:38:41 -0700 Subject: [PATCH 14/41] avx512 Signed-off-by: liqunfu --- onnxruntime/core/mlas/lib/sqnbitgemm.h | 2 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 36 +------------ .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 9 
++-- .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 25 +++++---- .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 16 +++--- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 20 +++---- .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 32 +++++------ .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 54 ++++++++++++++++++- .../test/contrib_ops/matmul_4bits_test.cc | 8 +-- 9 files changed, 102 insertions(+), 100 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index b848c9f59aca..03dd020c9b30 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -47,7 +47,7 @@ struct PackedQuantBDataStruct { PackedQuantBData = (std::byte*)PackedQuantBWorkspace; QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - const size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); + constexpr size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); const uintptr_t QuantBBlkSumAddr = reinterpret_cast(QuantBBlkSum); QuantBBlkSum = reinterpret_cast( (QuantBBlkSumAddr + Alignment - 1) & (~(Alignment - 1)) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 8e1f7171e6e4..969eab711772 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1144,7 +1144,6 @@ SQ4BitGemmPackQuantBDataAndBlkSum( MLAS_THREADPOOL* ThreadPool ) { - constexpr size_t BlkBitWidth = 4; assert(BlkLen >= 16 && BlkLen % 16 == 0); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -1154,39 +1153,8 @@ SQ4BitGemmPackQuantBDataAndBlkSum( if (BlkLen == 32 && ComputeType == CompInt8) { SubBlkLen = 64; } - - // TODO: move to avx_common - if (QuantBDataBegin) { - PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); - } - - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - - if (QuantBScaleBegin && has_zp_input && !QuantBZPBegin) { - // scale is provided but still missing zp in order to compute the blksum. - // cache the scale in the later half of PackedQuantBData. 
- std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, (float*)(PackedQuantBDataBegin + PackedQuantBDataSize)); - return; - } - - // if called with QuantBZPBegin and without QuantBScaleBegin it must be that - // the scale is already cached in PackedQuantBData (offset PackedQuantBDataSize) - bool delete_quant_b_scale_begin = false; - if (!QuantBScaleBegin && QuantBZPBegin) { - QuantBScaleBegin = new float[N * BlockCountK]; - const float* QuantBScaleBeginSaved = reinterpret_cast(PackedQuantBDataBegin + PackedQuantBDataSize); - std::copy(QuantBScaleBeginSaved, QuantBScaleBeginSaved + N * BlockCountK, const_cast(QuantBScaleBegin)); - delete_quant_b_scale_begin = true; - } - - bool last_call = QuantBScaleBegin && (!has_zp_input || QuantBZPBegin); - - if (last_call) { - ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); - } - if (delete_quant_b_scale_begin) { - delete[] QuantBScaleBegin; - } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, + PackedQuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, BlockSumBegin, ThreadPool); } // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 299095b7f3a6..16ed24640127 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -318,7 +318,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, - bool /*has_zp_input*/, + bool has_zp_input, const std::byte* QuantBZPBegin, float* BlockSumBegin, MLAS_THREADPOOL* ThreadPool @@ -332,11 +332,8 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( if (ComputeType == CompInt8) { SubBlkLen = 128; } - PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); - - if (QuantBScaleBegin) { - ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); - } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, + PackedQuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, BlockSumBegin, ThreadPool); } const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index d759bfc4a176..d9a3366b5864 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -176,14 +176,14 @@ Q4Int8GemmR2xC4BlkLen128Avx512( } // k_blks_remaining #if 1 - *SumPtr = _mm512_reduce_add_ps(acc[0]) - *SumPtr; - *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]) - *(SumPtr + 1); - *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]) - *(SumPtr + 2); - *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]) - *(SumPtr + 3); - *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]) - *(SumPtr + ldc); - *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]) - *(SumPtr + ldc + 1); - *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]) - *(SumPtr + ldc + 2); - *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]) - *(SumPtr + ldc + 3); + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) 
= _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); if (BiasPtr != nullptr) { *SumPtr += *BiasPtr; *(SumPtr + 1) += *(BiasPtr + 1); @@ -281,8 +281,8 @@ Q4Int8GemmR2xC1BlkLen128Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_16(acc0) - *SumPtr; - *(SumPtr + ldc) = hsum_float_16(acc1) - *(SumPtr + ldc); + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -364,8 +364,7 @@ Q4Int8GemmR1xC4BlkLen128Avx512( acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -437,7 +436,7 @@ Q4Int8GemmR1xC1BlkLen128Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_16(acc0) - *SumPtr; + *SumPtr = hsum_float_16(acc0); if (BiasPtr) { *SumPtr += *BiasPtr; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index 52eeba8fe215..02254a7056bc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -264,11 +264,8 @@ Q4Int8GemmR2xC4BlkLen16Avx512( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); - - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -364,8 +361,8 @@ Q4Int8GemmR2C1BlkLen16Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc20) - *SumPtr; - *(SumPtr + ldc) = hsum_float_8(acc21) - *(SumPtr + ldc); + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -485,8 +482,7 @@ Q4Int8GemmR1xC4BlkLen16Avx512( acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -568,7 +564,7 @@ Q4Int8GemmR1xC1BlkLen16Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc2) - *SumPtr; + *SumPtr = hsum_float_8(acc2); if (BiasPtr) { *SumPtr += *BiasPtr; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index aa7648af208f..e37dfb191546 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -93,7 +93,7 @@ accumulate_blklen32_r2c1blk4_avx512( const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av01_64_epi8); // 2~2,3~3 const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 const __m512i sum_32_epi16 = 
_mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 const __m512i one_32_epi16 = generate_ones_32_epi16(); const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 @@ -113,7 +113,7 @@ accumulate_blklen32_r2c1blk4_avx512( const __m512i dot1_32_epi16 = _mm512_maddubs_epi16(bv1_64_epi8, av11_64_epi8); // 2~2,3~3 const __m512i t1 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 - const __m512i t2 = _mm512_unpackhi_epi32(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 + const __m512i t2 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 00002222000022221111333311113333 const __m512i sum_32_epi16 = _mm512_add_epi16(t1, t2); // 00002222000022221111333311113333 const __m512i one_32_epi16 = generate_ones_32_epi16(); const __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, sum_32_epi16); // 0022002211331133 @@ -267,11 +267,8 @@ Q4Int8GemmR2xC4BlkLen32Avx512( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); - - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -368,8 +365,8 @@ Q4Int8GemmR2C1BlkLen32Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc20) - *SumPtr; - *(SumPtr + ldc) = hsum_float_8(acc21) - *(SumPtr + ldc); + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -490,8 +487,7 @@ Q4Int8GemmR1xC4BlkLen32Avx512( acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -573,7 +569,7 @@ Q4Int8GemmR1xC1BlkLen32Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_8(acc2) - *SumPtr; + *SumPtr = hsum_float_8(acc2); if (BiasPtr) { *SumPtr += *BiasPtr; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 0af21c037629..2494039134ba 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -311,14 +311,14 @@ Q4Int8GemmR2xC4BlkLen64Avx512( } #if 1 - *SumPtr = _mm512_reduce_add_ps(acc[0]) - *SumPtr; - *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]) - *(SumPtr + 1); - *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]) - *(SumPtr + 2); - *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]) - *(SumPtr + 3); - *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]) - *(SumPtr + ldc); - *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]) - *(SumPtr + ldc + 1); - *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]) - *(SumPtr + ldc + 2); - *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]) - *(SumPtr + ldc + 3); + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = 
_mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); if (BiasPtr != nullptr) { *SumPtr += *BiasPtr; *(SumPtr + 1) += *(BiasPtr + 1); @@ -337,11 +337,8 @@ Q4Int8GemmR2xC4BlkLen64Avx512( acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); - - const __m128 level_r1 = _mm_loadu_ps(SumPtr + ldc); - _mm_storeu_ps(SumPtr + ldc, _mm_sub_ps(acc_r1, level_r1)); + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); #endif // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -427,8 +424,8 @@ Q4Int8GemmR2xC1BlkLen64Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_16(acc0) - *SumPtr; - *(SumPtr + ldc) = hsum_float_16(acc1) - *(SumPtr + ldc); + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); if (BiasPtr) { *SumPtr += *BiasPtr; *(SumPtr + ldc) += *BiasPtr; @@ -528,8 +525,7 @@ Q4Int8GemmR1xC4BlkLen64Avx512( acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); } - const __m128 level_r0 = _mm_loadu_ps(SumPtr); - _mm_storeu_ps(SumPtr, _mm_sub_ps(acc_r0, level_r0)); + _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; @@ -610,7 +606,7 @@ Q4Int8GemmR1xC1BlkLen64Avx512( QuantBScalePtr++; } - *SumPtr = hsum_float_16(acc0) - *SumPtr; + *SumPtr = hsum_float_16(acc0); if (BiasPtr) { *SumPtr += *BiasPtr; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index ccee4ed8f0df..fb02da204022 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -22,7 +22,7 @@ SQ4BitGemmPackQuantBDataSize( const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - const size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); + constexpr size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); BlkSumSize += Alignment - 1; return PackedQuantBDataSize + BlkSumSize; @@ -145,6 +145,56 @@ ComputePackBlkSum( ); } +#pragma warning(disable:4505) +static void +PackQuantBDataAndBlkSum( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + if (QuantBScaleBegin && has_zp_input && !QuantBZPBegin) { + // scale is provided but still missing zp in order to compute the blksum. + // cache the scale in the later half of PackedQuantBData. 
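+ // The stash fits because the workspace reserves MlasDivRoundup(N, 16) * 16 * BlockCountK
+ // floats for the block sums beyond the packed data, which is at least the
+ // N * BlockCountK scale values parked here until the zero points arrive.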
+ std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, (float*)(PackedQuantBDataBegin + PackedQuantBDataSize)); + return; + } + + // if called with QuantBZPBegin and without QuantBScaleBegin it must be that + // the scale is already cached in PackedQuantBData (offset PackedQuantBDataSize) + bool delete_quant_b_scale_begin = false; + if (!QuantBScaleBegin && QuantBZPBegin) { + QuantBScaleBegin = new float[N * BlockCountK]; + const float* QuantBScaleBeginSaved = reinterpret_cast(PackedQuantBDataBegin + PackedQuantBDataSize); + std::copy(QuantBScaleBeginSaved, QuantBScaleBeginSaved + N * BlockCountK, const_cast(QuantBScaleBegin)); + delete_quant_b_scale_begin = true; + } + + bool last_call = QuantBScaleBegin && (!has_zp_input || QuantBZPBegin); + + if (last_call) { + ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); + } + if (delete_quant_b_scale_begin) { + delete[] QuantBScaleBegin; + } +} +#pragma warning(default:4505) // // Workspace size calculation function implementation. // @@ -344,7 +394,7 @@ get_2_zps(const std::byte* QuantBZeroPointPtr, int8_t& zp0, int8_t& zp1) zp1 = 8; (void)QuantBZeroPointPtr; } -} +} template int8_t MLAS_FORCEINLINE diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index b3563799116e..3b85939cb471 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -265,10 +265,10 @@ void RunTest(const TestOptions& opts, TEST(MatMulNBits, Float32) { //onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); for (auto M : {1, 2, 100}) { - for (auto N : {2560, 1, 2, 32, 288}) { - for (auto K : {2560, 16, 32, 64, 128, 256, 1024, 93, 1234}) { - for (auto block_size : {16, 32, 64, 128}) { - for (auto accuracy_level : {/*0, 1, */4}) { + for (auto N : {/*2560, */1, 2, 32, 288}) { + for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) { + for (auto block_size : {16, 32, 64, 128 }) { + for (auto accuracy_level : {0, 1, 4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; From 740031ac55a449bf20061985db79fb5cffe49530 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 27 Jun 2024 16:21:18 -0700 Subject: [PATCH 15/41] layout to follow compute, M1 separate with M > 1 Signed-off-by: Liqun Fu --- .../cpu/quantization/matmul_nbits.cc | 11 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 3 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 85 +-- onnxruntime/core/mlas/lib/sqnbitgemm.h | 43 +- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 213 ++++-- .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 389 ++++------ .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 28 +- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 6 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 52 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 177 +++-- .../lib/sqnbitgemm_kernel_avx_common_int8.h | 37 +- ...bitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 688 ++++++++++++++++++ ...bitgemm_m1_sym_kernel_avx2_int8_blklen64.h | 312 ++++++++ .../test/contrib_ops/matmul_4bits_test.cc | 2 +- .../test/mlas/bench/bench_sqnbitgemm.cpp | 9 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 29 +- 16 files changed, 1543 insertions(+), 541 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h create mode 100644 onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h diff --git 
a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 5836e42531e5..e102b4dc6752 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -358,7 +358,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); + data[i].QuantBDataWorkspace = packed_b_.get(); data[i].QuantBScale = scales_data; data[i].QuantBZeroPoint = zero_points_data; data[i].Bias = bias_data; @@ -368,7 +368,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } //auto start2 = std::chrono::high_resolution_clock::now(); // Start timing here - //int count = 200; + //const int CountTotal = 2000; + //int count = CountTotal; //while (count-- > 0) MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), thread_pool); @@ -377,9 +378,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { //std::chrono::duration elapsed2 = end - start2; //// Calculate and print the duration in nanoseconds - // std::chrono::duration elapsed = end - start; - //std::cout << "MlasSQNBitGemmBatch: " << elapsed2.count() << " ns\n"; - //std::cout << "main Duration_M" << M << "xN" << N << "xK" << K << ": " << elapsed.count() << " ns\n"; + //std::chrono::duration elapsed = end - start; + //std::cout << "MlasSQNBitGemmBatch: " << elapsed2.count() / CountTotal << " ns\n"; + //std::cout << "main Duration_M" << M << "xN" << N << "xK" << K << ": " << elapsed.count() / CountTotal << " ns\n"; return Status::OK(); } } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 09c8c2311425..1ad270cc2928 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -46,7 +46,8 @@ typedef enum { struct MLAS_SQNBIT_GEMM_DATA_PARAMS { const float* A = nullptr; ///< address of A (float32 matrix) size_t lda = 0; ///< leading dimension of A - const void* QuantBData = nullptr; ///< address of quantized B (quantized n-bit int values) + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 814d8daa50df..c9b3a55f9617 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -252,7 +252,7 @@ MlasSQNBitGemmPackQuantBData( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, - void* PackedQuantBDataAndOrBlkSum, + void* PackedQuantBDataAndOrBlkSumWorkspace, const void* QuantBScale, bool has_zp_input, const void* QuantBZeroPoint, @@ -274,24 +274,23 @@ MlasSQNBitGemmPackQuantBData( BlkLen, ComputeType, static_cast(QuantBData), - static_cast(PackedQuantBDataAndOrBlkSum), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), ThreadPool ); return; } else if (Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct 
packed_quant_b(PackedQuantBDataAndOrBlkSum, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, BlkLen, ComputeType, static_cast(QuantBData), - packed_quant_b.PackedQuantBData, static_cast(QuantBScale), has_zp_input, static_cast(QuantBZeroPoint), - packed_quant_b.QuantBBlkSum, + packed_quant_b, ThreadPool ); } @@ -363,7 +362,7 @@ SQ4BitGemm_CompFp32( const float* A = DataParams->A + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -565,7 +564,7 @@ SQ4BitGemm_CompInt8( // // route to fp32 compute before int8 compute is improved. // SQ4BitGemm_CompFp32( // BlkLen, -// K, DataParams, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN +// K, DataParams, per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN // ); // return; // } @@ -583,7 +582,8 @@ SQ4BitGemm_CompInt8( const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; - const std::byte* QuantBData = static_cast(DataParams->QuantBData) + RangeStartN * ldb; + assert(RangeStartN % 4 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -602,14 +602,14 @@ SQ4BitGemm_CompInt8( const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; if (RangeCountM == 1) { - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8) + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 && BlkLen == 16) { //auto start = std::chrono::high_resolution_clock::now(); // Start timing here size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); - + assert(n % 4 == 0); const std::byte* b_col = QuantBData + n * ldb; const float* b_col_scale = QuantBScale + n * k_blks; float* c_blk = C + n; @@ -618,6 +618,7 @@ SQ4BitGemm_CompInt8( //if (profiler_->IsEnabled()) { // tp = profiler_->Start(); //} + //assert(false); GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, QuantA, @@ -684,7 +685,7 @@ SQ4BitGemm_CompInt8( GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + a_row, QuantAScale, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); if (DataParams->PostProcessor != nullptr) { @@ -766,44 +767,6 @@ SQ4BitGemm_CompInt8( } return; } - - // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. - // TODO Replace it with an optimized implementation. - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const std::byte* a_row = QuantA; - const std::byte* b_col = QuantBData + n * ldb; - const float* b_col_scale = QuantBScale + n * k_blks; - const std::byte* b_col_zp = - (QuantBZeroPoint == nullptr) ? 
nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; - - float* c_blk = C + n; - const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - for (size_t m = 0; m < RangeCountM; ++m) { - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( - BlkLen, - a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias - ); - // GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( - // BlkLen, - // a_row, b_col, b_col_scale, b_col_zp, c_blk, /*RangeCountM*/1, CountN, - // K, k_blks, bias, lda, ldc - //); - - // TODO: shall be processed outsize the loop - if (DataParams->PostProcessor != nullptr) { - DataParams->PostProcessor->Process( - DataParams->C, RangeStartM, RangeStartN + n, - RangeCountM, CountN, ldc - ); - } - - c_blk += ldc; - a_row += lda; - } - } } typedef void(InitializeWorkspaceFn)( @@ -994,8 +957,10 @@ MlasSQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; if (ComputeType == CompInt8) { // TODO: shall sepqrate QuantBBlkSum from QuantBData - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBData), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); @@ -1118,10 +1083,10 @@ MlasSQNBitGemmBatch( //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << "\t" // << "ThreadCountM: " << ThreadCountM << "\t" // << "ThreadCountN: " << ThreadCountN << "\n"; - std::chrono::high_resolution_clock::time_point tp; - if (profiler_->IsEnabled()) { - tp = profiler_->Start(); - } + //std::chrono::high_resolution_clock::time_point tp; + //if (profiler_->IsEnabled()) { + // tp = profiler_->Start(); + //} MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; @@ -1138,8 +1103,10 @@ MlasSQNBitGemmBatch( const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); if (ComputeType == CompInt8) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBData), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; @@ -1149,10 +1116,10 @@ MlasSQNBitGemmBatch( ComputeOperation(BlkLen, K, Data, nullptr, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } }); - if (profiler_->IsEnabled()) { - std::string eventName = DataParams->node_name + "MlasTrySimpleParallel" + std::to_string(ThreadsPerGemm) + "-" + std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K); - profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - } + //if (profiler_->IsEnabled()) { + // std::string eventName = DataParams->node_name + "MlasTrySimpleParallel" + std::to_string(ThreadsPerGemm) + "-" + std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K); + // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); + //} 
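/*
 * Layout note for the packed-B workspace used above: PackedQuantBDataStruct
 * (changed in the sqnbitgemm.h hunk below) carves the single QuantBDataWorkspace
 * pointer into three views,
 *
 *   PackedQuantBData   - aligned to 32 bytes for _mm256_load_si256,
 *                        N * BlockCountK * BlkLen / 2 bytes of 4-bit weights
 *   QuantBBlkSum       - aligned to MlasQNBitQuantBBlkSumAlignment(),
 *                        MlasDivRoundup(N, 16) * 16 * BlockCountK floats
 *   PackedQuantBScale  - N * BlockCountK floats, immediately after the block sums
 *
 * which is why the CompInt8 path re-derives PackedQuantBData, QuantBBlkSum and
 * QuantBScale from the workspace before dispatching each per-thread tile.
 */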
// auto end = std::chrono::high_resolution_clock::now(); // End timing here //// Calculate and print the duration in nanoseconds diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 03dd020c9b30..066f1956684b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -38,22 +38,49 @@ MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) return BlkLen * BlkBitWidth / 8; } +MLAS_FORCEINLINE void* +MlasAlignAddress(void* addr, const size_t alignment) +{ + const uintptr_t QuantBBlkSumAddr = reinterpret_cast(addr); + addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1))); + return addr; +} + struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { + // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - PackedQuantBData = (std::byte*)PackedQuantBWorkspace; + //const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + + PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + + //PackedQuantBScale = (float*)PackedQuantBWorkspace; + //PackedQuantBData = (std::byte*)(PackedQuantBScale) + ScaleSize; + //QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - constexpr size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); - const uintptr_t QuantBBlkSumAddr = reinterpret_cast(QuantBBlkSum); - QuantBBlkSum = reinterpret_cast( - (QuantBBlkSumAddr + Alignment - 1) & (~(Alignment - 1)) - ); + + //PackedQuantBScale = (float*)PackedQuantBWorkspace; + + //PackedQuantBData = (std::byte*)PackedQuantBWorkspace + ScaleSize; + //QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); + ////PackedQuantBData = (std::byte*)MlasAlignAddress64(PackedQuantBData); + ////QuantBBlkSum = (float*)MlasAlignAddress64(QuantBBlkSum); + + //constexpr size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); + //const uintptr_t QuantBBlkSumAddr = reinterpret_cast(QuantBBlkSum); + //QuantBBlkSum = reinterpret_cast((QuantBBlkSumAddr + Alignment - 1) & (~(Alignment - 1))); } std::byte* PackedQuantBData; + float* PackedQuantBScale; float* QuantBBlkSum; void* QuantBWorkspace_; @@ -109,11 +136,10 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - float* BlockSumBegin, // BlockCountK by N => (BlockCountK * N) / 16 by 16 + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ); @@ -246,6 +272,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { typedef void(SQ4BitGemmM1Kernel_CompInt8_Fn)( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 969eab711772..0a3aaad0e9ae 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -26,6 +26,9 @@ Module Name: #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx2_int8_blklen64.h" +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" +#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" + MLAS_FORCEINLINE __m256 load_float_n_avx2(const float* data, int n) @@ -409,9 +412,10 @@ void SQ4BitGemmM1Kernel_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, - const std::byte* /*QuantBZeroPoint*/, + const std::byte* QuantBZeroPoint, float* C, size_t CountN, size_t CountK, @@ -419,51 +423,161 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( const float* Bias ) { - if (BlkLen == 16) { - const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); - MlasQ4Int8GemmKernelBlkLen16Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - 1, // CountM - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0 // ldc, not needed when CountM = 1 - ); - } else if (BlkLen == 32) { - const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); - MlasQ4Int8GemmKernelBlkLen32Avx2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - 1, // CountM - CountN, - CountK, - BlockStrideQuantB, - Bias, - 0 // ldc, not needed when CountM = 1 - ); + if (QuantBZeroPoint != nullptr) { + constexpr bool HasZeroPoint = true; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } } else { - const float* QuantAScale = (const float*)(QuantA + BlockStrideQuantB * BlkLen); - MlasQ4Int8GemmKernelBlkLen64Avx2( - BlkLen, - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - C, - 1, // CountM - CountN, - BlockStrideQuantB, - Bias, - 0 // ldc, not needed when CountM = 1 - ); + constexpr bool HasZeroPoint = false; + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else if (BlkLen == 32) { + SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_Sym_CompInt8_avx2( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t /*CountK*/, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint) { + if (BlkLen == 16) { + } else if (BlkLen == 32) { + MlasQ4Int8GemmM1KernelBlkLen32Avx2( + QuantA, + 
QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + MlasQ4Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } + } else { + if (BlkLen == 16) { + } else if (BlkLen == 32) { + MlasQ4Int8GemmM1KernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } else { + MlasQ4Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + BlockStrideQuantB, + Bias + ); + } } } @@ -1136,11 +1250,10 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - float* BlockSumBegin, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -1153,8 +1266,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( if (BlkLen == 32 && ComputeType == CompInt8) { SubBlkLen = 64; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, - PackedQuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, BlockSumBegin, ThreadPool); + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } // @@ -1173,7 +1285,8 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_Sym_CompInt8_avx2; + //d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index ff059cb698cf..2201a1d7d501 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -62,10 +62,7 @@ accumulate_blklen32_r2c1blk2_avx2( __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_mul_ps( - _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), - _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0))); - + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); @@ -80,12 +77,68 @@ accumulate_blklen32_r2c1blk2_avx2( const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); - __m256 scale_8_ps_ = _mm256_mul_ps( - _mm256_permute_ps(scale_a1_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), - _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0))); + __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); } +static MLAS_FORCEINLINE void 
+accumulate_blklen32_r2c1blk2_no_bc_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float& scale_a00, + const float& scale_a01, + const float& scale_a10, + const float& scale_a11, + const float& scale_b0, + const float& scale_b1, + __m256& acc0, + __m256& acc1 +) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + + // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); + // low_mask = _mm256_packus_epi16(low_mask, low_mask); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + // TODO: this (the second line below) is faster and does not keep low_mask in use. + // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv1_32_epi8, bv1_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, scale_a00 * scale_b0, one_16_epi16, acc0); + accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, scale_a01 * scale_b1, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, scale_a10 * scale_b0, one_16_epi16, acc1); + accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, scale_a11 * scale_b1, one_16_epi16, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_no_bc_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float& scale_a_0, + const float& scale_a_1, + const float& scale_b_0, + const float& scale_b_1, + __m256& acc0 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i low_mask = _mm256_set1_epi8(0x0F); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, scale_a_0 * scale_b_0, one_16_epi16, acc0); + accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, scale_a_1 * scale_b_1, one_16_epi16, acc0); +} + static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( const __m256i& av00_32_epi8, @@ -113,11 +166,7 @@ accumulate_blklen32_r1c1blk2_avx2( __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_mul_ps( - _mm256_permute_ps(scale_a0_2_ps, _MM_SHUFFLE(1, 1, 0, 0)), - _mm256_permute_ps(scale_b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)) - ); - + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); } @@ -183,12 +232,17 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( constexpr size_t PerAccuBlk2 = 2; const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM % NRows2 == 0); assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; for (size_t m = 0; m < CountM; m += NRows2) { const std::byte* QuantBDataColPtr = QuantBData; @@ -211,79 +265,39 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( size_t k_blks_remaining = BlockCountK; // process 2 blks of 64 4b weights a time for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + 32; - const std::byte* QuantABlk10 = QuantAPtr + lda; - const std::byte* QuantABlk11 = QuantABlk10 + 32; - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); - - //const float& scale_a00 = Q8BlkScale(QuantABlk00); - //const float& scale_a01 = Q8BlkScale(QuantABlk01); - //const float& scale_a10 = Q8BlkScale(QuantABlk10); - //const float& scale_a11 = Q8BlkScale(QuantABlk11); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i 
av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); { - // Col0 - //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - //const float& scale_10 = scale_a10 * QuantBScalePtr[0]; - //const float& scale_11 = scale_a11 * QuantBScalePtr[1]; - //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, scale_10, scale_11, acc[0], acc[NCols4]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); } - { - // Col1 - //const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - //const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; - //const float& scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - //const float& scale_11 = scale_a11 * (QuantBScalePtr + StrideQuantBScale)[1]; - //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[1], acc[NCols4 + 1]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); } { - // Col2 - //const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - //const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - //const float& scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - //const float& scale_11 = scale_a11 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[2], acc[NCols4 + 2]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); } { - // Col3 - //const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - //const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - //const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - //const float& scale_11 = scale_a11 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - //accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, scale_10, scale_11, acc[3], acc[NCols4 + 
3]); - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); } // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk2; QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; } // k_blks_remaining // TODO: use a loop in case PerAccuBlk2 is not 2. if (k_blks_remaining > 0) { // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const float& scale_a00 = *QuantAScalePtr; const float& scale_a10 = *(QuantAScalePtr + BlockCountK); @@ -297,23 +311,23 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( { // Col1 - const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); } { // Col2 - const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); } { // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); } } // k_blks_remaining @@ -329,129 +343,8 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( _mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - - BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; - SumPtr += NCols4; - } - } -} - -template -MLAS_FORCEINLINE void -Q4Int8Gemm2x4x1BlkLen32Avx2( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc -) -{ - constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - constexpr size_t Q8Blk32Size = BlkLen32; - - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; - const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); - - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer - assert(CountM % NRows2 == 0); - assert(CountN % NCols4 == 0); - - for (size_t m = 0; m < CountM; m += NRows2) { - const std::byte* QuantBDataColPtr = QuantBData; - const float* QuantBScaleColPtr = QuantBScale; - const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; - const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; - - for (size_t n = 0; n < CountN; n += NCols4) { - const std::byte* QuantAPtr = QuantA + m * lda; - const float* QuantAScalePtr = QuantAScale + m * BlockCountK; - - const std::byte* QuantBDataPtr = QuantBDataColPtr; - const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - - __m256 acc[NCols4 * NRows2] = { - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), - _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() - }; - - // process 2 blks of 64 4b weights a time - for (size_t k = 0; k < BlockCountK; k++) { - const __m256i av0_32_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av1_32_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); - - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); - - { - // Col0 - const float& scale0 = scale_a0 * QuantBScalePtr[0]; - const float& scale1 = scale_a1 * QuantBScalePtr[0]; - accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale0, scale1, acc[0], acc[NCols4], k % 2 == 0); - } - - { - // Col1 - const float& scale0 = scale_a0 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float& scale1 = scale_a1 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + StrideQuantBZeroPoint, scale0, scale1, acc[1], acc[NCols4 + 1], k % 2 == 0); - } - - { - // Col2 - const float& scale0 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float& scale1 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale0, scale1, acc[2], acc[NCols4 + 2], k % 2 == 0); - } - - { - // Col3 - const float& scale0 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale1 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av0_32_epi8, av1_32_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 
* StrideQuantBZeroPoint, scale0, scale1, acc[3], acc[NCols4 + 3], k % 2 == 0); - } - - // increment block pointers - QuantAPtr += Q8Blk32Size; - QuantBDataPtr += BlkDataSizeInBytes16; - QuantBScalePtr++; - if constexpr (HasZeroPoint) { - QuantBZeroPointPtr += k % 2; - } - } // k_blks_remaining - - __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); - __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); - if (BiasPtr != nullptr) { - const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); - acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); - acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); - } - _mm_storeu_ps(SumPtr, acc_r0); - - _mm_storeu_ps(SumPtr + ldc, acc_r1); - - // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; - if constexpr (HasZeroPoint) { - QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; - } + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? NCols4 : 0; SumPtr += NCols4; @@ -482,7 +375,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBScale = BlockCountK; [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM % NRows2 == 0); @@ -506,16 +399,10 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( size_t k_blks_remaining = BlockCountK; // process 2 blks of 64 4b weights a time for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const std::byte* QuantABlk00 = QuantAPtr; - const std::byte* QuantABlk01 = QuantABlk00 + BlkLen32; - const std::byte* QuantABlk10 = QuantAPtr + lda; - const std::byte* QuantABlk11 = QuantABlk10 + BlkLen32; - - // load A: - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); - const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); - const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); accumulate_blklen32_r2c1blk2_avx2( av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, @@ -528,12 +415,9 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( QuantBScalePtr += PerAccuBlk2; } - // TODO: use a loop in case PerAccuBlk2 is not 2. 
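+ // The paired loop above steps through the K blocks two at a time (PerAccuBlk2 == 2),
+ // so this tail handles the single leftover block when BlockCountK is odd, using
+ // scalar scale_a * scale_b products per row.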
if (k_blks_remaining > 0) { - // load A - const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantABlk0 + lda)); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const float& scale_a00 = *QuantAScalePtr; const float& scale_a10 = *(QuantAScalePtr + BlockCountK); @@ -552,7 +436,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( // move to next column QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; + QuantBScaleColPtr += BlockCountK; BiasPtr += BiasPtr != nullptr ? 1 : 0; SumPtr += 1; @@ -584,11 +468,16 @@ Q4Int8GemmXx4BlkLen32Avx2( constexpr size_t PerAccuBlk2 = 2; const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; assert(CountM < NRows2); assert(CountN % NCols4 == 0); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; for (size_t m = 0; m < CountM; m++) { const std::byte* QuantBDataColPtr = QuantBData; @@ -606,51 +495,36 @@ Q4Int8GemmXx4BlkLen32Avx2( __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); { - // Col0 - //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc[0]); accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); } { - // Col1 - //const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - //const float& scale_01 = scale_a01 * (QuantBScalePtr + StrideQuantBScale)[1]; - //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantBZeroPointPtr + 1 * StrideQuantBZeroPoint, scale_00, scale_01, acc[1]); accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1] + av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1] ); } { - // Col2 - //const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - //const float& scale_01 = scale_a01 * (QuantBScalePtr + 2 * StrideQuantBScale)[1]; - //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, scale_00, scale_01, 
acc[2]); accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, - QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2] + av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2] ); } { - // Col3 - //const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - //const float& scale_01 = scale_a01 * (QuantBScalePtr + 3 * StrideQuantBScale)[1]; - //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, scale_00, scale_01, acc[3]); accumulate_blklen32_r1c1blk2_avx2( - av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, - QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, + QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3] + ); } // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk2; QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; } // TODO: use a loop in case PerAccuBlk2 is not 2. @@ -661,24 +535,24 @@ Q4Int8GemmXx4BlkLen32Avx2( const float& scale_a00 = *QuantAScalePtr; { - // Col0 + // Col0 const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); } { // Col1 - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc[1]); + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); } { // Col2 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc[2]); + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); } { // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc[3]); + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); } } @@ -690,8 +564,8 @@ Q4Int8GemmXx4BlkLen32Avx2( _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; SumPtr += NCols4; } @@ -723,7 +597,7 @@ Q4Int8GemmXxXBlkLen32Avx2( const size_t lda = BlockCountK * BlkLen32; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBScale = BlockCountK; [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM < NRows2); @@ -744,15 +618,12 @@ Q4Int8GemmXxXBlkLen32Avx2( __m256 acc0 = _mm256_setzero_ps(); size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { - const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - - //const float& scale_00 = scale_a00 * QuantBScalePtr[0]; - //const float& scale_01 = scale_a01 * QuantBScalePtr[1]; - //accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantBZeroPointPtr, scale_00, scale_01, acc0); accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr, - QuantAScalePtr, QuantBScalePtr, acc0); + QuantAScalePtr, QuantBScalePtr, acc0 + ); // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk2; @@ -776,7 +647,7 @@ Q4Int8GemmXxXBlkLen32Avx2( // move to next column QuantBDataColPtr += StrideQuantBData; - QuantBScaleColPtr += StrideQuantBScale; + QuantBScaleColPtr += BlockCountK; BiasPtr += BiasPtr != nullptr ? 1 : 0; SumPtr += 1; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 8e6f0995a24f..15898a719632 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -107,7 +107,7 @@ Q4Int8GemmR2xC4BlkLen64Avx2( const size_t lda = BlockCountK * BlkLen; const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBScale = BlockCountK; assert(CountM % NRows2 == 0); assert(CountN % NCols4 == 0); @@ -139,16 +139,16 @@ Q4Int8GemmR2xC4BlkLen64Avx2( const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, 
QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; } QuantAScalePtr++; - QuantBScalePtr++; + QuantBScalePtr += NCols4; } // k_blks_remaining __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); @@ -163,7 +163,7 @@ Q4Int8GemmR2xC4BlkLen64Avx2( // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? NCols4 : 0; SumPtr += NCols4; } @@ -275,7 +275,7 @@ Q4Int8GemmR1xC4BlkLen64Avx2( const size_t lda = BlockCountK * BlkLen; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBScale = BlockCountK; assert(CountM < NRows2); assert(CountN % NCols4 == 0); @@ -299,16 +299,16 @@ Q4Int8GemmR1xC4BlkLen64Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); // increment block pointers QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; } QuantAScalePtr++; - QuantBScalePtr++; + QuantBScalePtr += NCols4; } __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); @@ -320,7 +320,7 @@ Q4Int8GemmR1xC4BlkLen64Avx2( // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; SumPtr += NCols4; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 16ed24640127..ede3ddfc2659 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -316,11 +316,10 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - float* BlockSumBegin, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -332,8 +331,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( if (ComputeType == CompInt8) { SubBlkLen = 128; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, - PackedQuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, BlockSumBegin, ThreadPool); + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index baf6ba90a55e..8af8b8e7f1fe 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -146,6 +146,7 @@ void SQ4BitGemmM1Kernel_CompInt8_avx512vnni( size_t BlkLen, const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -157,44 +158,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( ) { if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } + assert(false); } else { constexpr bool HasZeroPoint = false; if (BlkLen == 16) { @@ -212,6 +176,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } else if (BlkLen == 32) { SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl>( QuantA, + QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, @@ -254,11 +219,10 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, - bool /*has_zp_input*/, + bool has_zp_input, const std::byte* QuantBZPBegin, - float* BlockSumBegin, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -271,11 +235,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum( SubBlkLen = 64; } - PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); - - if (QuantBScaleBegin) { - ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); - } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } // diff 
--git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index fb02da204022..01c57c3a238e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -19,13 +19,50 @@ SQ4BitGemmPackQuantBDataSize( constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - constexpr size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); - BlkSumSize += Alignment - 1; + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; - return PackedQuantBDataSize + BlkSumSize; + return PackedQuantBDataSize + ScaleSize + BlkSumSize; +} + +static size_t +GetContinueLayoutOffset64(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +{ + size_t T = n / 4, t = n % 4; + bool te = T == N / 4; + size_t scale_dst_offset = T * 4 * SubOrBlkCountK; + if (te) { + scale_dst_offset += t * SubOrBlkCountK + k_sub_or_blk; + } else { + scale_dst_offset += k_sub_or_blk * 4 + t; + } + return scale_dst_offset; +} + +static size_t +GetContinueLayoutOffset32(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk) +{ + size_t T = n / 4, t = n % 4, k_subblk = k_blk / 2, b = k_blk % 2; + bool te = T == N / 4, be = k_subblk == BlockCountK / 2; + size_t scale_dst_offset = T * 4 * BlockCountK; + if (te) { + scale_dst_offset += t * BlockCountK + k_blk; + } else { + scale_dst_offset += k_subblk * 2 * 4; + if (be) { + scale_dst_offset += b * 4 + t; + } else { + scale_dst_offset += t * 2 + b; + } + } + return scale_dst_offset; } static void @@ -65,9 +102,8 @@ PackQuantB( const size_t n = tid / SubBlkCountK; const size_t k_subblk = tid % SubBlkCountK; - const size_t data_offset = n * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + const size_t src_data_offset = n * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + src_data_offset; size_t PackBytePairCount = SubBlkBytePairCount; size_t PackDataSize = SubBlkDataSize; @@ -95,57 +131,100 @@ PackQuantB( PackDataSize = BlkDataSize; const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; for (size_t k = 0; k < k_blks_remaining; k++) { - pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k; + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else if (BlkLen == 32) { + const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + 
dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else { + // shall not reach here (with avx2?) + assert(false); + } } } else { - pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + else if (BlkLen == 32) { + const size_t k_blk = k_subblk * SubBlkLen / BlkLen; + const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + else { // if (BlkLen > 32) + const size_t dst_data_offset = GetContinueLayoutOffset64(N, n, SubBlkCountK, k_subblk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } } - } ); } +//#include + static void ComputePackBlkSum( + size_t Blklen, size_t N, - const float* QuantBScaleBegin, + float* QuantBScaleBegin, const std::byte* QuantBZPBegin, float* BlockSumBegin, MLAS_THREADPOOL* ThreadPool, const size_t BlockCountK) { + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t src_blk_offset = n * BlockCountK + k_blk; - const float* QuantBScale = QuantBScaleBegin + src_blk_offset; - uint8_t zp = 8; - if (QuantBZPBegin) { - size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); - size_t src_zp_offset = ZPCountK * n + k_blk / 2; - bool low_zp = k_blk % 2 == 0; - const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset; - const std::byte low_mask{0X0F}; - zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); - } + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 8; + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + size_t src_zp_offset = ZPCountK * n + k_blk / 2; + bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset; + const std::byte low_mask{0X0F}; + zp = (uint8_t)(low_zp ? 
((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } //#define BlockSumM1Layout 1 #if defined(BlockSumM1Layout) - // BlockSum is a regular row major matrix - const size_t dst_offset = k_blk * N + n; + // BlockSum is a regular row major matrix + const size_t dst_offset = k_blk * N + n; #else - // BlockSum is a width 16 row major matrix - const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; #endif - *(BlockSumBegin + dst_offset) = -(*QuantBScale) * zp; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (Blklen == 16) { + + } else if (Blklen == 32) { + size_t scale_dst_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; + } else if (Blklen > 32) { + const size_t scale_dst_offset = GetContinueLayoutOffset64(N, n, BlockCountK, k_blk); + *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; } + } ); + + //for (int i = 0; i < N * BlockCountK; i++) { + // std::cout << *(QuantBScaleBegin + i) << "\n"; + //} } -#pragma warning(disable:4505) static void PackQuantBDataAndBlkSum( size_t N, @@ -153,48 +232,26 @@ PackQuantBDataAndBlkSum( size_t BlkLen, size_t SubBlkLen, const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - float* BlockSumBegin, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { - constexpr size_t BlkBitWidth = 4; if (QuantBDataBegin) { - PackQuantB(QuantBDataBegin, PackedQuantBDataBegin, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); } - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - - if (QuantBScaleBegin && has_zp_input && !QuantBZPBegin) { - // scale is provided but still missing zp in order to compute the blksum. - // cache the scale in the later half of PackedQuantBData. - std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, (float*)(PackedQuantBDataBegin + PackedQuantBDataSize)); - return; + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, packed_quant_b.PackedQuantBScale); } - // if called with QuantBZPBegin and without QuantBScaleBegin it must be that - // the scale is already cached in PackedQuantBData (offset PackedQuantBDataSize) - bool delete_quant_b_scale_begin = false; - if (!QuantBScaleBegin && QuantBZPBegin) { - QuantBScaleBegin = new float[N * BlockCountK]; - const float* QuantBScaleBeginSaved = reinterpret_cast(PackedQuantBDataBegin + PackedQuantBDataSize); - std::copy(QuantBScaleBeginSaved, QuantBScaleBeginSaved + N * BlockCountK, const_cast(QuantBScaleBegin)); - delete_quant_b_scale_begin = true; - } - - bool last_call = QuantBScaleBegin && (!has_zp_input || QuantBZPBegin); - - if (last_call) { - ComputePackBlkSum(N, QuantBScaleBegin, QuantBZPBegin, BlockSumBegin, ThreadPool, BlockCountK); - } - if (delete_quant_b_scale_begin) { - delete[] QuantBScaleBegin; + if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { + ComputePackBlkSum(BlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); } } -#pragma warning(default:4505) + // // Workspace size calculation function implementation. 
// diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h index 250ffeacd7c2..aa0686552838 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -240,6 +240,7 @@ template accumulator> void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( const std::byte* QuantA, + const float* QuantAScale, const std::byte* QuantBData, const float* QuantBScale, const std::byte* QuantBZeroPoint, @@ -273,6 +274,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( int64_t nblk = (int64_t)(CountN)-4; while (nblk >= 0) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -286,14 +288,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -320,7 +322,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -331,9 +334,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -374,6 +377,7 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( nblk += NCols; for (int64_t n = 0; n < nblk; n++) { const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; @@ -383,14 +387,14 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( size_t k_blks_remaining = BlockCountK; for (; k_blks_remaining > 1; k_blks_remaining -= 2) { const std::byte* QuantABlk0 = QuantAPtr; - const std::byte* QuantABlk1 = QuantABlk0 + Q8BlkSize(BlkLen); + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; // load A: - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); - const __m256i av_1_epi8 = _mm256_loadu_si256((const 
__m256i*)Q8BlkData(QuantABlk1)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); - const float& scale_a0 = Q8BlkScale(QuantABlk0); - const float& scale_a1 = Q8BlkScale(QuantABlk1); + const float& scale_a0 = *QuantAScalePtr; + const float& scale_a1 = *(QuantAScalePtr + 1); // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; @@ -399,7 +403,8 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( accumulator(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); // increment block pointers - QuantAPtr += Q8BlkSize(BlkLen) * 2; + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; QuantBDataPtr += 16 * 2; QuantBScalePtr += 2; if constexpr (HasZeroPoint) { @@ -410,9 +415,9 @@ SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl( if (k_blks_remaining > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; - const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)Q8BlkData(QuantABlk0)); + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const float& scale_a0 = Q8BlkScale(QuantABlk0); + const float& scale_a0 = *QuantAScalePtr; // Col0 const float& scale_00 = scale_a0 * QuantBScalePtr[0]; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h new file mode 100644 index 000000000000..836c22f2cb82 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -0,0 +1,688 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_zp_avx2( + const __m256i& av_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale, + const std::byte* QuantBZeroPointPtr, + __m256& acc, + const __m256i& low_mask +) +{ + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(low_mask, bv_32_epi8); + + bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + __m256& acc0, + const __m256i& low_mask +) +{ + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // accumulate_blklen32_r1c1blk2_zp_is_8_avx2 is much faster than + // accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2: + // BlkBitWidth:4/BlkLen:32/M:1/N:2560/K:2560/Threads:8/Symmetric:1/HasBias:0/ComputeType:4 + // 36591 vs 40270 ns (the main is 51836 ns). both are not as good as main with genai. + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps( + _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) + ); + + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( + const __m256i& av0_32_epi8, + const __m256i& av1_32_epi8, + const __m256& scale_a0_8_ps, + const __m256& scale_a1_8_ps, + const std::byte* QuantBDataPtr, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // TODO: consolidate with accumulate_blklen32_r1c1blk2_avx2 using a zp8 template option + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + //const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData2 = 2 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBData1 = 1 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale2 = 2; + const size_t StrideQuantBScale1 = 1; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const 
__m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc[0], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale, acc[1], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], low_mask, bzp8); + //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], low_mask, bzp8); + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); + } + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. 
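+        // Illustrative note, not part of the original kernel: for the single tail
+        // block handled below, each of the NCols4 columns accumulates
+        //     acc[c] += (scale_a * scale_b[c]) * sum_{i=0..31} a[i] * (b_c[i] - zp_c)
+        // where b_c[i] are column c's unpacked 4-bit weights and zp_c is that
+        // column's zero point (assumed to be the constant 8 when HasZeroPoint is
+        // false). A scalar reference sketch of one such block, assuming the 4-bit
+        // weights have already been unpacked to int8 (ScalarRefBlk32 is a
+        // hypothetical helper, shown only for clarity):
+        //
+        //     float ScalarRefBlk32(const int8_t* a, const int8_t* b, int8_t zp,
+        //                          float scale_a, float scale_b) {
+        //         int32_t sum = 0;
+        //         for (int i = 0; i < 32; ++i) {
+        //             sum += static_cast<int32_t>(a[i]) * (static_cast<int32_t>(b[i]) - zp);
+        //         }
+        //         return static_cast<float>(sum) * scale_a * scale_b;
+        //     }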
+ if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i zero = _mm256_setzero_si256(); + const float& scale_a00 = *QuantAScalePtr; + { + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + } + { + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk2 = 2; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountN < NCols4); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const __m256i bzp8 = _mm256_set1_epi8(8); + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const 
__m256i*)(QuantAPtr + BlkLen32)); + //const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); + + if constexpr (HasZeroPoint) { + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); + } else { + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += 1; + } + } + + // TODO: use a loop in case PerAccuBlk2 is not 2. + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE +void +MlasQ4Int8GemmM1KernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias + ) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr); + } +} + +//#define SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout 1 +void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + // port from neon implementation + constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = 32; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout +#else + constexpr bool HasZeroPoint = false; +#endif + + float* CRowPtr = C; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + //const size_t StrideQuantBScale = BlockCountK; + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const __m256i bzp8 = _mm256_set1_epi8(8); + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + (void)StrideQuantBZeroPoint; +#else + const __m256i zero = _mm256_setzero_si256(); + const __m128i low_mask = _mm_set1_epi8(0xF); +#endif + const size_t NCols = 4; + constexpr size_t StrideQuantBScale2 = 2; + constexpr size_t StrideQuantBScale1 = 1; + + int64_t nblk = (int64_t)(CountN)-4; + while (nblk >= 0) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 + acc0 = _mm256_setzero_ps(), + acc1 = _mm256_setzero_ps(), + acc2 = _mm256_setzero_ps(), + acc3 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen))); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + + // Col1 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + StrideQuantBData, QuantBScalePtr + StrideQuantBScale2, acc1, low_mask, bzp8); +#else + const float& scale_10 = scale_a0 * (QuantBScalePtr + StrideQuantBScale2)[0]; + const float& scale_11 = scale_a1 * (QuantBScalePtr + StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_10, acc1); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, false, scale_11, acc1); +#endif + + // Col2 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale2, acc2, low_mask, bzp8); +#else + const float& scale_20 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale2)[0]; + const float& scale_21 = scale_a1 * (QuantBScalePtr + 2 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_20, acc2); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, false, scale_21, acc2); +#endif + // Col3 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale2, acc3, low_mask, bzp8); +#else + const float& scale_30 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale2)[0]; + const float& scale_31 = scale_a1 * (QuantBScalePtr + 3 * StrideQuantBScale2)[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_30, acc3); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData + 16), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, false, scale_31, acc3); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2 * NCols; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_0 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_0, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_0, acc0); +#endif + + // Col1 + const float& scale_1 = scale_a0 * (QuantBScalePtr + StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + StrideQuantBData, scale_1, acc1, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + StrideQuantBZeroPoint, true, scale_1, acc1); +#endif + + // Col2 + const 
float& scale_2 = scale_a0 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_2, acc2, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 2 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, true, scale_2, acc2); +#endif + + // Col3 + const float& scale_3 = scale_a0 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_3, acc3, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr + 3 * StrideQuantBData), low_mask, zero, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, true, scale_3, acc3); +#endif + } + + __m128 acc_x = FoldAccumulators(acc0, acc1, acc2, acc3); + if (BiasPtr != nullptr) { + acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(BiasPtr)); + } + _mm_storeu_ps(SumPtr, acc_x); + + // move to next NCols columns + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + nblk -= NCols; + } + + nblk += NCols; + for (int64_t n = 0; n < nblk; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + (void)QuantBZeroPointPtr; +#endif + __m256 acc0 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= 2) { + const std::byte* QuantABlk0 = QuantAPtr; + const std::byte* QuantABlk1 = QuantABlk0 + BlkLen; + + // load A: + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + const __m256i av_1_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk1); + +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + const __m256 scale_a0_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk0)); + const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantABlk1)); +#else + const float& scale_a0 = QuantAScalePtr[0]; + const float& scale_a1 = QuantAScalePtr[1]; +#endif + + // Col0 +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_0_epi8, av_1_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr, QuantBScalePtr, acc0, low_mask, bzp8); +#else + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; + const float& scale_01 = scale_a1 * QuantBScalePtr[1]; + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); + accumulate_mul_sum_avx2(av_1_epi8, reinterpret_cast(QuantBDataPtr + 16), low_mask, zero, QuantBZeroPointPtr, false, scale_01, acc0); +#endif + // increment block pointers + QuantAPtr += BlkLen * 2; + QuantAScalePtr += 2; + QuantBDataPtr += 16 * 2; + QuantBScalePtr += 2; + } + + if (k_blks_remaining > 0) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m256i av_0_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); + + const float& scale_a0 = *QuantAScalePtr; + + // Col0 + const float& scale_00 = scale_a0 * QuantBScalePtr[0]; +#if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout + 
accumulate_blklen32_r1c1blk1_zp_avx2(av_0_epi8, QuantBDataPtr, scale_00, acc0, low_mask, bzp8); +#else + accumulate_mul_sum_avx2(av_0_epi8, reinterpret_cast(QuantBDataPtr), low_mask, zero, QuantBZeroPointPtr, true, scale_00, acc0); +#endif + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h new file mode 100644 index 000000000000..d76b687145b9 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -0,0 +1,312 @@ +#pragma once +#include +#include +#include + +#include "sqnbitgemm.h" +#include "sqnbitgemm_kernel_avx_common.h" + + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + const std::byte* QuantBZeroPointPtr, + const bool is_lower_half_byte_zp, + __m256& acc0, + const __m256i& low_mask +) +{ + // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + const __m256i bzp8 = _mm256_set1_epi8(get_zp(is_lower_half_byte_zp, QuantBZeroPointPtr)); + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +static MLAS_FORCEINLINE void +accumulate_blklen64_r1c1blk1_zp_is_8_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m256& acc0, + const __m256i& low_mask, + const __m256i& bzp8 +) +{ + // | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 + __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 + + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); + + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av00_32_epi8, bv0_32_epi8)); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + constexpr size_t SubblkLen64 = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen64; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + assert(CountN % NCols4 == 0); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + const size_t StrideQuantBData1 = 1 * SubblkDataSizeInBytes; + const size_t StrideQuantBScale1 = 1; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc[0], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale1, 
QuantBZeroPointPtr + StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[1], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[2], low_mask); + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, is_lower_half_byte_zp, acc[3], low_mask); + } else { + const __m256i bzp8 = _mm256_set1_epi8(8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale1, acc[1], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale1, acc[2], low_mask, bzp8); + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData1, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale1, acc[3], low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen64; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += NCols4 * StrideQuantBZeroPoint; + } + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmM1C1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + constexpr size_t BlkBitWidth4 = 4; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + assert(CountN < NCols4); + + const __m256i low_mask = _mm256_set1_epi8(0x0F); + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + const float* BiasPtr = Bias; + auto* SumPtr = C; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA; + const float* QuantAScalePtr = QuantAScale; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const bool is_lower_half_byte_zp = (k % 2) == 0; + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i bzp8 = _mm256_set1_epi8(8); + + if constexpr (HasZeroPoint) { + accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc0, low_mask); + } else { + accumulate_blklen64_r1c1blk1_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + } + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + if constexpr (HasZeroPoint) { + QuantBZeroPointPtr += k % 2; + } + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +template +MLAS_FORCEINLINE void +MlasQ4Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleCols > 0) { + Q4Int8GemmM1C4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + multipleCols, + BlockCountK, + Bias); + } + + if (remainingCols > 0) { + Q4Int8GemmM1C1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, + C + multipleCols, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr); + } +} diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3b85939cb471..3c3eff81bad7 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -268,7 +268,7 @@ TEST(MatMulNBits, Float32) { for (auto N : {/*2560, */1, 2, 32, 288}) { for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) { for (auto block_size : {16, 32, 64, 128 }) { - for (auto accuracy_level : {0, 1, 4}) { + for (auto accuracy_level : {/*0, 1, */4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 62f84a6bcc36..6f2e1e94e1b0 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -79,9 +79,10 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; - params.QuantBData = PackedQuantBData != nullptr - ? static_cast(PackedQuantBData.get()) - : static_cast(QuantBData.data()); + if (PackedQuantBData != nullptr) + params.QuantBDataWorkspace = static_cast(PackedQuantBData.get()); + else + params.QuantBDataWorkspace = static_cast(QuantBData.data()); params.QuantBScale = QuantBScale.data(); params.QuantBZeroPoint = Symmetric ? nullptr : QuantBZeroPoint.data(); params.Bias = HasBias ? 
Bias.data() : nullptr; @@ -89,7 +90,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.ldc = N; // warm up run - //MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 381634bd3643..1157cd4c868f 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -56,7 +56,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { const float* A, size_t lda, const void* QuantBData, - const void* PackedQuantBData, + const void* PackedQuantBDataWorkspace, const float* QuantBScale, const void* QuantBZeroPoint, const float* Bias, @@ -71,7 +71,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.Bias = Bias; params.C = C; params.ldc = ldc; - params.QuantBData = PackedQuantBData != nullptr ? PackedQuantBData : QuantBData; + params.QuantBDataWorkspace = PackedQuantBDataWorkspace != nullptr ? PackedQuantBDataWorkspace : QuantBData; params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; @@ -265,16 +265,25 @@ class MlasSQNBitGemmTest : public MlasTestBase { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } - void* PackedQuantBData = nullptr; + void* PackedQuantBDataWorkspace = nullptr; if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { - PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); + PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); bool has_zp_input = QuantBZeroPoint != nullptr; - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, QuantBScale, has_zp_input, QuantBZeroPoint, GetMlasThreadPool()); } + CallGemm(M, N, K, + A, /* lda */ K, + QuantBData, PackedQuantBDataWorkspace, QuantBScale, QuantBZeroPoint, + Bias, + C, /* ldc */ N, + Workspace, + ComputeType, + Threadpool); + if (ComputeType == CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else if (ComputeType == CompInt8) { @@ -284,15 +293,6 @@ class MlasSQNBitGemmTest : public MlasTestBase { << ComputeType << " (" << ComputeTypeName(ComputeType) << ")"; } - CallGemm(M, N, K, - A, /* lda */ K, - QuantBData, PackedQuantBData, QuantBScale, QuantBZeroPoint, - Bias, - C, /* ldc */ N, - Workspace, - ComputeType, - Threadpool); - size_t f = 0; for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++, f++) { @@ -479,6 +479,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Fri, 28 Jun 2024 22:48:26 +0000 Subject: [PATCH 16/41] make avx512 run Signed-off-by: liqunfu --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 2 +- onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index c9b3a55f9617..8ec0f0718b2c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp 
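A note on the parameter flow that the bench and unit-test hunks above switch to: callers now hand the (possibly packed) B buffer to the new QuantBDataWorkspace field instead of QuantBData. A minimal caller-side sketch, using only names that appear in this patch series (buffer allocation, workspace sizing, and the thread pool are taken as given), might look like:

    // Pack B once up front; scales and zero points may be passed so block sums can be
    // precomputed alongside the packed data (signature as used in test_sqnbitgemm.cpp).
    MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType,
                                 QuantBData, PackedQuantBDataWorkspace,
                                 QuantBScale, /*has_zp_input*/ QuantBZeroPoint != nullptr,
                                 QuantBZeroPoint, ThreadPool);

    MLAS_SQNBIT_GEMM_DATA_PARAMS params{};
    params.A = A;                    // M x K float activations
    params.lda = K;
    params.QuantBDataWorkspace = PackedQuantBDataWorkspace;  // or raw QuantBData if packing is unavailable
    params.QuantBScale = QuantBScale;
    params.QuantBZeroPoint = QuantBZeroPoint;  // nullptr for symmetric quantization
    params.Bias = Bias;                        // nullptr when there is no bias
    params.C = C;
    params.ldc = N;

    MlasSQNBitGemmBatch(M, N, K, /*BatchN*/ 1, BlkBitWidth, BlkLen, ComputeType,
                        &params, Workspace, ThreadPool);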
@@ -602,7 +602,7 @@ SQ4BitGemm_CompInt8( const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; if (RangeCountM == 1) { - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 && BlkLen == 16) + if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8 == nullptr || (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 && BlkLen == 16)) { //auto start = std::chrono::high_resolution_clock::now(); // Start timing here diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 01c57c3a238e..3e556b8432b8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -132,7 +132,7 @@ PackQuantB( const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; for (size_t k = 0; k < k_blks_remaining; k++) { const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k; - if (BlkLen == 16) { + if (BlkLen == 16 || SubBlkLen == 128) { // TODO: // not to do the compute order layout yet std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); @@ -148,7 +148,7 @@ PackQuantB( } else { - if (BlkLen == 16) { + if (BlkLen == 16 || SubBlkLen == 128) { // not to do the compute order layout yet std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); @@ -208,7 +208,7 @@ ComputePackBlkSum( const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; #endif *(BlockSumBegin + dst_offset) = -QuantBScale * zp; - if (Blklen == 16) { + if (true || Blklen == 16) { // TODO } else if (Blklen == 32) { size_t scale_dst_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk); From d0359391672573e462ad647dc6b898710f36b9af Mon Sep 17 00:00:00 2001 From: liqunfu Date: Thu, 4 Jul 2024 21:04:05 +0000 Subject: [PATCH 17/41] avx512 blklen64 pass Signed-off-by: liqunfu --- .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 68 ++++++++----------- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 64 ++++++++--------- .../test/contrib_ops/matmul_4bits_test.cc | 2 +- onnxruntime/test/mlas/bench/bench_q4dq.cpp | 24 +++---- .../test/mlas/unittest/test_sqnbitgemm.cpp | 5 ++ 5 files changed, 80 insertions(+), 83 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 2494039134ba..4a67f2971494 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -243,8 +243,8 @@ Q4Int8GemmR2xC4BlkLen64Avx512( constexpr size_t PerAccuBlk2 = 2; const size_t lda = BlockCountK * BlkLen64; - const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; - const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; assert(CountM % NRows2 == 0); assert(CountN % NCols4 == 0); @@ -275,20 +275,16 @@ Q4Int8GemmR2xC4BlkLen64Avx512( const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - 
QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += BlkLen64 * PerAccuBlk2; QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; } // k_blks_remaining while (k_blks_remaining-- > 0) { @@ -298,16 +294,16 @@ Q4Int8GemmR2xC4BlkLen64Avx512( accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, - QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); + QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); + QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); QuantAPtr += BlkLen64; QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; } #if 1 @@ -341,8 +337,8 @@ Q4Int8GemmR2xC4BlkLen64Avx512( _mm_storeu_ps(SumPtr + ldc, acc_r1); #endif // move to next NCols columns - 
QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? NCols4 : 0; SumPtr += NCols4; } @@ -465,8 +461,8 @@ Q4Int8GemmR1xC4BlkLen64Avx512( constexpr size_t PerAccuBlk2 = 2; const size_t lda = BlockCountK * BlkLen; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; //assert(CountM < NRows2); //assert(CountN % NCols4 == 0); @@ -490,34 +486,30 @@ Q4Int8GemmR1xC4BlkLen64Avx512( for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, - QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, - QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); // increment block pointers QuantAPtr += BlkLen64 * PerAccuBlk2; QuantAScalePtr += PerAccuBlk2; - QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; - QuantBScalePtr += PerAccuBlk2; + QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; } while (k_blks_remaining-- > 0) { const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, 
QuantAScalePtr, QuantBScalePtr + 3, acc[3]); QuantAPtr += BlkLen64; QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes; - QuantBScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; } __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); @@ -528,8 +520,8 @@ Q4Int8GemmR1xC4BlkLen64Avx512( _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? NCols4 : 0; SumPtr += NCols4; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 3e556b8432b8..9b804e38d2c9 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -33,7 +33,7 @@ SQ4BitGemmPackQuantBDataSize( } static size_t -GetContinueLayoutOffset64(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) { size_t T = n / 4, t = n % 4; bool te = T == N / 4; @@ -47,19 +47,19 @@ GetContinueLayoutOffset64(size_t N, const size_t n, const size_t SubOrBlkCountK, } static size_t -GetContinueLayoutOffset32(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk) +GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) { - size_t T = n / 4, t = n % 4, k_subblk = k_blk / 2, b = k_blk % 2; - bool te = T == N / 4, be = k_subblk == BlockCountK / 2; + size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; + bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; size_t scale_dst_offset = T * 4 * BlockCountK; if (te) { scale_dst_offset += t * BlockCountK + k_blk; } else { - scale_dst_offset += k_subblk * 2 * 4; + scale_dst_offset += k_subblk * blks_per_sub * 4; if (be) { scale_dst_offset += b * 4 + t; } else { - scale_dst_offset += t * 2 + b; + scale_dst_offset += t * blks_per_sub + b; } } return scale_dst_offset; @@ -132,37 +132,35 @@ PackQuantB( const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; for (size_t k = 0; k < k_blks_remaining; k++) { const size_t k_blk = k_subblk * SubBlkLen / BlkLen + k; - if (BlkLen == 16 || SubBlkLen == 128) { // TODO: + if (BlkLen == 16) { // not to do the compute order layout yet std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); - } else if (BlkLen == 32) { - const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk); + } else if (BlkLen >= SubBlkLen) { + // shall not reach here with avx2 + assert(SubBlkLen == 128); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); - } else { - // shall not reach here (with avx2?) 
- assert(false); } } - } - else - { - if (BlkLen == 16 || SubBlkLen == 128) { + } else { + if (BlkLen == 16) { // not to do the compute order layout yet std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); - } - else if (BlkLen == 32) { - const size_t k_blk = k_subblk * SubBlkLen / BlkLen; - const size_t dst_data_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk); - std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; - pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); - } - else { // if (BlkLen > 32) - const size_t dst_data_offset = GetContinueLayoutOffset64(N, n, SubBlkCountK, k_subblk); + } else if (BlkLen >= SubBlkLen) { + const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t k_blk = k_subblk * blks_per_sub; + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); } } } @@ -173,7 +171,8 @@ PackQuantB( static void ComputePackBlkSum( - size_t Blklen, + size_t BlkLen, + size_t SubBlkLen, size_t N, float* QuantBScaleBegin, const std::byte* QuantBZPBegin, @@ -208,13 +207,14 @@ ComputePackBlkSum( const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; #endif *(BlockSumBegin + dst_offset) = -QuantBScale * zp; - if (true || Blklen == 16) { // TODO + if (BlkLen == 16) { // TODO - } else if (Blklen == 32) { - size_t scale_dst_offset = GetContinueLayoutOffset32(N, n, BlockCountK, k_blk); + } else if (BlkLen >= SubBlkLen) { + const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; - } else if (Blklen > 32) { - const size_t scale_dst_offset = GetContinueLayoutOffset64(N, n, BlockCountK, k_blk); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); *(QuantBScaleBegin + scale_dst_offset) = QuantBScale; } } @@ -248,7 +248,7 @@ PackQuantBDataAndBlkSum( } if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { - ComputePackBlkSum(BlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); + ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); } } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3c3eff81bad7..f54a0fd03679 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -267,7 +267,7 @@ TEST(MatMulNBits, Float32) { for (auto M : {1, 2, 100}) { for (auto N : {/*2560, */1, 2, 32, 288}) { for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) { - for (auto block_size : {16, 32, 64, 128 }) { + for (auto block_size : {/*16, 32, */64/*, 128*/ }) { for (auto accuracy_level : {/*0, 1, */4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K 
= K; diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp index 00234ecfd2ce..eb0727207b83 100644 --- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp @@ -9,10 +9,10 @@ #include "core/util/thread_utils.h" static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -37,10 +37,10 @@ static void BM_QDQBlockwiseQuantizer_QuantizeColumnwise(benchmark::State& state) } static void BM_MlasQuantizeBlockwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); size_t scale_size = (M + quant_block_size - 1) / quant_block_size * N; auto src = RandomVectorUniform(M * N, -16.0f, 14.0f); @@ -65,10 +65,10 @@ static void BM_MlasQuantizeBlockwise(benchmark::State& state) { } static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); int quant_num_M = (M + quant_block_size - 1) / quant_block_size; int blob_size = (quant_block_size + 1) / 2; size_t scale_size = quant_num_M * N; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 1157cd4c868f..b994a21ff979 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -461,7 +461,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Fri, 5 Jul 2024 21:32:55 +0000 Subject: [PATCH 18/41] pass avx512 blklen32 Signed-off-by: liqunfu --- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 83 +++++++++---------- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 2 +- .../test/contrib_ops/matmul_4bits_test.cc | 2 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 1 + 4 files changed, 41 insertions(+), 47 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index e37dfb191546..60e7e5a84b55 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -146,10 +146,9 @@ Q4Int8GemmR2xC4BlkLen32Avx512( constexpr size_t PerAccuBlk4 = 4; const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBData = PerAccuBlk4 * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM % 
NRows2 == 0); assert(CountN % NCols4 == 0); @@ -185,22 +184,22 @@ Q4Int8GemmR2xC4BlkLen32Avx512( acc[0], acc[NCols4]); accumulate_blklen32_r2c1blk4_avx512( av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[1], acc[NCols4 + 1]); accumulate_blklen32_r2c1blk4_avx512( av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[2], acc[NCols4 + 2]); accumulate_blklen32_r2c1blk4_avx512( av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk4; QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; } // k_blks_remaining __m256 acc2[NCols4 * NRows2] = { @@ -232,32 +231,31 @@ Q4Int8GemmR2xC4BlkLen32Avx512( { // Col1 - const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData, scale_00, scale_10, + const float scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); } { // Col2 - const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, scale_10, + const float scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); } { // Col3 - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r2c1blk1_avx2( - av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, scale_10, + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; + const float& scale_10 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); } QuantAPtr += BlkLen32; QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes16; - QuantBScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; } // k_blks_remaining __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); @@ -271,8 +269,8 @@ Q4Int8GemmR2xC4BlkLen32Avx512( 
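// A sketch, illustration only (not part of the patch): the stride changes in this
// blklen32 hunk assume the "continue" packed layout produced by PackQuantBDataAndBlkSum,
// in which each group of PerAccuBlk4 K-blocks of one column is stored contiguously and
// the four columns of a tile follow one another. Under that reading, the scale index of
// block k of column c within a 4-column tile is:
#include <cstddef>
namespace packed_layout_sketch {
constexpr size_t NCols4 = 4;
constexpr size_t PerAccuBlk4 = 4;
inline size_t PackedScaleIndex(size_t k, size_t c)
{
    const size_t group = k / PerAccuBlk4;  // which group of 4 blocks along K
    const size_t b = k % PerAccuBlk4;      // block within the group
    return (group * NCols4 + c) * PerAccuBlk4 + b;  // data offset = this index * BlkDataSizeInBytes16
}
}  // namespace packed_layout_sketch
// Remainder blocks past the last full group are interleaved one block at a time across
// the four columns, which matches the tail loop below (scales read at QuantBScalePtr + 1/2/3
// and QuantBDataPtr stepped by NCols4 * BlkDataSizeInBytes16 per block).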
_mm_storeu_ps(SumPtr + ldc, acc_r1); // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? NCols4 : 0; SumPtr += NCols4; @@ -306,7 +304,6 @@ Q4Int8GemmR2C1BlkLen32Avx512( const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); const size_t StrideQuantBScale = BlockCountK; - [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer assert(CountM % NRows2 == 0); assert(CountN < NCols4); @@ -406,8 +403,8 @@ Q4Int8GemmR1xC4BlkLen32Avx512( constexpr size_t PerAccuBlk4 = 4; const size_t lda = BlockCountK * BlkLen32; - const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; assert(CountM < NRows2); assert(CountN % NCols4 == 0); @@ -433,19 +430,15 @@ Q4Int8GemmR1xC4BlkLen32Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); QuantAPtr += BlkLen32 * PerAccuBlk4; QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4; - QuantBScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; } __m256 acc2[NCols4] = { @@ -463,22 +456,22 @@ Q4Int8GemmR1xC4BlkLen32Avx512( accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); } { - const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData, scale_00, acc2[1]); + const float& scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); } { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale)[0]; - 
accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData, scale_00, acc2[2]); + const float& scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); } { - const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData, scale_00, acc2[3]); + const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); } QuantAPtr += BlkLen32; QuantAScalePtr++; - QuantBDataPtr += BlkDataSizeInBytes16; - QuantBScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; } @@ -490,8 +483,8 @@ Q4Int8GemmR1xC4BlkLen32Avx512( _mm_storeu_ps(SumPtr, acc_r0); // move to next NCols columns - QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? NCols4 : 0; SumPtr += NCols4; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 9b804e38d2c9..1a32359696de 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -143,7 +143,7 @@ PackQuantB( int blks_per_sub = (int)(SubBlkLen / BlkLen); const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; - pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); } } } else { diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index f54a0fd03679..bc9b5d6d80a7 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -267,7 +267,7 @@ TEST(MatMulNBits, Float32) { for (auto M : {1, 2, 100}) { for (auto N : {/*2560, */1, 2, 32, 288}) { for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) { - for (auto block_size : {/*16, 32, */64/*, 128*/ }) { + for (auto block_size : {/*16, */32, 64/*, 128*/ }) { for (auto accuracy_level : {/*0, 1, */4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index b994a21ff979..d39cc31778d3 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -476,6 +476,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Fri, 5 Jul 2024 21:59:53 +0000 Subject: [PATCH 19/41] pass avx512 blklen 16, 128, 256 Signed-off-by: liqunfu --- .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 34 ++++++++----------- .../test/contrib_ops/matmul_4bits_test.cc | 2 +- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index d9a3366b5864..b5eecefe8474 100644 --- 
a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -127,7 +127,7 @@ Q4Int8GemmR2xC4BlkLen128Avx512( const size_t lda = BlockCountK * BlkLen; const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBScale = BlockCountK; assert(CountM % NRows2 == 0); assert(CountN % NCols4 == 0); @@ -158,21 +158,17 @@ Q4Int8GemmR2xC4BlkLen128Avx512( const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, - QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, acc[1], acc[NCols4 + 1]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], acc[NCols4 + 2]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], acc[NCols4 + 3]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; } QuantAScalePtr++; - QuantBScalePtr++; + QuantBScalePtr += NCols4; } // k_blks_remaining #if 1 @@ -210,7 +206,7 @@ Q4Int8GemmR2xC4BlkLen128Avx512( #endif // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; SumPtr += NCols4; } @@ -323,7 +319,7 @@ Q4Int8GemmR1xC4BlkLen128Avx512( const size_t lda = BlockCountK * BlkLen; const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); - const size_t StrideQuantBScale = BlockCountK; + //const size_t StrideQuantBScale = BlockCountK; assert(CountM < NRows2); assert(CountN % NCols4 == 0); @@ -347,16 +343,16 @@ Q4Int8GemmR1xC4BlkLen128Avx512( const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); // increment block pointers QuantAPtr += SubblkLen; - QuantBDataPtr += SubblkDataSizeInBytes; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; } QuantAScalePtr++; - QuantBScalePtr++; + QuantBScalePtr +=NCols4; } __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); @@ -368,7 +364,7 @@ Q4Int8GemmR1xC4BlkLen128Avx512( // move to next NCols columns QuantBDataColPtr += NCols4 * StrideQuantBData; - QuantBScaleColPtr += NCols4 * StrideQuantBScale; + QuantBScaleColPtr += NCols4 * BlockCountK; BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; SumPtr += NCols4; } diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index bc9b5d6d80a7..3c3eff81bad7 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -267,7 +267,7 @@ TEST(MatMulNBits, Float32) { for (auto M : {1, 2, 100}) { for (auto N : {/*2560, */1, 2, 32, 288}) { for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) { - for (auto block_size : {/*16, */32, 64/*, 128*/ }) { + for (auto block_size : {16, 32, 64, 128 }) { for (auto accuracy_level : {/*0, 1, */4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; From edee3198ff0efa101905cc664136e6d884207061 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Thu, 11 Jul 2024 10:54:13 -0700 Subject: [PATCH 20/41] pass fp32, refactor sqnbitgemm Signed-off-by: Liqun Fu --- .../cpu/quantization/matmul_nbits.cc | 19 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 301 ++---------------- onnxruntime/core/mlas/lib/sqnbitgemm.h | 17 - .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 104 +----- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx_common.h | 109 +++++-- .../test/contrib_ops/matmul_4bits_test.cc | 2 +- 8 files changed, 122 insertions(+), 434 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index e102b4dc6752..7df608a92d74 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -230,15 +230,16 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat prepacked_weights->buffer_sizes_.push_back(packed_b_size_); } is_packed = true; - } - else if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); - is_packed = false; - } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { - auto zptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); - is_packed = false; + } else if (compute_type == CompInt8) { + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + is_packed = false; + } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); + is_packed = false; + } } #endif // defined(ORT_NEURAL_SPEED) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 8ec0f0718b2c..9358bc2c3363 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -19,36 +19,6 @@ Module Name: #include "sqnbitgemm_q8_block.h" #include -#include -#include -#include "core/common/profiler.h" - -class ProfilerWrapper -{ - public: - ProfilerWrapper() - { - profiler_ = std::make_unique(); - //profiler_->StartProfiling("profile.json"); - } - - ~ProfilerWrapper() - { - if 
(profiler_) { - //profiler_->EndProfiling(); - } - } - - onnxruntime::profiling::Profiler* operator->() - { - return profiler_.get(); - } - - private: - std::unique_ptr profiler_; -}; - -static ProfilerWrapper profiler_; namespace { @@ -265,20 +235,7 @@ MlasSQNBitGemmPackQuantBData( } if (BlkBitWidth == 4) { - if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { - assert(QuantBScale == nullptr); - assert(QuantBZeroPoint == nullptr); - Dispatch->SQ4BitGemmPackQuantBData( - N, - K, - BlkLen, - ComputeType, - static_cast(QuantBData), - static_cast(PackedQuantBDataAndOrBlkSumWorkspace), - ThreadPool - ); - return; - } else if (Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( @@ -293,6 +250,20 @@ MlasSQNBitGemmPackQuantBData( packed_quant_b, ThreadPool ); + } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas tests. + //assert(QuantBScale == nullptr); + //assert(QuantBZeroPoint == nullptr); + Dispatch->SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); + return; } } } @@ -362,7 +333,7 @@ SQ4BitGemm_CompFp32( const float* A = DataParams->A + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->QuantBDataWorkspace) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -456,96 +427,6 @@ SQ4BitGemm_CompFp32( } } -//#define BlockSumM1Layout 1 -//#define CALL_SGEMM_SEPARATELY 1 -#if defined(CALL_SGEMM_SEPARATELY) -void -SQ4BitGemm_CompInt8_0( - const size_t BlkLen, - const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN -) -{ - const size_t k_blks = MlasDivRoundup(K, BlkLen); - - // quant A scale is embedded in QuantData if QuantScale is nullptr. 
- const size_t ldc = DataParams->ldc; - - const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; - -#if defined(BlockSumM1Layout) - const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN; -#else - const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; -#endif - - - float* C = DataParams->C + RangeStartM * ldc + RangeStartN; - - if (RangeCountM == 1) { - // auto start = std::chrono::high_resolution_clock::now(); // Start timing here - - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - float* c_blk = C + n; - //std::chrono::high_resolution_clock::time_point tp; - //if (profiler_->IsEnabled()) { - // tp = profiler_->Start(); - //} -#if defined(BlockSumM1Layout) - const float* b_blk_sum = QuantBBlkSum + n; - GetMlasPlatform().KernelM1Routine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); - // GetMlasPlatform().KernelM1TransposeBRoutine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); -#else - const float* b_blk_sum = QuantBBlkSum + n * k_blks; - GetMlasPlatform().GemmFloatKernel( - ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, true - ); -#endif - //if (profiler_->IsEnabled()) { - // std::string eventName = DataParams->node_name + "Sep GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); - // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - //} - } - // auto end = std::chrono::high_resolution_clock::now(); // End timing here - //// Calculate and print the duration in nanoseconds - // std::chrono::duration elapsed = end - start; - // std::cout << "Duration_M" << RangeCountM << "xN" << RangeCountN << "xK" << K << ": " << elapsed.count() << " ns\n"; - return; - } else { - size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, size_t{128}); - - const float* b_blk_sum = QuantBBlkSum + n * k_blks; - - float* c_blk = C + n; - - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum) { - size_t RowsRemaining = RangeCountM; - const float* a_blksum_row = ABlockSum; - while (RowsRemaining > 0) { - auto RowsHandled = GetMlasPlatform().GemmFloatKernel( - a_blksum_row, b_blk_sum, c_blk, k_blks, RowsRemaining, CountN, k_blks, ldc, 1.f, true - ); - - c_blk += ldc * RowsHandled; - a_blksum_row += k_blks * RowsHandled; - RowsRemaining -= RowsHandled; - } - } - } - } -} -#endif - void SQ4BitGemm_CompInt8( const size_t BlkLen, @@ -589,14 +470,8 @@ SQ4BitGemm_CompInt8( (DataParams->QuantBZeroPoint == nullptr) ? nullptr : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; -#ifndef CALL_SGEMM_SEPARATELY const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; -#if defined(BlockSumM1Layout) - const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN; -#else const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; -#endif -#endif float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; @@ -604,8 +479,6 @@ SQ4BitGemm_CompInt8( if (RangeCountM == 1) { if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8 == nullptr || (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 && BlkLen == 16)) { - //auto start = std::chrono::high_resolution_clock::now(); // Start timing here - size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { CountN = std::min(RangeCountN - n, size_t{128}); @@ -614,11 +487,6 @@ SQ4BitGemm_CompInt8( const float* b_col_scale = QuantBScale + n * k_blks; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - //std::chrono::high_resolution_clock::time_point tp; - //if (profiler_->IsEnabled()) { - // tp = profiler_->Start(); - //} - //assert(false); GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, QuantA, @@ -634,30 +502,10 @@ SQ4BitGemm_CompInt8( lda, ldc ); - //if (profiler_->IsEnabled()) { - // std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); - // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - //} - -// #ifndef CALL_SGEMM_SEPARATELY -// if (profiler_->IsEnabled()) { -// tp = profiler_->Start(); -// } -#if defined(BlockSumM1Layout) - const float* b_blk_sum = QuantBBlkSum + n; - GetMlasPlatform().KernelM1Routine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); - // GetMlasPlatform().KernelM1TransposeBRoutine(ABlockSum, b_blk_sum, c_blk, k_blks, CountN, ldc, 0.0f); -#else const float* b_blk_sum = QuantBBlkSum + n * k_blks; GetMlasPlatform().GemmFloatKernel( ABlockSum, b_blk_sum, c_blk, k_blks, RangeCountM, CountN, k_blks, ldc, 1.f, false ); -#endif - // if (profiler_->IsEnabled()) { - // std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); - // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - // } -// #endif if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM, RangeStartN + n, @@ -665,10 +513,6 @@ SQ4BitGemm_CompInt8( ); } } - //auto end = std::chrono::high_resolution_clock::now(); // End timing here - //// Calculate and print the duration in nanoseconds - //std::chrono::duration elapsed = end - start; - //std::cout << "Duration_M" << RangeCountM << "xN" << RangeCountN << "xK" << K << ": " << elapsed.count() << " ns\n"; return; } else { size_t CountN; @@ -709,10 +553,6 @@ SQ4BitGemm_CompInt8( const float* b_col_scale = QuantBScale + n * k_blks; float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - //std::chrono::high_resolution_clock::time_point tp; - //if (profiler_->IsEnabled()) { - // tp = profiler_->Start(); - //} GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, QuantA, @@ -728,15 +568,6 @@ SQ4BitGemm_CompInt8( lda, ldc ); - //if (profiler_->IsEnabled()) { - // std::string eventName = DataParams->node_name + "SQ4BitGemmKernel_CompInt8_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(K); - // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - //} - -#ifndef CALL_SGEMM_SEPARATELY - //if (profiler_->IsEnabled()) { - // tp = profiler_->Start(); - //} const float* b_blk_sum = QuantBBlkSum + n * k_blks; if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum) { @@ -752,11 +583,7 @@ SQ4BitGemm_CompInt8( RowsRemaining -= RowsHandled; } } - //if (profiler_->IsEnabled()) { - // std::string eventName = DataParams->node_name + "GemmFloatKernel_" + std::to_string(RangeCountM) + "_" + std::to_string(CountN) + "_" + std::to_string(k_blks); - // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - //} -#endif + if (DataParams->PostProcessor != nullptr) { DataParams->PostProcessor->Process( DataParams->C, RangeStartM, RangeStartN + n, @@ -937,21 +764,16 @@ MlasSQNBitGemmBatch( if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; InitializeWorkspaceOperation != nullptr) { - //auto start = std::chrono::high_resolution_clock::now(); // Start timing here InitializeWorkspaceOperation( M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); - //auto end = std::chrono::high_resolution_clock::now(); // End timing here - //// Calculate and print the duration in nanoseconds - //std::chrono::duration elapsed = end - start; - //std::cout << "InitializeWorkspaceOperation: " << elapsed.count() << " ns\n"; } const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (/*true || */ThreadPool == nullptr) { + if (ThreadPool == nullptr) { //auto start = std::chrono::high_resolution_clock::now(); // Start timing here for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; @@ -964,19 +786,11 @@ MlasSQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); -#if defined(CALL_SGEMM_SEPARATELY) - //SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); -#endif ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { ComputeOperation(BlkLen, K, Data, nullptr, 0, M, 0, N); } } - //auto end = std::chrono::high_resolution_clock::now(); // End timing here - //// Calculate and print the duration in nanoseconds - //std::chrono::duration elapsed = end - start; - //std::cout << "ThreadPool == nullptr: " << elapsed.count() << " ns\n"; - return; } @@ -984,66 +798,7 @@ MlasSQNBitGemmBatch( // Compute the number of target threads given the complexity of the SGEMM // operation. Small requests should run using the single threaded path. 
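// --- Illustrative sketch (not the MLAS implementation) of the thread-targeting idea the
// --- comment above describes: derive a target thread count from the arithmetic complexity
// --- of the GEMM, clamp it to the pool size, and let small requests fall back to a single
// --- thread. The constant and helper names below are assumptions for illustration only;
// --- PartitionWork mirrors the role MlasPartitionWork plays in the M/N split, not its code.
#include <algorithm>
#include <cstddef>

namespace sketch {

constexpr double kWorkPerThread = 64.0 * 1024.0;  // assumed per-thread complexity budget

inline std::ptrdiff_t TargetThreadCount(size_t M, size_t N, size_t K, size_t BatchN,
                                        std::ptrdiff_t MaxThreads)
{
    const double complexity = double(M) * double(N) * double(K) * double(BatchN);
    const std::ptrdiff_t threads = std::ptrdiff_t(complexity / kWorkPerThread) + 1;
    return std::min(threads, MaxThreads);  // tiny problems end up with 1 thread
}

// Split Total items as evenly as possible over ThreadCount workers; thread ThreadId owns
// the half-open range [*Start, *Start + *Count).
inline void PartitionWork(std::ptrdiff_t ThreadId, std::ptrdiff_t ThreadCount,
                          size_t Total, size_t* Start, size_t* Count)
{
    const size_t base = Total / size_t(ThreadCount);
    const size_t extra = Total % size_t(ThreadCount);
    if (size_t(ThreadId) < extra) {
        *Start = size_t(ThreadId) * (base + 1);
        *Count = base + 1;
    } else {
        *Start = size_t(ThreadId) * base + extra;
        *Count = base;
    }
}

}  // namespace sketch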
// - //auto start = std::chrono::high_resolution_clock::now(); // Start timing here - -#if defined(CALL_SGEMM_SEPARATELY) - if (ComputeType == CompInt8) { - size_t ThreadCountM, ThreadCountN, ThreadsPerGemm; - ComputeParallelTasksSGemm(M, N, BlockCountK, BatchN, ThreadPool, - ThreadCountM, ThreadCountN, ThreadsPerGemm); - - //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << "\t" - // << "ThreadCountM: " << ThreadCountM << "\t" - // << "ThreadCountN: " << ThreadCountN << "\n"; - //auto start = std::chrono::high_resolution_clock::now(); // Start timing here - MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { - ptrdiff_t GemmIdx = tid / ThreadsPerGemm; - ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; - - // MlasSgemmThreaded - ptrdiff_t ThreadId = ThreadIdx; - const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; - const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; - - // - // Partition the operation along the M dimension. - // - - size_t RangeStartM; - size_t RangeCountM; - MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); - - // - // Partition the operation along the N dimension. - // - - size_t RangeStartN; - size_t RangeCountN; - - const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); - - RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - RangeCountN = std::min(N - RangeStartN, RangeCountN); - - const auto* Data = &DataParams[GemmIdx]; - - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBData), N, BlockCountK, BlkLen); - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - - void* PerGemmWorkspace = - reinterpret_cast(Workspace) + GemmIdx * PerGemmWorkspaceStride; - PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); - //SQ4BitGemm_CompInt8_0(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); - //ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); - }); - } -#endif const double Complexity = double(M) * double(N) * double(K) * double(BatchN); ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; @@ -1080,14 +835,6 @@ MlasSQNBitGemmBatch( const size_t ThreadCountN = MlasDivRoundup(N, StrideN); ThreadsPerGemm = ThreadCountM * ThreadCountN; - //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << "\t" - // << "ThreadCountM: " << ThreadCountM << "\t" - // << "ThreadCountN: " << ThreadCountN << "\n"; - //std::chrono::high_resolution_clock::time_point tp; - //if (profiler_->IsEnabled()) { - // tp = profiler_->Start(); - //} - MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; @@ -1116,16 +863,4 @@ MlasSQNBitGemmBatch( ComputeOperation(BlkLen, K, Data, nullptr, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } }); - //if (profiler_->IsEnabled()) { - // std::string eventName = DataParams->node_name + "MlasTrySimpleParallel" + std::to_string(ThreadsPerGemm) + "-" + std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K); - // profiler_->EndTimeAndRecordEvent(onnxruntime::profiling::KERNEL_EVENT, eventName, tp); - //} - - // auto end = std::chrono::high_resolution_clock::now(); // End timing here - //// Calculate and 
print the duration in nanoseconds - // std::chrono::duration elapsed = end - start; - // std::chrono::duration elapsed_batch = end - start_batch; - - // std::cout << "ThreadPool kernel: " << elapsed.count() << " ns\n"; - // std::cout << "Batch Internal: " << elapsed_batch.count() << " ns\n"; } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 066f1956684b..49f7a355e75b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -60,24 +60,7 @@ struct PackedQuantBDataStruct { PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); - - //PackedQuantBScale = (float*)PackedQuantBWorkspace; - //PackedQuantBData = (std::byte*)(PackedQuantBScale) + ScaleSize; - //QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - - - //PackedQuantBScale = (float*)PackedQuantBWorkspace; - - //PackedQuantBData = (std::byte*)PackedQuantBWorkspace + ScaleSize; - //QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - ////PackedQuantBData = (std::byte*)MlasAlignAddress64(PackedQuantBData); - ////QuantBBlkSum = (float*)MlasAlignAddress64(QuantBBlkSum); - - //constexpr size_t Alignment = MlasQNBitQuantBBlkSumAlignment(); - //const uintptr_t QuantBBlkSumAddr = reinterpret_cast(QuantBBlkSum); - //QuantBBlkSum = reinterpret_cast((QuantBBlkSumAddr + Alignment - 1) & (~(Alignment - 1))); } std::byte* PackedQuantBData; float* PackedQuantBScale; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 0a3aaad0e9ae..bb77771b509c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -410,105 +410,6 @@ SQ4BitGemmKernel_CompInt8_avx2( MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompInt8_avx2( - size_t BlkLen, - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - const std::byte* QuantBZeroPoint, - float* C, - size_t CountN, - size_t CountK, - size_t BlockStrideQuantB, - const float* Bias -) -{ - if (QuantBZeroPoint != nullptr) { - constexpr bool HasZeroPoint = true; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } - } else { - constexpr bool HasZeroPoint = false; - if (BlkLen == 16) { - SQ4BitGemmM1Kernel_BlkLen16_CompInt8_Impl( - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } else if (BlkLen == 32) { - SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( - QuantA, - QuantAScale, - QuantBData, - QuantBScale, - QuantBZeroPoint, - C, - CountN, - BlockStrideQuantB, - Bias - ); - } else { - SQ4BitGemmM1Kernel_BlkLen64Plus_CompInt8_Impl( - BlkLen, - QuantA, - QuantBData, - QuantBScale, - QuantBZeroPoint, - 
C, - CountN, - CountK, - BlockStrideQuantB, - Bias - ); - } - } -} - -MLAS_FORCEINLINE -void -SQ4BitGemmM1Kernel_Sym_CompInt8_avx2( size_t BlkLen, const std::byte* QuantA, const float* QuantAScale, @@ -1276,7 +1177,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { MLAS_SQNBIT_GEMM_DISPATCH d; d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = nullptr; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; @@ -1285,8 +1186,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_Sym_CompInt8_avx2; - //d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8_avx2; d.SQ4BitGemmKernel_CompInt8 = SQ4BitGemmKernel_CompInt8_avx2; d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index ede3ddfc2659..2f217b31d249 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -338,7 +338,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { MLAS_SQNBIT_GEMM_DISPATCH d; d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = nullptr; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 8af8b8e7f1fe..cc4dde6a6be3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -245,7 +245,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { MLAS_SQNBIT_GEMM_DISPATCH d; d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = nullptr; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 1a32359696de..dedca6a424ee 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -14,22 +14,101 @@ SQ4BitGemmPackQuantBDataSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + constexpr size_t BlkBitWidth = 4; + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ComputeType == CompInt8) { + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ScaleSize = N * BlockCountK * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t 
PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } else { + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } +} +static void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ constexpr size_t BlkBitWidth = 4; + assert(BlkLen >= 16 && BlkLen % 16 == 0); + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t ScaleSize = N * BlockCountK * sizeof(float); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - // _mm256_load_si256 requires alignment on a 32-byte boundary - constexpr size_t PackedQuantBDataAlignment = 32; - PackedQuantBDataSize += PackedQuantBDataAlignment - 1; - constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); - BlkSumSize += BlkSumAlignment - 1; + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + // + // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | + // => + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - return PackedQuantBDataSize + ScaleSize + BlkSumSize; + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); } static size_t @@ -198,14 +277,8 @@ ComputePackBlkSum( zp = (uint8_t)(low_zp ? 
((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); } -//#define BlockSumM1Layout 1 -#if defined(BlockSumM1Layout) - // BlockSum is a regular row major matrix - const size_t dst_offset = k_blk * N + n; -#else // BlockSum is a width 16 row major matrix const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; -#endif *(BlockSumBegin + dst_offset) = -QuantBScale * zp; if (BlkLen == 16) { // TODO @@ -219,10 +292,6 @@ ComputePackBlkSum( } } ); - - //for (int i = 0; i < N * BlockCountK; i++) { - // std::cout << *(QuantBScaleBegin + i) << "\n"; - //} } static void diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3c3eff81bad7..3b85939cb471 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -268,7 +268,7 @@ TEST(MatMulNBits, Float32) { for (auto N : {/*2560, */1, 2, 32, 288}) { for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) { for (auto block_size : {16, 32, 64, 128 }) { - for (auto accuracy_level : {/*0, 1, */4}) { + for (auto accuracy_level : {0, 1, 4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; From c109b4b29a3c42f8137a06ef69dccb84d07af100 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Thu, 18 Jul 2024 06:36:02 +0000 Subject: [PATCH 21/41] avx512vnni Signed-off-by: liqunfu --- onnxruntime/core/mlas/lib/platform.cpp | 2 +- .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 49 ++- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 8 +- .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 92 ++++-- .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 205 +++++++++++-- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 285 ++++++++++++++---- .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 227 +++++++++----- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 108 ++++++- 8 files changed, 771 insertions(+), 205 deletions(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 8c4445e3a691..72eb35c89409 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -455,7 +455,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - //this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 2201a1d7d501..6ff52be83b3e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -19,6 +19,14 @@ accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } +MLAS_FORCEINLINE void +accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) +{ + __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk2_avx2( const __m256i& av00_32_epi8, @@ -170,6 +178,7 @@ accumulate_blklen32_r1c1blk2_avx2( acc0 = 
_mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); } +template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk1_avx2( const __m256i& av00_32_epi8, @@ -186,11 +195,17 @@ accumulate_blklen32_r2c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); - accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); + } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_avx2( const __m256i& av00_32_epi8, @@ -204,8 +219,12 @@ accumulate_blklen32_r1c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + if constexpr (vnni) { + accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); + } } MLAS_FORCEINLINE void @@ -306,28 +325,28 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( // Col0 const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); } { // Col1 const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); } { // Col2 const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + 
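// --- Standalone sketch of the two dot-product paths selected by the vnni template
// --- parameter above: with VNNI a single _mm256_dpbusd_epi32 multiplies unsigned bytes
// --- by signed bytes and accumulates straight into 32-bit lanes, while the plain AVX2
// --- path needs _mm256_maddubs_epi16 followed by _mm256_madd_epi16 against ones.
// --- Since the unsigned operand holds unpacked 4-bit values (0..15), the 16-bit
// --- intermediate in the AVX2 path cannot saturate. Illustration only, not kernel code.
#include <immintrin.h>

static inline __m256i dot_u8s8_avx2(__m256i u8, __m256i s8)
{
    // pairwise u8*s8 products summed into signed 16-bit lanes
    const __m256i prod_epi16 = _mm256_maddubs_epi16(u8, s8);
    // adjacent 16-bit pairs summed into signed 32-bit lanes
    return _mm256_madd_epi16(prod_epi16, _mm256_set1_epi16(1));
}

static inline __m256i dot_u8s8_vnni(__m256i acc, __m256i u8, __m256i s8)
{
    // fused u8*s8 multiply-accumulate into 32-bit lanes (AVX512VNNI+VL / AVX-VNNI targets)
    return _mm256_dpbusd_epi32(acc, u8, s8);
}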
accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); } } // k_blks_remaining @@ -424,7 +443,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); } *SumPtr = hsum_float_8(acc0); @@ -537,22 +556,22 @@ Q4Int8GemmXx4BlkLen32Avx2( { // Col0 const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); } { // Col1 const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); } { // Col2 const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); } } @@ -637,7 +656,7 @@ Q4Int8GemmXxXBlkLen32Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); } *SumPtr = hsum_float_8(acc0); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index c3da9aa52bf5..6289242ac054 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -175,7 +175,7 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512( ) { if (BlkLen == 16) { - MlasQ4Int8GemmKernelBlkLen16Avx512( + MlasQ4Int8GemmKernelBlkLen16Avx512( QuantA, QuantAScale, QuantBData, @@ -188,7 +188,7 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512( ldc ); } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx512( + MlasQ4Int8GemmKernelBlkLen32Avx512( QuantA, QuantAScale, QuantBData, @@ -201,7 +201,7 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512( ldc ); } else if (BlkLen == 64) { - MlasQ4Int8GemmKernelBlkLen64Avx512( + MlasQ4Int8GemmKernelBlkLen64Avx512( BlkLen, QuantA, QuantAScale, @@ -215,7 +215,7 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512( ldc ); } else { - MlasQ4Int8GemmKernelBlkLen128Avx512( + MlasQ4Int8GemmKernelBlkLen128Avx512( BlkLen, QuantA, QuantAScale, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index b5eecefe8474..60a887345d0e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -57,6 +57,23 @@ dot_accumulate_1blk( acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); } +static MLAS_FORCEINLINE void +dot_accumulate_1blkvnni( + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float combined_scale, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(dot0_16_epi32, bv1_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_set1_ps(combined_scale), acc); +} + +template static MLAS_FORCEINLINE void accumulate_blklen128_r1c1blk1_avx512( const __m512i& av00_64_epi8, @@ -70,11 +87,20 @@ accumulate_blklen128_r1c1blk1_avx512( __m512i bv0_64_epi8, bv1_64_epi8; load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - dot_accumulate_1blk( - bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, - (*scale_a) * (*scale_b), acc); + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a) * (*scale_b), acc + ); + } } +template static MLAS_FORCEINLINE void accumulate_blklen128_r2c1blk1_avx512( const __m512i& av00_64_epi8, @@ -92,16 +118,28 @@ accumulate_blklen128_r2c1blk1_avx512( __m512i bv0_64_epi8, bv1_64_epi8; load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - dot_accumulate_1blk( - bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, - (*scale_a0) * (*scale_b), acc0 - ); - dot_accumulate_1blk( - bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, - (*scale_a1) * (*scale_b), acc1 - ); + if constexpr (vnni) { + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blkvnni( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } else { + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, + (*scale_a0) * (*scale_b), acc0 + ); + dot_accumulate_1blk( + bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, + (*scale_a1) * (*scale_b), acc1 + ); + } } +template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen128Avx512( const size_t BlkLen, @@ -158,10 +196,10 @@ Q4Int8GemmR2xC4BlkLen128Avx512( const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + 
BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += SubblkLen; @@ -213,6 +251,7 @@ Q4Int8GemmR2xC4BlkLen128Avx512( } } +template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen128Avx512( const size_t BlkLen, @@ -266,7 +305,7 @@ Q4Int8GemmR2xC1BlkLen128Avx512( const __m512i av10_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av11_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); - accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, + accumulate_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); // increment block pointers @@ -293,6 +332,7 @@ Q4Int8GemmR2xC1BlkLen128Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen128Avx512( const size_t BlkLen, @@ -342,10 +382,10 @@ Q4Int8GemmR1xC4BlkLen128Avx512( for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); - accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); // increment block pointers QuantAPtr += SubblkLen; @@ -371,6 +411,7 @@ Q4Int8GemmR1xC4BlkLen128Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen128Avx512( const size_t BlkLen, @@ -420,7 +461,7 @@ Q4Int8GemmR1xC1BlkLen128Avx512( const __m512i av0_64_epi8 = _mm512_loadu_si512((const 
__m512i*)QuantAPtr); const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); - accumulate_blklen128_r1c1blk1_avx512( + accumulate_blklen128_r1c1blk1_avx512( av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 ); @@ -446,6 +487,7 @@ Q4Int8GemmR1xC1BlkLen128Avx512( } } +template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen128Avx512( const size_t BlkLen, @@ -476,7 +518,7 @@ MlasQ4Int8GemmKernelBlkLen128Avx512( size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen128Avx512( + Q4Int8GemmR2xC4BlkLen128Avx512( BlkLen, QuantA, QuantAScale, @@ -491,7 +533,7 @@ MlasQ4Int8GemmKernelBlkLen128Avx512( ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2xC1BlkLen128Avx512( + Q4Int8GemmR2xC1BlkLen128Avx512( BlkLen, QuantA, QuantAScale, @@ -506,7 +548,7 @@ MlasQ4Int8GemmKernelBlkLen128Avx512( } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen128Avx512( + Q4Int8GemmR1xC4BlkLen128Avx512( BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, @@ -521,7 +563,7 @@ MlasQ4Int8GemmKernelBlkLen128Avx512( } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen128Avx512( + Q4Int8GemmR1xC1BlkLen128Avx512( BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index 02254a7056bc..a71b64ffa26c 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -120,6 +120,101 @@ accumulate_blklen16_r2c1blk4_avx512( } } +static MLAS_FORCEINLINE void +accumulate_blklen16_r1c1blk8_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_load_ps(scale_a); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0044115522663377 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen16_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* 
scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + { + const __m256 scale_a0_ps = _mm256_load_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) + ); // 0123456701234567 + + // TODO: load from memory + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m256 scale_a1_ps = _mm256_load_ps(scale_a1); // 01234567 + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_castsi512_ps( + _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) + ); // 0123456701234567 + + __m512i idx = _mm512_set_epi32(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen16Avx512( const std::byte* QuantA, @@ -177,22 +272,49 @@ Q4Int8GemmR2xC4BlkLen16Avx512( const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, - acc[0], acc[NCols4]); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, - acc[1], acc[NCols4 + 1]); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, - acc[2], acc[NCols4 + 2]); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * 
StrideQuantBScale, - acc[3], acc[NCols4 + 3]); + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale, + acc[3], acc[NCols4 + 3] + ); + } // increment block pointers QuantAPtr += BlkLen16 * PerAccuBlk8; @@ -277,6 +399,7 @@ Q4Int8GemmR2xC4BlkLen16Avx512( } } +template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen16Avx512( const std::byte* QuantA, @@ -330,9 +453,17 @@ Q4Int8GemmR2C1BlkLen16Avx512( const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - accumulate_blklen16_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + if constexpr (vnni) { + accumulate_blklen16_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen16_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } // increment block pointers QuantAPtr += BlkLen16 * PerAccuBlk8; @@ -378,6 +509,7 @@ Q4Int8GemmR2C1BlkLen16Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen16Avx512( const std::byte* QuantA, @@ -429,14 +561,17 @@ Q4Int8GemmR1xC4BlkLen16Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - 
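// --- Scalar reference, for illustration only, of the value one output element converges
// --- to in these blklen16 kernels: per 16-wide block, a signed dot product of quantized A
// --- against the 4-bit B values, scaled by the product of the per-block A and B scales.
// --- The vector kernels above accumulate the raw u8 x s8 dot products and the zero-point
// --- term is restored separately through the block sums; a zero point of 8 and the
// --- original (unpacked) nibble order are assumed here just to keep the sketch simple.
#include <cstddef>
#include <cstdint>

static float BlockQuantDotRef(const int8_t* a_quant,    // BlockCount * 16 quantized A values
                              const float* a_scale,     // BlockCount per-block A scales
                              const uint8_t* b_nibbles,  // BlockCount * 8 bytes of packed 4-bit B
                              const float* b_scale,     // BlockCount per-block B scales
                              size_t BlockCount)
{
    float sum = 0.0f;
    for (size_t blk = 0; blk < BlockCount; ++blk) {
        int32_t dot = 0;
        for (size_t i = 0; i < 16; ++i) {
            const uint8_t byte = b_nibbles[blk * 8 + i / 2];
            const int32_t b4 = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
            dot += int32_t(a_quant[blk * 16 + i]) * (b4 - 8);  // assumed zero point of 8
        }
        sum += float(dot) * a_scale[blk] * b_scale[blk];
    }
    return sum;
}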
accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } else { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale, acc[1]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale, acc[2]); + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale, acc[3]); + } QuantAPtr += BlkLen16 * PerAccuBlk8; QuantAScalePtr += PerAccuBlk8; @@ -493,6 +628,7 @@ Q4Int8GemmR1xC4BlkLen16Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen16Avx512( const std::byte* QuantA, @@ -542,7 +678,11 @@ Q4Int8GemmR1xC1BlkLen16Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + if constexpr (vnni) { + accumulate_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } QuantAPtr += BlkLen16 * PerAccuBlk8; QuantAScalePtr += PerAccuBlk8; @@ -579,6 +719,7 @@ Q4Int8GemmR1xC1BlkLen16Avx512( } } +template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen16Avx512( @@ -612,7 +753,7 @@ MlasQ4Int8GemmKernelBlkLen16Avx512( size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen16Avx512( + Q4Int8GemmR2xC4BlkLen16Avx512( QuantA, QuantAScale, QuantBData, @@ -626,7 +767,7 @@ MlasQ4Int8GemmKernelBlkLen16Avx512( ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2C1BlkLen16Avx512( + Q4Int8GemmR2C1BlkLen16Avx512( QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData, @@ -640,7 +781,7 @@ MlasQ4Int8GemmKernelBlkLen16Avx512( } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen16Avx512( + Q4Int8GemmR1xC4BlkLen16Avx512( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData, @@ -654,7 +795,7 @@ MlasQ4Int8GemmKernelBlkLen16Avx512( } if 
(remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen16Avx512( + Q4Int8GemmR1xC1BlkLen16Avx512( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData + multipleCols * StrideQuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index 60e7e5a84b55..7feccb182531 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -122,6 +122,95 @@ accumulate_blklen32_r2c1blk4_avx512( } } +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk4_avx512vnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + + const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_load_ps(scale_a); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } +} + +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8, bv1_64_epi8; + load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + //__m512i idx = _mm512_loadu_epi8(&index_array[0]); + + const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + { + const __m128 scale_a0_ps = _mm_load_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + 
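// --- Small standalone check (illustrative values; build for an AVX-512 target) of the
// --- scale permute used in the blklen32 VNNI accumulators above: with the per-block
// --- scales repeated as 0123 0123 0123 0123, the index vector reorders them into
// --- 0 0 2 2 0 0 2 2 1 1 3 3 1 1 3 3, matching the lane order produced by the
// --- unpacklo/unpackhi interleave of the two dpbusd results.
#include <immintrin.h>
#include <cstdio>

int main()
{
    alignas(64) float scales[16] = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3};
    const __m512 src = _mm512_load_ps(scales);

    // same index vector as the kernel; permutexvar_ps picks src[idx[i]] for lane i
    const __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0);
    const __m512 permuted = _mm512_permutexvar_ps(idx, src);

    alignas(64) float out[16];
    _mm512_store_ps(out, permuted);
    for (int i = 0; i < 16; ++i) {
        std::printf("%g ", out[i]);  // expected: 0 0 2 2 0 0 2 2 1 1 3 3 1 1 3 3
    }
    std::printf("\n");
    return 0;
}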
const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); + } + { + const __m128 scale_a1_ps = _mm_load_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 + + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 + + const __m512i dot0_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av10_64_epi8); // 0000000011111111 + const __m512i dot1_32_epi16 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av11_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_32_epi16, dot1_32_epi16); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, scale_a1b_16_ps, acc1); + } +} + +template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen32Avx512( const std::byte* QuantA, @@ -178,22 +267,49 @@ Q4Int8GemmR2xC4BlkLen32Avx512( const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, - acc[0], acc[NCols4]); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, - acc[1], acc[NCols4 + 1]); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, - acc[2], acc[NCols4 + 2]); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, - QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, - acc[3], acc[NCols4 + 3]); + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], acc[NCols4 + 3] + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, + acc[0], 
acc[NCols4] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, + acc[1], acc[NCols4 + 1] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, + acc[2], acc[NCols4 + 2] + ); + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, + QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, + acc[3], acc[NCols4 + 3] + ); + } // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk4; @@ -226,31 +342,44 @@ Q4Int8GemmR2xC4BlkLen32Avx512( // Col0 const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + if constexpr (vnni) { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } else { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + } } { // Col1 const float scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, - acc2[1], acc2[NCols4 + 1]); + if constexpr (vnni) { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); + } else { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); + } } { // Col2 const float scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 2)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, - acc2[2], acc2[NCols4 + 2]); + if constexpr (vnni) { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); + } else { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); + } } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, - acc2[3], acc2[NCols4 + 3]); + if constexpr (vnni) { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); + } else { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); + } } QuantAPtr += BlkLen32; QuantAScalePtr++; @@ -278,6 +407,7 @@ Q4Int8GemmR2xC4BlkLen32Avx512( } } +template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen32Avx512( const std::byte* QuantA, @@ -330,9 +460,17 @@ Q4Int8GemmR2C1BlkLen32Avx512( const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const 
__m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - accumulate_blklen32_r2c1blk4_avx512( - av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, - QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + if constexpr (vnni) { + accumulate_blklen32_r2c1blk4_avx512vnni( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } else { + accumulate_blklen32_r2c1blk4_avx512( + av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, + QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1 + ); + } // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk4; @@ -354,7 +492,11 @@ Q4Int8GemmR2C1BlkLen32Avx512( const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + if constexpr (vnni) { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + } else { + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + } QuantAPtr += BlkLen32; QuantAScalePtr++; @@ -379,6 +521,7 @@ Q4Int8GemmR2C1BlkLen32Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen32Avx512( const std::byte* QuantA, @@ -430,10 +573,17 @@ Q4Int8GemmR1xC4BlkLen32Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 
* BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } QuantAPtr += BlkLen32 * PerAccuBlk4; QuantAScalePtr += PerAccuBlk4; @@ -453,19 +603,35 @@ Q4Int8GemmR1xC4BlkLen32Avx512( const float& scale_a00 = *QuantAScalePtr; { const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + if constexpr (vnni) { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } else { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + } } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); + if constexpr (vnni) { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); + } else { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); + } } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); + if constexpr (vnni) { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); + } else { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); + } } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); + if constexpr (vnni) { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); + } else { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); + } } QuantAPtr += BlkLen32; @@ -491,6 +657,7 @@ Q4Int8GemmR1xC4BlkLen32Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen32Avx512( const std::byte* QuantA, @@ -540,7 +707,12 @@ Q4Int8GemmR1xC1BlkLen32Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + else { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } QuantAPtr += BlkLen32 * PerAccuBlk4; QuantAScalePtr += PerAccuBlk4; @@ -554,7 +726,11 @@ Q4Int8GemmR1xC1BlkLen32Avx512( const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + if constexpr (vnni) { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + } else { + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + } QuantAPtr += BlkLen32; QuantAScalePtr++; @@ -577,20 +753,21 @@ Q4Int8GemmR1xC1BlkLen32Avx512( } } +template MLAS_FORCEINLINE - size_t - MlasQ4Int8GemmKernelBlkLen32Avx512( - const std::byte* QuantA, - const float* QuantAScale, - const std::byte* QuantBData, - const float* QuantBScale, - float* C, - size_t CountM, - size_t CountN, - size_t BlockCountK, - const float* Bias, - size_t ldc - ) 
+size_t +MlasQ4Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) { constexpr size_t BlkLen32 = 32; constexpr size_t BlkBitWidth4 = 4; @@ -610,7 +787,7 @@ MLAS_FORCEINLINE size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen32Avx512( + Q4Int8GemmR2xC4BlkLen32Avx512( QuantA, QuantAScale, QuantBData, @@ -624,7 +801,7 @@ MLAS_FORCEINLINE ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2C1BlkLen32Avx512( + Q4Int8GemmR2C1BlkLen32Avx512( QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData, @@ -638,7 +815,7 @@ MLAS_FORCEINLINE } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen32Avx512( + Q4Int8GemmR1xC4BlkLen32Avx512( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData, @@ -652,7 +829,7 @@ MLAS_FORCEINLINE } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen32Avx512( + Q4Int8GemmR1xC1BlkLen32Avx512( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData + multipleCols * StrideQuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 4a67f2971494..2d85a150da45 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -90,6 +90,33 @@ dot_accumulate_2blk( acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); } +static MLAS_FORCEINLINE void +dot_accumulate_2blkvnni( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const float* scale_a, + const __m512i& bv0_64_epi8, + const __m512i& bv1_64_epi8, + const __m512& scale_b_16_ps, + // const __m512i& one_32_epi16, + __m512& acc +) +{ + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av0_64_epi8); + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av1_64_epi8); + + __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); + __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... 
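// Illustrative sketch (editor's addition, not part of this patch): the scalar
// semantics of one 32-bit lane of _mm512_dpbusd_epi32 as used just above. The
// first vector operand (the unpacked 4-bit B data) is treated as unsigned bytes
// and the second (the quantized int8 A data) as signed bytes, and the four
// widened products are accumulated straight into int32. That is why this VNNI
// path can drop the separate maddubs/madd widening steps used by the non-VNNI
// kernels. The helper name below is hypothetical.
#include <cstdint>

static inline int32_t dpbusd_lane_reference(int32_t acc, const uint8_t u[4], const int8_t s[4])
{
    for (int i = 0; i < 4; ++i) {
        acc += static_cast<int32_t>(u[i]) * static_cast<int32_t>(s[i]);  // u8 * s8 -> i32, no int16 saturation step
    }
    return acc;
}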
+ __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m512 scale_a_16_ps = _mm512_broadcast_f32x8(scale_a_8_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); +} + +template static MLAS_FORCEINLINE void accumulate_blklen64_r2c1blk2_avx512( const __m512i& av00_64_epi8, @@ -110,19 +137,34 @@ accumulate_blklen64_r2c1blk2_avx512( const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); - dot_accumulate_2blk( - av00_64_epi8, av01_64_epi8, scale_a0, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc0 - ); - - dot_accumulate_2blk( - av10_64_epi8, av11_64_epi8, scale_a1, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc1 - ); + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blkvnni( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } else { + dot_accumulate_2blk( + av00_64_epi8, av01_64_epi8, scale_a0, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc0 + ); + + dot_accumulate_2blk( + av10_64_epi8, av11_64_epi8, scale_a1, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc1 + ); + } } +template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk2_avx512( const __m512i& av0_64_epi8, @@ -139,13 +181,22 @@ accumulate_blklen64_r1c1blk2_avx512( const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); - dot_accumulate_2blk( - av0_64_epi8, av1_64_epi8, scale_a, - bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, - acc - ); + if constexpr (vnni) { + dot_accumulate_2blkvnni( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } else { + dot_accumulate_2blk( + av0_64_epi8, av1_64_epi8, scale_a, + bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, + acc + ); + } } +template static MLAS_FORCEINLINE void accumulate_blklen64_r2c1blk1_avx512( const __m512i& av0_64_epi8, @@ -163,35 +214,58 @@ accumulate_blklen64_r2c1blk1_avx512( const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); - const __m512i zeros = _mm512_setzero_si512(); - //const __m512i one_32_epi16_ = _mm512_andnot_epi32(zeros, zeros); - //const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_andnot_epi32(zeros, zeros), 15); + if constexpr (vnni) { + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); - const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); - { - __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av0_64_epi8); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } - __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); - __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + { + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); - acc0 = _mm512_fmadd_ps(sum_16_ps, 
_mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); - } + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); - { - __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av1_64_epi8); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } + } else { + const __m512i zeros = _mm512_setzero_si512(); + // const __m512i one_32_epi16_ = _mm512_andnot_epi32(zeros, zeros); + // const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_andnot_epi32(zeros, zeros), 15); + + const __m512i one_32_epi16 = _mm512_srli_epi16(_mm512_ternarylogic_epi32(zeros, zeros, zeros, 1), 15); + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av0_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } - __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); - __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + { + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av1_64_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); - acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } } } +template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx512( const __m512i& av_32_epi8, @@ -206,19 +280,32 @@ accumulate_blklen64_r1c1blk1_avx512( const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); - const __m512i one_32_epi16 = _mm512_set1_epi16(1); + if constexpr (vnni) { + const __m512i one_32_epi16 = _mm512_set1_epi16(1); - __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av_32_epi8); - __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av_32_epi8); + __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); - __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); - __m128 scale_a_ps = _mm_broadcast_ss(scale_a); - __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } else { + const __m512i one_32_epi16 = _mm512_set1_epi16(1); - acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + __m512i dot_32_epi16 = _mm512_maddubs_epi16(bv_64_epi8, av_32_epi8); + __m512i sum_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot_32_epi16); + + __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + + __m128 scale_a_ps = _mm_broadcast_ss(scale_a); + __m512 scale_a_16_ps = _mm512_broadcast_f32x2(scale_a_ps); + + acc = _mm512_fmadd_ps(sum_16_ps, _mm512_mul_ps(scale_a_16_ps, scale_b_16_ps), acc); + } } +template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen64Avx512( const std::byte* QuantA, @@ -275,10 
+362,10 @@ Q4Int8GemmR2xC4BlkLen64Avx512( const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += BlkLen64 * PerAccuBlk2; @@ -291,13 +378,13 @@ Q4Int8GemmR2xC4BlkLen64Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); QuantAPtr += BlkLen64; @@ -345,6 +432,7 @@ Q4Int8GemmR2xC4BlkLen64Avx512( } } +template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx512( const size_t BlkLen, @@ -399,7 +487,7 @@ Q4Int8GemmR2xC1BlkLen64Avx512( const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); const __m512i av_11_epi8 = _mm512_loadu_si512((const 
__m512i*)(QuantAPtr + lda + 64)); - accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + accumulate_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); // increment block pointers QuantAPtr += BlkLen64 * PerAccuBlk2; @@ -412,7 +500,7 @@ Q4Int8GemmR2xC1BlkLen64Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); - accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + accumulate_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); QuantAPtr += BlkLen64; QuantAScalePtr++; @@ -436,6 +524,7 @@ Q4Int8GemmR2xC1BlkLen64Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx512( const size_t BlkLen, @@ -486,10 +575,10 @@ Q4Int8GemmR1xC4BlkLen64Avx512( for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { const __m512i av0_64_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av1_64_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); - accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); // increment block pointers QuantAPtr += BlkLen64 * PerAccuBlk2; @@ -501,10 +590,10 @@ Q4Int8GemmR1xC4BlkLen64Avx512( while (k_blks_remaining-- > 0) { const __m512i av_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); - accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + 
accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); QuantAPtr += BlkLen64; QuantAScalePtr++; @@ -528,6 +617,7 @@ Q4Int8GemmR1xC4BlkLen64Avx512( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx512( const size_t BlkLen, @@ -578,7 +668,7 @@ Q4Int8GemmR1xC1BlkLen64Avx512( const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); - accumulate_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); // increment block pointers QuantAPtr += BlkLen64 * PerAccuBlk2; @@ -590,7 +680,7 @@ Q4Int8GemmR1xC1BlkLen64Avx512( while (k_blks_remaining-- > 0) { const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); - accumulate_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); QuantAPtr += BlkLen64; QuantAScalePtr++; @@ -612,6 +702,7 @@ Q4Int8GemmR1xC1BlkLen64Avx512( } } +template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx512( const size_t BlkLen, @@ -643,7 +734,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( if (multipleRows > 0 && multipleCols > 0) { if (NRows2 == 2) - Q4Int8GemmR2xC4BlkLen64Avx512( + Q4Int8GemmR2xC4BlkLen64Avx512( QuantA, QuantAScale, QuantBData, @@ -656,7 +747,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( ldc ); else - Q4Int8GemmR1xC4BlkLen64Avx512( + Q4Int8GemmR1xC4BlkLen64Avx512( BlkLen, QuantA, QuantAScale, @@ -672,7 +763,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( } if (remainingCols > 0 && multipleRows > 0) { if (NRows2 == 2) - Q4Int8GemmR2xC1BlkLen64Avx512( + Q4Int8GemmR2xC1BlkLen64Avx512( BlkLen, QuantA, QuantAScale, @@ -685,7 +776,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( Bias ? 
Bias + multipleCols : nullptr, ldc); else - Q4Int8GemmR1xC1BlkLen64Avx512( + Q4Int8GemmR1xC1BlkLen64Avx512( BlkLen, QuantA, QuantAScale, @@ -701,7 +792,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen64Avx512( + Q4Int8GemmR1xC4BlkLen64Avx512( BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, @@ -715,7 +806,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( ldc); } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen64Avx512( + Q4Int8GemmR1xC1BlkLen64Avx512( BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 8fc4a879a167..37451413b8c4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -23,6 +23,10 @@ Module Name: #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen16.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen32.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +#include "sqnbitgemm_kernel_avx512_int8_blklen128.h" MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompFp32( @@ -202,6 +206,99 @@ SQ4BitGemmM1Kernel_CompInt8_avx512vnni( } } +MLAS_FORCEINLINE +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ4Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ4Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ4Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ4Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL QuantizeARow_CompInt8_avx512( size_t BlkLen, @@ -213,7 +310,7 @@ QuantizeARow_CompInt8_avx512( ); static void -SQ4BitGemmPackQuantBDataAndBlkSum( +SQ4BitGemmPackQuantBDataAndBlkSum512vnni( size_t N, size_t K, size_t BlkLen, @@ -231,10 +328,9 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { - SubBlkLen = 64; + if (ComputeType == CompInt8) { + SubBlkLen = 128; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } @@ -246,7 +342,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; @@ -254,7 +350,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; - d.QuantizeARow_CompInt8 = nullptr; + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx512; return d; From 4b91bedb7335b19dffd87b0b3bd119a428a8f2cc Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Fri, 19 Jul 2024 17:57:27 -0700 Subject: [PATCH 22/41] avxvnni Signed-off-by: Liqun Fu --- onnxruntime/core/mlas/lib/mlasi.h | 2 + onnxruntime/core/mlas/lib/platform.cpp | 1 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 88 ++++++- .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 238 ++++++++---------- .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 116 ++++++--- ...bitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 200 +++++++++------ 6 files changed, 400 insertions(+), 245 deletions(-) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 83200187963e..4239e2ecaeb6 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -993,6 +993,8 @@ extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; + extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 859b7c2f560a..ed437f20f7c2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -409,6 +409,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; + this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 2a46d68692dd..0ed41b150a0e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -344,6 +344,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_avx2( } } +template MLAS_FORCEINLINE void SQ4BitGemmKernel_CompInt8_avx2( @@ -376,7 +377,7 @@ SQ4BitGemmKernel_CompInt8_avx2( ldc ); } else if (BlkLen == 32) { - MlasQ4Int8GemmKernelBlkLen32Avx2( + MlasQ4Int8GemmKernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -390,7 +391,7 @@ SQ4BitGemmKernel_CompInt8_avx2( ldc ); } else { - 
MlasQ4Int8GemmKernelBlkLen64Avx2( + MlasQ4Int8GemmKernelBlkLen64Avx2( BlkLen, QuantA, QuantAScale, @@ -406,6 +407,7 @@ SQ4BitGemmKernel_CompInt8_avx2( } } +template MLAS_FORCEINLINE void SQ4BitGemmM1Kernel_CompInt8_avx2( @@ -425,7 +427,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( if (QuantBZeroPoint) { if (BlkLen == 16) { } else if (BlkLen == 32) { - MlasQ4Int8GemmM1KernelBlkLen32Avx2( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -453,7 +455,7 @@ SQ4BitGemmM1Kernel_CompInt8_avx2( } else { if (BlkLen == 16) { } else if (BlkLen == 32) { - MlasQ4Int8GemmM1KernelBlkLen32Avx2( + MlasQ4Int8GemmM1KernelBlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -502,11 +504,66 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2( ) { if (BlkLen >= 32 && CountM == 1) { - SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); return CountM; } - SQ4BitGemmKernel_CompInt8_avx2( + SQ4BitGemmKernel_CompInt8_avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + +size_t +SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen >= 32 && CountM == 1) { + SQ4BitGemmM1Kernel_CompInt8_avx2(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias); + return CountM; + } + + SQ4BitGemmKernel_CompInt8_avx2( BlkLen, QuantA, QuantAScale, @@ -1246,3 +1303,22 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { return d; }(); + +const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + + d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; + d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + + d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2; + + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 6ff52be83b3e..d1762420b3a3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ 
b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -22,11 +22,12 @@ accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, MLAS_FORCEINLINE void accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) { - __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } +template static MLAS_FORCEINLINE void accumulate_blklen32_r2c1blk2_avx2( const __m256i& av00_32_epi8, @@ -53,100 +54,62 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16( - bv0_32_epi8, av00_32_epi8 - ); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16( - bv1_32_epi8, av01_32_epi8 - ); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - - // generating constant 1s is faster here. - // __m256i one = _mm256_set1_epi16(1); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); - - - const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16( - bv0_32_epi8, av10_32_epi8 - ); - const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16( - bv1_32_epi8, av11_32_epi8 - ); - const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); - const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); - const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); - - __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); - __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); -} - -static MLAS_FORCEINLINE void -accumulate_blklen32_r2c1blk2_no_bc_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const __m256i& av10_32_epi8, - const __m256i& av11_32_epi8, - const std::byte* QuantBDataPtr, - const float& scale_a00, - const float& scale_a01, - const float& scale_a10, - const float& scale_a11, - const float& scale_b0, - const float& scale_b1, - __m256& acc0, - __m256& acc1 -) -{ - // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - - // generating low_mask of 0x0Fs is not as fast as just calling _mm256_set1_epi8(0x0F). 
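// Illustrative sketch (editor's addition, not part of this patch): how one byte
// of the packed QuantBData is split into the two 4-bit weights it carries. Byte i
// of a 32-byte load holds element i in its low nibble and element i + 32 in its
// high nibble (the "| v0 v32 | v1 v33 | ... |" layout noted above), which is why
// av block 0 is paired with bv0 and av block 1 with bv1. The helper name below is
// hypothetical.
#include <cstdint>

static inline void unpack_q4_byte_reference(uint8_t packed, uint8_t& lo, uint8_t& hi)
{
    lo = static_cast<uint8_t>(packed & 0x0F);  // element i      (vector form: _mm256_and_si256 with 0x0F)
    hi = static_cast<uint8_t>(packed >> 4);    // element i + 32 (vector form: subtract lo first so the
                                               // 16-bit right shift cannot smear bits across byte lanes)
}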
- const __m256i low_mask = _mm256_set1_epi8(0x0F); - //__m256i low_mask = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_packed, bv_packed), 12); - // low_mask = _mm256_packus_epi16(low_mask, low_mask); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - // TODO: this (the second line below) is faster and does not keep low_mask in use. - // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv1_32_epi8, bv1_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, scale_a00 * scale_b0, one_16_epi16, acc0); - accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, scale_a01 * scale_b1, one_16_epi16, acc0); - accumulate_1blk_dot(av10_32_epi8, bv0_32_epi8, scale_a10 * scale_b0, one_16_epi16, acc1); - accumulate_1blk_dot(av11_32_epi8, bv1_32_epi8, scale_a11 * scale_b1, one_16_epi16, acc1); -} - -static MLAS_FORCEINLINE void -accumulate_blklen32_r1c1blk2_no_bc_avx2( - const __m256i& av00_32_epi8, - const __m256i& av01_32_epi8, - const std::byte* QuantBDataPtr, - const float& scale_a_0, - const float& scale_a_1, - const float& scale_b_0, - const float& scale_b_1, - __m256& acc0 -) -{ - // | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | - const __m256i bv_packed = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); - const __m256i low_mask = _mm256_set1_epi8(0x0F); - __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 - __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - accumulate_1blk_dot(av00_32_epi8, bv0_32_epi8, scale_a_0 * scale_b_0, one_16_epi16, acc0); - accumulate_1blk_dot(av01_32_epi8, bv1_32_epi8, scale_a_1 * scale_b_1, one_16_epi16, acc0); + if constexpr (vnni) { + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); + } + } else { + //{ + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + // 
generating constant 1s is faster here. + // __m256i one = _mm256_set1_epi16(1); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + //} + //{ + const __m256i dot0_16_epi16_ = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + const __m256i dot1_16_epi16_ = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + const __m256i sum_16_epi16_ = _mm256_hadd_epi16(dot0_16_epi16_, dot1_16_epi16_); + const __m256i sum_8_epi32_ = _mm256_madd_epi16(one_16_epi16, sum_16_epi16_); + const __m256 sum_ps_ = _mm256_cvtepi32_ps(sum_8_epi32_); + + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); + //} + } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( const __m256i& av00_32_epi8, @@ -163,19 +126,33 @@ accumulate_blklen32_r1c1blk2_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { + const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); // 00110011 - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + __m256i one_16_epi16 = 
_mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } } template @@ -227,6 +204,7 @@ accumulate_blklen32_r1c1blk1_avx2( } } +template MLAS_FORCEINLINE void Q4Int8Gemm2x4x2BlkLen32Avx2( const std::byte* QuantA, @@ -291,18 +269,18 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); } { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + StrideQuantBScale2, acc[1], acc[NCols4 + 1]); } { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], acc[NCols4 + 2]); } { - accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], acc[NCols4 + 3]); } // increment block pointers @@ -325,28 +303,28 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( // Col0 const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); } { // Col1 const float scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, scale_10, acc[1], acc[NCols4 + 1]); } { // Col2 
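// Illustrative sketch (editor's addition, not part of this patch): the scalar
// arithmetic behind each accumulate_blklen32_r2c1blk1_avx2 call in this tail
// loop. The A-row scale and the B-column scale are folded into scale_00/scale_10
// up front, so one float multiply finishes the block after the integer dot
// product. The helper name is hypothetical, and the 4-bit zero point is assumed
// to be corrected separately by the ABlockSum x QuantBBlkSum pass that runs
// after the kernel.
#include <cstdint>

static inline float q4_block_contribution_reference(const int8_t a[32],   // quantized A block
                                                     const uint8_t b[32],  // unpacked 4-bit B block (0..15)
                                                     float combined_scale)  // scale_a * scale_b, i.e. scale_00
{
    int32_t dot = 0;
    for (int i = 0; i < 32; ++i) {
        dot += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }
    return static_cast<float>(dot) * combined_scale;
}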
const float scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, scale_10, acc[2], acc[NCols4 + 2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, scale_10, acc[3], acc[NCols4 + 3]); } } // k_blks_remaining @@ -371,6 +349,7 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( } } +template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const float* QuantAScale, @@ -423,7 +402,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); - accumulate_blklen32_r2c1blk2_avx2( + accumulate_blklen32_r2c1blk2_avx2( av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); @@ -443,7 +422,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); } *SumPtr = hsum_float_8(acc0); @@ -463,6 +442,7 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmXx4BlkLen32Avx2( const std::byte* QuantA, @@ -518,23 +498,23 @@ Q4Int8GemmXx4BlkLen32Avx2( const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); { - accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, + accumulate_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); } { - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1] ); } { - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2] ); } { - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3] ); @@ -556,22 +536,22 @@ Q4Int8GemmXx4BlkLen32Avx2( { // Col0 const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); } { // Col1 const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, 
scale_00, acc[1]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, acc[1]); } { // Col2 const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, acc[2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, acc[3]); } } @@ -591,6 +571,7 @@ Q4Int8GemmXx4BlkLen32Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmXxXBlkLen32Avx2( const std::byte* QuantA, @@ -639,7 +620,7 @@ Q4Int8GemmXxXBlkLen32Avx2( for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr)); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + BlkLen32)); - accumulate_blklen32_r1c1blk2_avx2( + accumulate_blklen32_r1c1blk2_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 ); @@ -656,7 +637,7 @@ Q4Int8GemmXxXBlkLen32Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); } *SumPtr = hsum_float_8(acc0); @@ -673,6 +654,7 @@ Q4Int8GemmXxXBlkLen32Avx2( } } +template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen32Avx2( @@ -707,7 +689,7 @@ MLAS_FORCEINLINE size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8Gemm2x4x2BlkLen32Avx2( + Q4Int8Gemm2x4x2BlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -721,7 +703,7 @@ MLAS_FORCEINLINE ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8Gemm2xXBlkLen32Avx2( + Q4Int8Gemm2xXBlkLen32Avx2( QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData, @@ -735,7 +717,7 @@ MLAS_FORCEINLINE } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmXx4BlkLen32Avx2( + Q4Int8GemmXx4BlkLen32Avx2( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData, @@ -749,7 +731,7 @@ MLAS_FORCEINLINE } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmXxXBlkLen32Avx2( + Q4Int8GemmXxXBlkLen32Avx2( QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, QuantBData + multipleCols * StrideQuantBData, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 15898a719632..ec9a338348c8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -6,6 +6,7 @@ #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +template static MLAS_FORCEINLINE void accumulate_blklen64_r2c1blk1_avx2( const __m256i& av00_32_epi8, @@ -25,32 +26,53 @@ accumulate_blklen64_r2c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - __m256i dot0_16_epi16 = 
_mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); - __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); - __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); - __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av11_32_epi8); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); - dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); - dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); - sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + + } else { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + + __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_ps = _mm256_broadcast_ss(scale_a0); + __m256 scale_b_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a0_ps, scale_b_ps), acc0); - sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av10_32_epi8); + dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av11_32_epi8); + sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); + + acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); + } } +template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx2( const __m256i& av00_32_epi8, @@ -67,20 +89,32 @@ accumulate_blklen64_r1c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 - const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); - const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { 
+ __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + } else { + const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); + const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); - __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + __m256 scale_a_8_ps = _mm256_broadcast_ss(scale_a); + __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); + + acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); + } } +template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen64Avx2( const size_t BlkLen, @@ -138,10 +172,10 @@ Q4Int8GemmR2xC4BlkLen64Avx2( const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); // increment block pointers QuantAPtr += 
SubblkLen; @@ -170,6 +204,7 @@ Q4Int8GemmR2xC4BlkLen64Avx2( } } +template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx2( const size_t BlkLen, @@ -223,7 +258,7 @@ Q4Int8GemmR2xC1BlkLen64Avx2( const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); - accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + accumulate_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); // increment block pointers QuantAPtr += SubblkLen; @@ -249,6 +284,7 @@ Q4Int8GemmR2xC1BlkLen64Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx2( const size_t BlkLen, @@ -298,10 +334,10 @@ Q4Int8GemmR1xC4BlkLen64Avx2( for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); - accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); // increment block pointers QuantAPtr += SubblkLen; @@ -327,6 +363,7 @@ Q4Int8GemmR1xC4BlkLen64Avx2( } } +template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx2( const size_t BlkLen, @@ -376,7 +413,7 @@ Q4Int8GemmR1xC1BlkLen64Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - accumulate_blklen64_r1c1blk1_avx2( + accumulate_blklen64_r1c1blk1_avx2( av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0 ); @@ -402,6 +439,7 @@ Q4Int8GemmR1xC1BlkLen64Avx2( } } +template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx2( const size_t BlkLen, @@ -432,7 +470,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( size_t multipleCols = CountN - remainingCols; if (multipleRows > 0 && multipleCols > 0) { - Q4Int8GemmR2xC4BlkLen64Avx2( + Q4Int8GemmR2xC4BlkLen64Avx2( BlkLen, QuantA, QuantAScale, @@ -447,7 +485,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( ); } if (remainingCols > 0 && multipleRows > 0) { - Q4Int8GemmR2xC1BlkLen64Avx2( + Q4Int8GemmR2xC1BlkLen64Avx2( BlkLen, QuantA, QuantAScale, @@ -462,7 +500,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( } if (remainingRows > 0 && multipleCols > 0) { - Q4Int8GemmR1xC4BlkLen64Avx2( + Q4Int8GemmR1xC4BlkLen64Avx2( 
BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, @@ -477,7 +515,7 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( } if (remainingCols > 0 && remainingRows > 0) { - Q4Int8GemmR1xC1BlkLen64Avx2( + Q4Int8GemmR1xC1BlkLen64Avx2( BlkLen, QuantA + multipleRows * lda, QuantAScale + multipleRows * lda_scale, diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 836c22f2cb82..0476766c2351 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -6,7 +6,7 @@ #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" -template +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_zp_avx2( const __m256i& av_32_epi8, @@ -24,13 +24,20 @@ accumulate_blklen32_r1c1blk1_zp_avx2( bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); - __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); - const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + if constexpr (vnni) { + const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + } else { + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); + const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); + } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_zp_avx2( const __m256i& av0_32_epi8, @@ -47,29 +54,48 @@ accumulate_blklen32_r1c1blk2_zp_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 - { - bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); - const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); - __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) - ); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } + if constexpr (vnni) { + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } - { - bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, 
_mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); - const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); - __m256i dot_16_epi16 = _mm256_maddubs_epi16( - _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) - ); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { + { + bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + + { + bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, _mm256_set1_epi8(get_zp(false, QuantBZeroPointPtr))); + const __m256 scale = _mm256_set1_ps(*(scale_a + 1) * *(scale_b + 1)); + __m256i dot_16_epi16 = _mm256_maddubs_epi16( + _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8) + ); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_zp_is_8_avx2( const __m256i& av0_32_epi8, @@ -96,24 +122,39 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + if constexpr (vnni) { + __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_8_epi32 = _mm256_hadd_epi32(dot0_8_epi32, dot1_8_epi32); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); - const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 
0)); - __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); - __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); - // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 - __m256 scale_8_ps = _mm256_permute_ps( - _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) - ); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } else { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); + + const __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(low_mask, low_mask), 15); + const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 + __m256 scale_8_ps = _mm256_permute_ps( + _mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0) + ); - acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } } +template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( const __m256i& av0_32_epi8, @@ -136,25 +177,40 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); - { - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); - } - { - __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); - __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); - const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + if constexpr (vnni) { + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + } else { + { + __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*scale_b), scale_a0_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } + { + __m256i 
dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); + __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); - const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); - acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); + acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); + } } } -template +template MLAS_FORCEINLINE void Q4Int8GemmM1C4BlkLen32Avx2( const std::byte* QuantA, @@ -217,17 +273,17 @@ Q4Int8GemmM1C4BlkLen32Avx2( //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 2 * StrideQuantBData, QuantBScalePtr + 2 * StrideQuantBScale, acc[2], low_mask, bzp8); //accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2(av_00_epi8, av_01_epi8, scale_a0_8_ps, scale_a1_8_ps, QuantBDataPtr + 3 * StrideQuantBData, QuantBScalePtr + 3 * StrideQuantBScale, acc[3], low_mask, bzp8); if constexpr (HasZeroPoint) { - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); } else { const __m256i bzp8 = _mm256_set1_epi8(8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); + 
accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + StrideQuantBScale2, acc[1], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * StrideQuantBScale2, acc[2], low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * StrideQuantBScale2, acc[3], low_mask, bzp8); } // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk2; @@ -248,19 +304,19 @@ Q4Int8GemmM1C4BlkLen32Avx2( const float& scale_a00 = *QuantAScalePtr; { const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc[0], low_mask); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + StrideQuantBData1, scale_00, QuantBZeroPointPtr + StrideQuantBZeroPoint, acc[1], low_mask); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 2 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 2 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 2 * StrideQuantBZeroPoint, acc[2], low_mask); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 3 * StrideQuantBScale1)[0]; - accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr + 3 * StrideQuantBData1, scale_00, QuantBZeroPointPtr + 3 * StrideQuantBZeroPoint, acc[3], low_mask); } } @@ -283,7 +339,7 @@ Q4Int8GemmM1C4BlkLen32Avx2( } } -template +template MLAS_FORCEINLINE void Q4Int8GemmM1C1BlkLen32Avx2( const std::byte* QuantA, @@ -336,9 +392,9 @@ Q4Int8GemmM1C1BlkLen32Avx2( //const __m256 scale_a1_8_ps = _mm256_set1_ps(Q8BlkScale(QuantAPtr + Q8BlkSize(BlkLen32))); if constexpr (HasZeroPoint) { - accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); + accumulate_blklen32_r1c1blk2_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, acc0, low_mask); } else { - accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); + accumulate_blklen32_r1c1blk2_zp_is_8_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0, low_mask, bzp8); } // increment block pointers @@ -356,7 +412,7 @@ Q4Int8GemmM1C1BlkLen32Avx2( const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - 
accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); + accumulate_blklen32_r1c1blk1_zp_avx2(av_00_epi8, QuantBDataPtr, scale_00, QuantBZeroPointPtr, acc0, low_mask); } *SumPtr = hsum_float_8(acc0); @@ -376,7 +432,7 @@ Q4Int8GemmM1C1BlkLen32Avx2( } } -template +template MLAS_FORCEINLINE void MlasQ4Int8GemmM1KernelBlkLen32Avx2( @@ -403,7 +459,7 @@ MlasQ4Int8GemmM1KernelBlkLen32Avx2( size_t multipleCols = CountN - remainingCols; if (multipleCols > 0) { - Q4Int8GemmM1C4BlkLen32Avx2( + Q4Int8GemmM1C4BlkLen32Avx2( QuantA, QuantAScale, QuantBData, @@ -416,7 +472,7 @@ MlasQ4Int8GemmM1KernelBlkLen32Avx2( } if (remainingCols > 0) { - Q4Int8GemmM1C1BlkLen32Avx2( + Q4Int8GemmM1C1BlkLen32Avx2( QuantA, QuantAScale, QuantBData + multipleCols * StrideQuantBData, From 8674b9f103f8172631bafa71384b7c7f9c2f4037 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 23 Jul 2024 20:32:31 +0000 Subject: [PATCH 23/41] rm unused ComputeParallelTasksSGemm Signed-off-by: liqunfu --- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 52 +----------------------- 1 file changed, 1 insertion(+), 51 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 8e47a7018c98..5c74d3e510c7 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -518,7 +518,7 @@ SQ4BitGemm_CompInt8( QuantAScale, b_col, b_col_scale, - b_col_zp, + b_col_zp, c_blk, RangeCountM, CountN, @@ -622,56 +622,6 @@ constexpr auto OperationMap = []() { return ops; }(); - -void -ComputeParallelTasksSGemm(const size_t M, const size_t N, const size_t CountK, const size_t BatchN, - MLAS_THREADPOOL* ThreadPool, - size_t& ThreadCountM, size_t& ThreadCountN, size_t& ThreadsPerGemm) -{ - const double Complexity = double(M) * double(N) * double(CountK); - - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - - ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool) * 8; - - if (TargetThreadCount >= MaximumThreadCount) { - TargetThreadCount = MaximumThreadCount; - } - - // - // Segment the operation across multiple threads. - // - // N.B. Currently, the operation is segmented as a 1D partition, which - // works okay for operations involving skinny matrices. 
- // - - ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; - if (N > M) { - const size_t BlockedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / - MLAS_SGEMM_STRIDEN_THREAD_ALIGN; - - if (size_t(ThreadsPerGemm) > BlockedN) { - ThreadsPerGemm = ptrdiff_t(BlockedN); - } - - ThreadCountM = 1; - ThreadCountN = ThreadsPerGemm; - - } else { - if (size_t(ThreadsPerGemm) > M) { - ThreadsPerGemm = ptrdiff_t(M); - } - - ThreadCountM = ThreadsPerGemm; - ThreadCountN = 1; - } -} } // namespace void MLASCALL From e26e29e8589d32ee673824a1466594407ef0241d Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 24 Jul 2024 19:17:05 +0000 Subject: [PATCH 24/41] avoid _mm256_dpbusds_avx_epi32 in avx512vnni Signed-off-by: liqunfu --- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 113 ++++++++++-------- onnxruntime/test/mlas/bench/bench_q4dq.cpp | 8 +- 2 files changed, 67 insertions(+), 54 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index 7feccb182531..62d4d9f4604f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -210,6 +210,59 @@ accumulate_blklen32_r2c1blk4_avx512vnni( } } +MLAS_FORCEINLINE void +accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) +{ + __m256i sum_8_epi32 = _mm256_dpbusds_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r1c1blk1_avx512( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + __m256& acc0 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + accumulate_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_blklen32_r2c1blk1_avx512( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + const float& combined_scale00, + const float& combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + if constexpr (vnni) { + // | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); + __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); + bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); + + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + accumulate_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen32Avx512( @@ -342,44 +395,28 @@ Q4Int8GemmR2xC4BlkLen32Avx512( // Col0 const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; - if constexpr (vnni) { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); - } else { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); - } + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); } { // Col1 const float scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 1)[0]; - if constexpr (vnni) { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); - } else { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); - } + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, scale_10, acc2[1], acc2[NCols4 + 1]); } { // Col2 const float scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; const float scale_10 = scale_a10 * (QuantBScalePtr + 2)[0]; - if constexpr (vnni) { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); - } else { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); - } + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[2], acc2[NCols4 + 2]); } { // Col3 const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr + 3)[0]; - if constexpr (vnni) { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); - } else { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); - } + accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, scale_10, acc2[3], acc2[NCols4 + 3]); } QuantAPtr += BlkLen32; QuantAScalePtr++; @@ -492,11 +529,7 @@ Q4Int8GemmR2C1BlkLen32Avx512( const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; const float& scale_10 = scale_a10 * (QuantBScalePtr)[0]; - if constexpr (vnni) { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); - } else { - accumulate_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); - } + 
accumulate_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); QuantAPtr += BlkLen32; QuantAScalePtr++; @@ -603,35 +636,19 @@ Q4Int8GemmR1xC4BlkLen32Avx512( const float& scale_a00 = *QuantAScalePtr; { const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - if constexpr (vnni) { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); - } else { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); - } + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 1)[0]; - if constexpr (vnni) { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); - } else { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); - } + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_00, acc2[1]); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 2)[0]; - if constexpr (vnni) { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); - } else { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); - } + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_00, acc2[2]); } { const float& scale_00 = scale_a00 * (QuantBScalePtr + 3)[0]; - if constexpr (vnni) { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); - } else { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); - } + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_00, acc2[3]); } QuantAPtr += BlkLen32; @@ -726,11 +743,7 @@ Q4Int8GemmR1xC1BlkLen32Avx512( const float& scale_a00 = *QuantAScalePtr; const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; - if constexpr (vnni) { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); - } else { - accumulate_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); - } + accumulate_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); QuantAPtr += BlkLen32; QuantAScalePtr++; diff --git a/onnxruntime/test/mlas/bench/bench_q4dq.cpp b/onnxruntime/test/mlas/bench/bench_q4dq.cpp index a77009bee7b5..6d21ed2eef86 100644 --- a/onnxruntime/test/mlas/bench/bench_q4dq.cpp +++ b/onnxruntime/test/mlas/bench/bench_q4dq.cpp @@ -65,10 +65,10 @@ static void BM_MlasQuantizeBlockwise(benchmark::State& state) { } static void BM_QDQBlockwiseQuantizer_TransposeColumnwise(benchmark::State& state) { - int M = state.range(0); - int N = state.range(1); - int quant_block_size = state.range(2); - int threads = state.range(3); + int M = (int)state.range(0); + int N = (int)state.range(1); + int quant_block_size = (int)state.range(2); + int threads = (int)state.range(3); bool add8 = state.range(4) != 0; int quant_num_M = (M + quant_block_size - 1) / quant_block_size; int blob_size = (quant_block_size + 1) / 2; From 2b0307e05c7dee2ff34e46459fe31f0b07aa9b11 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 24 Jul 2024 13:52:17 -0700 Subject: [PATCH 25/41] fix linux build Signed-off-by: liqunfu --- cmake/onnxruntime_mlas.cmake | 4 ++-- .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 20 ++++++++--------- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 22 
+++++++++---------- .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 2 -- ...bitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 7 +++--- 5 files changed, 26 insertions(+), 29 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 66f4aea606ef..292ae52d4f99 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -555,7 +555,7 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S @@ -575,7 +575,7 @@ else() ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mavx512bw -mavx512dq -mavx512vl") + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") set(mlas_platform_srcs_avx512vnni ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index a71b64ffa26c..3cd610796a5e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -33,9 +33,9 @@ accumulate_blklen16_r1c1blk8_avx512( __m512i bv0_64_epi8, bv1_64_epi8; load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 { - const __m256 scale_a0_ps = _mm256_load_ps(scale_a); // 01234567 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_castsi512_ps( _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) @@ -74,9 +74,9 @@ accumulate_blklen16_r2c1blk4_avx512( __m512i bv0_64_epi8, bv1_64_epi8; load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 { - const __m256 scale_a0_ps = _mm256_load_ps(scale_a0); // 01234567 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_castsi512_ps( _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) @@ -98,7 +98,7 @@ accumulate_blklen16_r2c1blk4_avx512( acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); } { - const __m256 scale_a1_ps = _mm256_load_ps(scale_a1); // 01234567 + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); // 01234567 const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); __m512 scale_a1b_16_ps = _mm512_castsi512_ps( _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) @@ -133,9 +133,9 @@ accumulate_blklen16_r1c1blk8_avx512vnni( __m512i bv0_64_epi8, bv1_64_epi8; load_4blk_4b_packed_blklen16(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 { - const __m256 scale_a0_ps = _mm256_load_ps(scale_a); // 
01234567 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a); // 01234567 const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_castsi512_ps( _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) @@ -172,9 +172,9 @@ accumulate_blklen16_r2c1blk4_avx512vnni( __m512i bv0_64_epi8, bv1_64_epi8; load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - const __m256 scale_b_ps = _mm256_load_ps(scale_b); // 01234567 + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 { - const __m256 scale_a0_ps = _mm256_load_ps(scale_a0); // 01234567 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_castsi512_ps( _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a0b_ps)) @@ -194,7 +194,7 @@ accumulate_blklen16_r2c1blk4_avx512vnni( acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); } { - const __m256 scale_a1_ps = _mm256_load_ps(scale_a1); // 01234567 + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); // 01234567 const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); __m512 scale_a1b_16_ps = _mm512_castsi512_ps( _mm512_broadcast_i64x4(_mm256_castps_si256(scale_a1b_ps)) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index 62d4d9f4604f..b865585c1605 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -39,9 +39,9 @@ accumulate_blklen32_r1c1blk4_avx512( __m512i bv0_64_epi8, bv1_64_epi8; load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 { - const __m128 scale_a0_ps = _mm_load_ps(scale_a); // 0123 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 @@ -79,9 +79,9 @@ accumulate_blklen32_r2c1blk4_avx512( __m512i bv0_64_epi8, bv1_64_epi8; load_2blk_4b_packed_blklen64(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 { - const __m128 scale_a0_ps = _mm_load_ps(scale_a0); // 0123 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 @@ -101,7 +101,7 @@ accumulate_blklen32_r2c1blk4_avx512( acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); } { - const __m128 scale_a1_ps = _mm_load_ps(scale_a1); // 0123 + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 @@ -135,9 +135,9 @@ accumulate_blklen32_r1c1blk4_avx512vnni( __m512i bv0_64_epi8, bv1_64_epi8; load_4blk_4b_packed_blklen32(QuantBDataPtr, bv0_64_epi8, bv1_64_epi8); - const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 { - const __m128 scale_a0_ps = _mm_load_ps(scale_a); // 0123 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a); // 0123 const __m128 scale_a0b_ps = 
_mm_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 @@ -175,9 +175,9 @@ accumulate_blklen32_r2c1blk4_avx512vnni( __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); //__m512i idx = _mm512_loadu_epi8(&index_array[0]); - const __m128 scale_b_ps = _mm_load_ps(scale_b); // 0123 + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 { - const __m128 scale_a0_ps = _mm_load_ps(scale_a0); // 0123 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 @@ -193,7 +193,7 @@ accumulate_blklen32_r2c1blk4_avx512vnni( acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); } { - const __m128 scale_a1_ps = _mm_load_ps(scale_a1); // 0123 + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 @@ -213,7 +213,7 @@ accumulate_blklen32_r2c1blk4_avx512vnni( MLAS_FORCEINLINE void accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) { - __m256i sum_8_epi32 = _mm256_dpbusds_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); + __m256i sum_8_epi32 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bv_32_epi8, av_32_epi8); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 2d85a150da45..dbdcc751d61a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -281,8 +281,6 @@ accumulate_blklen64_r1c1blk1_avx512( const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); if constexpr (vnni) { - const __m512i one_32_epi16 = _mm512_set1_epi16(1); - __m512i dot_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av_32_epi8); __m512 sum_16_ps = _mm512_cvtepi32_ps(dot_16_epi32); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 0476766c2351..db9446228478 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -300,7 +300,6 @@ Q4Int8GemmM1C4BlkLen32Avx2( // load A const std::byte* QuantABlk0 = QuantAPtr; const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); - const __m256i zero = _mm256_setzero_si256(); const float& scale_a00 = *QuantAScalePtr; { const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; @@ -373,7 +372,7 @@ Q4Int8GemmM1C1BlkLen32Avx2( const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; const float* BiasPtr = Bias; auto* SumPtr = C; - + const __m256i low_mask = _mm256_set1_epi8(0x0F); const __m256i bzp8 = _mm256_set1_epi8(8); for (size_t n = 0; n < CountN; n++) { @@ -381,7 +380,7 @@ Q4Int8GemmM1C1BlkLen32Avx2( const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + const std::byte* QuantBZeroPointPtr = 
QuantBZeroPointColPtr; __m256 acc0 = _mm256_setzero_ps(); size_t k_blks_remaining = BlockCountK; @@ -503,7 +502,7 @@ void SQ4BitGemmM1Kernel_BlkLen32_CompInt8_Impl2( constexpr size_t BlkLen = 32; #if defined SQ4BitGemmM1Kernel_BlkLen32_CompInt8_NewLayout #else - constexpr bool HasZeroPoint = false; + constexpr bool HasZeroPoint = false; #endif float* CRowPtr = C; From 51e97c8f488fece96f229e283877d9416e67fdfe Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Fri, 26 Jul 2024 14:57:27 -0700 Subject: [PATCH 26/41] refactor for Arm64 Signed-off-by: Liqun Fu --- .../cpu/quantization/matmul_nbits.cc | 27 ++++------ onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 50 ++++++++++++++----- .../test/mlas/unittest/test_sqnbitgemm.cpp | 9 +++- 3 files changed, 55 insertions(+), 31 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 7df608a92d74..39c87e8d5237 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -224,13 +224,14 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); if (prepacked_weights) { - // TODO: cannot use packed_b_ after + // TODO: cannot use packed_b_ after following code with std::move. assert(false); prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); } is_packed = true; } else if (compute_type == CompInt8) { +#ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { auto sptr = tensor.Data(); MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); @@ -240,6 +241,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), nullptr, has_zp_input_, zptr, nullptr); is_packed = false; } +#endif } #endif // defined(ORT_NEURAL_SPEED) @@ -359,7 +361,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { for (size_t i = 0; i < batch_count; ++i) { data[i].A = a_data + helper.LeftOffsets()[i]; data[i].lda = lda; - data[i].QuantBDataWorkspace = packed_b_.get(); +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); data[i].QuantBScale = scales_data; data[i].QuantBZeroPoint = zero_points_data; data[i].Bias = bias_data; @@ -367,21 +374,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { data[i].ldc = N; data[i].node_name = this->Node().Name(); } - //auto start2 = std::chrono::high_resolution_clock::now(); // Start timing here - - //const int CountTotal = 2000; - //int count = CountTotal; - //while (count-- > 0) - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - //auto end = std::chrono::high_resolution_clock::now(); // End timing here + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); - //std::chrono::duration elapsed2 = end - start2; - //// Calculate and print the duration in nanoseconds - 
//std::chrono::duration elapsed = end - start; - //std::cout << "MlasSQNBitGemmBatch: " << elapsed2.count() / CountTotal << " ns\n"; - //std::cout << "main Duration_M" << M << "xN" << N << "xK" << K << ": " << elapsed.count() / CountTotal << " ns\n"; return Status::OK(); } } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 5c74d3e510c7..f5eb0fdafc85 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -301,7 +301,7 @@ typedef void(SQNBitGemmFn)( size_t BlkLen, size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - PerGemmQuantAWorkspace* PerGemmWorkspace, + void* PerGemmWorkspace, size_t RangeStartM, size_t RangeCountM, size_t RangeStartN, @@ -313,7 +313,7 @@ SQ4BitGemm_CompFp32( const size_t BlkLen, const size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - PerGemmQuantAWorkspace* const PerGemmWorkspace, + void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, const size_t RangeStartN, @@ -333,7 +333,7 @@ SQ4BitGemm_CompFp32( const float* A = DataParams->A + RangeStartM * lda; - const std::byte* QuantBData = static_cast(DataParams->QuantBDataWorkspace) + RangeStartN * ldb; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; const std::byte* QuantBZeroPoint = (DataParams->QuantBZeroPoint == nullptr) @@ -432,7 +432,7 @@ SQ4BitGemm_CompInt8( const size_t BlkLen, const size_t K, const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, - PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace, + void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, const size_t RangeStartN, @@ -450,6 +450,8 @@ SQ4BitGemm_CompInt8( // return; // } //#endif +#ifdef MLAS_TARGET_AMD64_IX86 + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); constexpr size_t BlkBitWidth = 4; const size_t k_blks = MlasDivRoundup(K, BlkLen); @@ -475,6 +477,29 @@ SQ4BitGemm_CompInt8( float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; +#else + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = static_cast(PerGemmWorkspace) + RangeStartM * lda; + + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? 
nullptr : DataParams->Bias + RangeStartN; +#endif size_t CountN; for (size_t n = 0; n < RangeCountN; n += CountN) { @@ -671,18 +696,17 @@ MlasSQNBitGemmBatch( //auto start = std::chrono::high_resolution_clock::now(); // Start timing here for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; - if (ComputeType == CompInt8) { - // TODO: shall sepqrate QuantBBlkSum from QuantBData + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; - void* PerGemmWorkspace = - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { - ComputeOperation(BlkLen, K, Data, nullptr, 0, M, 0, N); + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, 0, M, 0, N); } } return; @@ -743,18 +767,18 @@ MlasSQNBitGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - if (ComputeType == CompInt8) { + void* PerGemmWorkspace = + reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; + if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; - void* PerGemmWorkspace = - reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } else { - ComputeOperation(BlkLen, K, Data, nullptr, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + ComputeOperation(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } }); } diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index cced40afafa3..a78768b06b6b 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -55,7 +55,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { size_t K, const float* A, size_t lda, - const void* QuantBData, + const void* /*QuantBData*/, const void* PackedQuantBDataWorkspace, const float* QuantBScale, const void* QuantBZeroPoint, @@ -71,7 +71,12 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.Bias = Bias; params.C = C; params.ldc = ldc; - params.QuantBDataWorkspace = PackedQuantBDataWorkspace != nullptr ? 
PackedQuantBDataWorkspace : QuantBData; +#ifdef MLAS_TARGET_AMD64_IX86 + if (ComputeType == CompInt8) { + params.QuantBDataWorkspace = PackedQuantBDataWorkspace; + } +#endif + params.PackedQuantBData = static_cast(PackedQuantBDataWorkspace); params.QuantBScale = QuantBScale; params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; From 48e8639dd5c391a3754b1a0da6b65c605c8593d5 Mon Sep 17 00:00:00 2001 From: Liqun Fu Date: Fri, 26 Jul 2024 15:19:45 -0700 Subject: [PATCH 27/41] more refactor for Arm64 Signed-off-by: Liqun Fu --- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | 1 - onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 39c87e8d5237..0c0f84303cee 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -281,7 +281,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } Status MatMulNBits::Compute(OpKernelContext* ctx) const { - //auto start = std::chrono::high_resolution_clock::now(); // Start timing here concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); const auto* a_data = a->Data(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index f5eb0fdafc85..7ab9b23d3467 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -534,6 +534,7 @@ SQ4BitGemm_CompInt8( RowsRemaining -= RowsHandled; } } +#ifdef MLAS_TARGET_AMD64_IX86 else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; @@ -562,6 +563,7 @@ SQ4BitGemm_CompInt8( ); } } +#endif } } From 705aa1f2fa2b1839ccc82082d52cfb7ef10dc10c Mon Sep 17 00:00:00 2001 From: liqunfu Date: Mon, 29 Jul 2024 20:31:59 +0000 Subject: [PATCH 28/41] hsum_float_16 Signed-off-by: liqunfu --- .../lib/sqnbitgemm_kernel_avx512_int8_blklen32.h | 11 +++++++++++ .../core/mlas/lib/sqnbitgemm_kernel_avx_common.h | 15 +-------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index b865585c1605..6b2a93604af4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -14,6 +14,17 @@ h_add_512(__m512 a) return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); } +static MLAS_FORCEINLINE float +hsum_float_16(const __m512 x) +{ + __m256 hi = h_add_512(x); + __m128 hi128 = _mm256_extractf128_ps(hi, 1); + __m128 lo128 = _mm256_castps256_ps128(hi); + hi128 = _mm_add_ps(hi128, lo128); + hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); + hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); + return _mm_cvtss_f32(hi128); +} static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h index 6097849096e5..177f5518bb89 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h @@ -611,7 
+611,7 @@ FoldAccumulators(const __m256& acc0, const __m256& acc1, const __m256& acc2, con return acc_y; } -static inline float +static MLAS_FORCEINLINE float hsum_float_8(const __m256 x) { __m128 res = _mm256_extractf128_ps(x, 1); @@ -621,19 +621,6 @@ hsum_float_8(const __m256 x) return _mm_cvtss_f32(res); } -static inline float -hsum_float_16(const __m512 x) -{ - __m256 hi = _mm512_extractf32x8_ps(x, 1); - __m256 lo = _mm512_castps512_ps256(x); - hi = _mm256_add_ps(hi, lo); - __m128 hi128 = _mm256_extractf128_ps(hi, 1); - __m128 lo128 = _mm256_castps256_ps128(hi); - hi128 = _mm_add_ps(hi128, lo128); - hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); - hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); - return _mm_cvtss_f32(hi128); -} /** * @brief Horizontally sum 4 vectors and store * the results in the returned vector From 012e9c46848004acaf3cd0cd5435007a12d0331c Mon Sep 17 00:00:00 2001 From: liqunfu Date: Mon, 29 Jul 2024 21:27:03 +0000 Subject: [PATCH 29/41] hsum_float_16 Signed-off-by: liqunfu --- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 18 ------------------ .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index 6b2a93604af4..dc1f4d4a5a25 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -8,24 +8,6 @@ #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" -static MLAS_FORCEINLINE __m256 -h_add_512(__m512 a) -{ - return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); -} - -static MLAS_FORCEINLINE float -hsum_float_16(const __m512 x) -{ - __m256 hi = h_add_512(x); - __m128 hi128 = _mm256_extractf128_ps(hi, 1); - __m128 lo128 = _mm256_castps256_ps128(hi); - hi128 = _mm_add_ps(hi128, lo128); - hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); - hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); - return _mm_cvtss_f32(hi128); -} - static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) { diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index dbdcc751d61a..2a65ac4af0c1 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -6,6 +6,24 @@ #include "sqnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +static MLAS_FORCEINLINE __m256 +h_add_512(__m512 a) +{ + return _mm256_add_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)); +} + +static MLAS_FORCEINLINE float +hsum_float_16(const __m512 x) +{ + __m256 hi = h_add_512(x); + __m128 hi128 = _mm256_extractf128_ps(hi, 1); + __m128 lo128 = _mm256_castps256_ps128(hi); + hi128 = _mm_add_ps(hi128, lo128); + hi128 = _mm_add_ps(hi128, _mm_movehl_ps(hi128, hi128)); + hi128 = _mm_add_ss(hi128, _mm_movehdup_ps(hi128)); + return _mm_cvtss_f32(hi128); +} + static MLAS_FORCEINLINE __m512i combine_two_m256i_to_m512i(const __m256i& a, const __m256i& b) { From 21b9138fa390052da11326fc64c8a78f9251a2b3 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 16:46:18 +0000 Subject: [PATCH 30/41] condition for -mavxvnni Signed-off-by: liqunfu --- 
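Illustrative note (commentary only, not part of the commit message or the diffstat below): older GCC releases reject both the -mavxvnni compile flag and the AVX-VNNI intrinsics, so this commit gates the flag in CMake on the compiler version and wraps every _mm256_dpbusds_avx_epi32 call site in a #if !defined(__GNUC__) || (__GNUC__ > 9) guard (later commits in this series bump the threshold to > 10), keeping the existing maddubs/madd path as the fallback. A minimal sketch of the guarded dot-product pattern, using stand-in names (dot_u8s8_accumulate, bv_u8, av_s8, scale, acc) that do not appear in the patch:

#include <immintrin.h>

// bv_u8: unpacked 4-bit B values (0..15, as unsigned bytes); av_s8: quantized int8 A values.
static inline __m256
dot_u8s8_accumulate(const __m256i bv_u8, const __m256i av_s8, const __m256 scale, __m256 acc)
{
#if !defined(__GNUC__) || (__GNUC__ > 10)
    // AVX-VNNI: vpdpbusds multiplies unsigned bytes by signed bytes, sums each group of four
    // products into a 32-bit lane, and accumulates, all in one instruction.
    const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv_u8, av_s8);
#else
    // Fallback: maddubs produces 16 signed 16-bit pair sums; madd against a vector of ones
    // widens them to eight 32-bit sums, matching the VNNI result for the 4-bit x 8-bit value
    // ranges used by these kernels.
    const __m256i dot_16_epi16 = _mm256_maddubs_epi16(bv_u8, av_s8);
    const __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot_16_epi16);
#endif
    // Convert the integer dot product to float and fold in the combined per-block scale.
    return _mm256_fmadd_ps(_mm256_cvtepi32_ps(sum_8_epi32), scale, acc);
}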
cmake/onnxruntime_mlas.cmake | 7 +++++-- .../lib/sqnbitgemm_kernel_avx2_int8_blklen32.h | 18 +++++++++++++++++- .../lib/sqnbitgemm_kernel_avx2_int8_blklen64.h | 12 ++++++++++-- ...nbitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 16 ++++++++++++++++ 4 files changed, 48 insertions(+), 5 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 292ae52d4f99..cc62d36ebfa3 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -555,8 +555,11 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") - +if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "9") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") +else() + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") +endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index d1762420b3a3..1a8216f69d0a 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -54,6 +54,7 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); { @@ -78,6 +79,7 @@ accumulate_blklen32_r2c1blk2_avx2( acc1 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc1); } } else { +#endif //{ const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); @@ -106,7 +108,9 @@ accumulate_blklen32_r2c1blk2_avx2( __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); //} +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -126,6 +130,7 @@ accumulate_blklen32_r1c1blk2_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); @@ -139,6 +144,7 @@ accumulate_blklen32_r1c1blk2_avx2( __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); } else { +#endif const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); @@ -152,7 +158,9 @@ accumulate_blklen32_r1c1blk2_avx2( // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, 
scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -171,15 +179,19 @@ accumulate_blklen32_r2c1blk1_avx2( const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - + +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); } else { +#endif __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -196,12 +208,16 @@ accumulate_blklen32_r1c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); } else { +#endif __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index ec9a338348c8..e4acef6b27b3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -26,6 +26,7 @@ accumulate_blklen64_r2c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); @@ -43,8 +44,9 @@ accumulate_blklen64_r2c1blk1_avx2( __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); - + } else { +#endif __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); @@ -69,7 +71,9 @@ accumulate_blklen64_r2c1blk1_avx2( __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -89,6 +93,7 @@ accumulate_blklen64_r1c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); 
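// Note (commentary, not part of the diff): with BlkLen = 64 a single quantization block spans
// two 32-byte loads, so both halves above are chained into the same 32-bit integer accumulator
// and the combined per-block scale is applied only once, after the sum is converted to float in
// the hunk below. The non-VNNI branch mirrors the same structure with maddubs/madd.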
@@ -99,6 +104,7 @@ accumulate_blklen64_r1c1blk1_avx2( acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); } else { +#endif const __m256i dot0_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, av00_32_epi8); const __m256i dot1_16_epi16 = _mm256_maddubs_epi16(bv1_32_epi8, av01_32_epi8); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); @@ -111,7 +117,9 @@ accumulate_blklen64_r1c1blk1_avx2( __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -134,7 +142,7 @@ Q4Int8GemmR2xC4BlkLen64Avx2( constexpr size_t NCols4 = 4; constexpr size_t NRows2 = 2; constexpr size_t SubblkLen = 64; - + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); const size_t PerBlkSubblkCount = BlkLen / SubblkLen; const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index db9446228478..42fd1131d9a4 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -24,17 +24,21 @@ accumulate_blklen32_r1c1blk1_zp_avx2( bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } else { +#endif __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); const __m256i dot_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -54,6 +58,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { { bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); @@ -71,6 +76,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } } else { +#endif { bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); const __m256 scale = _mm256_set1_ps(*(scale_a) * *(scale_b)); @@ -92,7 +98,9 @@ accumulate_blklen32_r1c1blk2_zp_avx2( const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -122,6 +130,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), 
_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); @@ -135,6 +144,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); } else { +#endif __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); __m256i dot1_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); @@ -151,7 +161,9 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( ); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template @@ -177,6 +189,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); +#if !defined(__GNUC__) || (__GNUC__ > 9) if constexpr (vnni) { { __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); @@ -191,6 +204,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } } else { +#endif { __m256i dot0_16_epi16 = _mm256_maddubs_epi16(_mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); __m256i sum_8_epi32 = _mm256_madd_epi16(_mm256_set1_epi16(1), dot0_16_epi16); @@ -207,7 +221,9 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } +#if !defined(__GNUC__) || (__GNUC__ > 9) } +#endif } template From 1fb1c83e13e104ffdaabacaa647fc79ea6ea2072 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 18:29:27 +0000 Subject: [PATCH 31/41] CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 10 Signed-off-by: liqunfu --- cmake/onnxruntime_mlas.cmake | 4 +++- ...sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 16 ++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index cc62d36ebfa3..079067a85bfc 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -555,9 +555,11 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) -if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "9") +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") else() + message(STATUS "Using -mavx2 -mfma flags") set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") endif() set(mlas_platform_srcs_avx512f diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 42fd1131d9a4..7c9828c6b979 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -24,7 +24,7 @@ accumulate_blklen32_r1c1blk1_zp_avx2( 
bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { const __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv_32_epi8, bv_32_epi8), _mm256_sign_epi8(av_32_epi8, bv_32_epi8)); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -36,7 +36,7 @@ accumulate_blklen32_r1c1blk1_zp_avx2( const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, dot_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -58,7 +58,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 32~63 -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { { bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, _mm256_set1_epi8(get_zp(true, QuantBZeroPointPtr))); @@ -98,7 +98,7 @@ accumulate_blklen32_r1c1blk2_zp_avx2( const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -130,7 +130,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av1_32_epi8, bv1_32_epi8)); @@ -161,7 +161,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_avx2( ); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -189,7 +189,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( bv0_32_epi8 = _mm256_sub_epi8(bv0_32_epi8, bzp8); bv1_32_epi8 = _mm256_sub_epi8(bv1_32_epi8, bzp8); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { { __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), _mm256_sign_epi8(bv0_32_epi8, bv0_32_epi8), _mm256_sign_epi8(av0_32_epi8, bv0_32_epi8)); @@ -221,7 +221,7 @@ accumulate_blklen32_r1c1blk2_zp_is_8_no_bc_avx2( const __m256 scale = _mm256_mul_ps(_mm256_set1_ps(*(scale_b + 1)), scale_a1_8_ps); acc0 = _mm256_fmadd_ps(sum_ps, scale, acc0); } -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } From 85918e98853949b7dfa47a71e38436dbb6add054 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 18:33:06 +0000 Subject: [PATCH 32/41] missed 2 files from (__GNUC__ > 10) Signed-off-by: liqunfu --- .../lib/sqnbitgemm_kernel_avx2_int8_blklen32.h | 16 ++++++++-------- .../lib/sqnbitgemm_kernel_avx2_int8_blklen64.h | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 1a8216f69d0a..9dddf0df53b4 100644 
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -54,7 +54,7 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1_32_epi8 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); { @@ -108,7 +108,7 @@ accumulate_blklen32_r2c1blk2_avx2( __m256 scale_8_ps_ = _mm256_permute_ps(_mm256_mul_ps(scale_a1_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc1 = _mm256_fmadd_ps(sum_ps_, scale_8_ps_, acc1); //} -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -130,7 +130,7 @@ accumulate_blklen32_r1c1blk2_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0~31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32~63 -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { const __m256i dot0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); const __m256i dot1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); @@ -158,7 +158,7 @@ accumulate_blklen32_r1c1blk2_avx2( // 1 0 1 0 1 0 1 0 -> 1 1 0 0 1 1 0 0 __m256 scale_8_ps = _mm256_permute_ps(_mm256_mul_ps(scale_a0_2_ps, scale_b_2_ps), _MM_SHUFFLE(1, 1, 0, 0)); acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -180,7 +180,7 @@ accumulate_blklen32_r2c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); accumulate_1blk_dot_vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); @@ -189,7 +189,7 @@ accumulate_blklen32_r2c1blk1_avx2( __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); accumulate_1blk_dot(av10_32_epi8, bv_32_epi8, combined_scale10, one_16_epi16, acc1); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -208,14 +208,14 @@ accumulate_blklen32_r1c1blk1_avx2( __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { accumulate_1blk_dot_vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); } else { #endif __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv_32_epi8, bv_32_epi8), 15); accumulate_1blk_dot(av00_32_epi8, bv_32_epi8, combined_scale00, one_16_epi16, acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index e4acef6b27b3..174ebc580904 100644 --- 
a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -26,7 +26,7 @@ accumulate_blklen64_r2c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); @@ -71,7 +71,7 @@ accumulate_blklen64_r2c1blk1_avx2( __m256 scale_a1_ps = _mm256_broadcast_ss(scale_a1); acc1 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a1_ps, scale_b_ps), acc1); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } @@ -93,7 +93,7 @@ accumulate_blklen64_r1c1blk1_avx2( __m256i bv0_32_epi8 = _mm256_and_si256(bv_packed, low_mask); // 0, 1,...30, 31 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 32, 33,...62, 63 -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) if constexpr (vnni) { __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); From 9530ac56fddd6dac6cfc3faeab9deba3db838437 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 19:03:42 +0000 Subject: [PATCH 33/41] missed _mm256_dpbusds_avx_epi32 and print out cmake msgs Signed-off-by: liqunfu --- cmake/onnxruntime_mlas.cmake | 4 ++++ .../core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h | 2 ++ 2 files changed, 6 insertions(+) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 079067a85bfc..c02ac2096db2 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -555,6 +555,10 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) + +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") message(STATUS "Using -mavx2 -mfma -mavxvnni flags") set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 9dddf0df53b4..af6f52090adc 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -19,6 +19,7 @@ accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } +#if !defined(__GNUC__) || (__GNUC__ > 10) MLAS_FORCEINLINE void accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) { @@ -26,6 +27,7 @@ accumulate_1blk_dot_vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, c const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); acc = _mm256_fmadd_ps(sum_ps, _mm256_set1_ps(combined_scale), acc); } +#endif template static MLAS_FORCEINLINE void From f77cffd458302fc79bfd185fcb43e33bbbb39060 Mon Sep 17 00:00:00 
2001 From: liqunfu Date: Tue, 30 Jul 2024 20:13:44 +0000 Subject: [PATCH 34/41] unused zp, etc. Signed-off-by: liqunfu --- .../lib/sqnbitgemm_kernel_avx512_int8_blklen32.h | 12 ++++++------ .../sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 2 +- .../sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index dc1f4d4a5a25..ca12cc14a787 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -38,8 +38,8 @@ accumulate_blklen32_r1c1blk4_avx512( const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 - //__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - __m512i idx = _mm512_loadu_epi8(&index_array[0]); + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av0_64_epi8); // 0~0,1~1 @@ -78,8 +78,8 @@ accumulate_blklen32_r2c1blk4_avx512( const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); __m512 scale_a0b_16_ps = _mm512_broadcast_f32x4(scale_a0b_ps); // 0123012301230123 - //__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - __m512i idx = _mm512_loadu_epi8(&index_array[0]); + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av00_64_epi8); // 0~0,1~1 @@ -98,8 +98,8 @@ accumulate_blklen32_r2c1blk4_avx512( const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); __m512 scale_a1b_16_ps = _mm512_broadcast_f32x4(scale_a1b_ps); // 0123012301230123 - //__m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); - __m512i idx = _mm512_loadu_epi8(&index_array[0]); + __m512i idx = _mm512_set_epi32(3, 3, 1, 1, 3, 3, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0); + // __m512i idx = _mm512_loadu_epi8(&index_array[0]); scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 0022002211331133 const __m512i dot0_32_epi16 = _mm512_maddubs_epi16(bv0_64_epi8, av10_64_epi8); // 0~0,1~1 diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 7c9828c6b979..45c3963365e6 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -390,7 +390,7 @@ Q4Int8GemmM1C1BlkLen32Avx2( auto* SumPtr = C; const __m256i low_mask = _mm256_set1_epi8(0x0F); - const __m256i bzp8 = _mm256_set1_epi8(8); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); for (size_t n = 0; n < CountN; n++) { const std::byte* QuantAPtr = QuantA; const float* QuantAScalePtr = QuantAScale; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h index d76b687145b9..fa44cf7241e1 100644 --- 
a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -201,6 +201,7 @@ Q4Int8GemmM1C1BlkLen64Avx2( assert(CountN < NCols4); const __m256i low_mask = _mm256_set1_epi8(0x0F); + [[maybe_unused]] const __m256i bzp8 = _mm256_set1_epi8(8); const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; @@ -213,15 +214,14 @@ Q4Int8GemmM1C1BlkLen64Avx2( const float* QuantAScalePtr = QuantAScale; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* QuantBScalePtr = QuantBScaleColPtr; - const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; + const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; __m256 acc0 = _mm256_setzero_ps(); for (size_t k = 0; k < BlockCountK; ++k) { - const bool is_lower_half_byte_zp = (k % 2) == 0; + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); - const __m256i bzp8 = _mm256_set1_epi8(8); if constexpr (HasZeroPoint) { accumulate_blklen64_r1c1blk1_zp_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, QuantBZeroPointPtr, is_lower_half_byte_zp, acc0, low_mask); From a6fd378880077a5a464c4c3fb42cc05bed41a89b Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 20:22:18 +0000 Subject: [PATCH 35/41] unused zp, etc. Signed-off-by: liqunfu --- .../core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h index fa44cf7241e1..e9c3812bde89 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -125,7 +125,7 @@ Q4Int8GemmM1C4BlkLen64Avx2( __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; for (size_t k = 0; k < BlockCountK; ++k) { - const bool is_lower_half_byte_zp = (k % 2) == 0; + [[maybe_unused]] const bool is_lower_half_byte_zp = (k % 2) == 0; for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); From c875e5c90d22eff3c68d2cd3f1ab075fce043bb0 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 20:43:12 +0000 Subject: [PATCH 36/41] remove test code changes Signed-off-by: liqunfu --- .../test/mlas/bench/bench_sqnbitgemm.cpp | 6 +- .../test/mlas/unittest/test_fgemm_fixture.h | 1 - .../test/mlas/unittest/test_sqnbitgemm.cpp | 86 ------------------- 3 files changed, 3 insertions(+), 90 deletions(-) diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 6f2e1e94e1b0..73c78b8cc3d4 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -119,9 +119,9 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgsProduct({ {16, 32, 64, 128, 256}, // BlkLen {1, 1024, 2048}, // M - {48, 2560, 4096, 11008}, // N - {4096, 2560, 10240, 11008}, // K - {1, 8, 64}, // Threads + {4096, 11008}, // N + {4096, 11008}, // K + {1, 
8}, // Threads {int64_t{false}, int64_t{true}}, // Symmetric {int64_t{false}, int64_t{true}}, // HasBias {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h index 57d5ed2167fc..53b3edafdf84 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h @@ -70,7 +70,6 @@ class FgemmShortExecuteTest : public MlasTestFixture Date: Tue, 30 Jul 2024 20:46:41 +0000 Subject: [PATCH 37/41] remove test code changes Signed-off-by: liqunfu --- onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 4fff85cbdb5a..0710981fa17c 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -396,7 +396,6 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Tue, 30 Jul 2024 21:02:10 +0000 Subject: [PATCH 38/41] lint Signed-off-by: liqunfu --- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 0c0f84303cee..bb222a5b5654 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -210,7 +210,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#else // defined(ORT_NEURAL_SPEED) +#else // defined(ORT_NEURAL_SPEED) const auto compute_type = static_cast(accuracy_level_); if (input_idx == InputIndex::B) { if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { From 52fc7fa886cc744c3c18c2317db4f863febf875d Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 22:52:48 +0000 Subject: [PATCH 39/41] lint Signed-off-by: liqunfu --- .../cpu/quantization/matmul_nbits.cc | 2 +- onnxruntime/core/mlas/inc/mlas_qnbit.h | 23 ++++++++++--------- .../test/contrib_ops/matmul_4bits_test.cc | 10 ++++---- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index bb222a5b5654..b81682450f76 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -349,7 +349,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { IAllocatorUniquePtr workspace{}; const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( - M, N, K, batch_count, nbits_, block_size_, compute_type); + M, N, K, batch_count, nbits_, block_size_, compute_type); if (workspace_size > 0) { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 1ad270cc2928..0f799cae8289 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -22,7 +22,6 @@ Module Name: #include "mlas.h" #include "mlas_gemm_postprocessor.h" -#include /** * @brief Define compute types of block quantization, in order of decreasing accuracy. 
@@ -47,7 +46,7 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { const float* A = nullptr; ///< address of A (float32 matrix) size_t lda = 0; ///< leading dimension of A const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) - const std::byte* PackedQuantBData = nullptr; + const std::byte* PackedQuantBData = nullptr; const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block @@ -57,7 +56,6 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; - std::string node_name = ""; }; /** @@ -171,17 +169,20 @@ MlasSQNBitGemmPackQuantBDataSize( * * Because ORT OpKernel::PrePack is called for each input (in this case, QuantBData, * QuantBScale, and QuantBZeroPoint) separately, this function may be called 3 times, first with QuantBData, - * and then QuantBScale and QuantBZeroPoint. The second time the function is called with QuantBScale, + * and then QuantBScale and QuantBZeroPoint. When the function is called with QuantBScale without QuantBZeroPoint, * BlkSum is computed with default zero point 8 and stored at the second part of PackedQuantBDataAndOrBlkSum. * If there is a third call with QuantBZeroPoint, BlkSum is recomputed/adjusted with provided zeropoint. * - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) - * @param[in] BlkLen number of quantized values per block - * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) - * @param[in] QuantBData quantized B data - * @param[out] PackedQuantBData packed quantized B data + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] QuantBData quantized B data + * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum + * @param[in] QuantBScale quantized B scale + * @param[in] has_zp_input whether QuantBZeroPoint is provided + * @param[in] QuantBZeroPoint quantized B zero point * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3b85939cb471..548f24e8ac69 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -263,11 +263,11 @@ void RunTest(const TestOptions& opts, } // namespace TEST(MatMulNBits, Float32) { - //onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); for (auto M : {1, 2, 100}) { - for (auto N : {/*2560, */1, 2, 32, 288}) { - for (auto K : {/*2560, */16, 32, 64, 128, 256, 1024, 93, 1234 }) { - for (auto block_size : {16, 32, 64, 128 }) { + for (auto N : {/*2560, */ 1, 2, 32, 288}) { + for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { + for 
(auto block_size : {16, 32, 64, 128}) { for (auto accuracy_level : {0, 1, 4}) { TestOptions base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; @@ -280,7 +280,7 @@ TEST(MatMulNBits, Float32) { { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { From 0933a6b81f2c336bb53ca6e6622ff0d5fc11e3c1 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Tue, 30 Jul 2024 23:02:11 +0000 Subject: [PATCH 40/41] code name Signed-off-by: liqunfu --- onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index b81682450f76..7c7dc18f69c6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -371,7 +371,6 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { data[i].Bias = bias_data; data[i].C = y_data + helper.OutputOffsets()[i]; data[i].ldc = N; - data[i].node_name = this->Node().Name(); } MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), thread_pool); From 2b35c82059ea557373b208169ec2f5e2c5f99646 Mon Sep 17 00:00:00 2001 From: liqunfu Date: Wed, 31 Jul 2024 06:26:22 +0000 Subject: [PATCH 41/41] update reviewers' comments Signed-off-by: liqunfu --- .../cpu/quantization/matmul_nbits.cc | 7 +------ onnxruntime/core/mlas/inc/mlas_qnbit.h | 20 +++++++++--------- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 21 +++---------------- onnxruntime/core/mlas/lib/sqnbitgemm.h | 11 +++++----- .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 4 ++-- .../mlas/lib/sqnbitgemm_kernel_avx512.cpp | 2 +- .../mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp | 2 +- 7 files changed, 24 insertions(+), 43 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 7c7dc18f69c6..5fdd2b017b8a 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -211,6 +211,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } #else // defined(ORT_NEURAL_SPEED) + ORT_UNUSED_PARAMETER(prepacked_weights); const auto compute_type = static_cast(accuracy_level_); if (input_idx == InputIndex::B) { if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { @@ -223,12 +224,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get(), nullptr, has_zp_input_, nullptr, nullptr); - if (prepacked_weights) { - // TODO: cannot use packed_b_ after following code with std::move. - assert(false); - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } is_packed = true; } else if (compute_type == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 0f799cae8289..232bf2261ef4 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -43,16 +43,16 @@ typedef enum { * @brief Data parameters for float/n-bit quantized int GEMM routine. 
*/ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) - const std::byte* PackedQuantBData = nullptr; - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + const float* A = nullptr; ///< address of A (float32 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data + const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const float* Bias = nullptr; ///< optional address of Bias, vector size N + float* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 7ab9b23d3467..a45494ef2e04 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -82,7 +82,7 @@ MlasIsSQNBitGemmAvailable( case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || - (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8_2 != nullptr); + (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } default: { return false; @@ -103,8 +103,6 @@ SQNBitGemmPerGemmWorkspaceSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(N); - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; if (Dispatch == nullptr) { return 0; @@ -439,17 +437,6 @@ SQ4BitGemm_CompInt8( const size_t RangeCountN ) { -//#ifdef MLAS_TARGET_AMD64_IX86 -// if (RangeCountM != 1) { -// // perf experiment shows fp32 is faster than int8 in M > 1 cases. -// // route to fp32 compute before int8 compute is improved. 
-// SQ4BitGemm_CompFp32( -// BlkLen, -// K, DataParams, per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN -// ); -// return; -// } -//#endif #ifdef MLAS_TARGET_AMD64_IX86 PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); constexpr size_t BlkBitWidth = 4; @@ -595,11 +582,12 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(N); const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8_2; + const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + // TODO: try parallel on BatchN * M threads because BatchN is usually 1. if (QuantizeARow) { MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; @@ -665,8 +653,6 @@ MlasSQNBitGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - //auto start_batch = std::chrono::high_resolution_clock::now(); // Start timing here - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); @@ -695,7 +681,6 @@ MlasSQNBitGemmBatch( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); if (ThreadPool == nullptr) { - //auto start = std::chrono::high_resolution_clock::now(); // Start timing here for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index 56af89a6cd1d..2da336ca2f0e 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -53,7 +53,6 @@ struct PackedQuantBDataStruct { // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - //const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); // _mm256_load_si256 requires alignment on a 32-byte boundary @@ -238,7 +237,6 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. - * This kernel handles the special case where M, the number of rows of A and C, is 1. * * @param BlkLen Number of values in a block. * @param QuantA Supplies the quantized A matrix. @@ -249,8 +247,11 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[out] C Supplies the output C matrix. * @param CountN Number of columns of B and C. * @param CountK Number of columns of A and rows of B. - * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. 
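 * (For reference, mirroring the MLAS_SQNBIT_GEMM_DATA_PARAMS field docs: QuantBBlkSum holds one
 * scale * zero-point value per block of B, and ABlockSum is the per-block scale_k * Sum_blklen(a_i)
 * value produced for A by QuantizeARowComputeBlkSum_CompInt8.)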
*/ typedef size_t(SQ4BitGemmKernel_BlkSum_CompInt8_Fn)( size_t BlkLen, @@ -327,7 +328,7 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; - typedef void(QuantizeARow_CompInt8_Fn2)( + typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)( size_t BlkLen, const float* A, size_t CountK, @@ -335,5 +336,5 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { float* QuantAScale, float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); - QuantizeARow_CompInt8_Fn2* QuantizeARow_CompInt8_2 = nullptr; + QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; }; diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 0ed41b150a0e..55d86bb9cc18 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -1299,7 +1299,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; - d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; }(); @@ -1318,7 +1318,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; - d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx2; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp index 6289242ac054..13bd369a065b 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp @@ -366,7 +366,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; - d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx512; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp index 37451413b8c4..6a5c01162c51 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -351,7 +351,7 @@ const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; - d.QuantizeARow_CompInt8_2 = QuantizeARow_CompInt8_avx512; + d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; }();
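Closing note on the B-prepacking flow the final patches settle on (commentary, not an additional patch): because ORT's OpKernel::PrePack receives each constant input separately, the CompInt8 x64 path in MatMulNBits::PrePack ends up calling MlasSQNBitGemmPackQuantBData up to three times per weight, exactly as the updated mlas_qnbit.h comment describes. A condensed sketch of that sequence, with the argument order of the calls in matmul_nbits.cc and abbreviated local names (N, K, nbits, block_size, quant_b_data, packed_b, scales_data, zero_points_data, has_zp_input are stand-ins for the member and tensor pointers used there):

// 1) The B tensor arrives first: pack the 4-bit quantized data into packed_b.
MlasSQNBitGemmPackQuantBData(N, K, nbits, block_size, CompInt8,
                             quant_b_data, packed_b, /*QuantBScale*/ nullptr,
                             has_zp_input, /*QuantBZeroPoint*/ nullptr, /*ThreadPool*/ nullptr);
// 2) Scales arrive next: pack them and compute BlkSum with the default zero point of 8.
MlasSQNBitGemmPackQuantBData(N, K, nbits, block_size, CompInt8,
                             nullptr, packed_b, scales_data,
                             has_zp_input, nullptr, nullptr);
// 3) If zero points are present, a third call recomputes/adjusts BlkSum with the real values.
MlasSQNBitGemmPackQuantBData(N, K, nbits, block_size, CompInt8,
                             nullptr, packed_b, nullptr,
                             has_zp_input, zero_points_data, nullptr);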