diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 781cb0731074f..1e2b2fd2feb32 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -666,6 +666,7 @@ static inline bool ggml_vk_build_shader(ggml_type type) {
     case GGML_TYPE_Q2_K:
     case GGML_TYPE_Q3_K:
     case GGML_TYPE_Q4_K:
+    case GGML_TYPE_Q5_K:
     case GGML_TYPE_Q6_K:
         return true;
     default:
@@ -962,6 +963,7 @@ static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
     case GGML_TYPE_Q2_K:
     case GGML_TYPE_Q3_K:
     case GGML_TYPE_Q4_K:
+    case GGML_TYPE_Q5_K:
     case GGML_TYPE_Q6_K:
         break;
     default:
@@ -985,6 +987,7 @@ static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bo
     case GGML_TYPE_Q2_K:
     case GGML_TYPE_Q3_K:
    case GGML_TYPE_Q4_K:
+    case GGML_TYPE_Q5_K:
     case GGML_TYPE_Q6_K:
         break;
     default:
@@ -2789,7 +2792,7 @@ void ggml_vk_check_results_1(ggml_compute_params * params, ggml_tensor * tensor)
 
     avg_err /= tensor->ne[3] * tensor->ne[2] * tensor->ne[1] * tensor->ne[0];
 
-    if (avg_err > 1.0 || std::isnan(avg_err)) {
+    if (avg_err > 0.1 || std::isnan(avg_err)) {
         std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << std::endl;
         std::cerr << "tensor->backend: " << tensor->backend << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << std::endl;
         if (tensor->src[0] != nullptr) {
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index c888dc6f79d50..a11d59694ea12 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -128,6 +128,19 @@
 
 #define A_TYPE block_q4_K
 """
+shader_q5_K_defines = """
+#define QUANT_K 256
+
+struct block_q5_K
+{
+    f16vec2 d;
+    uint8_t scales[12];
+    uint8_t qh[QUANT_K/8];
+    uint8_t qs[QUANT_K/2];
+};
+
+#define A_TYPE block_q5_K
+"""
 shader_q6_K_defines = """
 #define QUANT_K 256
 
@@ -635,13 +648,77 @@
         const FLOAT_TYPE d2 = dall * sc;
         const FLOAT_TYPE m2 = dmin * m;
 
-        for (int l = 0; l < n; ++l) {
+        [[unroll]] for (int l = 0; l < n; ++l) {
             y[y_idx + l     ] = D_TYPE(d1 * FLOAT_TYPE(x[i].qs[qs_idx + l] & 0xF) - m1);
             y[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(x[i].qs[qs_idx + l] >> 4) - m2);
         }
     }
 }
 """
+dequant_q5_K_body = """
+layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE x[];};
+layout (binding = 1) writeonly buffer D {D_TYPE y[];};
+
+layout (push_constant) uniform parameter
+{
+    int M;
+    int K;
+    int stride_a;
+    int stride_b;
+} p;
+
+void main() {
+    [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
+        const int i = int(gl_WorkGroupID.x * 256 + wgy);
+        if (i >= p.M * p.K / QUANT_K) {
+            return;
+        }
+
+        const int tid = int(gl_LocalInvocationID.x);
+        const int il = tid / 16;
+        const int ir = tid % 16;
+        const int is = 2 * il;
+
+        const FLOAT_TYPE dall = FLOAT_TYPE(x[i].d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(x[i].d.y);
+
+        const int y_idx = i * QUANT_K + 64 * il + 2 * ir;
+        const int qs_idx = 32*il + 2 * ir;
+        const int qh_idx = 2 * ir;
+
+        uint8_t sc;
+        uint8_t m;
+        if (is < 4) {
+            sc = uint8_t(x[i].scales[is] & 63);
+            m = uint8_t(x[i].scales[is + 4] & 63);
+        } else {
+            sc = uint8_t((x[i].scales[is + 4] & 0xF) | ((x[i].scales[is - 4] >> 6) << 4));
+            m = uint8_t((x[i].scales[is + 4] >> 4) | ((x[i].scales[is    ] >> 6) << 4));
+        }
+        const FLOAT_TYPE d1 = dall * sc;
+        const FLOAT_TYPE m1 = dmin * m;
+
+        if (is < 4) {
+            sc = uint8_t(x[i].scales[is + 1] & 63);
+            m = uint8_t(x[i].scales[is + 5] & 63);
+        } else {
+            sc = uint8_t((x[i].scales[is + 5] & 0xF) | ((x[i].scales[is - 3] >> 6) << 4));
+            m = uint8_t((x[i].scales[is + 5] >> 4) | ((x[i].scales[is + 1] >> 6) << 4));
+        }
+        const FLOAT_TYPE d2 = dall * sc;
+        const FLOAT_TYPE m2 = dmin * m;
+
+        const uint8_t hm1 = uint8_t(1 << (2 * il    ));
+        const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
+        y[y_idx     ] = D_TYPE(d1 * FLOAT_TYPE((x[i].qs[qs_idx    ] & 0xF) + (((x[i].qh[qh_idx    ] & hm1) != 0) ? 16 : 0)) - m1);
+        y[y_idx +  1] = D_TYPE(d1 * FLOAT_TYPE((x[i].qs[qs_idx + 1] & 0xF) + (((x[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
+        y[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((x[i].qs[qs_idx    ] >> 4) + (((x[i].qh[qh_idx    ] & hm2) != 0) ? 16 : 0)) - m2);
+        y[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((x[i].qs[qs_idx + 1] >> 4) + (((x[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
+    }
+}
+"""
 dequant_q6_K_body = """
 layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
 
@@ -981,10 +1058,10 @@
         const uint8_t q4_6 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4);
         const uint8_t q4_7 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4);
 
-        const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx] * q4_0 + y[y1_idx + 1] * q4_1);
-        const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_2 + y[y1_idx + 33] * q4_3);
-        const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx] * q4_4 + y[y2_idx + 1] * q4_5);
-        const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_6 + y[y2_idx + 33] * q4_7);
+        const FLOAT_TYPE sx = FLOAT_TYPE(y[y1_idx     ] * q4_0  + y[y1_idx +  1] * q4_1  + y[y1_idx +  2] * q4_2  + y[y1_idx +  3] * q4_3 );
+        const FLOAT_TYPE sy = FLOAT_TYPE(y[y1_idx + 32] * q4_4  + y[y1_idx + 33] * q4_5  + y[y1_idx + 34] * q4_6  + y[y1_idx + 35] * q4_7 );
+        const FLOAT_TYPE sz = FLOAT_TYPE(y[y2_idx     ] * q4_8  + y[y2_idx +  1] * q4_9  + y[y2_idx +  2] * q4_10 + y[y2_idx +  3] * q4_11);
+        const FLOAT_TYPE sw = FLOAT_TYPE(y[y2_idx + 32] * q4_12 + y[y2_idx + 33] * q4_13 + y[y2_idx + 34] * q4_14 + y[y2_idx + 35] * q4_15);
         const FLOAT_TYPE smin = FLOAT_TYPE(
             y[y1_idx] * sc2 + y[y1_idx + 32] * sc3 + y[y2_idx] * sc6 + y[y2_idx + 32] * sc7
           + y[y1_idx + 1] * sc2 + y[y1_idx + 33] * sc3 + y[y2_idx + 1] * sc6 + y[y2_idx + 33] * sc7
@@ -1007,6 +1084,121 @@
     }
 }
 """
+mul_mat_vec_q5_K_body = """
+layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE x[];};
+layout (binding = 1) readonly buffer B {B_TYPE y[];};
+layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
+
+layout (push_constant) uniform parameter
+{
+    int ncols;
+} p;
+
+shared FLOAT_TYPE tmp[32];
+
+void main() {
+    const int row = int(gl_WorkGroupID.x);
+
+    const int num_blocks_per_row = p.ncols / QUANT_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = int(gl_LocalInvocationID.x)/2;  // 0...31 or 0...16
+    const int ix  = int(gl_LocalInvocationID.x)%2;  // 0 or 0, 1
+
+    const int il = tid/4;       // 0...3
+    const int ir = tid - 4*il;  // 0...7 or 0...3
+
+    const int v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int v_in = il % 2;
+
+    const int l0 = 4*ir + 2*v_in;  // 0...15
+    const int q_offset = 32*v_im + l0;
+    const int y_offset = 64*v_im + l0;
+
+    const uint8_t hm1 = uint8_t(1 << (2*v_im));
+    const uint8_t hm2 = uint8_t(hm1 << 4);
+
+    tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
+
+    [[unroll]] for (int i = ix; i < num_blocks_per_row; i += 2) {
+        const int y1_idx = i * QUANT_K + y_offset;
+        const int y2_idx = y1_idx + 128;
+
+        const FLOAT_TYPE dall = FLOAT_TYPE(x[ib0 + i].d.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(x[ib0 + i].d.y);
+
+        const uint8_t sc0 = uint8_t(  x[ib0 + i].scales[v_im * 2    ] & 0x3f);
+        const uint8_t sc1 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
+        const uint8_t sc2 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
+        const uint8_t sc3 = uint8_t(  x[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
+        const uint8_t sc4 = uint8_t((  x[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((x[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
+        const uint8_t sc5 = uint8_t((  x[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
+        const uint8_t sc6 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
+        const uint8_t sc7 = uint8_t(((x[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((x[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
+
+        const uint8_t q4_0  = uint8_t(x[ib0 + i].qs[q_offset     ] & 0xf);
+        const uint8_t q4_1  = uint8_t(x[ib0 + i].qs[q_offset +  1] & 0xf);
+        const uint8_t q4_2  = uint8_t(x[ib0 + i].qs[q_offset + 16] & 0xf);
+        const uint8_t q4_3  = uint8_t(x[ib0 + i].qs[q_offset + 17] & 0xf);
+        const uint8_t q4_4  = uint8_t(x[ib0 + i].qs[q_offset     ] >> 4);
+        const uint8_t q4_5  = uint8_t(x[ib0 + i].qs[q_offset +  1] >> 4);
+        const uint8_t q4_6  = uint8_t(x[ib0 + i].qs[q_offset + 16] >> 4);
+        const uint8_t q4_7  = uint8_t(x[ib0 + i].qs[q_offset + 17] >> 4);
+        const uint8_t q4_8  = uint8_t(x[ib0 + i].qs[q_offset + 64] & 0xf);
+        const uint8_t q4_9  = uint8_t(x[ib0 + i].qs[q_offset + 65] & 0xf);
+        const uint8_t q4_10 = uint8_t(x[ib0 + i].qs[q_offset + 80] & 0xf);
+        const uint8_t q4_11 = uint8_t(x[ib0 + i].qs[q_offset + 81] & 0xf);
+        const uint8_t q4_12 = uint8_t(x[ib0 + i].qs[q_offset + 64] >> 4);
+        const uint8_t q4_13 = uint8_t(x[ib0 + i].qs[q_offset + 65] >> 4);
+        const uint8_t q4_14 = uint8_t(x[ib0 + i].qs[q_offset + 80] >> 4);
+        const uint8_t q4_15 = uint8_t(x[ib0 + i].qs[q_offset + 81] >> 4);
+
+        const FLOAT_TYPE sx = FLOAT_TYPE(
+            y[y1_idx     ] * (q4_0 + (((x[ib0 + i].qh[l0     ] & hm1) != 0) ? 16 : 0))
+          + y[y1_idx +  1] * (q4_1 + (((x[ib0 + i].qh[l0 +  1] & hm1) != 0) ? 16 : 0))
+          + y[y1_idx + 16] * (q4_2 + (((x[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
+          + y[y1_idx + 17] * (q4_3 + (((x[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
+        );
+        const FLOAT_TYPE sy = FLOAT_TYPE(
+            y[y1_idx + 32] * (q4_4 + (((x[ib0 + i].qh[l0     ] & (hm1 << 1)) != 0) ? 16 : 0))
+          + y[y1_idx + 33] * (q4_5 + (((x[ib0 + i].qh[l0 +  1] & (hm1 << 1)) != 0) ? 16 : 0))
+          + y[y1_idx + 48] * (q4_6 + (((x[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
+          + y[y1_idx + 49] * (q4_7 + (((x[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
+        );
+        const FLOAT_TYPE sz = FLOAT_TYPE(
+            y[y2_idx     ] * (q4_8  + (((x[ib0 + i].qh[l0     ] & hm2) != 0) ? 16 : 0))
+          + y[y2_idx +  1] * (q4_9  + (((x[ib0 + i].qh[l0 +  1] & hm2) != 0) ? 16 : 0))
+          + y[y2_idx + 16] * (q4_10 + (((x[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
+          + y[y2_idx + 17] * (q4_11 + (((x[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
+        );
+        const FLOAT_TYPE sw = FLOAT_TYPE(
+            y[y2_idx + 32] * (q4_12 + (((x[ib0 + i].qh[l0     ] & (hm2 << 1)) != 0) ? 16 : 0))
+          + y[y2_idx + 33] * (q4_13 + (((x[ib0 + i].qh[l0 +  1] & (hm2 << 1)) != 0) ? 16 : 0))
+          + y[y2_idx + 48] * (q4_14 + (((x[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
+          + y[y2_idx + 49] * (q4_15 + (((x[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
+        );
+        const FLOAT_TYPE smin = FLOAT_TYPE(
+            (y[y1_idx] + y[y1_idx + 1] + y[y1_idx + 16] + y[y1_idx + 17]) * sc2 + (y[y1_idx + 32] + y[y1_idx + 33] + y[y1_idx + 48] + y[y1_idx + 49]) * sc3
+          + (y[y2_idx] + y[y2_idx + 1] + y[y2_idx + 16] + y[y2_idx + 17]) * sc6 + (y[y2_idx + 32] + y[y2_idx + 33] + y[y2_idx + 48] + y[y2_idx + 49]) * sc7
+        );
+        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
+    }
+
+    // sum up partial sums and write back result
+    barrier();
+    [[unroll]] for (int s = 16; s > 0; s >>= 1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier();
+    }
+    if (tid == 0) {
+        dst[row] = D_TYPE(tmp[0]);
+    }
+}
+"""
 mul_mat_vec_q6_K_body = """
 layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 
@@ -1067,7 +1259,7 @@
         tmp[16 * ix + tid] += sum;
 #else
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        for (int l = 0; l < 4; ++l) {
+        [[unroll]] for (int l = 0; l < 4; ++l) {
             sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
                  + FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
                  + FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
@@ -1379,6 +1571,8 @@ async def main():
             stream.extend((shader_q3_K_defines, dequant_q3_K_body))
         elif i == GGML_TYPE_Q4_K:
             stream.extend((shader_q4_K_defines, dequant_q4_K_body))
+        elif i == GGML_TYPE_Q5_K:
+            stream.extend((shader_q5_K_defines, dequant_q5_K_body))
         elif i == GGML_TYPE_Q6_K:
             stream.extend((shader_q6_K_defines, dequant_q6_K_body))
         else:
@@ -1409,6 +1603,8 @@ async def main():
             stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
         elif i == GGML_TYPE_Q4_K:
            stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body))
+        elif i == GGML_TYPE_Q5_K:
+            stream.extend((shader_q5_K_defines, mul_mat_vec_q5_K_body))
        elif i == GGML_TYPE_Q6_K:
             stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
         else:
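
Review note (not part of the patch): the scale/min packing and the qh high-bit handling in the new dequant_q5_K_body and mul_mat_vec_q5_K_body shaders are easiest to check against a scalar reference. The sketch below is illustrative only: the names block_q5_K_ref, get_scale_min and dequantize_block_q5_K_ref are made up for this note, and the two fp16 super-block factors (the f16vec2 d in the GLSL struct) are simplified to plain floats d and dmin.

    // Scalar sketch of Q5_K dequantization, mirroring the logic of the new shaders.
    #include <stdint.h>

    #define QK_K 256

    typedef struct {
        float   d;            // super-block scale (x[i].d.x in the shader)
        float   dmin;         // super-block min   (x[i].d.y in the shader)
        uint8_t scales[12];   // 8 packed (6-bit scale, 6-bit min) pairs
        uint8_t qh[QK_K / 8]; // high (5th) bit of each quant, one bit per value per group
        uint8_t qs[QK_K / 2]; // low 4 bits of the quants, two values per byte
    } block_q5_K_ref;

    // Unpack the j-th (0..7) scale/min pair; same bit layout as the `is < 4` branch in the shader.
    static void get_scale_min(int j, const uint8_t *s, uint8_t *sc, uint8_t *m) {
        if (j < 4) {
            *sc = s[j]     & 63;
            *m  = s[j + 4] & 63;
        } else {
            *sc = (s[j + 4] & 0x0F) | ((s[j - 4] >> 6) << 4);
            *m  = (s[j + 4] >>   4) | ((s[j    ] >> 6) << 4);
        }
    }

    // Dequantize one 256-value super-block: four groups of 64 values, each group
    // split into a low-nibble half and a high-nibble half with its own scale/min
    // pair, plus one extra bit per value taken from qh (the hm1/hm2 masks).
    static void dequantize_block_q5_K_ref(const block_q5_K_ref *x, float *y) {
        const uint8_t *ql = x->qs;
        const uint8_t *qh = x->qh;
        uint8_t u1 = 1, u2 = 2; // qh bit masks for this group, like hm1/hm2 in the shader

        for (int j = 0, is = 0; j < QK_K; j += 64, is += 2, ql += 32, u1 <<= 2, u2 <<= 2) {
            uint8_t sc, m;
            get_scale_min(is + 0, x->scales, &sc, &m);
            const float d1 = x->d * sc, m1 = x->dmin * m;
            get_scale_min(is + 1, x->scales, &sc, &m);
            const float d2 = x->d * sc, m2 = x->dmin * m;

            for (int l = 0; l < 32; ++l) {
                y[j +      l] = d1 * ((ql[l] & 0x0F) + ((qh[l] & u1) ? 16 : 0)) - m1;
                y[j + 32 + l] = d2 * ((ql[l] >>   4) + ((qh[l] & u2) ? 16 : 0)) - m2;
            }
        }
    }

Calling dequantize_block_q5_K_ref once per super-block reproduces what one iteration of the shader's wgy loop writes for that block; the mul_mat_vec shader computes the same values on the fly and folds them into the dot product instead of storing them.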