Skip to content

Commit

Permalink
Add q8_q4_512 native simd
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Oct 19, 2023
1 parent 48cd7e4 commit 8e1f24c
Showing 1 changed file with 84 additions and 4 deletions.
88 changes: 84 additions & 4 deletions jlama-native/src/main/c/vector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,6 @@ float dot_product_f32_q4_512(const float* a, int aoffset, const float *bf, const
//Subtract 8 from each int
__m128i eight = _mm_set1_epi8(8);
first_4bits0 = _mm_sub_epi8(first_4bits0, eight);

last_4bits0 = _mm_sub_epi8(last_4bits0, eight);

// Extend these bytes to 32-bit integers (low and high)
Expand Down Expand Up @@ -885,10 +884,91 @@ float dot_product_q8_q4_256(const float *af, const char* a, int aoffset, const f
return dot;
}

/*
 * Dot product of an 8-bit quantized vector with a 4-bit quantized vector,
 * using 512-bit SIMD when the file is compiled with AVX-512F. Without
 * AVX-512F this simply forwards to dot_product_q8_q4_256().
 *
 * af      - per-block float scale factors for `a`, indexed by ao / Q4_BLOCK_SIZE
 * a       - signed 8-bit values, 32 bytes consumed per block
 * aoffset - starting byte offset into `a`
 * bf      - per-block float scale factors for `b`, indexed by (bo * 2) / Q4_BLOCK_SIZE
 * b       - packed 4-bit values (two per byte), 16 bytes consumed per block;
 *           each nibble is biased by 8 and is re-centred before use
 * boffset - starting byte offset into `b`
 * length  - number of elements; length / Q4_BLOCK_SIZE blocks are processed
 *
 * Returns the sum over all blocks of (a_scale * b_scale * sum(a_i * b_i)).
 *
 * NOTE(review): the 32/16 byte strides are hard-coded, which presumably
 * matches Q4_BLOCK_SIZE == 32 -- confirm against the macro's definition.
 */
float dot_product_q8_q4_512(const float *af, const char* a, int aoffset, const float *bf, const char* b, int boffset, int length) {
#if defined(__AVX512F__)
    __m512 sum = _mm512_setzero_ps();

    int ao = aoffset;
    int bo = boffset;
    int numBlocks = length / Q4_BLOCK_SIZE;

    // Mask that keeps the low 4 bits of each byte
    const __m128i nibble_mask = _mm_set1_epi8(0xF);
    // Subtracted from each nibble to map 0..15 onto -8..7
    const __m128i eight = _mm_set1_epi8(8);

    // _mm512_store_ps needs a 64-byte aligned destination
    __attribute__((aligned(64))) float scalef[16];

    // Blocks are handled in groups of up to 16 so that the two scale factors
    // of every block in the group can be multiplied with one SIMD multiply.
    for (int i = 0; i < numBlocks; i += 16) { // 512 bits == 16 floats
        // The final group may be short; never touch blocks past numBlocks.
        int remaining = numBlocks - i;
        int group = remaining < 16 ? remaining : 16;
        __mmask16 lanes = (__mmask16)((1u << group) - 1u);

        // Masked loads read only the scale factors that exist; other lanes are zero.
        __m512 ablock = _mm512_maskz_loadu_ps(lanes, af + (ao / Q4_BLOCK_SIZE));
        __m512 bblock = _mm512_maskz_loadu_ps(lanes, bf + ((bo * 2) / Q4_BLOCK_SIZE));
        __m512 scaled = _mm512_mul_ps(ablock, bblock);
        _mm512_store_ps(scalef, scaled);

        // Process one block at a time
        for (int j = 0; j < group; j++, ao += 32, bo += 16) {
            // Broadcast this block's combined scale to all lanes
            __m512 scale_f32 = _mm512_set1_ps(scalef[j]);

            // Load the block's 16 packed bytes (32 nibbles). Unaligned load:
            // the address depends on the caller-supplied offset.
            __m128i packed_b = _mm_loadu_si128((__m128i const*)(b + bo));

            // Low nibble of every byte
            __m128i lo_nibbles = _mm_and_si128(packed_b, nibble_mask);

            // High nibble of every byte, shifted down into the low 4 bits.
            // The 16-bit shift drags bits across byte boundaries, so mask again.
            __m128i hi_nibbles = _mm_srli_epi16(packed_b, 4);
            hi_nibbles = _mm_and_si128(hi_nibbles, nibble_mask);

            // Remove the bias of 8 to obtain signed values
            lo_nibbles = _mm_sub_epi8(lo_nibbles, eight);
            hi_nibbles = _mm_sub_epi8(hi_nibbles, eight);

            // Sign-extend the bytes to 32-bit integers
            __m512i vb_lo = _mm512_cvtepi8_epi32(lo_nibbles);
            __m512i vb_hi = _mm512_cvtepi8_epi32(hi_nibbles);

            // Load the block's 32 bytes of `a` as two 128-bit halves
            __m128i int_va0 = _mm_loadu_si128((__m128i const*)(a + ao));
            __m128i int_va1 = _mm_loadu_si128((__m128i const*)(a + ao + 16));

            // Sign-extend to 32-bit integers
            __m512i va0 = _mm512_cvtepi8_epi32(int_va0);
            __m512i va1 = _mm512_cvtepi8_epi32(int_va1);

            // First 16 values of `a` pair with the low nibbles, the next 16
            // with the high nibbles; accumulate both products per lane.
            __m512i isum = _mm512_mullo_epi32(va0, vb_lo);
            isum = _mm512_add_epi32(_mm512_mullo_epi32(va1, vb_hi), isum);

            // Convert the integer partial sums to float
            __m512 fsum = _mm512_cvtepi32_ps(isum);

            // sum += scale * fsum
            sum = _mm512_fmadd_ps(scale_f32, fsum, sum);
        }
    }

    // Horizontal sum of the 16 lanes, added in lane order
    __attribute__((aligned(64))) float result[16];
    _mm512_store_ps(result, sum);

    float dot = 0.0f;
    for (int i = 0; i < 16; ++i) {
        dot += result[i];
    }

    return dot;
#else
    return dot_product_q8_q4_256(af, a, aoffset, bf, b, boffset, length);
#endif
}

/*
 * Entry point for the Q8 x Q4 dot product: picks a kernel based on `flags`
 * and forwards all remaining arguments unchanged.
 *
 * When HAS_AVX2 is set in `flags` the 512-bit kernel is used, otherwise the
 * 256-bit kernel. dot_product_q8_q4_512() itself falls back to the 256-bit
 * kernel when the file is built without AVX-512F.
 *
 * NOTE(review): the 512-bit kernel is gated on HAS_AVX2 rather than on an
 * AVX-512 specific flag -- confirm this is the intended runtime check.
 */
float dot_product_q8_q4(int flags, const float* af, const char* a, int aoffset, const float *bf, const char* b, int boffset, int length) {
    return ((flags & HAS_AVX2) != 0)
            ? dot_product_q8_q4_512(af, a, aoffset, bf, b, boffset, length)
            : dot_product_q8_q4_256(af, a, aoffset, bf, b, boffset, length);
}

void dot_product_q8_q4_chunked(int flags, float *r, const float* af, const char *a, int aoffset, const float *bf, const char* b, int boffset, int length, int bchunkstart, int bchunksize) {
Expand Down

0 comments on commit 8e1f24c

Please sign in to comment.