Always favor fp16 arithmetic in tinyBLAS
It was assumed earlier that upcasting fp16 to fp32 would help precision.
However, that wasn't the case: measured by Levenshtein distance on
whisperfile output, this change makes things objectively better in
noticeable ways. So we now skip the upcast and do the arithmetic in fp16
whenever the ISA supports it. This should be safe and accurate even for
large sums, since the ruler-reduction divide-and-conquer approach in
tinyBLAS::gemm keeps roundoff growth logarithmic in k.
jart committed Aug 21, 2024
1 parent 6287b60 commit c44664b
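
The "ruler reduction" mentioned in the commit message is what makes dropping the fp32 upcast safe for long dot products. Below is a minimal standalone sketch of the idea, not the actual tinyBLAS::gemm code (the function name and layout are invented for illustration): partial sums are merged pairwise in a binary tree driven by the ruler sequence, so every input passes through only O(log2 k) additions and roundoff grows logarithmically in k rather than linearly.

// Hypothetical sketch, not the tinyBLAS implementation: ruler-style
// divide-and-conquer summation. stack[j] holds the sum of a completed
// block of 2^j consecutive inputs; completed blocks are merged as the
// low bits of the index roll over, like the markings on a ruler.
float ruler_sum(const float *x, long k) {
    float stack[64] = {0};            // one slot per tree level
    for (long i = 0; i < k; ++i) {
        float v = x[i];
        int j = 0;
        for (long m = i; m & 1; m >>= 1, ++j) {
            v += stack[j];            // fold two 2^j blocks into one 2^(j+1) block
            stack[j] = 0;
        }
        stack[j] = v;
    }
    float total = 0;
    for (int j = 0; j < 64; ++j)      // fold whatever partial blocks remain
        total += stack[j];
    return total;
}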
Showing 9 changed files with 42 additions and 74 deletions.
6 changes: 2 additions & 4 deletions llama.cpp/ggml.c
@@ -11091,8 +11091,7 @@ static void ggml_compute_forward_mul_mat(
ith, nth,
src0->type,
src1->type,
dst->type,
dst->op_params[0]))
dst->type))
goto UseGgmlGemm1;
return;
}
@@ -11153,8 +11152,7 @@ UseGgmlGemm1:;
ith, nth,
src0->type,
vec_dot_type,
dst->type,
dst->op_params[0]))
dst->type))
goto UseGgmlGemm2;
return;
}
6 changes: 2 additions & 4 deletions llamafile/sgemm.cpp
@@ -123,13 +123,11 @@ static const struct GemmFuncs {
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @param precision may be used to control the internal compute type
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void *B, long ldb,
void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype,
int precision) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype, precision);
void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
}

/**
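
For reference, a hedged usage sketch of the updated entry point now that the precision argument is gone. The fp16/fp16/fp32 type combination is the path this commit keeps in fp16 arithmetic on CPUs with fp16 vector instructions; the wrapper name gemm_f16 and the densely packed leading dimensions (lda = ldb = k, ldc = m) are assumptions for illustration, not something this diff specifies.

// Illustrative only: single-threaded (ith = 0, nth = 1) call of the new
// 14-argument llamafile_sgemm() with fp16 inputs and an fp32 output.
// Leading dimensions below assume densely packed A, B, and C buffers.
#include "llamafile/sgemm.h"
#include "llama.cpp/ggml.h"

bool gemm_f16(long m, long n, long k,
              const ggml_fp16_t *A, const ggml_fp16_t *B, float *C) {
    return llamafile_sgemm(m, n, k,
                           A, /*lda=*/k,
                           B, /*ldb=*/k,
                           C, /*ldc=*/m,
                           /*ith=*/0, /*nth=*/1,
                           GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32);
}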
20 changes: 10 additions & 10 deletions llamafile/sgemm.h
@@ -21,30 +21,30 @@ bool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void *, const
long, long, const void *, int, int);

bool llamafile_sgemm(long, long, long, const void *, long, const void *, long, void *, long, int,
int, int, int, int, int);
int, int, int, int);
bool llamafile_mixmul(const struct ggml_compute_params *, const struct ggml_tensor *,
const struct ggml_tensor *, const struct ggml_tensor *, struct ggml_tensor *);
size_t llamafile_mixmul_needs(const struct ggml_tensor *, const struct ggml_tensor *,
const struct ggml_tensor *);

bool llamafile_sgemm_unsupported(long, long, long, const void *, long, const void *, long, void *,
long, int, int, int, int, int, int);
long, int, int, int, int, int);
bool llamafile_sgemm_amd_avx(long, long, long, const void *, long, const void *, long, void *, long,
int, int, int, int, int, int);
int, int, int, int, int);
bool llamafile_sgemm_amd_fma(long, long, long, const void *, long, const void *, long, void *, long,
int, int, int, int, int, int);
int, int, int, int, int);
bool llamafile_sgemm_amd_avx2(long, long, long, const void *, long, const void *, long, void *,
long, int, int, int, int, int, int);
long, int, int, int, int, int);
bool llamafile_sgemm_amd_avxvnni(long, long, long, const void *, long, const void *, long, void *,
long, int, int, int, int, int, int);
long, int, int, int, int, int);
bool llamafile_sgemm_amd_avx512f(long, long, long, const void *, long, const void *, long, void *,
long, int, int, int, int, int, int);
long, int, int, int, int, int);
bool llamafile_sgemm_amd_zen4(long, long, long, const void *, long, const void *, long, void *,
long, int, int, int, int, int, int);
long, int, int, int, int, int);
bool llamafile_sgemm_arm80(long, long, long, const void *, long, const void *, long, void *, long,
int, int, int, int, int, int);
int, int, int, int, int);
bool llamafile_sgemm_arm82(long, long, long, const void *, long, const void *, long, void *, long,
int, int, int, int, int, int);
int, int, int, int, int);

bool llamafile_mixmul_unsupported(const struct ggml_compute_params *, const struct ggml_tensor *,
const struct ggml_tensor *, const struct ggml_tensor *,
8 changes: 3 additions & 5 deletions llamafile/sgemm_matmul_test.cpp
@@ -32,13 +32,11 @@
int cpu_get_num_math();

void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
long ldb, void *C, long ldc, int Atype, int Btype, int Ctype,
int precision) {
long ldb, void *C, long ldc, int Atype, int Btype, int Ctype) {
static int nth = cpu_get_num_math();
#pragma omp parallel for
for (int ith = 0; ith < nth; ++ith) {
bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype,
precision);
bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
assert(res);
}
}
@@ -63,7 +61,7 @@ int test(void) {

BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TYPE_F32, GGML_TYPE_F32,
GGML_TYPE_F32, GGML_PREC_DEFAULT));
GGML_TYPE_F32));

int flips = 0;
double err_sum = 0;
8 changes: 3 additions & 5 deletions llamafile/sgemm_sss_test.cpp
@@ -31,13 +31,11 @@
int cpu_get_num_math();

void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
long ldb, void *C, long ldc, int Atype, int Btype, int Ctype,
int precision) {
long ldb, void *C, long ldc, int Atype, int Btype, int Ctype) {
static int nth = cpu_get_num_math();
#pragma omp parallel for
for (int ith = 0; ith < nth; ++ith) {
bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype,
precision);
bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
assert(res);
}
}
@@ -63,7 +61,7 @@ int test(void) {

BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TYPE_F32, GGML_TYPE_F32,
GGML_TYPE_F32, GGML_PREC_DEFAULT));
GGML_TYPE_F32));

double err_sum = 0;
long long err_worst = 0;
8 changes: 3 additions & 5 deletions llamafile/sgemm_vecdot_test.cpp
@@ -30,13 +30,11 @@
int cpu_get_num_math();

void llamafile_sgemm_openmp(long m, long n, long k, const void *A, long lda, const void *B,
long ldb, void *C, long ldc, int Atype, int Btype, int Ctype,
int precision) {
long ldb, void *C, long ldc, int Atype, int Btype, int Ctype) {
static int nth = cpu_get_num_math();
#pragma omp parallel for
for (int ith = 0; ith < nth; ++ith) {
bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype,
precision);
bool res = llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, Atype, Btype, Ctype);
assert(res);
}
}
@@ -61,7 +59,7 @@ int test(void) {

BENCH(ansiBLAS::sgemm(m, n, k, A, lda, B, ldb, G, ldc));
BENCH(llamafile_sgemm_openmp(m, n, k, A, lda, B, ldb, C, ldc, GGML_TYPE_F32, GGML_TYPE_F32,
GGML_TYPE_F32, GGML_PREC_DEFAULT));
GGML_TYPE_F32));

double err_sum = 0;
long long err_worst = 0;
15 changes: 4 additions & 11 deletions llamafile/tinyblas_cpu_mixmul.inc
@@ -224,17 +224,10 @@ class MixMul {
return false;
}
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (result->op_params[0] == GGML_PREC_F32) {
return mixmat<
4, 1,
tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
} else {
return mixmat<
8, 1,
tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
}
return mixmat<
8, 1,
tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC>,
ggml_fp16_t, ggml_fp16_t, TC>();
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
return mixmat<
4, 1,
43 changes: 14 additions & 29 deletions llamafile/tinyblas_cpu_sgemm.inc
@@ -34,10 +34,8 @@
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
// With the F32, F16, and BF16 data types, the accumulation of roundoff
// errors will only grow logarithmically, thanks to the ruler function.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
@@ -46,8 +44,7 @@ namespace {

template <typename TC>
bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const void *B, long ldb,
TC *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype,
int precision) {
TC *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype) {

switch (Atype) {

@@ -160,23 +157,14 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
if (n < 2)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
return NOT_PROFITABLE;
if (precision == GGML_PREC_F32) {
if (Btype != GGML_TYPE_F32)
return NOT_SUPPORTED;
tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{
k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n);
return true;
} else {
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n);
return true;
}
if (Btype == GGML_TYPE_F32)
return WANT_QUANTIZATION;
if (Btype != GGML_TYPE_F16)
return NOT_SUPPORTED;
tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{
k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, C, ldc, ith, nth};
tb.matmul(m, n);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (n < 2 && !FLAG_precise)
// TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?
@@ -249,7 +237,6 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
(void)nth;
(void)Atype;
(void)Btype;
(void)precision;
}

} // namespace
@@ -265,8 +252,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,
* GGML_PREC_DEFAULT);
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
@@ -286,8 +272,7 @@ bool llamafile_sgemm_impl(long m, long n, long k, const void *A, long lda, const
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void *B, long ldb,
void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype,
int precision) {
void *C, long ldc, int ith, int nth, int Atype, int Btype, int Ctype) {

assert(m >= 0);
assert(n >= 0);
@@ -339,7 +324,7 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
switch (Ctype) {
case GGML_TYPE_F32:
return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float *)C, ldc, ith, nth, Atype,
Btype, Ctype, precision);
Btype, Ctype);
default:
return NOT_SUPPORTED;
}
2 changes: 1 addition & 1 deletion llamafile/tinyblas_cpu_unsupported.cpp
@@ -19,7 +19,7 @@

bool llamafile_sgemm_unsupported(long m, long n, long k, const void *A, long lda, const void *B,
long ldb, void *C, long ldc, int ith, int nth, int Atype,
int Btype, int Ctype, int precision) {
int Btype, int Ctype) {
return false;
}

