From 9fec8570b29f2601790f34c4e5d5f48472e1fe09 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Tue, 16 Jul 2024 20:51:11 -0400 Subject: [PATCH] BF16 avx512 native --- .../jlama/model/CausalSelfAttention.java | 2 - jlama-native/src/main/c/vector_simd.c | 62 ++++++++++++----- .../operations/cnative/constants$2.java | 67 +++++++++++++++++++ 3 files changed, 112 insertions(+), 19 deletions(-) create mode 100644 jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java index 6b8057b..aa7cbec 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/CausalSelfAttention.java @@ -40,7 +40,6 @@ public class CausalSelfAttention { private final AbstractTensor outputProjectionWeights; - private final Optional preAttentionScale; private final float attentionScale; private final AbstractTensor[] qkvResults; @@ -109,7 +108,6 @@ public CausalSelfAttention( this.outputProjectionWeights = outputProjectionWeights; this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize)); - this.preAttentionScale = Optional.of((float)Math.pow(c.headSize, -0.5)); this.qkvResults = new AbstractTensor[3]; this.qkvWeights = new AbstractTensor[] {queryAttnWeights, keyAttnWeights, valueAttnWeights}; diff --git a/jlama-native/src/main/c/vector_simd.c b/jlama-native/src/main/c/vector_simd.c index 2c52e16..a953c2a 100644 --- a/jlama-native/src/main/c/vector_simd.c +++ b/jlama-native/src/main/c/vector_simd.c @@ -1025,7 +1025,7 @@ void __attribute__((noinline)) gemm_bf16_256(int m0, int m, int n0, int n, int R } } -/*void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) { +void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) { #if 
defined(__AVX512F__) int ytiles = (m - m0) / RM; int xtiles = (n - n0) / RN; @@ -1048,15 +1048,38 @@ void __attribute__((noinline)) gemm_bf16_256(int m0, int m, int n0, int n, int R for (int ni = 0; ni < RN; ++ni) { int ao = params.aoffset; int bo = params.boffset; - for(int j = 0; j < params.k; j += 16, ao += 16, bo += 16) { // 512bits == 16floats - // Load float32 - __m512 vb = _mm512_loadu_ps(params.bf + params.ldb * (jj + ni) + bo); + for(int j = 0; j < params.k; j += 32, ao += 32, bo += 32) { // 512bits == 32bfloats + // Load shorts + __m512i vb = _mm512_loadu_si512((__m512i*)(params.bs + params.ldb * (jj + ni) + bo)); + + // Extract lower 16 shorts and convert to int (lower 256 bits) + __m512i vb0i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 0)); + // Shift left 16 bits and convert to float + __m512 vb0 = _mm512_castsi512_ps(_mm512_slli_epi32(vb0i, 16)); + + // Extract upper 16 shorts and convert to int (upper 256 bits) + __m512i vb1i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 1)); + // Shift left 16 bits and convert to float + __m512 vb1 = _mm512_castsi512_ps(_mm512_slli_epi32(vb1i, 16)); for (int mi = 0; mi < RM; ++mi) { - __m512 va = _mm512_loadu_ps(params.af + params.lda * (ii + mi) + ao); + // Load shorts + __m512i va = _mm512_loadu_si512((__m512i*)(params.as + params.lda * (ii + mi) + ao)); + + // Extract lower 16 shorts and convert to int (lower 256 bits) + __m512i va0i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(va, 0)); + // Shift left 16 bits and convert to float + __m512 va0 = _mm512_castsi512_ps(_mm512_slli_epi32(va0i, 16)); + + // Extract upper 16 shorts and convert to int (upper 256 bits) + __m512i va1i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(va, 1)); + // Shift left 16 bits and convert to float + __m512 va1 = _mm512_castsi512_ps(_mm512_slli_epi32(va1i, 16)); + // Multiply and accumulate - sums[mi][ni] = _mm512_fmadd_ps(va, vb, sums[mi][ni]); + sums[mi][ni] = _mm512_fmadd_ps(va0, vb0, sums[mi][ni]); + 
sums[mi][ni] = _mm512_fmadd_ps(va1, vb1, sums[mi][ni]); } } } @@ -1065,14 +1088,17 @@ void __attribute__((noinline)) gemm_bf16_256(int m0, int m, int n0, int n, int R for (int ni = 0; ni < RN; ++ni) { // Horizontal sum of the vector to get dot product float r = _mm512_reduce_add_ps(sums[mi][ni]); - params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r; + if (params.rs != NULL) + params.rs[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = fp32_to_bf16(r); + else + params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r; } } } #else gemm_bf16_256(m0, m, n0, n, RM, RN, params); #endif -}*/ +} #endif //!ARM_NEON void gemm_bf16(int flags, const short *a, int aoffset, const short *b, int boffset, short *rs, float *r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc) @@ -1097,10 +1123,9 @@ void gemm_bf16(int flags, const short *a, int aoffset, const short *b, int boffs }; #if !defined(__ARM_NEON__) - //((flags & HAS_AVX2) != 0) - // ? gemm(0, m, n0, n0 + n, gemm_bf16_512, p) - //: - gemm(0, m, n0, n0 + n, gemm_bf16_256, p); + ((flags & HAS_AVX2) != 0) + ? 
gemm(0, m, n0, n0 + n, gemm_bf16_512, p) + : gemm(0, m, n0, n0 + n, gemm_bf16_256, p); #else gemm(0, m, n0, n0 + n, gemm_bf16_128_arm, p); #endif @@ -1229,7 +1254,7 @@ void __attribute__((noinline)) gemm_f32_bf16_256(int m0, int m, int n0, int n, i } } -void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) { +void gemm_f32_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) { #if defined(__AVX512F__) int ytiles = (m - m0) / RM; int xtiles = (n - n0) / RN; @@ -1257,12 +1282,12 @@ void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_par __m512i vb = _mm512_loadu_si512((__m512i*)(params.bs + params.ldb * (jj + ni) + bo)); // Extract lower 8 shorts and convert to int (lower 128 bits) - __m512i vb0i = _mm512_cvtepu16_epi32(_mm512_extracti128_si512(vb, 0)); + __m512i vb0i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 0)); // Shift left 16 bits and convert to float __m512 vb0 = _mm512_castsi512_ps(_mm512_slli_epi32(vb0i, 16)); // Extract lower 8 shorts and convert to int (upper 128 bits) - __m512i vb1i = _mm512_cvtepu16_epi32(_mm512_extracti128_si512(vb, 1)); + __m512i vb1i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 1)); // Shift left 16 bits and convert to float __m512 vb1 = _mm512_castsi512_ps(_mm512_slli_epi32(vb1i, 16)); @@ -1281,12 +1306,15 @@ void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_par for (int ni = 0; ni < RN; ++ni) { // Horizontal sum of the vector to get dot product float r = _mm512_reduce_add_ps(sums[mi][ni]); - params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r; + if (params.rs != NULL) + params.rs[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = fp32_to_bf16(r); + else + params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r; } } } #else - gemm_bf16_256(m0, m, n0, n, RM, RN, params); + gemm_f32_bf16_256(m0, m, n0, n, RM, RN, params); #endif } #endif //!ARM_NEON diff --git 
a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java new file mode 100644 index 0000000..d644d5a --- /dev/null +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/cnative/constants$2.java @@ -0,0 +1,67 @@ +// Generated by jextract + +package com.github.tjake.jlama.tensor.operations.cnative; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.VarHandle; +import java.nio.ByteOrder; +import java.lang.foreign.*; +import static java.lang.foreign.ValueLayout.*; +final class constants$2 { + + // Suppresses default constructor, ensuring non-instantiability. + private constants$2() {} + static final FunctionDescriptor const$0 = FunctionDescriptor.ofVoid( + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT + ); + static final MethodHandle const$1 = RuntimeHelper.downcallHandle( + "gemm_bf16", + constants$2.const$0 + ); + static final FunctionDescriptor const$2 = FunctionDescriptor.ofVoid( + JAVA_INT, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + JAVA_INT, + RuntimeHelper.POINTER, + RuntimeHelper.POINTER, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT, + JAVA_INT + ); + static final MethodHandle const$3 = RuntimeHelper.downcallHandle( + "gemm_bf16_batch", + constants$2.const$2 + ); + static final MethodHandle const$4 = RuntimeHelper.downcallHandle( + "gemm_f32_bf16", + constants$2.const$0 + ); + static final MethodHandle const$5 = RuntimeHelper.downcallHandle( + "gemm_f32_bf16_batch", + constants$2.const$2 + ); +} + +