From 55927a31e3e3cfbf5b3a7df5e76c464012f14405 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sun, 23 Jul 2023 03:58:21 +0000 Subject: [PATCH] . --- .../cpp/inference/mixture_of_experts/moe.cc | 1 + .../inference/transformers/transformers.cc | 1 + .../inc_multihead_self_attention_kernels.h | 5 +- inference/file_loader.cc | 14 +- inference/file_loader.h | 3 +- inference/models/falcon.cc | 1 + inference/models/llama.cc | 2 + inference/models/opt.cc | 2 + src/c/flexflow_c.cc | 1 + src/ops/inc_multihead_self_attention.cu | 312 ++++++++++++------ src/ops/spec_inc_multihead_self_attention.cu | 3 +- src/ops/tree_inc_multihead_self_attention.cu | 3 +- 12 files changed, 231 insertions(+), 117 deletions(-) diff --git a/examples/cpp/inference/mixture_of_experts/moe.cc b/examples/cpp/inference/mixture_of_experts/moe.cc index ff3f6bb53a..fac59366d0 100644 --- a/examples/cpp/inference/mixture_of_experts/moe.cc +++ b/examples/cpp/inference/mixture_of_experts/moe.cc @@ -78,6 +78,7 @@ Tensor create_moe_encoder(FFModel *model, x, moeConfig->hidden_size, moeConfig->num_attention_heads, + moeConfig->num_attention_heads, moeConfig->attention_kdim, moeConfig->attention_vdim) : model->multihead_attention(x, diff --git a/examples/cpp/inference/transformers/transformers.cc b/examples/cpp/inference/transformers/transformers.cc index 074e832d47..56e583e6e4 100644 --- a/examples/cpp/inference/transformers/transformers.cc +++ b/examples/cpp/inference/transformers/transformers.cc @@ -46,6 +46,7 @@ Tensor create_inc_multihead_attention_decoder( input, transformerConfig->hidden_size, transformerConfig->num_attention_heads, + transformerConfig->num_attention_heads, transformerConfig->attention_kdim, transformerConfig->attention_vdim) : model->multihead_attention(input, diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h index 5b40136524..0d2e7a79fb 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h @@ -23,6 +23,7 @@ template __global__ void apply_proj_bias_w(DT *input_ptr, DT const *bias_ptr, int num_tokens, + int qkv_weight_size, int oProjSize); template @@ -34,6 +35,7 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr, int kProjSize, int vProjSize, int num_heads, + int num_kv_heads, bool scaling_query, float scaling_factor); @@ -46,9 +48,10 @@ __global__ void int kProjSize, int num_heads, int num_tokens, + int num_kv_heads, int q_block_size, int k_block_size, - int v_block_size, + int q_array_size, bool q_tensor); template diff --git a/inference/file_loader.cc b/inference/file_loader.cc index 8cb479f1b8..3c9ed96e56 100644 --- a/inference/file_loader.cc +++ b/inference/file_loader.cc @@ -25,10 +25,11 @@ using namespace Legion; FileDataLoader::FileDataLoader(std::string _input_path, std::string _weight_file_path, int _num_heads, + int _num_kv_heads, size_t _hidden_dim, size_t _qkv_inner_dim) : input_path(_input_path), weight_file_path(_weight_file_path), - num_heads(_num_heads), hidden_dim(_hidden_dim), + num_heads(_num_heads), num_kv_heads(_num_kv_heads), hidden_dim(_hidden_dim), qkv_inner_dim(_qkv_inner_dim){}; BatchConfig::TokenId *FileDataLoader::generate_requests(int num, int length) { @@ -279,6 +280,7 @@ void load_attention_weights_multi_query(DT *ptr, template void load_attention_bias_v2(DT *ptr, int num_heads, + int num_kv_heads, size_t hidden_dim, size_t qkv_inner_dim, std::string layer_name, @@ -298,8 +300,10 @@ 
void load_attention_bias_v2(DT *ptr, std::vector bias_files = {q_file, k_file, v_file, o_file}; int file_index = 0; + for (auto file : bias_files) { - size_t qkv_partial_size = qkv_inner_dim * num_heads; + int n_heads = file_index == 0 ? num_heads : num_kv_heads; + size_t qkv_partial_size = qkv_inner_dim * n_heads; size_t out_partial_size = hidden_dim; size_t partial_size = (file_index < 3) ? qkv_partial_size : out_partial_size; @@ -785,16 +789,18 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, if (file_path.find("attention_w") != std::string::npos) { if (weight_idx == 0) { - load_attention_weights(data, + load_attention_weights_v2(data, num_heads, + num_kv_heads, hidden_dim, qkv_inner_dim, file_path, weight_file_path, volume); } else { - load_attention_bias(data, + load_attention_bias_v2(data, num_heads, + num_kv_heads, hidden_dim, qkv_inner_dim, file_path, diff --git a/inference/file_loader.h b/inference/file_loader.h index 8be820b1bd..bfd9c502c5 100644 --- a/inference/file_loader.h +++ b/inference/file_loader.h @@ -27,6 +27,7 @@ class FileDataLoader { FileDataLoader(std::string _input_path, std::string _weight_file_path, int _num_heads, + int _num_kv_heads, size_t _hidden_dim, size_t _qkv_inner_dim); @@ -54,7 +55,7 @@ class FileDataLoader { int offset); private: - int num_heads; + int num_heads, num_kv_heads; size_t hidden_dim, qkv_inner_dim; std::string input_path; std::string weight_file_path; diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index bced5dc1e0..fd8d6dd938 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -145,6 +145,7 @@ void FALCON::create_falcon_model(FFModel &ff, FileDataLoader fileloader("", weight_file_path, falcon_config.n_heads, + 1, falcon_config.dim, falcon_config.dim / falcon_config.n_heads); fileloader.load_weights(&ff, weights_layers, use_full_precision); diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 06dfaebcb1..40cd59cfb4 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -148,6 +148,7 @@ void LLAMA::create_llama_model(FFModel &ff, att_norm, llama_config.dim, llama_config.n_heads, + llama_config.n_heads, llama_config.dim / llama_config.n_heads, llama_config.dim / llama_config.n_heads, 0.0f, /*dropout*/ @@ -227,6 +228,7 @@ void LLAMA::create_llama_model(FFModel &ff, FileDataLoader fileloader("", weight_file_path, llama_config.n_heads, + llama_config.n_heads, llama_config.dim, llama_config.dim / llama_config.n_heads); fileloader.load_weights(&ff, weights_layers, use_full_precision); diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 503be39672..fca1cfd6e5 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -170,6 +170,7 @@ void OPT::create_opt_model(FFModel &ff, hidden_states, opt_config.hidden_size, opt_config.num_attention_heads, + opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, opt_config.hidden_size / opt_config.num_attention_heads, 0.0f, @@ -244,6 +245,7 @@ void OPT::create_opt_model(FFModel &ff, FileDataLoader fileloader("", weight_file_path, opt_config.num_attention_heads, + opt_config.num_attention_heads, opt_config.hidden_size, opt_config.hidden_size / opt_config.num_attention_heads); diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 09258d8206..d78070f44b 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1022,6 +1022,7 @@ flexflow_tensor_t flexflow_model_add_inc_multihead_attention( Tensor tensor = handle->inc_multihead_self_attention(input, 
embed_dim, num_heads, + num_heads, kdim, vdim, dropout, diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 002cdf508f..8d5e3f6b44 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -140,15 +140,18 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr, // int real_part_index = // head_idx * qkv_block_size + qkv_index * q_block_size + idx; int bias_idx = 0; - if(qkv_index == 0){ - int head_idx = i / (num_tokens *qProjSize); + if (qkv_index == 0) { + int head_idx = i / (num_tokens * qProjSize); int global_head_idx = head_idx + shard_id * num_heads; - bias_idx = global_head_idx * qProjSize + (i % (num_tokens * (qProjSize)) % qProjSize); - }else{ - int idx = qkv_index == 1 ? i - q_block_size : i - q_block_size - k_block_size + bias_idx = global_head_idx * qProjSize + + (i % (num_tokens * (qProjSize)) % qProjSize); + } else { + int idx = + qkv_index == 1 ? i - q_block_size : i - q_block_size - k_block_size; int head_idx = i / (num_tokens * kProjSize); int global_head_idx = head_idx + shard_id * num_kv_heads; - bias_idx = global_head_idx * kProjSize + (idx % (num_tokens * (qProjSize)) % qProjSize) + bias_idx = global_head_idx * kProjSize + + (idx % (num_tokens * (qProjSize)) % qProjSize); } // int bias_idx = qkv_index * qProjSize * global_num_heads + // global_head_idx * qProjSize + (idx % qProjSize); @@ -170,11 +173,13 @@ __global__ void int kProjSize, int num_heads, int num_tokens, + int num_kv_heads, int q_block_size, int k_block_size, - int v_block_size, + int q_array_size, bool q_tensor) { int proj_size = q_tensor ? qProjSize : kProjSize; + int n_heads = q_tensor ? num_heads : num_kv_heads; CUDA_KERNEL_LOOP(i, num_tokens * proj_size * num_heads / 2) { // create complex number int head_idx = i / (num_tokens * proj_size / 2); @@ -182,10 +187,14 @@ __global__ void int token_idx = (i - head_idx * (num_tokens * proj_size / 2)) / (proj_size / 2); - int real_part_index = - idx + token_idx * (proj_size / 2) + - head_idx * (q_block_size + k_block_size + v_block_size) + - (q_tensor ? 0 : q_block_size); + // int real_part_index = + // idx + token_idx * (proj_size / 2) + + // head_idx * (q_block_size + k_block_size + v_block_size) + + // (q_tensor ? 0 : q_block_size); + + int real_part_index = idx + token_idx * (proj_size / 2) + + head_idx * (q_tensor ? q_block_size : k_block_size) + + (q_tensor ? 
0 : q_array_size); int complex_part_index = real_part_index + (proj_size / 2); complex_input[i] = {input_ptr[real_part_index], @@ -239,12 +248,12 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, // Weights: qSize x qProjSize x 3 x num_heads // Input: qSize x num_tokens // Output >>> qProjSize x num_tokens x 3 x num_heads - int m_q = m->qProjSize * m->embed_dim; - int m_k = m->kProjSize; - int m_v = m->vProjSize; - assert(m_q == m_k && m_k == m_v); // keep things simple for now + int m_q = m->qProjSize * m->num_heads; + int m_k = m->kProjSize * m->num_kv_heads; + int m_v = m->vProjSize * m->num_kv_heads; + // assert(m_q == m_k && m_k == m_v); // keep things simple for now int n = bc->num_active_tokens(); - int k_q = m->qSize, k_k = m->kProjSize, k_v = m->vProjSize; + int k = m->qSize; int lda = k, ldb = k, ldc_q = m_q, ldc_k = m_k, ldc_v = m_v; // Q checkCUDA(cublasGemmEx(m->handle.blas, @@ -252,7 +261,7 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, CUBLAS_OP_N, m_q, n, - k_q, + k, &alpha, weight_ptr, cublas_data_type, @@ -272,43 +281,44 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, CUBLAS_OP_N, m_k, n, - k_k, + k, &alpha, - weight_ptr + m->embed_dim * m->embed_dim, + weight_ptr + m->qSize * m->qProjSize * m->num_heads, cublas_data_type, lda, input_ptr, cublas_data_type, ldb, &beta, - output_ptr + num_tokens * m->embed_dim, + output_ptr + n * m->qProjSize * m->num_heads, cublas_data_type, ldc_k, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // value - checkCUDA(cublasGemmEx( - m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_v, - n, - k_v, - &alpha, - weight_ptr + - m->embed_dim * (m->embed_dim + m->kProjSize * m->num_kv_heads), - cublas_data_type, - lda, - input_ptr, - cublas_data_type, - ldb, - &beta, - output_ptr + num_tokens * (m->embed_dim + m->kProjSize * m->num_kv_heads), - cublas_data_type, - ldc_v, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + checkCUDA( + cublasGemmEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_v, + n, + k, + &alpha, + weight_ptr + m->qSize * (m->qProjSize * m->num_heads + + m->kProjSize * m->num_kv_heads), + cublas_data_type, + lda, + input_ptr, + cublas_data_type, + ldb, + &beta, + output_ptr + n * (m->qProjSize * m->num_heads + + m->kProjSize * m->num_kv_heads), + cublas_data_type, + ldc_v, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // apply rotary emmmbedding for k and v // step1 change the k, v to complex tensor int num_tokens = bc->num_active_tokens(); @@ -316,6 +326,7 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, int q_block_size = m->qProjSize * num_tokens; int k_block_size = m->kProjSize * num_tokens; int v_block_size = m->vProjSize * num_tokens; + int q_array_size = m->qProjSize * num_tokens * m->num_heads; // apply bias for q, k, v if (*m->bias) { apply_proj_bias_qkv<<vProjSize, m->global_num_heads, m->num_heads, + m->num_kv_heads, *m->scaling_query, m->scaling_factor); } @@ -345,10 +357,11 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, m->qProjSize, m->kProjSize, m->num_heads, + m->num_kv_heads, num_tokens, q_block_size, k_block_size, - v_block_size, + q_array_size, true); /*k*/ apply_rotary_embedding<<qProjSize, m->kProjSize, m->num_heads, + m->num_kv_heads, num_tokens, q_block_size, k_block_size, - v_block_size, + q_array_size, false); } } @@ -386,6 +400,7 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, m->vProjSize, num_tokens, m->num_heads, + m->num_kv_heads, BatchConfig::MAX_SEQ_LENGTH, /* k_cache = */ 
true); @@ -401,6 +416,7 @@ void update_kv_cache_kernel(IncMultiHeadSelfAttentionMeta const *m, m->vProjSize, num_tokens, m->num_heads, + m->num_kv_heads, BatchConfig::MAX_SEQ_LENGTH, /* k_cache = */ false); } @@ -515,7 +531,8 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, // phase 3: Compute attention score // 3 kernels for pahse 3: matmul1 - softmax - matmal2 - compute_attention_kernel(m, bc, shard_id, output_ptr, bias_ptr, stream); + compute_attention_kernel( + m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); } } // namespace IncMultiHeadAttention @@ -532,27 +549,36 @@ __global__ void store_kv_cache(DT const *devQKVProjArray, int vProjSize, int num_tokens, int num_heads, + int num_kv_heads, int max_seq_len, bool k_cache) { - CUDA_KERNEL_LOOP(i, - num_tokens * (k_cache ? kProjSize : vProjSize) * num_heads) { + CUDA_KERNEL_LOOP( + i, num_tokens * (k_cache ? kProjSize : vProjSize) * num_kv_heads) { int proj_size = k_cache ? kProjSize : vProjSize; int head_idx = i / (num_tokens * proj_size); int token_idx = (i - head_idx * (num_tokens * proj_size)) / proj_size; int data_idx = i % proj_size; - int qkv_block_size = (qProjSize + kProjSize + vProjSize) * num_tokens; - int current_head_block_size = - num_tokens * (k_cache ? qProjSize : qProjSize + kProjSize); - DT val = - devQKVProjArray[head_idx * qkv_block_size + current_head_block_size + - token_idx * proj_size + data_idx]; + // int qkv_block_size = (qProjSize + kProjSize + vProjSize) * num_tokens; + // int current_head_block_size = + // num_tokens * (k_cache ? qProjSize : qProjSize + kProjSize); + + int q_array_size = qProjSize * num_tokens * num_heads; + int k_array_size = kProjSize * num_tokens * num_kv_heads; + + // DT val = + // devQKVProjArray[head_idx * qkv_block_size + current_head_block_size + + // token_idx * proj_size + data_idx]; + + DT val = devQKVProjArray[q_array_size + (k_cache ? 
0 : k_array_size) + + head_idx * proj_size + token_idx * proj_size + + data_idx]; // int const req_id = id_map[token_idx].request_index; // int const tok_id = id_map[token_idx].token_position; int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - cache_ptr[req_id * (num_heads * max_seq_len * proj_size) + + cache_ptr[req_id * (num_kv_heads * max_seq_len * proj_size) + head_idx * (max_seq_len * proj_size) + tok_id * proj_size + data_idx] = val; } @@ -581,6 +607,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, int shard_id, DT *output_ptr, DT const *bias_ptr, + DT const *weight_ptr, cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); @@ -596,8 +623,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, // int num_requests = bc->num_active_requests(); int num_tokens = bc->num_active_tokens(); int tokens_previous_requests = 0; - int qkv_block_size = - (m->qProjSize + m->kProjSize + m->vProjSize) * num_tokens; + int qkv_block_size = m->qProjSize * num_tokens; int kt_block_size = m->kProjSize * BatchConfig::MAX_SEQ_LENGTH; int kt_req_block_size = kt_block_size * m->num_heads; int vt_block_size = m->vProjSize * BatchConfig::MAX_SEQ_LENGTH; @@ -613,15 +639,14 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, bc->requestsInfo[i].num_tokens_in_batch; // bc->token_last_available_idx[i] + 1; // Compute (QK^T/sqrt(d_k)) + // a flag of using this scaling alpha int m_ = num_new_tokens; int n = total_tokens; - int k = m->qProjSize; + int k = m->qProjSize * m->num_heads; int lda = k, ldb = k, ldc = m_; int strideA = qkv_block_size; - int strideB = kt_block_size; + int strideB = 0; int strideC = num_new_tokens * total_tokens; - - // a flag of using this scaling alpha DT alpha = 1.0f, beta = 0.0f; if (*m->qk_prod_scaling) { alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize));
@@ -634,30 +659,62 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
         void const *B = static_cast<DT *>
(m->keyCache) + i * kt_req_block_size; // To get C, skip over QK^T products from previous requests void *C = (void *)(m->qk_prods); + if (m->num_kv_heads == m->num_heads) { + // use cublasGemmEx + checkCUDA(cublasGemmEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + B, + cublas_data_type, + ldb, + &beta, + C, + cublas_data_type, + ldc, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + // use cublasGemmStridedBatchedEx + int one_step_heads = m->num_heads / m->num_kv_heads; + m_ = num_new_tokens; + n = total_tokens; + k = m->qProjSize; + lda = k, ldb = k, ldc = m_; + for (int step = 0; step < m->num_kv_heads; step++) { + checkCUDA( + cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_, + n, + k, + &alpha, + A + step * strideA * one_step_heads, + cublas_data_type, + lda, + strideA, + B + step * kt_block_size, + cublas_data_type, + ldb, + strideB, + &beta, + C + step * strideC * one_step_heads, + cublas_data_type, + ldc, + strideC, + one_step_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + } // Fill all elements above diagonal in qk prods with -inf to force // causal attention. @@ -716,7 +773,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, // Matmul softmax(QK^T/sqrt(d_k)) by V alpha = 1.0f, beta = 0.0f; m_ = num_new_tokens; - n = m->vProjSize; + n = m->vProjSize * m->num_heads; k = total_tokens; lda = m_, ldb = n, ldc = m_; strideA = num_new_tokens * total_tokens; @@ -733,36 +790,70 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m, C = static_cast
(m->attn_heads) + tokens_previous_requests * m->num_heads * m->vProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->num_heads == m->num_kv_heads) { + checkCUDA(cublasGemmEx(m->handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + B, + cublas_data_type, + ldb, + &beta, + C, + cublas_data_type, + ldc, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + int one_step_heads = m->num_heads / m->num_kv_heads; + n = m->vProjSize; + lda = m_, ldb = n, ldc = m_; + strideA = num_new_tokens * total_tokens; + strideB = 0; + // strideB = vt_block_size; + strideC = num_new_tokens * m->vProjSize; + for (int step = 0; step < m->num_kv_heads; step++) { + checkCUDA(cublasGemmStridedBatchedEx( + m->handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + m_, + n, + k, + &alpha, + A + step * one_step_heads * strideA, + cublas_data_type, + lda, + strideA, + B + step * vt_block_size, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC + step * one_step_heads * strideC, + one_step_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + } // Project to output, save result directly on output tensor alpha = 1.0f, beta = 0.0f; m_ = m->oProjSize; k = m->vProjSize * m->num_heads; n = num_new_tokens; lda = k, ldb = n, ldc = m_; - A = static_cast
<DT *>(m->W_out_contiguous);
+    A = weight_ptr + m->qSize * (m->qProjSize * m->num_heads +
+                                 m->kProjSize * m->num_kv_heads +
+                                 m->vProjSize * m->num_kv_heads);
     B = C;
     C = static_cast<DT *>
(output_ptr) + tokens_previous_requests * m->oProjSize;
@@ -791,12 +882,15 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
 
   if (*m->bias && shard_id == 0) {
     int parallelism = m->oProjSize * num_tokens;
-    int qkv_weight_size = m->qProjSize * m->
+    int qkv_weight_size = m->qProjSize * m->num_heads +
+                          m->kProjSize * m->num_kv_heads +
+                          m->vProjSize * m->num_kv_heads;
+
     apply_proj_bias_w<<<GET_BLOCKS(parallelism),
                         min(CUDA_NUM_THREADS, parallelism),
                         0,
                         stream>>>(
-        output_ptr, bias_ptr, num_tokens, m->oProjSize);
+        output_ptr, bias_ptr, num_tokens, qkv_weight_size, m->oProjSize);
   }
 
   assert(tokens_previous_requests == num_tokens);
diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu
index 44080b7c5c..9b3d8520c3 100644
--- a/src/ops/spec_inc_multihead_self_attention.cu
+++ b/src/ops/spec_inc_multihead_self_attention.cu
@@ -459,7 +459,7 @@ void compute_attention_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
     apply_proj_bias_w<<<GET_BLOCKS(parallelism),
                         min(CUDA_NUM_THREADS, parallelism),
                         0,
                         stream>>>(
-        output_ptr, bias_ptr, num_tokens, m->oProjSize);
+        output_ptr, bias_ptr, num_tokens, 0, m->oProjSize);
   }
 }
@@ -613,6 +613,7 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta(
                                     num_samples,
                                     attn->num_heads,
                                     _num_heads,
+                                    attn->num_heads,
                                     DT_NONE,
                                     false) {
   cudaStream_t stream;
diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu
index b46ccb4853..6bc991871e 100644
--- a/src/ops/tree_inc_multihead_self_attention.cu
+++ b/src/ops/tree_inc_multihead_self_attention.cu
@@ -447,7 +447,7 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m,
     apply_proj_bias_w<<<GET_BLOCKS(parallelism),
                         min(CUDA_NUM_THREADS, parallelism),
                         0,
                         stream>>>(
-        output_ptr, bias_ptr, processed_tokens_in_batch, m->oProjSize);
+        output_ptr, bias_ptr, processed_tokens_in_batch, 0, m->oProjSize);
   }
 
   assert(processed_tokens_in_batch == bc->num_active_tokens());
@@ -646,6 +646,7 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta(
                                     num_samples,
                                     attn->num_heads,
                                     _num_heads,
+                                    attn->num_heads,
                                     attn->quantization_type,
                                     attn->offload),
      num_active_tokens(0) {
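
Note on the layout this patch assumes (not part of the diff itself): devQKVProjArray is now laid out as [all Q heads | all K heads | all V heads], with the Q block sized by num_heads and the K/V blocks sized by num_kv_heads, and each KV head is shared by a contiguous group of num_heads / num_kv_heads query heads (the one_step_heads loops above). The standalone C++ sketch below only illustrates that indexing with made-up sizes; the variable names mirror the diff, but the program is an illustrative sketch, not FlexFlow code.

// Illustrative only: grouped-query head mapping and QKV block sizes
// implied by the num_kv_heads changes in this patch (assumed layout).
#include <cassert>
#include <cstdio>

int main() {
  int num_heads = 32;    // query heads (m->num_heads)
  int num_kv_heads = 4;  // key/value heads (m->num_kv_heads); 1 => multi-query
  int qProjSize = 128, kProjSize = 128, vProjSize = 128;  // per-head dims
  int num_tokens = 7;    // active tokens in the batch

  // devQKVProjArray after compute_qkv_kernel: [Q block | K block | V block].
  int q_array_size = qProjSize * num_tokens * num_heads;
  int k_array_size = kProjSize * num_tokens * num_kv_heads;
  int v_array_size = vProjSize * num_tokens * num_kv_heads;
  std::printf("Q block %d, K block %d, V block %d elements\n",
              q_array_size, k_array_size, v_array_size);

  // Each KV head serves num_heads / num_kv_heads consecutive query heads,
  // which is why the attention GEMMs above loop over num_kv_heads and run
  // one_step_heads query heads per strided-batched call.
  assert(num_heads % num_kv_heads == 0);
  int one_step_heads = num_heads / num_kv_heads;
  for (int kv_head = 0; kv_head < num_kv_heads; kv_head++) {
    std::printf("kv head %d <- query heads [%d, %d)\n",
                kv_head,
                kv_head * one_step_heads,
                (kv_head + 1) * one_step_heads);
  }
  return 0;
}

When num_kv_heads == num_heads, the kernels above take the original single-GEMM path, so standard multi-head attention behaves as before.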