Skip to content

Commit

Permalink
BF16 avx512 native
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Jul 17, 2024
1 parent dd8c3bf commit 9fec857
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ public class CausalSelfAttention {

private final AbstractTensor outputProjectionWeights;

private final Optional<Float> preAttentionScale;
private final float attentionScale;

private final AbstractTensor[] qkvResults;
Expand Down Expand Up @@ -109,7 +108,6 @@ public CausalSelfAttention(
this.outputProjectionWeights = outputProjectionWeights;

this.attentionScale = (float) (1.0 / StrictMath.sqrt(c.headSize));
this.preAttentionScale = Optional.of((float)Math.pow(c.headSize, -0.5));

this.qkvResults = new AbstractTensor[3];
this.qkvWeights = new AbstractTensor[] {queryAttnWeights, keyAttnWeights, valueAttnWeights};
Expand Down
62 changes: 45 additions & 17 deletions jlama-native/src/main/c/vector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ void __attribute__((noinline)) gemm_bf16_256(int m0, int m, int n0, int n, int R
}
}

/*void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) {
void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) {
#if defined(__AVX512F__)
int ytiles = (m - m0) / RM;
int xtiles = (n - n0) / RN;
Expand All @@ -1048,15 +1048,38 @@ void __attribute__((noinline)) gemm_bf16_256(int m0, int m, int n0, int n, int R
for (int ni = 0; ni < RN; ++ni) {
int ao = params.aoffset;
int bo = params.boffset;
for(int j = 0; j < params.k; j += 16, ao += 16, bo += 16) { // 512bits == 16floats
// Load float32
__m512 vb = _mm512_loadu_ps(params.bf + params.ldb * (jj + ni) + bo);
for(int j = 0; j < params.k; j += 32, ao += 32, bo += 32) { // 512bits == 32bfloats
// Load shorts
__m512i vb = _mm512_loadu_si512((__m512i*)(params.bs + params.ldb * (jj + ni) + bo));

// Extract lower 8 shorts and convert to int (lower 128 bits)
__m512i vb0i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 0));
// Shift left 16 bits and convert to float
__m512 vb0 = _mm512_castsi512_ps(_mm512_slli_epi32(vb0i, 16));

// Extract lower 8 shorts and convert to int (upper 128 bits)
__m512i vb1i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 1));
// Shift left 16 bits and convert to float
__m512 vb1 = _mm512_castsi512_ps(_mm512_slli_epi32(vb1i, 16));

for (int mi = 0; mi < RM; ++mi) {
__m512 va = _mm512_loadu_ps(params.af + params.lda * (ii + mi) + ao);
// Load shorts
__m512i va = _mm512_loadu_si512((__m512i*)(params.as + params.lda * (jj + ni) + ao));

// Extract lower 8 shorts and convert to int (lower 128 bits)
__m512i va0i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(va, 0));
// Shift left 16 bits and convert to float
__m512 va0 = _mm512_castsi512_ps(_mm512_slli_epi32(va0i, 16));

// Extract lower 8 shorts and convert to int (upper 128 bits)
__m512i va1i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(va, 1));
// Shift left 16 bits and convert to float
__m512 va1 = _mm512_castsi512_ps(_mm512_slli_epi32(va1i, 16));


// Multiply and accumulate
sums[mi][ni] = _mm512_fmadd_ps(va, vb, sums[mi][ni]);
sums[mi][ni] = _mm512_fmadd_ps(va0, vb0, sums[mi][ni]);
sums[mi][ni] = _mm512_fmadd_ps(va1, vb1, sums[mi][ni]);
}
}
}
Expand All @@ -1065,14 +1088,17 @@ void __attribute__((noinline)) gemm_bf16_256(int m0, int m, int n0, int n, int R
for (int ni = 0; ni < RN; ++ni) {
// Horizontal sum of the vector to get dot product
float r = _mm512_reduce_add_ps(sums[mi][ni]);
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r;
if (params.rs != NULL)
params.rs[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = fp32_to_bf16(r);
else
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r;
}
}
}
#else
gemm_bf16_256(m0, m, n0, n, RM, RN, params);
#endif
}*/
}
#endif //!ARM_NEON

void gemm_bf16(int flags, const short *a, int aoffset, const short *b, int boffset, short *rs, float *r, int roffset, int m, int n0, int n, int k, int lda, int ldb, int ldc)
Expand All @@ -1097,10 +1123,9 @@ void gemm_bf16(int flags, const short *a, int aoffset, const short *b, int boffs
};

#if !defined(__ARM_NEON__)
//((flags & HAS_AVX2) != 0)
// ? gemm(0, m, n0, n0 + n, gemm_bf16_512, p)
//:
gemm(0, m, n0, n0 + n, gemm_bf16_256, p);
((flags & HAS_AVX2) != 0)
? gemm(0, m, n0, n0 + n, gemm_bf16_512, p)
: gemm(0, m, n0, n0 + n, gemm_bf16_256, p);
#else
gemm(0, m, n0, n0 + n, gemm_bf16_128_arm, p);
#endif
Expand Down Expand Up @@ -1229,7 +1254,7 @@ void __attribute__((noinline)) gemm_f32_bf16_256(int m0, int m, int n0, int n, i
}
}

void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) {
void gemm_f32_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_params params) {
#if defined(__AVX512F__)
int ytiles = (m - m0) / RM;
int xtiles = (n - n0) / RN;
Expand Down Expand Up @@ -1257,12 +1282,12 @@ void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_par
__m512i vb = _mm512_loadu_si512((__m512i*)(params.bs + params.ldb * (jj + ni) + bo));

// Extract lower 8 shorts and convert to int (lower 128 bits)
__m512i vb0i = _mm512_cvtepu16_epi32(_mm512_extracti128_si512(vb, 0));
__m512i vb0i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 0));
// Shift left 16 bits and convert to float
__m512 vb0 = _mm512_castsi512_ps(_mm512_slli_epi32(vb0i, 16));

// Extract lower 8 shorts and convert to int (upper 128 bits)
__m512i vb1i = _mm512_cvtepu16_epi32(_mm512_extracti128_si512(vb, 1));
__m512i vb1i = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(vb, 1));
// Shift left 16 bits and convert to float
__m512 vb1 = _mm512_castsi512_ps(_mm512_slli_epi32(vb1i, 16));

Expand All @@ -1281,12 +1306,15 @@ void gemm_bf16_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_par
for (int ni = 0; ni < RN; ++ni) {
// Horizontal sum of the vector to get dot product
float r = _mm512_reduce_add_ps(sums[mi][ni]);
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r;
if (params.rs != NULL)
params.rs[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = fp32_to_bf16(r);
else
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r;
}
}
}
#else
gemm_bf16_256(m0, m, n0, n, RM, RN, params);
gemm_f32_bf16_256(m0, m, n0, n, RM, RN, params);
#endif
}
#endif //!ARM_NEON
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Generated by jextract

package com.github.tjake.jlama.tensor.operations.cnative;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;
import java.lang.foreign.*;
import static java.lang.foreign.ValueLayout.*;
/**
 * Holder for the function descriptors and downcall method handles used to call the
 * native {@code gemm_bf16}, {@code gemm_bf16_batch}, {@code gemm_f32_bf16} and
 * {@code gemm_f32_bf16_batch} symbols.
 *
 * <p>This file is generated by jextract; regenerate it rather than editing by hand.
 * The fields are initialised in declaration order, so each descriptor must stay
 * declared before the handles that reference it.
 */
final class constants$2 {

// Suppresses default constructor, ensuring non-instantiability.
private constants$2() {}
// Descriptor for a void native function taking 15 arguments:
// (int, pointer, int, pointer, int, pointer, pointer, int x 8).
// This lines up with the C signature
//   gemm_bf16(int flags, const short *a, int aoffset, const short *b, int boffset,
//             short *rs, float *r, int roffset, int m, int n0, int n, int k,
//             int lda, int ldb, int ldc)
// It is shared by gemm_bf16 and gemm_f32_bf16 below.
static final FunctionDescriptor const$0 = FunctionDescriptor.ofVoid(
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT
);
// Handle for the native "gemm_bf16" symbol, built from const$0.
// NOTE(review): RuntimeHelper.downcallHandle is not visible here; presumably it resolves
// the symbol by name and links it with the given descriptor - confirm how a missing
// symbol is reported (null vs. exception).
static final MethodHandle const$1 = RuntimeHelper.downcallHandle(
"gemm_bf16",
constants$2.const$0
);
// Descriptor for a void native function taking 16 arguments:
// (int, int, pointer, int, pointer, int, pointer, pointer, int x 8),
// i.e. the const$0 layout with one additional leading int.
// NOTE(review): presumably the extra int is a batch count - verify against the C
// declaration of gemm_bf16_batch.
static final FunctionDescriptor const$2 = FunctionDescriptor.ofVoid(
JAVA_INT,
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
JAVA_INT,
RuntimeHelper.POINTER,
RuntimeHelper.POINTER,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT,
JAVA_INT
);
// Handle for the native "gemm_bf16_batch" symbol, built from const$2.
static final MethodHandle const$3 = RuntimeHelper.downcallHandle(
"gemm_bf16_batch",
constants$2.const$2
);
// Handle for the native "gemm_f32_bf16" symbol; reuses the 15-argument descriptor const$0.
// NOTE(review): assumes the C signature has the same argument kinds as gemm_bf16 - confirm.
static final MethodHandle const$4 = RuntimeHelper.downcallHandle(
"gemm_f32_bf16",
constants$2.const$0
);
// Handle for the native "gemm_f32_bf16_batch" symbol; reuses the 16-argument descriptor const$2.
static final MethodHandle const$5 = RuntimeHelper.downcallHandle(
"gemm_f32_bf16_batch",
constants$2.const$2
);
}


0 comments on commit 9fec857

Please sign in to comment.