Skip to content

Commit

Permalink
Add q5_k support
Browse files Browse the repository at this point in the history
  • Loading branch information
0cc4m committed Oct 20, 2023
1 parent 4a97d2d commit 0ec595f
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 7 deletions.
5 changes: 4 additions & 1 deletion ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ static inline bool ggml_vk_build_shader(ggml_type type) {
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return true;
default:
Expand Down Expand Up @@ -962,6 +963,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break;
default:
Expand All @@ -985,6 +987,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break;
default:
Expand Down Expand Up @@ -2789,7 +2792,7 @@ void ggml_vk_check_results_1(ggml_compute_params * params, ggml_tensor * tensor)

avg_err /= tensor->ne[3] * tensor->ne[2] * tensor->ne[1] * tensor->ne[0];

if (avg_err > 1.0 || std::isnan(avg_err)) {
if (avg_err > 0.1 || std::isnan(avg_err)) {
std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << std::endl;
std::cerr << "tensor->backend: " << tensor->backend << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << std::endl;
if (tensor->src[0] != nullptr) {
Expand Down
208 changes: 202 additions & 6 deletions ggml_vk_generate_shaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@
#define A_TYPE block_q4_K
"""
shader_q5_K_defines = """
#define QUANT_K 256
struct block_q5_K
{
f16vec2 d;
uint8_t scales[12];
uint8_t qh[QUANT_K/8];
uint8_t qs[QUANT_K/2];
};
#define A_TYPE block_q5_K
"""
shader_q6_K_defines = """
#define QUANT_K 256
Expand Down Expand Up @@ -635,13 +648,77 @@
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * m;
for (int l = 0; l < n; ++l) {
[[unroll]] for (int l = 0; l < n; ++l) {
y[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(x[i].qs[qs_idx + l] & 0xF) - m1);
y[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(x[i].qs[qs_idx + l] >> 4) - m2);
}
}
}
"""
dequant_q5_K_body = """
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE x[];};
layout (binding = 1) writeonly buffer D {D_TYPE y[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.M * p.K / QUANT_K) {
return;
}
const int tid = int(gl_LocalInvocationID.x);
const int il = tid / 16;
const int ir = tid % 16;
const int is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y);
const int y_idx = i * QUANT_K + 64 * il + 2 * ir;
const int qs_idx = 32*il + 2 * ir;
const int qh_idx = 2 * ir;
uint8_t sc;
uint8_t m;
if (is < 4) {
sc = uint8_t(x[i].scales[is] & 63);
m = uint8_t(x[i].scales[is + 4] & 63);
} else {
sc = uint8_t((x[i].scales[is + 4] & 0xF) | ((x[i].scales[is - 4] >> 6) << 4));
m = uint8_t((x[i].scales[is + 4] >> 4) | ((x[i].scales[is ] >> 6) << 4));
}
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * m;
if (is < 4) {
sc = uint8_t(x[i].scales[is + 1] & 63);
m = uint8_t(x[i].scales[is + 5] & 63);
} else {
sc = uint8_t((x[i].scales[is + 5] & 0xF) | ((x[i].scales[is - 3] >> 6) << 4));
m = uint8_t((x[i].scales[is + 5] >> 4) | ((x[i].scales[is + 1] >> 6) << 4));
}
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * m;
const uint8_t hm1 = uint8_t(1 << (2 * il ));
const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
y[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((x[i].qs[qs_idx ] & 0xF) + (((x[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
y[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((x[i].qs[qs_idx + 1] & 0xF) + (((x[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
y[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((x[i].qs[qs_idx ] >> 4) + (((x[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
y[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((x[i].qs[qs_idx + 1] >> 4) + (((x[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
}
}
"""
dequant_q6_K_body = """
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
Expand Down Expand Up @@ -981,10 +1058,10 @@
const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4);
const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx] * q4_0 + y[y1_idx + 1] * q4_1);
const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_2 + y[y1_idx + 33] * q4_3);
const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx] * q4_4 + y[y2_idx + 1] * q4_5);
const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_6 + y[y2_idx + 33] * q4_7);
const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx ] * q4_0 + y[y1_idx + 1] * q4_1 + y[y1_idx + 2] * q4_2 + y[y1_idx + 3] * q4_3 );
const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_4 + y[y1_idx + 33] * q4_5 + y[y1_idx + 34] * q4_6 + y[y1_idx + 35] * q4_7 );
const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx ] * q4_8 + y[y2_idx + 1] * q4_9 + y[y2_idx + 2] * q4_10 + y[y2_idx + 3] * q4_11);
const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_12 + y[y2_idx + 33] * q4_13 + y[y2_idx + 34] * q4_14 + y[y2_idx + 35] * q4_15);
const FLOAT_TYPE smin = FLOAT_TYPE(
y[y1_idx] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx] * sc6 + y[y2_idx + 32] * sc7
+ y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
Expand All @@ -1007,6 +1084,121 @@
}
}
"""
mul_mat_vec_q5_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE x[];};
layout (binding = 1) readonly buffer B {B_TYPE y[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/2; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%2; // 0 or 0, 1
const int il = tid/4; // 0...3
const int ir = tid - 4*il; // 0...7 or 0...3
const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int v_in = il % 2;
const int l0 = 4*ir + 2*v_in; // 0...15
const int q_offset = 32*v_im + l0;
const int y_offset = 64*v_im + l0;
const uint8_t hm1 = uint8_t(1 << (2*v_im));
const uint8_t hm2 = uint8_t(hm1 << 4);
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (int i = ix; i < num_blocks_per_row; i += 2) {
const int y1_idx = i * QUANT_K + y_offset;
const int y2_idx = y1_idx + 128;
const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y);
const uint8_t sc0 = uint8_t( x[ib0 + i].scales[v_im * 2 ] & 0x3f);
const uint8_t sc1 = uint8_t( x[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
const uint8_t sc2 = uint8_t( x[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
const uint8_t sc3 = uint8_t( x[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
const uint8_t sc4 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
const uint8_t sc5 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
const uint8_t sc6 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
const uint8_t sc7 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset + 16] & 0xf);
const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset + 17] & 0xf);
const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 16] >> 4);
const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 17] >> 4);
const uint8_t q4_8 = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_9 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_10 = uint8_t(x[ib0 + i].qs[q_offset + 80] & 0xf);
const uint8_t q4_11 = uint8_t(x[ib0 + i].qs[q_offset + 81] & 0xf);
const uint8_t q4_12 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_13 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4);
const uint8_t q4_14 = uint8_t(x[ib0 + i].qs[q_offset + 80] >> 4);
const uint8_t q4_15 = uint8_t(x[ib0 + i].qs[q_offset + 81] >> 4);
const FLOAT_TYPE sx = FLOAT_TYPE(
y[y1_idx ] * (q4_0 + (((x[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0))
+ y[y1_idx + 1] * (q4_1 + (((x[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0))
+ y[y1_idx + 16] * (q4_2 + (((x[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
+ y[y1_idx + 17] * (q4_3 + (((x[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
);
const FLOAT_TYPE sy = FLOAT_TYPE(
y[y1_idx + 32] * (q4_4 + (((x[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0))
+ y[y1_idx + 33] * (q4_5 + (((x[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0))
+ y[y1_idx + 48] * (q4_6 + (((x[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
+ y[y1_idx + 49] * (q4_7 + (((x[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
);
const FLOAT_TYPE sz = FLOAT_TYPE(
y[y2_idx ] * (q4_8 + (((x[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0))
+ y[y2_idx + 1] * (q4_9 + (((x[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0))
+ y[y2_idx + 16] * (q4_10 + (((x[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
+ y[y2_idx + 17] * (q4_11 + (((x[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
);
const FLOAT_TYPE sw = FLOAT_TYPE(
y[y2_idx + 32] * (q4_12 + (((x[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0))
+ y[y2_idx + 33] * (q4_13 + (((x[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0))
+ y[y2_idx + 48] * (q4_14 + (((x[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
+ y[y2_idx + 49] * (q4_15 + (((x[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
);
const FLOAT_TYPE smin = FLOAT_TYPE(
(y[y1_idx] + y[y1_idx + 1] + y[y1_idx + 16] + y[y1_idx + 17]) * sc2 + (y[y1_idx + 32] + y[y1_idx + 33] + y[y1_idx + 48] + y[y1_idx + 49]) * sc3
+ (y[y2_idx] + y[y2_idx + 1] + y[y2_idx + 16] + y[y2_idx + 17]) * sc6 + (y[y2_idx + 32] + y[y2_idx + 33] + y[y2_idx + 48] + y[y2_idx + 49]) * sc7
);
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[row] = D_TYPE(tmp[0]);
}
}
"""
mul_mat_vec_q6_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
Expand Down Expand Up @@ -1067,7 +1259,7 @@
tmp[16 * ix + tid] += sum;
#else
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
for (int l = 0; l < 4; ++l) {
[[unroll]] for (int l = 0; l < 4; ++l) {
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
Expand Down Expand Up @@ -1379,6 +1571,8 @@ async def main():
stream.extend((shader_q3_K_defines, dequant_q3_K_body))
elif i == GGML_TYPE_Q4_K:
stream.extend((shader_q4_K_defines, dequant_q4_K_body))
elif i == GGML_TYPE_Q5_K:
stream.extend((shader_q5_K_defines, dequant_q5_K_body))
elif i == GGML_TYPE_Q6_K:
stream.extend((shader_q6_K_defines, dequant_q6_K_body))
else:
Expand Down Expand Up @@ -1409,6 +1603,8 @@ async def main():
stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
elif i == GGML_TYPE_Q4_K:
stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body))
elif i == GGML_TYPE_Q5_K:
stream.extend((shader_q5_K_defines, mul_mat_vec_q5_K_body))
elif i == GGML_TYPE_Q6_K:
stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
else:
Expand Down

0 comments on commit 0ec595f

Please sign in to comment.