diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index ce91da057e8b5..781cb0731074f 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -665,6 +665,7 @@ static inline bool ggml_vk_build_shader(ggml_type type) {
     case GGML_TYPE_Q8_0:
     case GGML_TYPE_Q2_K:
     case GGML_TYPE_Q3_K:
+    case GGML_TYPE_Q4_K:
     case GGML_TYPE_Q6_K:
         return true;
     default:
@@ -960,6 +961,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
     case GGML_TYPE_Q8_0:
     case GGML_TYPE_Q2_K:
     case GGML_TYPE_Q3_K:
+    case GGML_TYPE_Q4_K:
     case GGML_TYPE_Q6_K:
         break;
     default:
@@ -982,6 +984,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
     case GGML_TYPE_Q8_0:
     case GGML_TYPE_Q2_K:
     case GGML_TYPE_Q3_K:
+    case GGML_TYPE_Q4_K:
     case GGML_TYPE_Q6_K:
         break;
     default:
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index b245515acdb5d..c888dc6f79d50 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -116,6 +116,18 @@
 #define A_TYPE block_q3_K
 """
+shader_q4_K_defines = """
+#define QUANT_K 256
+
+struct block_q4_K
+{
+    f16vec2 d;
+    uint8_t scales[3*QUANT_K/64];
+    uint8_t qs[QUANT_K/2];
+};
+
+#define A_TYPE block_q4_K
+"""
 shader_q6_K_defines = """
 #define QUANT_K 256
 
 struct block_q6_K
@@ -568,6 +580,68 @@
     }
 }
 """
+dequant_q4_K_body = """
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE x[];};
+layout (binding = 1) writeonly buffer D {D_TYPE y[];};
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int K;
+    int stride_a;
+    int stride_b;
+} p;
+
+void main() {
+    [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
+        const int i = int(gl_WorkGroupID.x * 256 + wgy);
+        if (i >= p.M * p.K / QUANT_K) {
+            return;
+        }
+
+        const int tid = int(gl_LocalInvocationID.x);
+        const int il = tid / 8;
+        const int ir = tid % 8;
+        const int is = 2 * il;
+        const int n = 4;
+
+        const FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y);
+
+        const int y_idx = i * QUANT_K + 64 * il + n * ir;
+        const int qs_idx = 32*il + n * ir;
+
+        uint8_t sc;
+        uint8_t m;
+        if (is < 4) {
+            sc = uint8_t(x[i].scales[is] & 63);
+            m  = uint8_t(x[i].scales[is + 4] & 63);
+        } else {
+            sc = uint8_t((x[i].scales[is + 4] & 0xF) | ((x[i].scales[is - 4] >> 6) << 4));
+            m  = uint8_t((x[i].scales[is + 4] >>  4) | ((x[i].scales[is    ] >> 6) << 4));
+        }
+        const FLOAT_TYPE d1 = dall * sc;
+        const FLOAT_TYPE m1 = dmin * m;
+
+        if (is < 4) {
+            sc = uint8_t(x[i].scales[is + 1] & 63);
+            m  = uint8_t(x[i].scales[is + 5] & 63);
+        } else {
+            sc = uint8_t((x[i].scales[is + 5] & 0xF) | ((x[i].scales[is - 3] >> 6) << 4));
+            m  = uint8_t((x[i].scales[is + 5] >>  4) | ((x[i].scales[is + 1] >> 6) << 4));
+        }
+        const FLOAT_TYPE d2 = dall * sc;
+        const FLOAT_TYPE m2 = dmin * m;
+
+        for (int l = 0; l < n; ++l) {
+            y[y_idx + l     ] = D_TYPE(d1 * FLOAT_TYPE(x[i].qs[qs_idx + l] & 0xF) - m1);
+            y[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(x[i].qs[qs_idx + l] >>  4) - m2);
+        }
+    }
+}
+"""
 dequant_q6_K_body = """
 layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
 
@@ -814,6 +888,125 @@
     }
 }
 """
+mul_mat_vec_q4_K_body = """
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE x[];};
+layout (binding = 1) readonly buffer B {B_TYPE y[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+
+layout (push_constant) uniform parameter
+{
+    int ncols;
+} p;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+    const int row = int(gl_WorkGroupID.x);
+
+    const int num_blocks_per_row = p.ncols / QUANT_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 8/K_QUANTS_PER_ITERATION;              // 8 or 4
+
+    const int il = tid/step;                                // 0...3
+    const int ir = tid - step*il;                           // 0...7 or 0...3
+    const int n = 2 * K_QUANTS_PER_ITERATION;               // 2 or 4
+
+    const int v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int v_in = il % 2;
+
+    const int l0 = n * (2 * ir + v_in);                     // 0...15
+    const int q_offset = 32*v_im + l0;
+    const int y_offset = 64*v_im + l0;
+
+    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+    [[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+        const int y1_idx = i * QUANT_K + y_offset;
+        const int y2_idx = y1_idx + 128;
+
+        const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y);
+
+        const uint8_t sc0 = uint8_t(  x[ib0 + i].scales[v_im * 2    ]       & 0x3f);
+        const uint8_t sc1 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
+        const uint8_t sc2 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
+        const uint8_t sc3 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
+        const uint8_t sc4 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((x[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
+        const uint8_t sc5 = uint8_t(( x[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
+        const uint8_t sc6 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
+        const uint8_t sc7 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
+
+#if K_QUANTS_PER_ITERATION == 2
+        const uint8_t q4_0  = uint8_t(x[ib0 + i].qs[q_offset     ] & 0xf);
+        const uint8_t q4_1  = uint8_t(x[ib0 + i].qs[q_offset +  1] & 0xf);
+        const uint8_t q4_2  = uint8_t(x[ib0 + i].qs[q_offset +  2] & 0xf);
+        const uint8_t q4_3  = uint8_t(x[ib0 + i].qs[q_offset +  3] & 0xf);
+        const uint8_t q4_4  = uint8_t(x[ib0 + i].qs[q_offset     ]  >> 4);
+        const uint8_t q4_5  = uint8_t(x[ib0 + i].qs[q_offset +  1]  >> 4);
+        const uint8_t q4_6  = uint8_t(x[ib0 + i].qs[q_offset +  2]  >> 4);
+        const uint8_t q4_7  = uint8_t(x[ib0 + i].qs[q_offset +  3]  >> 4);
+        const uint8_t q4_8  = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
+        const uint8_t q4_9  = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
+        const uint8_t q4_10 = uint8_t(x[ib0 + i].qs[q_offset + 66] & 0xf);
+        const uint8_t q4_11 = uint8_t(x[ib0 + i].qs[q_offset + 67] & 0xf);
+        const uint8_t q4_12 = uint8_t(x[ib0 + i].qs[q_offset + 64]  >> 4);
+        const uint8_t q4_13 = uint8_t(x[ib0 + i].qs[q_offset + 65]  >> 4);
+        const uint8_t q4_14 = uint8_t(x[ib0 + i].qs[q_offset + 66]  >> 4);
+        const uint8_t q4_15 = uint8_t(x[ib0 + i].qs[q_offset + 67]  >> 4);
+
+        const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx     ] * q4_0  + y[y1_idx +  1] * q4_1  + y[y1_idx +  2] * q4_2  + y[y1_idx +  3] * q4_3);
+        const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_4  + y[y1_idx + 33] * q4_5  + y[y1_idx + 34] * q4_6  + y[y1_idx + 35] * q4_7);
+        const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx     ] * q4_8  + y[y2_idx +  1] * q4_9  + y[y2_idx +  2] * q4_10 + y[y2_idx +  3] * q4_11);
+        const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_12 + y[y2_idx + 33] * q4_13 + y[y2_idx + 34] * q4_14 + y[y2_idx + 35] * q4_15);
+        const FLOAT_TYPE smin = FLOAT_TYPE(
+            y[y1_idx    ] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx    ] * sc6 + y[y2_idx + 32] * sc7
+          + y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
+          + y[y1_idx + 2] * sc2 + y[y1_idx + 34] * sc3 + y[y2_idx + 2] * sc6 + y[y2_idx + 34] * sc7
+          + y[y1_idx + 3] * sc2 + y[y1_idx + 35] * sc3 + y[y2_idx + 3] * sc6 + y[y2_idx + 35] * sc7
+        );
+        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+#else
+        const uint8_t q4_0 = uint8_t(x[ib0 + i].qs[q_offset     ] & 0xf);
+        const uint8_t q4_1 = uint8_t(x[ib0 + i].qs[q_offset +  1] & 0xf);
+        const uint8_t q4_2 = uint8_t(x[ib0 + i].qs[q_offset     ]  >> 4);
+        const uint8_t q4_3 = uint8_t(x[ib0 + i].qs[q_offset +  1]  >> 4);
+        const uint8_t q4_4 = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
+        const uint8_t q4_5 = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
+        const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 64]  >> 4);
+        const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 65]  >> 4);
+
+        const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx     ] * q4_0 + y[y1_idx +  1] * q4_1);
+        const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_2 + y[y1_idx + 33] * q4_3);
+        const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx     ] * q4_4 + y[y2_idx +  1] * q4_5);
+        const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_6 + y[y2_idx + 33] * q4_7);
+        const FLOAT_TYPE smin = FLOAT_TYPE(
+            y[y1_idx] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx] * sc6 + y[y2_idx + 32] * sc7
+          + y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
+        );
+
+        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(x[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(x[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((x[ib0 + i].scales[v_im + 4] & 0x0f) | ((x[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((x[ib0 + i].scales[v_im + 5] & 0x0f) | ((x[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
+#endif
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier();
+    }
+    if (tid == 0) {
+        dst[row] = D_TYPE(tmp[0]);
+    }
+}
+"""
 mul_mat_vec_q6_K_body = """
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
@@ -1068,7 +1261,7 @@
     GGML_TYPE_Q8_K: "q8_K",
 }
 
-K_QUANTS_PER_ITERATION = 1
+K_QUANTS_PER_ITERATION = 2
 
 
 async def string_to_spv_file(name, code, defines, fp16):
@@ -1184,6 +1377,8 @@ async def main():
             stream.extend((shader_q2_K_defines, dequant_q2_K_body))
         elif i == GGML_TYPE_Q3_K:
             stream.extend((shader_q3_K_defines, dequant_q3_K_body))
+        elif i == GGML_TYPE_Q4_K:
+            stream.extend((shader_q4_K_defines, dequant_q4_K_body))
         elif i == GGML_TYPE_Q6_K:
            stream.extend((shader_q6_K_defines, dequant_q6_K_body))
         else:
@@ -1212,6 +1407,8 @@ async def main():
             stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
         elif i == GGML_TYPE_Q3_K:
             stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
+        elif i == GGML_TYPE_Q4_K:
+            stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body))
         elif i == GGML_TYPE_Q6_K:
             stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
         else:
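
For context on the bit manipulation above: a Q4_K super-block holds an fp16 scale/min pair (`d.x`/`d.y`), 12 bytes of packed 6-bit sub-block scales and mins, and 128 bytes containing 256 4-bit quants; each value is reconstructed as `d * sc * q - dmin * m`. The following is a minimal host-side C++ sketch mirroring the unpacking that `dequant_q4_K_body` performs, not part of the patch; the struct and function names are illustrative only, and the fp16-to-float conversion of `d`/`dmin` is left to the caller.

```cpp
#include <cstdint>

// Hypothetical host-side mirror of the block layout used by the shaders above.
// The fp16 super-block scale/min is kept as raw half bits; conversion is omitted.
struct block_q4_K_ref {
    uint16_t d[2];       // d[0] = super-block scale, d[1] = super-block min (fp16 bits)
    uint8_t  scales[12]; // 8 packed 6-bit (scale, min) pairs
    uint8_t  qs[128];    // 256 4-bit quants, two per byte
};

// Unpack the 6-bit scale/min pair for sub-block `is` (0..7), using the same
// bit layout the GLSL in dequant_q4_K_body decodes.
static void get_scale_min_ref(const uint8_t * scales, int is, uint8_t * sc, uint8_t * m) {
    if (is < 4) {
        *sc = scales[is]     & 63;
        *m  = scales[is + 4] & 63;
    } else {
        *sc = (scales[is + 4] & 0xF) | ((scales[is - 4] >> 6) << 4);
        *m  = (scales[is + 4] >> 4)  | ((scales[is]     >> 6) << 4);
    }
}

// Dequantize one 256-value super-block: each 4-bit quant q becomes d*sc*q - dmin*m.
// `d` and `dmin` are the already-converted fp16 values from blk->d.
static void dequantize_q4_K_ref(const block_q4_K_ref * blk, float d, float dmin, float * out) {
    for (int is = 0; is < 8; ++is) {          // 8 sub-blocks of 32 values
        uint8_t sc, m;
        get_scale_min_ref(blk->scales, is, &sc, &m);
        const float d1 = d * sc;
        const float m1 = dmin * m;
        for (int l = 0; l < 32; ++l) {
            // low nibbles hold sub-blocks 0,2,4,6; high nibbles hold 1,3,5,7
            const uint8_t byte = blk->qs[(is / 2) * 32 + l];
            const uint8_t q = (is % 2 == 0) ? (byte & 0xF) : (byte >> 4);
            out[is * 32 + l] = d1 * q - m1;
        }
    }
}
```

The mul_mat_vec shader folds the same decode into the dot product: it accumulates `sx..sw` against the unpacked scales and subtracts `dmin * smin`, so the 4-bit quants never have to be materialized as floats.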