diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu
index ce30b5dfda..7da9aa389c 100644
--- a/src/ops/inc_multihead_self_attention.cu
+++ b/src/ops/inc_multihead_self_attention.cu
@@ -504,7 +504,6 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
   checkCUDA(cublasSetStream(m->handle.blas, stream));
   checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
-  DT alpha = 1.0f, beta = 0.0f;
   assert(m->qSize == m->vSize && m->qSize == m->kSize);
   cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
 #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
@@ -518,43 +517,52 @@
     compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
   }
 #endif
-  // Compute (W^T)x matmul: einsum(ijkl,im->jmkl)
-  // Weights: qSize x qProjSize x 3 x num_q_heads
-  // Input: qSize x num_tokens
-  // Output >>> qProjSize x num_tokens x 3 x num_q_heads
-  int m_q = m->qProjSize * m->num_q_heads;
-  int m_k = m->kProjSize * m->num_q_heads;
-  int m_v = m->vProjSize * m->num_q_heads;
-  assert(m_q == m_k && m_k == m_v); // keep things simple for now
-  int n = bc->num_active_tokens();
-  int k = m->qSize;
-  int m_ = m_q * QKV_WEIGHT_NUM;
-  int lda = k, ldb = k, ldc = m_;
-  checkCUDA(cublasGemmEx(m->handle.blas,
-                         CUBLAS_OP_T,
-                         CUBLAS_OP_N,
-                         m_,
-                         n,
-                         k,
-                         &alpha,
-                         weight_ptr,
-                         cublas_data_type,
-                         lda,
-                         input_ptr,
-                         cublas_data_type,
-                         ldb,
-                         &beta,
-                         output_ptr,
-                         cublas_data_type,
-                         ldc,
-                         compute_type,
-                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  // apply rotary emmmbedding for q
-  // and k step1 change the k, v to complex tensor
+
+  // Step 1: Compute QKV projections
+  {
+    DT alpha = 1.0f, beta = 0.0f;
+    // after transpositions
+    int m_q = m->qProjSize * m->num_q_heads;
+    int m_k = m->kProjSize * m->num_q_heads;
+    int m_v = m->vProjSize * m->num_q_heads;
+    assert(m_q == m_k && m_k == m_v); // keep things simple for now
+    int n = bc->num_active_tokens();
+    int k = m->qSize;
+    int m_ = m_q * QKV_WEIGHT_NUM;
+    // before transpositions
+    int lda = k, ldb = k, ldc = m_;
+    // matrix A: QKV weights
+    // matrix A's layout: [qSize (hidden_dim), qProjSize, num_heads, 3]
+    // matrix B: input
+    // matrix B's layout: [qSize (hidden_dim), num_new_tokens]
+    // matrix C: devQKVProjArray
+    // matrix C's layout: [qProjSize, num_heads, 3, num_new_tokens]
+    checkCUDA(cublasGemmEx(m->handle.blas,
+                           CUBLAS_OP_T,
+                           CUBLAS_OP_N,
+                           m_,
+                           n,
+                           k,
+                           &alpha,
+                           weight_ptr,
+                           cublas_data_type,
+                           lda,
+                           input_ptr,
+                           cublas_data_type,
+                           ldb,
+                           &beta,
+                           output_ptr,
+                           cublas_data_type,
+                           ldc,
+                           compute_type,
+                           CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  }
+
   int num_tokens = bc->num_active_tokens();
   int parallelism = m->kProjSize * num_tokens * m->num_q_heads;
   size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads;
-  // apply bias for q, k, v
+
+  // Step 2: apply bias for QKV, or scale the query
   if (*m->qkv_bias) {
     apply_proj_bias_qkv<<<GET_BLOCKS(parallelism),
                           min(CUDA_NUM_THREADS, parallelism),
                           0,
                           stream>>>(output_ptr,
                                     bias_ptr,
                                     shard_id,
                                     num_tokens,
                                     m->qProjSize,
                                     m->kProjSize,
                                     m->vProjSize,
                                     m->global_num_q_heads,
                                     m->num_q_heads,
                                     *m->scaling_query,
                                     m->scaling_factor,
                                     m->hidden_size);
   }
+
+  // Step 3: apply rotary embedding if needed
   if (*m->apply_rotary_embedding) {
     /*q&k*/
     parallelism = num_tokens * m->hidden_size;
@@ -638,38 +648,47 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m,
   cudaDataType_t compute_type = cublas_data_type;
 #endif
   // Project to output, save result directly on output tensor
-  DT alpha = 1.0f, beta = 0.0f;
-  // int num_tokens = bc->num_active_tokens();
-  int m_ = m->oProjSize;
-  int k = m->vProjSize * m->num_q_heads;
-  int n = num_tokens;
-  int lda = k, ldb = k, ldc = m_;
-  DT const *A = weight_ptr + m->qSize *
-      (m->qProjSize * m->num_q_heads +
-       m->kProjSize * m->num_q_heads +
-       m->vProjSize * m->num_q_heads);
-  DT const *B = static_cast<DT *>
-      (m->attn_heads);
-  DT *C = static_cast<DT *>
-      (output_ptr);
-
-  checkCUDA(cublasGemmEx(m->handle.blas,
-                         CUBLAS_OP_T,
-                         CUBLAS_OP_N,
-                         m_,
-                         n,
-                         k,
-                         &alpha,
-                         A,
-                         cublas_data_type,
-                         lda,
-                         B,
-                         cublas_data_type,
-                         ldb,
-                         &beta,
-                         C,
-                         cublas_data_type,
-                         ldc,
-                         compute_type,
-                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
+  {
+    DT alpha = 1.0f, beta = 0.0f;
+    // after transpositions
+    int m_ = m->oProjSize;
+    int k = m->vProjSize * m->num_q_heads;
+    int n = num_tokens;
+    // before transpositions
+    int lda = k, ldb = k, ldc = m_;
+    // matrix A: output projection weight
+    // matrix A's layout: [vProjSize * num_heads, oProjSize]
+    DT const *A = weight_ptr + m->qSize * (m->qProjSize * m->num_q_heads +
+                                           m->kProjSize * m->num_q_heads +
+                                           m->vProjSize * m->num_q_heads);
+    // matrix B: attn heads
+    // matrix B's layout: [vProjSize * num_heads, num_new_tokens]
+    DT const *B = static_cast<DT *>
+        (m->attn_heads);
+    // matrix C: output
+    // matrix C's layout: [oProjSize, num_new_tokens]
+    DT *C = static_cast<DT *>
+        (output_ptr);
+
+    checkCUDA(cublasGemmEx(m->handle.blas,
+                           CUBLAS_OP_T,
+                           CUBLAS_OP_N,
+                           m_,
+                           n,
+                           k,
+                           &alpha,
+                           A,
+                           cublas_data_type,
+                           lda,
+                           B,
+                           cublas_data_type,
+                           ldb,
+                           &beta,
+                           C,
+                           cublas_data_type,
+                           ldc,
+                           compute_type,
+                           CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  }
+  // Add final output bias
   if (*m->final_bias && shard_id == 0) {
     int parallelism = m->oProjSize * num_tokens;
     int qkv_weight_size = m->qProjSize * m->global_num_q_heads +
@@ -945,54 +964,69 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
     int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
     int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
                        bc->requestsInfo[i].num_tokens_in_batch;
-    // bc->token_last_available_idx[i] + 1;
-    // Compute (QK^T/sqrt(d_k))
-    // a flag of using this scaling alpha
-    int m_ = num_new_tokens;
-    int n = total_tokens;
-    int k = m->qProjSize;
-    int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads,
-        ldc = m_;
-    int strideA = q_block_size;
-    int strideB = kt_block_size;
-    int strideC = num_new_tokens * total_tokens;
-    DT alpha = 1.0f, beta = 0.0f;
-    if (*m->qk_prod_scaling) {
-      alpha = static_cast<DT>
-          (1.0f / sqrt(m->kProjSize));
-    }
+    // Step 1: compute query-key product QK.T/sqrt(d_k)
+    {
+      // Scale by 1/sqrt(d_k) as per the original attention paper
+      DT alpha = 1.0f, beta = 0.0f;
+      if (*m->qk_prod_scaling) {
+        alpha = static_cast<DT>
+            (1.0f / sqrt(m->kProjSize));
+      }
+      // after transpositions
+      int m_ = num_new_tokens;
+      int n = total_tokens;
+      int k = m->qProjSize;
+      // before transpositions
+      int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads,
+          ldc = m_;
+      // N.B. strides are applied before transpose operations
+      int strideA = q_block_size;
+      int strideB = kt_block_size;
+      int strideC = num_new_tokens * total_tokens;
+
+      // matrix A: devQKVProjArray
+      // matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens]
+      // To get query projection, skip over Q entries from previous requests
+      DT const *A = static_cast<DT *>
+          (m->devQKVProjArray) +
+          tokens_previous_requests * m->qProjSize * m->num_q_heads *
+              QKV_WEIGHT_NUM;
+      // matrix B: key cache
+      // matrix B's layout: [kProjSize * num_heads, total_tokens]
+      // To get B, skip over K entries from previous requests (all heads +
+      // padding)
+      DT const *B = static_cast<DT *>
+          (m->keyCache) + i * kt_req_block_size;
+      // matrix C: qk_prods
+      // matrix C's layout: [num_new_tokens, total_tokens, num_heads]
+      // To get C, skip over QK.T products from previous requests
+      DT *C = static_cast<DT *>
+          (m->qk_prods);
+      checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
+                                           CUBLAS_OP_T,
+                                           CUBLAS_OP_N,
+                                           m_,
+                                           n,
+                                           k,
+                                           &alpha,
+                                           A,
+                                           cublas_data_type,
+                                           lda,
+                                           strideA,
+                                           B,
+                                           cublas_data_type,
+                                           ldb,
+                                           strideB,
+                                           &beta,
+                                           C,
+                                           cublas_data_type,
+                                           ldc,
+                                           strideC,
+                                           m->num_q_heads,
+                                           compute_type,
+                                           CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    }
-    // To get A, skip over Q entries from previous requests (same head)
-    DT const *A = static_cast<DT *>
-        (m->devQKVProjArray) +
-        tokens_previous_requests * m->qProjSize * m->num_q_heads *
-            QKV_WEIGHT_NUM;
-    // To get B, skip over K entries from previous requests (all heads +
-    // padding)
-    DT const *B = static_cast<DT *>
-        (m->keyCache) + i * kt_req_block_size;
-    // To get C, skip over QK^T products from previous requests
+    // Step 2: Add alibi position bias to the qk product
+    // matrix C: qk_prods
+    // matrix C's layout: [num_new_tokens, total_tokens, num_heads]
+    // To get C, skip over QK.T products from previous requests
     DT *C = static_cast<DT *>
         (m->qk_prods);
-    checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
-                                         CUBLAS_OP_T,
-                                         CUBLAS_OP_N,
-                                         m_,
-                                         n,
-                                         k,
-                                         &alpha,
-                                         A,
-                                         cublas_data_type,
-                                         lda,
-                                         strideA,
-                                         B,
-                                         cublas_data_type,
-                                         ldb,
-                                         strideB,
-                                         &beta,
-                                         C,
-                                         cublas_data_type,
-                                         ldc,
-                                         strideC,
-                                         m->num_q_heads,
-                                         compute_type,
-                                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-    // add alibi position bias to qk production
     if (*m->position_bias) {
       size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens;
       apply_position_bias_qkprd<<<GET_BLOCKS(parallelism),
                                   min((size_t)CUDA_NUM_THREADS, parallelism),
                                   0,
                                   stream>>>(C,
                                             num_new_tokens,
                                             total_tokens,
                                             m->num_q_heads,
                                             m->global_num_q_heads,
                                             shard_id);
     }
-    // Fill all elements above the diagonal in qk prods with -inf to force
-    // causal attention.
+    // Step 3: Apply causal mask. Fill all elements above the diagonal in
+    // qk prods with -inf to force causal attention.
     assert(num_new_tokens <= total_tokens);
     size_t entries_above_diagonal = num_new_tokens * (num_new_tokens - 1) / 2;
     if (entries_above_diagonal > 0) {
@@ -1022,87 +1056,102 @@
                                     entries_above_diagonal,
                                     static_cast<DT>
          (-INFINITY));
     }
-    // Compute Softmax(QK^T/sqrt(d_k))
-    // Before modifying the parameters below, make sure to read the following
-    // description of the CUDNN_TENSOR_NCHW tensor layout, from
-    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t:
-    // This tensor format specifies that the data is laid out in the following
-    // order: batch size, feature maps, rows, columns. The strides are
-    // implicitly defined in such a way that the data are contiguous in memory
-    // with no padding between images, feature maps, rows, and columns; the
-    // columns are the inner dimension and the images are the outermost
-    // dimension.
-    int n_param = m->num_q_heads;
-    int c_param = total_tokens;
-    int h_param = 1;
-    int w_param = num_new_tokens;
-    checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor,
-                                          CUDNN_TENSOR_NCHW,
-                                          cudnn_data_type,
-                                          n_param,
-                                          c_param,
-                                          h_param,
-                                          w_param));
-    float softmax_alpha = 1.0f, softmax_beta = 0.0f;
-    DT *C_softmax = static_cast<DT *>
-        (m->qk_prods_softmax);
-    // The softmax operation below is executed according to the
-    // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The
-    // softmax operation is computed per spatial location (H,W) per image (N)
-    // across dimension C.
-    checkCUDNN(cudnnSoftmaxForward(m->handle.dnn,
-                                   CUDNN_SOFTMAX_ACCURATE,
-                                   CUDNN_SOFTMAX_MODE_CHANNEL,
-                                   &softmax_alpha,
-                                   m->qk_tensor,
-                                   C,
-                                   &softmax_beta,
-                                   m->qk_tensor,
-                                   C_softmax));
-    // Matmul softmax(QK^T/sqrt(d_k)) by V
-    alpha = 1.0f, beta = 0.0f;
-    m_ = m->vProjSize;
-    n = num_new_tokens;
-    k = total_tokens;
-    lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads;
-    strideA = vt_block_size;
-    strideB = num_new_tokens * total_tokens;
-    strideC = m->vProjSize;
-    // To get A, skip over V^T entries from previous requests (all heads +
-    // padding)
-    A = static_cast<DT *>
-        (m->valueCache) + i * vt_req_block_size;
-    // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous
-    // requests (all heads)
-    B = C_softmax;
-    // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous
-    // requests
-
-    // store the result attn heads, also skip the genration tokens
-    C = static_cast<DT *>
-        (m->attn_heads) +
-        (tokens_previous_requests + bc->num_generation_tokens) *
-            m->num_q_heads * m->vProjSize;
-    checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
-                                         CUBLAS_OP_N,
-                                         CUBLAS_OP_T,
-                                         m_,
-                                         n,
-                                         k,
-                                         &alpha,
-                                         A,
-                                         cublas_data_type,
-                                         lda,
-                                         strideA,
-                                         B,
-                                         cublas_data_type,
-                                         ldb,
-                                         strideB,
-                                         &beta,
-                                         C,
-                                         cublas_data_type,
-                                         ldc,
-                                         strideC,
-                                         m->num_q_heads,
-                                         compute_type,
-                                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+    // Step 4: Compute Softmax(QK.T/sqrt(d_k))
+    {
+      // Before modifying the parameters below, make sure to read the following
+      // description of the CUDNN_TENSOR_NCHW tensor layout, from
+      // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t:
+      // This tensor format specifies that the data is laid out in the following
+      // order: batch size, feature maps, rows, columns. The strides are
+      // implicitly defined in such a way that the data are contiguous in memory
+      // with no padding between images, feature maps, rows, and columns; the
+      // columns are the inner dimension and the images are the outermost
+      // dimension.
+      int n_param = m->num_q_heads;
+      int c_param = total_tokens;
+      int h_param = 1;
+      int w_param = num_new_tokens;
+      checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor,
+                                            CUDNN_TENSOR_NCHW,
+                                            cudnn_data_type,
+                                            n_param,
+                                            c_param,
+                                            h_param,
+                                            w_param));
+      float softmax_alpha = 1.0f, softmax_beta = 0.0f;
+      DT *C_softmax = static_cast<DT *>
+          (m->qk_prods_softmax);
+      // The softmax operation below is executed according to the
+      // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The
+      // softmax operation is computed per spatial location (H,W) per image (N)
+      // across dimension C.
+      checkCUDNN(cudnnSoftmaxForward(m->handle.dnn,
+                                     CUDNN_SOFTMAX_ACCURATE,
+                                     CUDNN_SOFTMAX_MODE_CHANNEL,
+                                     &softmax_alpha,
+                                     m->qk_tensor,
+                                     C,
+                                     &softmax_beta,
+                                     m->qk_tensor,
+                                     C_softmax));
+    }
+    // Step 5: Matmul softmax(QK.T/sqrt(d_k)) by V. Implemented as V @
+    // softmax(QK.T/sqrt(d_k)).T
+    {
+      DT alpha = 1.0f, beta = 0.0f;
+      // after transpositions
+      int m_ = m->vProjSize;
+      int n = num_new_tokens;
+      int k = total_tokens;
+      // before transpositions
+      int lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads;
+      // N.B. strides are applied before transpose operations
+      int strideA = vt_block_size;
+      int strideB = num_new_tokens * total_tokens;
+      int strideC = m->vProjSize;
+      // matrix A: value cache
+      // matrix A's layout: [vProjSize, num_heads, total_tokens]
+      // To get A, skip over V.T entries from previous requests (all heads +
+      // padding)
+      DT *A = static_cast<DT *>
+          (m->valueCache) + i * vt_req_block_size;
+      // matrix B: qk_prods_softmax
+      // matrix B's layout: [num_new_tokens, total_tokens, num_heads]
+      // To get B, skip over softmax(QK.T/sqrt(d_k)) entries from previous
+      // requests (all heads)
+      DT *B = static_cast<DT *>
+          (m->qk_prods_softmax);
+      // matrix C: attn heads
+      // matrix C's layout: [vProjSize, num_heads, num_new_tokens]
+      // To get C, skip over softmax(QK.T/sqrt(d_k))V products from previous
+      // requests
+      // store the result attn heads, also skip the generation tokens
+      DT *C = static_cast<DT *>
+          (m->attn_heads) +
+          (tokens_previous_requests + bc->num_generation_tokens) *
+              m->num_q_heads * m->vProjSize;
+      checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
+                                           CUBLAS_OP_N,
+                                           CUBLAS_OP_T,
+                                           m_,
+                                           n,
+                                           k,
+                                           &alpha,
+                                           A,
+                                           cublas_data_type,
+                                           lda,
+                                           strideA,
+                                           B,
+                                           cublas_data_type,
+                                           ldb,
+                                           strideB,
+                                           &beta,
+                                           C,
+                                           cublas_data_type,
+                                           ldc,
+                                           strideC,
+                                           m->num_q_heads,
+                                           compute_type,
+                                           CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    }
     tokens_previous_requests += num_new_tokens;
   }
   assert(tokens_previous_requests == num_tokens);
@@ -1255,7 +1304,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
   assert(kSize == vSize);
   qProjSize = _qProjSize;
   kProjSize = _kProjSize;
-  assert(qProjSize == kProjSize); // required for attention QK^T matmul
+  assert(qProjSize == kProjSize); // required for attention QK.T matmul
   vProjSize = _vProjSize;
   oProjSize = _oProjSize;
   size_t size_of_dt = data_type_size(attn->data_type);
diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu
index 6dad1c6de9..562dee4d93 100644
--- a/src/ops/spec_inc_multihead_self_attention.cu
+++ b/src/ops/spec_inc_multihead_self_attention.cu
@@ -492,7 +492,7 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
   // compute output production and bias together for all tokens
   int num_tokens =
       bc->num_active_tokens() * BeamSearchBatchConfig::MAX_BEAM_WIDTH;
-
+
   compute_o_prod_bias(
       m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream);
 }
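
Review note: every unbatched GEMM in this patch uses the same call shape.
cuBLAS is column-major, so a weight tensor whose leading dimension is the
hidden size (lda = k) is passed with CUBLAS_OP_T, yielding C = W^T * X with no
explicit transpose kernel. Below is a minimal, self-contained sketch of that
convention with a CPU reference check; the m_/n/k names mirror
compute_qkv_kernel, but the sizes are toy values and none of this is FlexFlow
code (error checking omitted for brevity).

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Toy stand-ins: m_ ~ qProjSize * num_q_heads * 3, n ~ num_active_tokens,
  // k ~ qSize (the hidden dimension).
  int m_ = 4, n = 3, k = 5;
  std::vector<float> W(k * m_), X(k * n), C(m_ * n, 0.0f);
  for (int i = 0; i < k * m_; i++) W[i] = 0.01f * i;
  for (int i = 0; i < k * n; i++) X[i] = 0.1f * (i % 7);

  float *dW, *dX, *dC;
  cudaMalloc(&dW, W.size() * sizeof(float));
  cudaMalloc(&dX, X.size() * sizeof(float));
  cudaMalloc(&dC, C.size() * sizeof(float));
  cudaMemcpy(dW, W.data(), W.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dX, X.data(), X.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  float alpha = 1.0f, beta = 0.0f;
  // Same shape as the patch: A = weights (k x m_, lda = k, transposed),
  // B = input (k x n, ldb = k), C = output (m_ x n, ldc = m_).
  cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m_, n, k,
               &alpha, dW, CUDA_R_32F, k, dX, CUDA_R_32F, k,
               &beta, dC, CUDA_R_32F, m_,
               CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
  cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);

  // CPU reference in column-major terms: C(row, col) = sum_kk W(kk, row) * X(kk, col).
  for (int col = 0; col < n; col++) {
    for (int row = 0; row < m_; row++) {
      float ref = 0.0f;
      for (int kk = 0; kk < k; kk++) ref += W[row * k + kk] * X[col * k + kk];
      assert(std::fabs(C[col * m_ + row] - ref) < 1e-4f);
    }
  }
  printf("projection GEMM layout check passed\n");
  cublasDestroy(handle);
  cudaFree(dW); cudaFree(dX); cudaFree(dC);
  return 0;
}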
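
Review note: the batched QK.T in Step 1 additionally relies on the rule that
the "N.B. strides are applied before transpose operations" comments call out:
lda/ldb measure the distance between consecutive tokens in the interleaved
[projSize, num_heads, tokens] buffers, while strideA/strideB/strideC measure
the distance between heads, so each of the num_q_heads sub-GEMMs sees only its
own head's slice. The sketch below verifies this with toy sizes; the real
kernel's lda carries an extra QKV_WEIGHT_NUM factor because Q, K and V are
interleaved in devQKVProjArray, and per-request offsets are omitted here.

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Toy stand-ins: proj ~ qProjSize (== kProjSize), heads ~ num_q_heads,
  // m_ ~ num_new_tokens, n ~ total_tokens.
  int proj = 3, heads = 2, m_ = 2, n = 4, k = proj;
  int lda = proj * heads, ldb = proj * heads, ldc = m_;
  long long strideA = proj, strideB = proj; // jump to the next head's Q / K
  long long strideC = (long long)m_ * n;    // jump to the next head's QK^T

  // Q: [proj, heads, m_], K: [proj, heads, n], C: [m_, n, heads]
  std::vector<float> Q(proj * heads * m_), K(proj * heads * n),
      C(m_ * n * heads, 0.0f);
  for (size_t i = 0; i < Q.size(); i++) Q[i] = 0.1f * (i % 5);
  for (size_t i = 0; i < K.size(); i++) K[i] = 0.2f * (i % 3);

  float *dQ, *dK, *dC;
  cudaMalloc(&dQ, Q.size() * sizeof(float));
  cudaMalloc(&dK, K.size() * sizeof(float));
  cudaMalloc(&dC, C.size() * sizeof(float));
  cudaMemcpy(dQ, Q.data(), Q.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dK, K.data(), K.size() * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  float alpha = 1.0f, beta = 0.0f;
  cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m_, n, k,
                             &alpha, dQ, CUDA_R_32F, lda, strideA,
                             dK, CUDA_R_32F, ldb, strideB, &beta,
                             dC, CUDA_R_32F, ldc, strideC, heads,
                             CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
  cudaMemcpy(C.data(), dC, C.size() * sizeof(float), cudaMemcpyDeviceToHost);

  // CPU reference: C[h](i, j) = sum_p Q(p, h, i) * K(p, h, j).
  for (int h = 0; h < heads; h++)
    for (int j = 0; j < n; j++)
      for (int i = 0; i < m_; i++) {
        float ref = 0.0f;
        for (int p = 0; p < proj; p++)
          ref += Q[p + h * proj + i * proj * heads] *
                 K[p + h * proj + j * proj * heads];
        assert(std::fabs(C[h * m_ * n + j * m_ + i] - ref) < 1e-4f);
      }
  printf("per-head strided QK^T layout check passed\n");
  cublasDestroy(handle);
  cudaFree(dQ); cudaFree(dK); cudaFree(dC);
  return 0;
}

Step 5 follows the same pattern, with op(B) = CUBLAS_OP_T reusing the per-head
[num_new_tokens, total_tokens] blocks produced by the softmax in Step 4.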