From 05fc0c60cadab3652869226d2f2aa0bf7d9eda7d Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Thu, 18 Jul 2024 13:37:29 -0700 Subject: [PATCH] [MLAS][AArch64] SQNBitGemm CompInt8 - Use 4x2 tiles (#21380) Update SQNBitGemm ARM NEON kernel to compute 4x2 tile of output. Note: Also tried 2x4 and 4x4 tiles but observed the best microbenchmark results with 4x2 tiles. --- .../mlas/lib/sqnbitgemm_kernel_neon_int8.cpp | 402 +++++++++++------- 1 file changed, 244 insertions(+), 158 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp index db3b9ee65659..ec5cdbc75220 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -155,7 +155,7 @@ namespace template MLAS_FORCEINLINE void -SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( +SQ4BitGemm_CompInt8_Compute4x2_BlkLen16( const std::byte* QuantARowPtr, const std::byte* QuantBDataColPtr, const float* QuantBScaleColPtr, @@ -177,11 +177,13 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - float32x4_t acc00{}, acc01{}, acc10{}, acc11{}; + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}, acc20{}, acc21{}, acc30{}, acc31{}; for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { const std::byte* QuantABlkRow0 = QuantAPtr; const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + const std::byte* QuantABlkRow2 = QuantAPtr + StrideQuantA * 2; + const std::byte* QuantABlkRow3 = QuantAPtr + StrideQuantA * 3; const float QuantBScaleCol0 = *QuantBScalePtr; const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); @@ -191,6 +193,10 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + const float scale20 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol0; + const float scale21 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol1; + const float scale30 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol0; + const float scale31 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol1; // load B zero point int8_t bzp_col0; @@ -212,13 +218,11 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + const int8_t* QuantADataPtrRow2 = Q8BlkData(QuantABlkRow2); + const int8_t* QuantADataPtrRow3 = Q8BlkData(QuantABlkRow3); // TODO handling only 16 elements per accumulator at a time here, probably can do better { - // load A - const int8x16_t av_row0 = vld1q_s8(QuantADataPtrRow0 + 0); - const int8x16_t av_row1 = vld1q_s8(QuantADataPtrRow1 + 0); - // load B const uint8x8_t bv_packed_col0 = vld1_u8(reinterpret_cast(QuantBDataPtr)); const uint8x8_t bv_packed_col1 = vld1_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); @@ -242,24 +246,55 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( bv_col0 = vsubq_s8(bv_col0, vdupq_n_s8(bzp_col0)); bv_col1 = vsubq_s8(bv_col1, vdupq_n_s8(bzp_col1)); - // quantized dot product - int32x4_t dot00{}, dot01{}, dot10{}, dot11{}; - dot00 = vdotq_s32(dot00, av_row0, bv_col0); - dot01 = vdotq_s32(dot01, av_row0, bv_col1); - dot10 = vdotq_s32(dot10, av_row1, bv_col0); - dot11 = vdotq_s32(dot11, av_row1, bv_col1); - - // convert to float - const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); - const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); - const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); - const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + // rows 0 and 1 of A + { + // load A + const int8x16_t av_row0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row1 = vld1q_s8(QuantADataPtrRow1 + 0); + + // quantized dot product + const int32x4_t dot00 = vdotq_s32(int32x4_t{}, av_row0, bv_col0); + const int32x4_t dot01 = vdotq_s32(int32x4_t{}, av_row0, bv_col1); + const int32x4_t dot10 = vdotq_s32(int32x4_t{}, av_row1, bv_col0); + const int32x4_t dot11 = vdotq_s32(int32x4_t{}, av_row1, bv_col1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + } - // multiply by scale and update accumulator - acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); - acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); - acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); - acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + // rows 2 and 3 of A + { + // load A + const int8x16_t av_row2 = vld1q_s8(QuantADataPtrRow2 + 0); + const int8x16_t av_row3 = vld1q_s8(QuantADataPtrRow3 + 0); + + // quantized dot product + const int32x4_t dot20 = vdotq_s32(int32x4_t{}, av_row2, bv_col0); + const int32x4_t dot21 = vdotq_s32(int32x4_t{}, av_row2, bv_col1); + const int32x4_t dot30 = vdotq_s32(int32x4_t{}, av_row3, bv_col0); + const int32x4_t dot31 = vdotq_s32(int32x4_t{}, av_row3, bv_col1); + + // convert to float + const float32x4_t dot_f32_20 = vcvtq_f32_s32(dot20); + const float32x4_t dot_f32_21 = vcvtq_f32_s32(dot21); + const float32x4_t dot_f32_30 = vcvtq_f32_s32(dot30); + const float32x4_t dot_f32_31 = vcvtq_f32_s32(dot31); + + // multiply by scale and update accumulator + acc20 = vfmaq_f32(acc20, dot_f32_20, vdupq_n_f32(scale20)); + acc21 = vfmaq_f32(acc21, dot_f32_21, vdupq_n_f32(scale21)); + acc30 = vfmaq_f32(acc30, dot_f32_30, vdupq_n_f32(scale30)); + acc31 = vfmaq_f32(acc31, dot_f32_31, vdupq_n_f32(scale31)); + } } // increment block pointers @@ -273,22 +308,30 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( } } - SumPtr[0] = vaddvq_f32(acc00); - SumPtr[1] = vaddvq_f32(acc01); - SumPtr[ldc + 0] = vaddvq_f32(acc10); - SumPtr[ldc + 1] = vaddvq_f32(acc11); + SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00); + SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01); + SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10); + SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11); + SumPtr[ldc * 2 + 0] = vaddvq_f32(acc20); + SumPtr[ldc * 2 + 1] = vaddvq_f32(acc21); + SumPtr[ldc * 3 + 0] = vaddvq_f32(acc30); + SumPtr[ldc * 3 + 1] = vaddvq_f32(acc31); if (BiasPtr != nullptr) { - SumPtr[0] += BiasPtr[0]; - SumPtr[1] += BiasPtr[1]; - SumPtr[ldc + 0] += BiasPtr[0]; - SumPtr[ldc + 1] += BiasPtr[1]; + SumPtr[ldc * 0 + 0] += BiasPtr[0]; + SumPtr[ldc * 0 + 1] += BiasPtr[1]; + SumPtr[ldc * 1 + 0] += BiasPtr[0]; + SumPtr[ldc * 1 + 1] += BiasPtr[1]; + SumPtr[ldc * 2 + 0] += BiasPtr[0]; + SumPtr[ldc * 2 + 1] += BiasPtr[1]; + SumPtr[ldc * 3 + 0] += BiasPtr[0]; + SumPtr[ldc * 3 + 1] += BiasPtr[1]; } } template MLAS_FORCEINLINE void -SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( +SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( size_t BlkLen, const std::byte* QuantARowPtr, const std::byte* QuantBDataColPtr, @@ -312,11 +355,13 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( const float* QuantBScalePtr = QuantBScaleColPtr; const std::byte* QuantBZeroPointPtr = QuantBZeroPointColPtr; - float32x4_t acc00{}, acc01{}, acc10{}, acc11{}; + float32x4_t acc00{}, acc01{}, acc10{}, acc11{}, acc20{}, acc21{}, acc30{}, acc31{}; for (size_t k_blk_idx = 0; k_blk_idx < BlockCountK; ++k_blk_idx) { const std::byte* QuantABlkRow0 = QuantAPtr; const std::byte* QuantABlkRow1 = QuantAPtr + StrideQuantA; + const std::byte* QuantABlkRow2 = QuantAPtr + StrideQuantA * 2; + const std::byte* QuantABlkRow3 = QuantAPtr + StrideQuantA * 3; const float QuantBScaleCol0 = *QuantBScalePtr; const float QuantBScaleCol1 = *(QuantBScalePtr + StrideQuantBScale); @@ -326,6 +371,10 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( const float scale01 = Q8BlkScale(QuantABlkRow0) * QuantBScaleCol1; const float scale10 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol0; const float scale11 = Q8BlkScale(QuantABlkRow1) * QuantBScaleCol1; + const float scale20 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol0; + const float scale21 = Q8BlkScale(QuantABlkRow2) * QuantBScaleCol1; + const float scale30 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol0; + const float scale31 = Q8BlkScale(QuantABlkRow3) * QuantBScaleCol1; // load B zero point int8_t bzp_col0; @@ -347,14 +396,10 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( const int8_t* QuantADataPtrRow0 = Q8BlkData(QuantABlkRow0); const int8_t* QuantADataPtrRow1 = Q8BlkData(QuantABlkRow1); + const int8_t* QuantADataPtrRow2 = Q8BlkData(QuantABlkRow2); + const int8_t* QuantADataPtrRow3 = Q8BlkData(QuantABlkRow3); for (size_t sub_blk_idx = 0; sub_blk_idx < SubBlksPerBlk; ++sub_blk_idx) { - // load A - const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0); - const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16); - const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0); - const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16); - // load B const uint8x16_t bv_packed_col0 = vld1q_u8(reinterpret_cast(QuantBDataPtr)); const uint8x16_t bv_packed_col1 = vld1q_u8(reinterpret_cast(QuantBDataPtr) + StrideQuantBData); @@ -372,28 +417,65 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( bv_col1_0 = vsubq_s8(bv_col1_0, vdupq_n_s8(bzp_col1)); bv_col1_1 = vsubq_s8(bv_col1_1, vdupq_n_s8(bzp_col1)); - // quantized dot product - int32x4_t dot00{}, dot01{}, dot10{}, dot11{}; - dot00 = vdotq_s32(vdotq_s32(dot00, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1); - dot01 = vdotq_s32(vdotq_s32(dot01, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1); - dot10 = vdotq_s32(vdotq_s32(dot10, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1); - dot11 = vdotq_s32(vdotq_s32(dot11, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1); - - // convert to float - const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); - const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); - const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); - const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + // rows 0 and 1 of A + { + // load A + const int8x16_t av_row0_0 = vld1q_s8(QuantADataPtrRow0 + 0); + const int8x16_t av_row0_1 = vld1q_s8(QuantADataPtrRow0 + 16); + const int8x16_t av_row1_0 = vld1q_s8(QuantADataPtrRow1 + 0); + const int8x16_t av_row1_1 = vld1q_s8(QuantADataPtrRow1 + 16); + + // quantized dot product + const int32x4_t dot00 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row0_0, bv_col0_0), av_row0_1, bv_col0_1); + const int32x4_t dot01 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row0_0, bv_col1_0), av_row0_1, bv_col1_1); + const int32x4_t dot10 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row1_0, bv_col0_0), av_row1_1, bv_col0_1); + const int32x4_t dot11 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row1_0, bv_col1_0), av_row1_1, bv_col1_1); + + // convert to float + const float32x4_t dot_f32_00 = vcvtq_f32_s32(dot00); + const float32x4_t dot_f32_01 = vcvtq_f32_s32(dot01); + const float32x4_t dot_f32_10 = vcvtq_f32_s32(dot10); + const float32x4_t dot_f32_11 = vcvtq_f32_s32(dot11); + + // multiply by scale and update accumulator + acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); + acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); + acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); + acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + } - // multiply by scale and update accumulator - acc00 = vfmaq_f32(acc00, dot_f32_00, vdupq_n_f32(scale00)); - acc01 = vfmaq_f32(acc01, dot_f32_01, vdupq_n_f32(scale01)); - acc10 = vfmaq_f32(acc10, dot_f32_10, vdupq_n_f32(scale10)); - acc11 = vfmaq_f32(acc11, dot_f32_11, vdupq_n_f32(scale11)); + // rows 2 and 3 of A + { + // load A + const int8x16_t av_row2_0 = vld1q_s8(QuantADataPtrRow2 + 0); + const int8x16_t av_row2_1 = vld1q_s8(QuantADataPtrRow2 + 16); + const int8x16_t av_row3_0 = vld1q_s8(QuantADataPtrRow3 + 0); + const int8x16_t av_row3_1 = vld1q_s8(QuantADataPtrRow3 + 16); + + // quantized dot product + const int32x4_t dot20 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row2_0, bv_col0_0), av_row2_1, bv_col0_1); + const int32x4_t dot21 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row2_0, bv_col1_0), av_row2_1, bv_col1_1); + const int32x4_t dot30 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row3_0, bv_col0_0), av_row3_1, bv_col0_1); + const int32x4_t dot31 = vdotq_s32(vdotq_s32(int32x4_t{}, av_row3_0, bv_col1_0), av_row3_1, bv_col1_1); + + // convert to float + const float32x4_t dot_f32_20 = vcvtq_f32_s32(dot20); + const float32x4_t dot_f32_21 = vcvtq_f32_s32(dot21); + const float32x4_t dot_f32_30 = vcvtq_f32_s32(dot30); + const float32x4_t dot_f32_31 = vcvtq_f32_s32(dot31); + + // multiply by scale and update accumulator + acc20 = vfmaq_f32(acc20, dot_f32_20, vdupq_n_f32(scale20)); + acc21 = vfmaq_f32(acc21, dot_f32_21, vdupq_n_f32(scale21)); + acc30 = vfmaq_f32(acc30, dot_f32_30, vdupq_n_f32(scale30)); + acc31 = vfmaq_f32(acc31, dot_f32_31, vdupq_n_f32(scale31)); + } // increment block data pointers to next sub-block QuantADataPtrRow0 += 32; QuantADataPtrRow1 += 32; + QuantADataPtrRow2 += 32; + QuantADataPtrRow3 += 32; QuantBDataPtr += 16; } @@ -407,16 +489,24 @@ SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( } } - SumPtr[0] = vaddvq_f32(acc00); - SumPtr[1] = vaddvq_f32(acc01); - SumPtr[ldc + 0] = vaddvq_f32(acc10); - SumPtr[ldc + 1] = vaddvq_f32(acc11); + SumPtr[ldc * 0 + 0] = vaddvq_f32(acc00); + SumPtr[ldc * 0 + 1] = vaddvq_f32(acc01); + SumPtr[ldc * 1 + 0] = vaddvq_f32(acc10); + SumPtr[ldc * 1 + 1] = vaddvq_f32(acc11); + SumPtr[ldc * 2 + 0] = vaddvq_f32(acc20); + SumPtr[ldc * 2 + 1] = vaddvq_f32(acc21); + SumPtr[ldc * 3 + 0] = vaddvq_f32(acc30); + SumPtr[ldc * 3 + 1] = vaddvq_f32(acc31); if (BiasPtr != nullptr) { - SumPtr[0] += BiasPtr[0]; - SumPtr[1] += BiasPtr[1]; - SumPtr[ldc + 0] += BiasPtr[0]; - SumPtr[ldc + 1] += BiasPtr[1]; + SumPtr[ldc * 0 + 0] += BiasPtr[0]; + SumPtr[ldc * 0 + 1] += BiasPtr[1]; + SumPtr[ldc * 1 + 0] += BiasPtr[0]; + SumPtr[ldc * 1 + 1] += BiasPtr[1]; + SumPtr[ldc * 2 + 0] += BiasPtr[0]; + SumPtr[ldc * 2 + 1] += BiasPtr[1]; + SumPtr[ldc * 3 + 0] += BiasPtr[0]; + SumPtr[ldc * 3 + 1] += BiasPtr[1]; } } @@ -478,8 +568,8 @@ SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( bv1 = vsubq_s8(bv1, bzp1); // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); - const int32x4_t dot1 = vdotq_s32(vdupq_n_s32(0), av1, bv1); + const int32x4_t dot0 = vdotq_s32(int32x4_t{}, av0, bv0); + const int32x4_t dot1 = vdotq_s32(int32x4_t{}, av1, bv1); // convert to float const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); @@ -527,7 +617,7 @@ SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( bv0 = vsubq_s8(bv0, bzp0); // quantized dot product - const int32x4_t dot0 = vdotq_s32(vdupq_n_s32(0), av0, bv0); + const int32x4_t dot0 = vdotq_s32(int32x4_t{}, av0, bv0); // convert to float const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); @@ -604,9 +694,8 @@ SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( bv_hi1 = vsubq_s8(bv_hi1, bzp1); // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); - dot1 = vdotq_s32(vdotq_s32(dot1, av_lo1, bv_lo1), av_hi1, bv_hi1); + const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo0, bv_lo0), av_hi0, bv_hi0); + const int32x4_t dot1 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo1, bv_lo1), av_hi1, bv_hi1); // convert to float const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); @@ -652,8 +741,7 @@ SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( bv_hi0 = vsubq_s8(bv_hi0, bzp0); // quantized dot product - int32x4_t dot0{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av_lo0, bv_lo0), av_hi0, bv_hi0); + const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av_lo0, bv_lo0), av_hi0, bv_hi0); // convert to float const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); @@ -736,9 +824,8 @@ SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( bv3 = vsubq_s8(bv3, bzp); // quantized dot product - int32x4_t dot0{}, dot1{}; - dot0 = vdotq_s32(vdotq_s32(dot0, av0, bv0), av1, bv1); - dot1 = vdotq_s32(vdotq_s32(dot1, av2, bv2), av3, bv3); + const int32x4_t dot0 = vdotq_s32(vdotq_s32(int32x4_t{}, av0, bv0), av1, bv1); + const int32x4_t dot1 = vdotq_s32(vdotq_s32(int32x4_t{}, av2, bv2), av3, bv3); // convert to float const float32x4_t dot_f32_0 = vcvtq_f32_s32(dot0); @@ -834,7 +921,7 @@ SQ4BitGemmKernel_CompInt8_BlkLen16( float* SumRowPtr = C; size_t m_remaining = CountM; - while (m_remaining > 1) { + while (m_remaining > 3) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; @@ -845,8 +932,8 @@ SQ4BitGemmKernel_CompInt8_BlkLen16( size_t n_remaining = CountN; while (n_remaining > 1) { - // Compute 2x2 tiles of output - SQ4BitGemm_CompInt8_Compute2x2_BlkLen16( + // Compute 4x2 tiles of output + SQ4BitGemm_CompInt8_Compute4x2_BlkLen16( QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, @@ -871,38 +958,30 @@ SQ4BitGemmKernel_CompInt8_BlkLen16( } if (n_remaining > 0) { - // Compute last 2x1 tile of output - SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK - ); - - SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( - QuantARowPtr + StrideQuantA, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr + ldc, - BlockCountK - ); + // Compute last 4x1 tile of output + for (size_t i = 0; i < 4; ++i) { + SQ4BitGemm_CompInt8_Compute1x1_BlkLen16( + QuantARowPtr + StrideQuantA * i, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc * i, + BlockCountK + ); + } } - // Move to next 2 rows - AdvanceRowPtrs<2>( + // Move to next 4 rows + AdvanceRowPtrs<4>( StrideQuantA, ldc, QuantARowPtr, SumRowPtr ); - m_remaining -= 2; + m_remaining -= 4; } - if (m_remaining > 0) { + while (m_remaining > 0) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; @@ -932,6 +1011,14 @@ SQ4BitGemmKernel_CompInt8_BlkLen16( n_remaining -= 1; } + + // Move to next row + AdvanceRowPtrs<1>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 1; } } @@ -964,7 +1051,7 @@ SQ4BitGemmKernel_CompInt8_BlkLen32( float* SumRowPtr = C; size_t m_remaining = CountM; - while (m_remaining > 1) { + while (m_remaining > 3) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; @@ -975,8 +1062,8 @@ SQ4BitGemmKernel_CompInt8_BlkLen32( size_t n_remaining = CountN; while (n_remaining > 1) { - // Compute 2x2 tiles of output - SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + // Compute 4x2 tiles of output + SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( BlkLen, QuantARowPtr, QuantBDataColPtr, @@ -1002,38 +1089,30 @@ SQ4BitGemmKernel_CompInt8_BlkLen32( } if (n_remaining > 0) { - // Compute last 2x1 tile of output - SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK - ); - - SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( - QuantARowPtr + StrideQuantA, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr + ldc, - BlockCountK - ); + // Compute last 4x1 tile of output + for (size_t i = 0; i < 4; ++i) { + SQ4BitGemm_CompInt8_Compute1x1_BlkLen32( + QuantARowPtr + StrideQuantA * i, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc * i, + BlockCountK + ); + } } - // Move to next 2 rows - AdvanceRowPtrs<2>( + // Move to next 4 rows + AdvanceRowPtrs<4>( StrideQuantA, ldc, QuantARowPtr, SumRowPtr ); - m_remaining -= 2; + m_remaining -= 4; } - if (m_remaining > 0) { + while (m_remaining > 0) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; @@ -1063,6 +1142,14 @@ SQ4BitGemmKernel_CompInt8_BlkLen32( n_remaining -= 1; } + + // Move to next row + AdvanceRowPtrs<1>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 1; } } @@ -1095,7 +1182,7 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( float* SumRowPtr = C; size_t m_remaining = CountM; - while (m_remaining > 1) { + while (m_remaining > 3) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; @@ -1106,8 +1193,8 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( size_t n_remaining = CountN; while (n_remaining > 1) { - // Compute 2x2 tiles of output - SQ4BitGemm_CompInt8_Compute2x2_BlkLenGreaterThan16( + // Compute 4x2 tiles of output + SQ4BitGemm_CompInt8_Compute4x2_BlkLenGreaterThan16( BlkLen, QuantARowPtr, QuantBDataColPtr, @@ -1133,40 +1220,31 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( } if (n_remaining > 0) { - // Compute last 2x1 tile of output - SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( - BlkLen, - QuantARowPtr, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr, - BlockCountK - ); - - SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( - BlkLen, - QuantARowPtr + StrideQuantA, - QuantBDataColPtr, - QuantBScaleColPtr, - QuantBZeroPointColPtr, - BiasPtr, - SumPtr + ldc, - BlockCountK - ); + // Compute last 4x1 tile of output + for (size_t i = 0; i < 4; ++i) { + SQ4BitGemm_CompInt8_Compute1x1_BlkLenGreaterThan32( + BlkLen, + QuantARowPtr + StrideQuantA * i, + QuantBDataColPtr, + QuantBScaleColPtr, + QuantBZeroPointColPtr, + BiasPtr, + SumPtr + ldc * i, + BlockCountK + ); + } } - // Move to next 2 rows - AdvanceRowPtrs<2>( + // Move to next 4 rows + AdvanceRowPtrs<4>( StrideQuantA, ldc, QuantARowPtr, SumRowPtr ); - m_remaining -= 2; + m_remaining -= 4; } - if (m_remaining > 0) { + while (m_remaining > 0) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; @@ -1197,6 +1275,14 @@ SQ4BitGemmKernel_CompInt8_BlkLenGreaterThan32( n_remaining -= 1; } + + // Move to next row + AdvanceRowPtrs<1>( + StrideQuantA, ldc, + QuantARowPtr, SumRowPtr + ); + + m_remaining -= 1; } }