Skip to content

Commit

Permalink
Simplify q6_k fp16 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
0cc4m committed Oct 8, 2023
1 parent 85c1a63 commit dad1cdb
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions ggml-vulkan-shaders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
// Generic
const std::string shader_f32 = R"(
#define FLOAT_TYPE float
#define INT8_TYPE int
#define UINT8_TYPE uint
)";
const std::string shader_f16 = R"(
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#define FLOAT_TYPE float16_t
#define INT8_TYPE int8_t
#define UINT8_TYPE uint8_t
)";
const std::string shader_int8_ext = R"(
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
Expand Down Expand Up @@ -605,22 +601,22 @@ void main() {
const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d);
#if K_QUANTS_PER_ITERATION == 1
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0xc0) >> 2)) - 32)
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0xc0) >> 2)) - 32);
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
tmp[16 * ix + tid] += sum;
#else
FLOAT_TYPE sum = 0;
for (int l = 0; l < 4; ++l) {
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 0) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 2) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 4) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 6) & 3) << 4)) - 32);
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
}
tmp[16 * ix + tid] += sum;
#endif
Expand Down

0 comments on commit dad1cdb

Please sign in to comment.