Finish broadcasting mul mat support for GQA
0cc4m committed Oct 22, 2023
1 parent a0db45f commit 3de5ba4
Showing 1 changed file with 100 additions and 44 deletions.
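
For context: with grouped-query attention (GQA), the batch dimensions of src1 (ne12, ne13) are whole multiples of those of src0 (ne02, ne03), so each src0 slice is broadcast across r2 = ne12 / ne02 and r3 = ne13 / ne03 slices of src1. A minimal sketch of the index mapping this commit threads through all three matmul paths (the struct and its names are illustrative, not part of the commit):

#include <cstdint>

// Illustrative sketch (not from the commit): map a src1 batch index
// (i12, i13) back to the src0 slice (i02, i03) it broadcasts against.
// The diff does this inline as i02 = i12 / r2 and i03 = i13 / r3.
struct BroadcastMap {
    int64_t ne02, ne03; // src0 batch dims (e.g. KV heads)
    int64_t ne12, ne13; // src1 batch dims, whole multiples of the src0 dims

    void map(int64_t i12, int64_t i13, int64_t & i02, int64_t & i03) const {
        const int64_t r2 = ne12 / ne02; // src1 slices per src0 slice in dim 2
        const int64_t r3 = ne13 / ne03; // same for dim 3
        i02 = i12 / r2;
        i03 = i13 / r3;
    }
};

For a layer with 32 query heads and 8 KV heads, ne12 = 32 and ne02 = 8 give r2 = 4: four consecutive query-head slices all read the same src0 slice.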
144 changes: 100 additions & 44 deletions ggml-vulkan.cpp
@@ -1881,6 +1881,9 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];

const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;

const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
@@ -1919,24 +1922,37 @@ static void ggml_vk_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
ggml_vk_pipeline_allocate_descriptor_sets(vk_pipeline_matmul_split_k_reduce, ne12 * ne13);
}

std::vector<vk::Semaphore> x_semaphores;

if (load_x) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const uint32_t x_offset = load_x ? x_sz * (i03 * ne02 + i02) : 0;
// copy data to device
vk::Semaphore s = ggml_vk_create_semaphore(vk_device.compute_queue);
x_semaphores.push_back(s);
// Wait for previous matmul to be done before writing to the input buffers again
extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_X, x_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { s }, nullptr, &extra->memcpys));
}
}
}

for (int64_t i13 = 0; i13 < ne13; i13++) {
const int64_t i03 = i13 / r3;
for (int64_t i12 = 0; i12 < ne12; i12++) {
const uint32_t x_offset = load_x ? x_sz * (i13 * ne02 + i12) : 0;
int64_t i02 = i12 / r2;

const uint32_t x_offset = load_x ? x_sz * (i03 * ne02 + i02) : 0;
const uint32_t y_offset = y_sz * (i13 * ne12 + i12);
const uint32_t d_offset = d_sz * (i13 * ne12 + i12);

vk::Semaphore s_x;
vk::Semaphore s_y = ggml_vk_create_semaphore(vk_device.compute_queue);
std::vector<vk::Semaphore> semaphores = { s_y };
// copy data to device
if (load_x) {
s_x = ggml_vk_create_semaphore(vk_device.compute_queue);
semaphores.push_back(s_x);
// Wait for previous matmul to be done before writing to the input buffers again
extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_X, x_offset, src0, i13 % ne03, i12 % ne02, vk_device.transfer_queues[0], {}, { s_x }, nullptr, &extra->memcpys));
s_x = x_semaphores[i03 * ne02 + i02];
}
vk::Semaphore s_y = ggml_vk_create_semaphore(vk_device.compute_queue);
std::vector<vk::Semaphore> semaphores = { s_y };

// Wait for previous matmul to be done before writing to the input buffers again
extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Y, y_offset, src1, i13, i12, vk_device.transfer_queues[1], {}, { s_y }, nullptr, &extra->memcpys));

// compute
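
The change above splits the f32 path into two phases: src0 slices are uploaded once per unique (i03, i02) pair, each upload signalling a semaphore stored in x_semaphores, and the matmul loop then waits on x_semaphores[i03 * ne02 + i02] instead of re-uploading the same slice for every broadcast (i13, i12). A condensed sketch of the pattern, assuming hypothetical helpers create_semaphore / upload_slice / enqueue_matmul in place of the backend's real calls:

// Phase 1: one H2D copy per unique src0 slice, each signalling a semaphore.
std::vector<vk::Semaphore> x_semaphores;
for (int64_t i03 = 0; i03 < ne03; i03++) {
    for (int64_t i02 = 0; i02 < ne02; i02++) {
        vk::Semaphore s = create_semaphore();          // hypothetical helper
        upload_slice(src0, i03, i02, /*signal=*/ s);   // hypothetical helper
        x_semaphores.push_back(s);
    }
}
// Phase 2: every broadcast consumer waits on the upload of the slice it reads.
for (int64_t i13 = 0; i13 < ne13; i13++) {
    for (int64_t i12 = 0; i12 < ne12; i12++) {
        vk::Semaphore s_x = x_semaphores[(i13 / r3) * ne02 + (i12 / r2)];
        enqueue_matmul(i12, i13, /*wait=*/ { s_x });   // hypothetical helper
    }
}

Note that a plain Vulkan binary semaphore supports one wait per signal, so reusing an entry for several consumers relies on the backend's own queue bookkeeping; timeline semaphores would lift that restriction.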
@@ -1970,6 +1986,9 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];

const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;

vk_queue& compq = vk_device.compute_queue;
vk_queue& tr0q = vk_device.transfer_queues[0];
vk_queue& tr1q = vk_device.transfer_queues[1];
@@ -2054,14 +2073,58 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
ggml_vk_pipeline_allocate_descriptor_sets(vk_pipeline_matmul_split_k_reduce, ne12 * ne13);
}

std::vector<vk::Semaphore> x_semaphores;

for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
const uint32_t it_idx0 = (i03 * ne02 + i02);
const uint32_t qx_offset = qx_sz * it_idx0;
const uint32_t x_offset = x_sz * it_idx0;

vk::Semaphore s_qx;
vk::Semaphore s_x;

if (load_x) {
// copy data to device
s_qx = ggml_vk_create_semaphore(vk_device.compute_queue);
extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Qx, qx_offset, src0, i03, i02, vk_device.transfer_queues[0], {}, { s_qx }, nullptr, &extra->memcpys));
}

if (qx_needs_dequant) {
s_x = ggml_vk_create_semaphore(vk_device.compute_queue);

vk_submission s = ggml_vk_begin_submission(compq);
const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
ggml_vk_sync_buffers(s.buffer, { { *d_Qx, qx_offset, qx_sz } }, compq, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
ggml_vk_sync_buffers(s.buffer, { { *d_X, x_offset, x_sz } }, compq, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
ggml_vk_dispatch_pipeline(s, *to_fp16_vk_0, { { *d_Qx, qx_offset, qx_sz }, { *d_X, x_offset, x_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1});
if (load_x) {
ggml_vk_end_submission(s, { s_qx }, { s_x });
} else {
ggml_vk_end_submission(s, {}, { s_x });
}

extra->comp_seqs.push_back({ s });

x_semaphores.push_back(s_x);
} else {
x_semaphores.push_back(s_qx);
}
}
}

for (int64_t i13 = 0; i13 < ne13; i13++) {
const int64_t i03 = i13 / r3;
for (int64_t i12 = 0; i12 < ne12; i12++) {
const uint32_t it_idx = (i13 * ne12 + i12);
const uint32_t qx_offset = load_x ? qx_sz * it_idx : 0;
const uint32_t qy_offset = load_y ? qy_sz * it_idx : 0;
const uint32_t x_offset = x_sz * it_idx;
const uint32_t y_offset = y_sz * it_idx;
const uint32_t d_offset = d_sz * it_idx;
int64_t i02 = i12 / r2;

const uint32_t it_idx0 = (i03 * ne02 + i02);
const uint32_t it_idx1 = (i13 * ne12 + i12);
const uint32_t qx_offset = load_x ? qx_sz * it_idx0 : 0;
const uint32_t qy_offset = load_y ? qy_sz * it_idx1 : 0;
const uint32_t x_offset = x_sz * it_idx0;
const uint32_t y_offset = y_sz * it_idx1;
const uint32_t d_offset = d_sz * it_idx1;

vk::Semaphore s_x;
vk::Semaphore s_y;
@@ -2073,13 +2136,12 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
std::vector<vk::Semaphore> mm_semaphores;

if (load_x) {
s_x = ggml_vk_create_semaphore(tr0q);
s_x = x_semaphores[it_idx0];
if (qx_needs_dequant) {
q_semaphores.push_back(s_x);
} else {
mm_semaphores.push_back(s_x);
}
extra->in0_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Qx, qx_offset, src0, i13 % ne03, i12 % ne02, tr0q, {}, { s_x }, nullptr, &extra->memcpys));
}
if (load_y) {
s_y = ggml_vk_create_semaphore(tr1q);
@@ -2091,22 +2153,15 @@ static void ggml_vk_mul_mat_q_f16(const ggml_tensor * src0, const ggml_tensor *
extra->in1_seqs.push_back(ggml_vk_h2d_tensor_2d(d_Qy, qy_offset, src1, i13, i12, tr1q, {}, { s_y }, nullptr, &extra->memcpys));
}

if (dq) {
if (qy_needs_dequant) {
s_q = ggml_vk_create_semaphore(tr0q);
vk_submission s = ggml_vk_begin_submission(compq);
if (qx_needs_dequant) {
const std::vector<int> pc = { (int)ne01, (int)ne10, (int)ne10, (int)ne10 };
ggml_vk_sync_buffers(s.buffer, { { *d_Qx, qx_offset, qx_sz } }, compq, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
ggml_vk_sync_buffers(s.buffer, { { *d_X, x_offset, x_sz } }, compq, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
ggml_vk_dispatch_pipeline(s, *to_fp16_vk_0, { { *d_Qx, qx_offset, qx_sz }, { *d_X, x_offset, x_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)x_ne, 1, 1});
}

if (qy_needs_dequant) {
const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
ggml_vk_sync_buffers(s.buffer, { { *d_Qy, qy_offset, qy_sz } }, compq, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
ggml_vk_sync_buffers(s.buffer, { { *d_Y, y_offset, y_sz } }, compq, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
ggml_vk_dispatch_pipeline(s, *to_fp16_vk_1, { { *d_Qy, qy_offset, qy_sz }, { *d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});
}
const std::vector<int> pc = { (int)ne11, (int)ne10, (int)ne10, (int)ne10 };
ggml_vk_sync_buffers(s.buffer, { { *d_Qy, qy_offset, qy_sz } }, compq, vk::AccessFlagBits::eTransferWrite, vk::AccessFlagBits::eShaderRead, false);
ggml_vk_sync_buffers(s.buffer, { { *d_Y, y_offset, y_sz } }, compq, vk::AccessFlagBits::eShaderRead, vk::AccessFlagBits::eShaderWrite, false);
ggml_vk_dispatch_pipeline(s, *to_fp16_vk_1, { { *d_Qy, qy_offset, qy_sz }, { *d_Y, y_offset, y_sz } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)y_ne, 1, 1});

ggml_vk_end_submission(s, std::move(q_semaphores), { s_q });
extra->comp_seqs.push_back({ s });
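
In the quantized path the same hoisting is combined with dequantization: the pre-pass uploads each unique src0 slice, runs to_fp16_vk_0 over it once, and records a per-slice semaphore, so the inner loop addresses buffers with two separate indices, it_idx0 for the src0 side and it_idx1 for the src1/dst side. A small sketch of that offset arithmetic (the helper is illustrative; qx_sz, x_sz, qy_sz, y_sz, d_sz are the per-slice strides from the surrounding code):

#include <cstdint>

// Illustrative helper (not in the commit): the two buffer indices used in
// the loop above. idx0 addresses the src0-side buffers (d_Qx, d_X), idx1
// the src1-side and dst buffers (d_Qy, d_Y, and the output).
struct Offsets { uint32_t qx, x, qy, y, d; };

static Offsets mul_mat_offsets(int64_t i02, int64_t i03, int64_t i12, int64_t i13,
                               int64_t ne02, int64_t ne12,
                               uint32_t qx_sz, uint32_t x_sz,
                               uint32_t qy_sz, uint32_t y_sz, uint32_t d_sz) {
    const uint32_t idx0 = (uint32_t)(i03 * ne02 + i02); // src0 slice index
    const uint32_t idx1 = (uint32_t)(i13 * ne12 + i12); // src1/dst slice index
    return { qx_sz * idx0, x_sz * idx0, qy_sz * idx1, y_sz * idx1, d_sz * idx1 };
}

With broadcasting, idx0 takes ne02 * ne03 distinct values while idx1 takes ne12 * ne13, which is why the pre-pass loops over (i03, i02) and the matmul loop over (i13, i12).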

@@ -2140,16 +2195,19 @@ static void ggml_vk_mul_mat_vec_q_f16(const ggml_tensor * src0, const ggml_tenso
const int64_t ne13 = src1->ne[3];

GGML_ASSERT(ne11 == 1);
GGML_ASSERT(src0->backend == GGML_BACKEND_GPU);

const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3];

const int64_t r2 = ne12 / ne02;
const int64_t r3 = ne13 / ne03;

vk_queue& compq = vk_device.compute_queue;
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;

const bool qy_needs_dequant = src1->type != GGML_TYPE_F16 && !f16_f32_kernel;

const bool load_x = src0->backend != GGML_BACKEND_GPU;
const bool load_y = src1->backend != GGML_BACKEND_GPU;

const int x_ne = ne01 * ne00;
@@ -2171,11 +2229,7 @@ static void ggml_vk_mul_mat_vec_q_f16(const ggml_tensor * src0, const ggml_tenso
vk_buffer* d_Qx;
vk_buffer* d_Qy;
vk_buffer* d_Y;
if (load_x) {
d_Qx = &vk_preallocated_buffers[extra->buffer_idx[buffer_idx++]];
} else {
d_Qx = (vk_buffer *) src0->data;
}
d_Qx = (vk_buffer *) src0->data;
if (load_y) {
d_Qy = &vk_preallocated_buffers[extra->buffer_idx[buffer_idx++]];
} else {
@@ -2200,18 +2254,20 @@ static void ggml_vk_mul_mat_vec_q_f16(const ggml_tensor * src0, const ggml_tenso
ggml_vk_pipeline_allocate_descriptor_sets(*dmmv, ne12 * ne13);

for (int64_t i13 = 0; i13 < ne13; i13++) {
const int64_t i03 = i13 / r3;
for (int64_t i12 = 0; i12 < ne12; i12++) {
const uint32_t it_idx = (i13 * ne12 + i12);
const uint32_t qx_offset = load_x ? qx_sz * it_idx : 0;
const uint32_t qy_offset = load_y ? qy_sz * it_idx : 0;
const uint32_t y_offset = y_sz * it_idx;
const uint32_t d_offset = d_sz * it_idx;
int64_t i02 = i12 / r2;

const uint32_t it_idx0 = (i03 * ne02 + i02);
const uint32_t it_idx1 = (i13 * ne12 + i12);
const uint32_t qx_offset = qx_sz * it_idx0;
const uint32_t qy_offset = qy_sz * it_idx1;
const uint32_t y_offset = y_sz * it_idx1;
const uint32_t d_offset = d_sz * it_idx1;

vk_submission s = ggml_vk_begin_submission(compq);

if (load_x) {
ggml_vk_h2d_tensor_2d(d_Qx, qx_offset, src0, i13 % ne03, i12 % ne02, compq, {}, {}, &s, &extra->memcpys);
}
vk::Semaphore s_x;
if (load_y) {
ggml_vk_h2d_tensor_2d(d_Qy, qy_offset, src1, i13, i12, compq, {}, {}, &s, &extra->memcpys);
}
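
Finally, the vector path now requires src0 to be resident on the GPU (the new GGML_ASSERT), so d_Qx aliases src0->data directly and the per-iteration H2D copy of src0 disappears; only src1 may still be uploaded. A quick sanity check (illustrative, not in the commit) of how the dispatch count relates to the broadcast ratios:

#include <cassert>
#include <cstdint>

// Illustrative sanity check (not in the commit): one dispatch is issued per
// src1 slice, and each src0 slice feeds exactly r2 * r3 of them, matching
// the ne12 * ne13 descriptor-set allocation above.
static void check_broadcast_counts(int64_t ne02, int64_t ne03,
                                   int64_t ne12, int64_t ne13) {
    assert(ne12 % ne02 == 0 && ne13 % ne03 == 0); // GQA needs whole multiples
    const int64_t r2 = ne12 / ne02;
    const int64_t r3 = ne13 / ne03;
    assert(ne12 * ne13 == ne02 * ne03 * r2 * r3);
}

For example, check_broadcast_counts(8, 1, 32, 1) models a 32-query-head, 8-KV-head layer: r2 = 4, r3 = 1, and 32 dispatches consume 8 resident src0 slices.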
