Skip to content

Commit

Permalink
Add q4_k support
Browse files Browse the repository at this point in the history
  • Loading branch information
0cc4m committed Oct 20, 2023
1 parent a861879 commit 4a97d2d
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 1 deletion.
3 changes: 3 additions & 0 deletions ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ static inline bool ggml_vk_build_shader(ggml_type type) {
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return true;
default:
Expand Down Expand Up @@ -960,6 +961,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
break;
default:
Expand All @@ -982,6 +984,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
break;
default:
Expand Down
199 changes: 198 additions & 1 deletion ggml_vk_generate_shaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@
#define A_TYPE block_q3_K
"""
shader_q4_K_defines = """
#define QUANT_K 256
struct block_q4_K
{
f16vec2 d;
uint8_t scales[3*QUANT_K/64];
uint8_t qs[QUANT_K/2];
};
#define A_TYPE block_q4_K
"""
shader_q6_K_defines = """
#define QUANT_K 256
Expand Down Expand Up @@ -568,6 +580,68 @@
}
}
"""
dequant_q4_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE x[];};
layout (binding = 1) writeonly buffer D {D_TYPE y[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.M * p.K / QUANT_K) {
return;
}
const int tid = int(gl_LocalInvocationID.x);
const int il = tid / 8;
const int ir = tid % 8;
const int is = 2 * il;
const int n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y);
const int y_idx = i * QUANT_K + 64 * il + n * ir;
const int qs_idx = 32*il + n * ir;
uint8_t sc;
uint8_t m;
if (is < 4) {
sc = uint8_t(x[i].scales[is] & 63);
m = uint8_t(x[i].scales[is + 4] & 63);
} else {
sc = uint8_t((x[i].scales[is + 4] & 0xF) | ((x[i].scales[is - 4] >> 6) << 4));
m = uint8_t((x[i].scales[is + 4] >> 4) | ((x[i].scales[is ] >> 6) << 4));
}
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * m;
if (is < 4) {
sc = uint8_t(x[i].scales[is + 1] & 63);
m = uint8_t(x[i].scales[is + 5] & 63);
} else {
sc = uint8_t((x[i].scales[is + 5] & 0xF) | ((x[i].scales[is - 3] >> 6) << 4));
m = uint8_t((x[i].scales[is + 5] >> 4) | ((x[i].scales[is + 1] >> 6) << 4));
}
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * m;
for (int l = 0; l < n; ++l) {
y[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(x[i].qs[qs_idx + l] & 0xF) - m1);
y[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(x[i].qs[qs_idx + l] >> 4) - m2);
}
}
}
"""
dequant_q6_K_body = """
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
Expand Down Expand Up @@ -814,6 +888,125 @@
}
}
"""
mul_mat_vec_q4_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE x[];};
layout (binding = 1) readonly buffer B {B_TYPE y[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int v_in = il % 2;
const int l0 = n * (2 * ir + v_in); // 0...15
const int q_offset = 32*v_im + l0;
const int y_offset = 64*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int y1_idx = i * QUANT_K + y_offset;
const int y2_idx = y1_idx + 128;
const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y);
const uint8_t sc0 = uint8_t( x[ib0 + i].scales[v_im * 2 ] & 0x3f);
const uint8_t sc1 = uint8_t( x[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
const uint8_t sc2 = uint8_t( x[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
const uint8_t sc3 = uint8_t( x[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
const uint8_t sc4 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
const uint8_t sc5 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
const uint8_t sc6 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
const uint8_t sc7 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
#if K_QUANTS_PER_ITERATION == 2
const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset + 2] & 0xf);
const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset + 3] & 0xf);
const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 2] >> 4);
const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 3] >> 4);
const uint8_t q4_8 = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_9 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_10 = uint8_t(x[ib0 + i].qs[q_offset + 66] & 0xf);
const uint8_t q4_11 = uint8_t(x[ib0 + i].qs[q_offset + 67] & 0xf);
const uint8_t q4_12 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_13 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4);
const uint8_t q4_14 = uint8_t(x[ib0 + i].qs[q_offset + 66] >> 4);
const uint8_t q4_15 = uint8_t(x[ib0 + i].qs[q_offset + 67] >> 4);
const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx] * q4_0 + y[y1_idx + 1] * q4_1 + y[y1_idx + 2] * q4_2 + y[y1_idx + 3] * q4_3);
const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_4 + y[y1_idx + 33] * q4_5 + y[y1_idx + 34] * q4_6 + y[y1_idx + 35] * q4_7);
const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx] * q4_8 + y[y2_idx + 1] * q4_9 + y[y2_idx + 2] * q4_10 + y[y2_idx + 3] * q4_11);
const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_12 + y[y2_idx + 33] * q4_13 + y[y2_idx + 34] * q4_14 + y[y2_idx + 35] * q4_15);
const FLOAT_TYPE smin = FLOAT_TYPE(
y[y1_idx ] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx ] * sc6 + y[y2_idx + 32] * sc7
+ y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
+ y[y1_idx + 2] * sc2 + y[y1_idx + 34] * sc3 + y[y2_idx + 2] * sc6 + y[y2_idx + 34] * sc7
+ y[y1_idx + 3] * sc2 + y[y1_idx + 35] * sc3 + y[y2_idx + 3] * sc6 + y[y2_idx + 35] * sc7
);
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
#else
const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4);
const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx] * q4_0 + y[y1_idx + 1] * q4_1);
const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_2 + y[y1_idx + 33] * q4_3);
const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx] * q4_4 + y[y2_idx + 1] * q4_5);
const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_6 + y[y2_idx + 33] * q4_7);
const FLOAT_TYPE smin = FLOAT_TYPE(
y[y1_idx] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx] * sc6 + y[y2_idx + 32] * sc7
+ y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
);
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(x[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(x[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((x[ib0 + i].scales[v_im + 4] & 0x0f) | ((x[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((x[ib0 + i].scales[v_im + 5] & 0x0f) | ((x[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
#endif
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[row] = D_TYPE(tmp[0]);
}
}
"""
mul_mat_vec_q6_K_body = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
Expand Down Expand Up @@ -1068,7 +1261,7 @@
GGML_TYPE_Q8_K: "q8_K",
}

K_QUANTS_PER_ITERATION = 1
K_QUANTS_PER_ITERATION = 2


async def string_to_spv_file(name, code, defines, fp16):
Expand Down Expand Up @@ -1184,6 +1377,8 @@ async def main():
stream.extend((shader_q2_K_defines, dequant_q2_K_body))
elif i == GGML_TYPE_Q3_K:
stream.extend((shader_q3_K_defines, dequant_q3_K_body))
elif i == GGML_TYPE_Q4_K:
stream.extend((shader_q4_K_defines, dequant_q4_K_body))
elif i == GGML_TYPE_Q6_K:
stream.extend((shader_q6_K_defines, dequant_q6_K_body))
else:
Expand Down Expand Up @@ -1212,6 +1407,8 @@ async def main():
stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
elif i == GGML_TYPE_Q3_K:
stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
elif i == GGML_TYPE_Q4_K:
stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body))
elif i == GGML_TYPE_Q6_K:
stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
else:
Expand Down

0 comments on commit 4a97d2d

Please sign in to comment.