diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index 2b68c4d481..002cdf508f 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -56,9 +56,10 @@ template
 __global__ void apply_proj_bias_w(DT *input_ptr,
                                   DT const *bias_ptr,
                                   int num_tokens,
+                                  int qkv_weight_size,
                                   int oProjSize) {
   CUDA_KERNEL_LOOP(i, num_tokens * oProjSize) {
-    int bias_idx = 3 * oProjSize + i % oProjSize;
+    int bias_idx = qkv_weight_size + i % oProjSize;
     input_ptr[i] += bias_ptr[bias_idx];
   }
 }
@@ -120,24 +121,42 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr,
                            vProjSize * num_kv_heads)) {
     // for simplicity, assume q, k, v is in same shape
     // 0->q, 1->k, 2->v
-    int qkv_index = i / (num_tokens * qProjSize) % 3;
-
-    int head_idx = i / (num_tokens * (qProjSize + kProjSize + vProjSize));
-    int qkv_block_size = (qProjSize + kProjSize + vProjSize) * num_tokens;
-    int q_block_size = qProjSize * num_tokens;
-
-    int idx = i % (num_tokens * (qProjSize));
-
-    int real_part_index =
-        head_idx * qkv_block_size + qkv_index * q_block_size + idx;
+    // int qkv_index = i / (num_tokens * qProjSize) % 3;
+
+    int qkv_index = i < num_tokens * qProjSize * num_heads
+                        ? 0
+                        : (i < num_tokens * (qProjSize * num_heads +
+                                             kProjSize * num_kv_heads)
+                               ? 1
+                               : 2);
+
+    // int head_idx = i / (num_tokens * (qProjSize + kProjSize + vProjSize));
+    // int qkv_block_size = (qProjSize + kProjSize + vProjSize) * num_tokens;
+    int q_block_size = qProjSize * num_tokens * num_heads;
+    int k_block_size = kProjSize * num_tokens * num_kv_heads;
+
+    // int idx = i % (num_tokens * (qProjSize));
+
+    // int real_part_index =
+    //     head_idx * qkv_block_size + qkv_index * q_block_size + idx;
+    int bias_idx = 0;
+    if (qkv_index == 0) {
+      int head_idx = i / (num_tokens * qProjSize);
+      int global_head_idx = head_idx + shard_id * num_heads;
+      bias_idx = global_head_idx * qProjSize +
+                 (i % (num_tokens * qProjSize) % qProjSize);
+    } else {
+      int idx =
+          qkv_index == 1 ? i - q_block_size : i - q_block_size - k_block_size;
+      int head_idx = idx / (num_tokens * kProjSize);
+      int global_head_idx = head_idx + shard_id * num_kv_heads;
+      bias_idx = global_head_idx * kProjSize +
+                 (idx % (num_tokens * kProjSize) % kProjSize);
+    }
+    // int bias_idx = qkv_index * qProjSize * global_num_heads +
+    //                global_head_idx * qProjSize + (idx % qProjSize);
-    int global_head_idx = head_idx + shard_id * num_heads;
-    int bias_idx = qkv_index * qProjSize * global_num_heads +
-                   global_head_idx * qProjSize + (idx % qProjSize);
-    input_ptr[real_part_index] += bias_ptr[bias_idx];
+    input_ptr[i] += bias_ptr[bias_idx];
     if (scaling_query && qkv_index == 0) {
-      input_ptr[real_part_index] *= scaling_factor;
+      input_ptr[i] *= scaling_factor;
     }
   }
 }
@@ -772,6 +791,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
   if (*m->bias && shard_id == 0) {
     int parallelism = m->oProjSize * num_tokens;
+    int qkv_weight_size = m->qProjSize * m->
     apply_proj_bias_w<<
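
For reviewers, here is a minimal host-side sketch of the new per-element bias lookup in `apply_proj_bias_qkv`, assuming the fused QKV activation layout implied by the diff: the Q block (`num_heads` query heads) first, then the K and V blocks (`num_kv_heads` heads each), with the projection dimension varying fastest within a head. The function name `qkv_bias_index` and the `main` harness are illustrative only, not part of the PR.

```cpp
#include <cassert>

// Illustrative mirror of the kernel's bias lookup (not PR code). Maps a flat
// index i over the [Q | K | V] activation blocks to a per-head bias entry.
int qkv_bias_index(int i, int num_tokens, int qProjSize, int kProjSize,
                   int num_heads, int num_kv_heads, int shard_id) {
  int q_block_size = qProjSize * num_tokens * num_heads;
  int k_block_size = kProjSize * num_tokens * num_kv_heads;
  if (i < q_block_size) { // Q region: this shard owns num_heads query heads
    int head_idx = i / (num_tokens * qProjSize);
    int global_head_idx = head_idx + shard_id * num_heads;
    // (i % (num_tokens * qProjSize)) % qProjSize reduces to i % qProjSize
    return global_head_idx * qProjSize + i % qProjSize;
  }
  // K or V region: rebase i to the start of its own block before dividing,
  // otherwise head_idx would still carry the Q-block offset.
  int idx = (i < q_block_size + k_block_size)
                ? i - q_block_size                 // K
                : i - q_block_size - k_block_size; // V
  int head_idx = idx / (num_tokens * kProjSize);
  int global_head_idx = head_idx + shard_id * num_kv_heads;
  return global_head_idx * kProjSize + idx % kProjSize;
}

int main() {
  int num_tokens = 4, qProjSize = 8, kProjSize = 8;
  int num_heads = 4, num_kv_heads = 2, shard_id = 1;
  int q_block = qProjSize * num_tokens * num_heads;
  // First Q element of local head 0 maps to global head shard_id * num_heads.
  assert(qkv_bias_index(0, num_tokens, qProjSize, kProjSize, num_heads,
                        num_kv_heads, shard_id) ==
         shard_id * num_heads * qProjSize);
  // First K element rebases to idx == 0, global kv head shard_id * num_kv_heads.
  assert(qkv_bias_index(q_block, num_tokens, qProjSize, kProjSize, num_heads,
                        num_kv_heads, shard_id) ==
         shard_id * num_kv_heads * kProjSize);
  return 0;
}
```

The rebasing step is the crux of the change: with `num_kv_heads != num_heads`, the Q, K, and V blocks have different sizes, so a K/V element's head index can only be recovered after subtracting the preceding blocks' sizes from `i`.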