Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaoc committed Jul 21, 2023
1 parent 87cc843 commit c268e9e
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ template <typename DT>
__global__ void apply_proj_bias_w(DT *input_ptr,
DT const *bias_ptr,
int num_tokens,
int qkv_weight_size,
int oProjSize) {
CUDA_KERNEL_LOOP(i, num_tokens * oProjSize) {
int bias_idx = 3 * oProjSize + i % oProjSize;
int bias_idx = qkv_weight_size + i % oProjSize;
input_ptr[i] += bias_ptr[bias_idx];
}
}
Expand Down Expand Up @@ -120,24 +121,42 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr,
vProjSize * num_kv_heads)) {
// for simplicity, assume q, k, v is in same shape
// 0->q, 1->k, 2->v
int qkv_index = i / (num_tokens * qProjSize) % 3;

int head_idx = i / (num_tokens * (qProjSize + kProjSize + vProjSize));
int qkv_block_size = (qProjSize + kProjSize + vProjSize) * num_tokens;
int q_block_size = qProjSize * num_tokens;

int idx = i % (num_tokens * (qProjSize));

int real_part_index =
head_idx * qkv_block_size + qkv_index * q_block_size + idx;
// int qkv_index = i / (num_tokens * qProjSize) % 3;

int qkv_index = i < num_tokens * qProjSize * num_heads
? 0
: (i < num_tokens * (qProjSize * num_heads +
kProjSize * num_kv_heads)
? 1
: 2);

// int head_idx = i / (num_tokens * (qProjSize + kProjSize + vProjSize));
// int qkv_block_size = (qProjSize + kProjSize + vProjSize) * num_tokens;
int q_block_size = qProjSize * num_tokens * num_heads;
int k_block_size = kProjSize * num_tokens * num_kv_heads;

// int idx = i % (num_tokens * (qProjSize));

// int real_part_index =
// head_idx * qkv_block_size + qkv_index * q_block_size + idx;
int bias_idx = 0;
if(qkv_index == 0){
int head_idx = i / (num_tokens *qProjSize);
int global_head_idx = head_idx + shard_id * num_heads;
bias_idx = global_head_idx * qProjSize + (i % (num_tokens * (qProjSize)) % qProjSize);
}else{
int idx = qkv_index == 1 ? i - q_block_size : i - q_block_size - k_block_size
int head_idx = i / (num_tokens * kProjSize);
int global_head_idx = head_idx + shard_id * num_kv_heads;
bias_idx = global_head_idx * kProjSize + (idx % (num_tokens * (qProjSize)) % qProjSize)
}
// int bias_idx = qkv_index * qProjSize * global_num_heads +
// global_head_idx * qProjSize + (idx % qProjSize);

int global_head_idx = head_idx + shard_id * num_heads;
int bias_idx = qkv_index * qProjSize * global_num_heads +
global_head_idx * qProjSize + (idx % qProjSize);
input_ptr[real_part_index] += bias_ptr[bias_idx];
input_ptr[i] += bias_ptr[bias_idx];

if (scaling_query && qkv_index == 0) {
input_ptr[real_part_index] *= scaling_factor;
input_ptr[i] *= scaling_factor;
}
}
}
Expand Down Expand Up @@ -772,6 +791,7 @@ void compute_attention_kernel(IncMultiHeadSelfAttentionMeta const *m,
if (*m->bias && shard_id == 0) {
int parallelism = m->oProjSize * num_tokens;
int qkv_weight_size = m->qProjSize * m->
apply_proj_bias_w<<<GET_BLOCKS(parallelism),
min(CUDA_NUM_THREADS, parallelism),
0,
Expand Down

0 comments on commit c268e9e

Please sign in to comment.