
Merge branch 'inference' into cuda_graph
goliaro authored Jan 7, 2024
2 parents 3157405 + 7b00e81 commit f151532
Showing 5 changed files with 45 additions and 51 deletions.
23 changes: 14 additions & 9 deletions include/flexflow/batch_config.h
@@ -167,9 +167,10 @@ class BeamSearchBatchConfig : public BatchConfig {
int current_depth = -1;
int max_depth = MAX_BEAM_DEPTH;

BatchConfig::TokenId tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
int parent_id[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
BatchConfig::TokenId
tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int parent_id[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int sub_request_num;
};

@@ -178,10 +179,11 @@ class BeamSearchBatchConfig : public BatchConfig {
};

BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS];
BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH];
BeamSearchPerTokenInfo
beamTokenInfo[MAX_NUM_TOKENS +
MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS];

// why is this == MAX_NUM_REQUESTS * MAX_BEAM_WIDTH?
int sub_requests[MAX_NUM_REQUESTS * MAX_BEAM_WIDTH];
int sub_requests[MAX_NUM_REQUESTS];

private:
size_t current_iteration;
@@ -190,9 +192,12 @@ class BeamSearchBatchConfig : public BatchConfig {
struct BeamInferenceResult {
static int const MAX_NUM_TOKENS = BatchConfig::MAX_NUM_TOKENS;
BatchConfig::TokenId
token_ids[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH];
float probs[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH];
int parent_id[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH];
token_ids[MAX_NUM_TOKENS *
BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
float probs[MAX_NUM_TOKENS *
BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int parent_id[MAX_NUM_TOKENS *
BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
};

}; // namespace FlexFlow
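
Note: the arrays above are now bounded by MAX_SPECULATIVE_TREE_BRANCHES (and beamTokenInfo by MAX_NUM_TOKENS + MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS) rather than MAX_BEAM_WIDTH, and only the first sub_request_num entries of each per-request array are meaningful. A minimal self-contained sketch of that access pattern; the struct is a simplified stand-in for BeamSearchPerRequestInfo and the constants are assumed values, not FlexFlow's configuration:

#include <cstdio>

constexpr int MAX_NUM_REQUESTS = 4;               // assumed, illustrative only
constexpr int MAX_SPECULATIVE_TREE_BRANCHES = 3;  // assumed, illustrative only

struct PerRequestBeamInfo {                       // simplified stand-in
  int tokens[MAX_SPECULATIVE_TREE_BRANCHES];
  float probs[MAX_SPECULATIVE_TREE_BRANCHES];
  int parent_id[MAX_SPECULATIVE_TREE_BRANCHES];
  int sub_request_num;                            // live branches for this request
};

int main() {
  PerRequestBeamInfo beam[MAX_NUM_REQUESTS] = {};
  beam[0].sub_request_num = 2;
  beam[0].tokens[0] = 11; beam[0].probs[0] = 0.7f; beam[0].parent_id[0] = 0;
  beam[0].tokens[1] = 42; beam[0].probs[1] = 0.3f; beam[0].parent_id[1] = 0;

  // Only the first sub_request_num entries are valid; the array bound is the
  // speculative-tree branch count, not the old beam width.
  for (int b = 0; b < beam[0].sub_request_num; b++) {
    std::printf("branch %d: token=%d prob=%.2f parent=%d\n",
                b, beam[0].tokens[b], beam[0].probs[b], beam[0].parent_id[b]);
  }
  return 0;
}
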
2 changes: 1 addition & 1 deletion include/flexflow/request_manager.h
@@ -76,7 +76,7 @@ struct BeamTree {
struct treeLayer {
BeamSearchBatchConfig::TokenId
tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int parent_ids[BeamSearchBatchConfig::MAX_BEAM_WIDTH];
int parent_ids[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES];
int nodes_num_this_layer = 0;
};
42 changes: 11 additions & 31 deletions src/ops/inc_multihead_self_attention.cu
@@ -52,9 +52,7 @@ __global__ void compute_attention_kernel_generation_kernel(
int max_seq_length,
int per_head_size,
int hidden_size,
BatchConfig::PerRequestInfo *request_infos,
bool is_beam,
int max_beam_width) {
BatchConfig::PerRequestInfo *request_infos) {

// q, k
using Q_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
@@ -85,10 +83,6 @@ __global__ void compute_attention_kernel_generation_kernel(
int const batch_config_request_id =
request_infos[request_idx].batch_config_request_id;

int const beam_request_idx =
is_beam ? request_idx / max_beam_width : request_idx;
int const beam_sub_request_idx = is_beam ? request_idx % max_beam_width : 0;

int const first_step = 0;

int const tlength =
@@ -106,8 +100,7 @@ __global__ void compute_attention_kernel_generation_kernel(
// first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum
__shared__ float red_smem[WARPS_PER_BLOCK * 2];

const DT *q_ptr = query +
batch_config_request_id * hidden_size * QKV_WEIGHT_NUM +
const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM +
head_idx * per_head_size;
__shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD];
// DT const *q_ptr =
@@ -142,10 +135,7 @@ __global__ void compute_attention_kernel_generation_kernel(
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;

DT const *k_cache_batch =
key_cache +
(batch_config_request_id * max_beam_width + beam_sub_request_idx) *
max_seq_length * hidden_size +
ki;
key_cache + batch_config_request_id * max_seq_length * hidden_size + ki;

int ti_end =
div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
@@ -248,10 +238,7 @@ __global__ void compute_attention_kernel_generation_kernel(

// The base pointer for the value in the cache buffer.
DT const *v_cache_batch =
value_cache +
(batch_config_request_id * max_beam_width + beam_sub_request_idx) *
max_seq_length * hidden_size +
vi;
value_cache + batch_config_request_id * max_seq_length * hidden_size + vi;

if (Dh == Dh_MAX || vi < Dh) {
for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
@@ -297,7 +284,7 @@ __global__ void compute_attention_kernel_generation_kernel(
// Output the final values.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
convert_from_float(
*reinterpret_cast<V_vec *>(output_ptr + beam_request_idx * hidden_size +
*reinterpret_cast<V_vec *>(output_ptr + request_idx * hidden_size +
head_idx * per_head_size + vi),
out);
}
@@ -727,9 +714,7 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig::max_sequence_length(), \
m->qProjSize, \
m->hidden_size, \
m->request_infos, \
false, \
0)
m->request_infos)
template <typename DT>
void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m,
@@ -944,14 +929,9 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
assert(m->qProjSize == m->kProjSize);
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
} else if (tokens_previous_requests < bc->num_generation_tokens) {
tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch;
if (bc->request_completed[i] || (!bc->requestsInfo[i].prompt_phase)) {
continue;
}
assert(tokens_previous_requests ==
bc->requestsInfo[i].first_token_offset_in_batch);
int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
bc->requestsInfo[i].num_tokens_in_batch;
@@ -978,8 +958,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
// matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens]
// To get query projection, skip over Q entries from previous requests
DT const *A = static_cast<DT *>(m->devQKVProjArray) +
tokens_previous_requests * m->qProjSize * m->num_q_heads *
QKV_WEIGHT_NUM;
bc->requestsInfo[i].first_token_offset_in_batch *
m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM;
// matrix B: key cache
// matrix B's layout: [kProjSize * num_heads, total_tokens]
// To get B, skip over K entries from previous requests (all heads +
@@ -1117,7 +1097,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
// requests
// store the result attn heads, also skip the generation tokens
DT *C = static_cast<DT *>(m->attn_heads) +
(tokens_previous_requests + bc->num_generation_tokens) *
(bc->requestsInfo[i].first_token_offset_in_batch) *
m->num_q_heads * m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_N,
@@ -1145,7 +1125,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m,
}
tokens_previous_requests += num_new_tokens;
}
assert(tokens_previous_requests == num_tokens);
assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens));
}
/*static*/
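
Note: the generation kernel above no longer takes is_beam/max_beam_width; the query, key-cache, and value-cache offsets are computed per request instead of per (request, beam sub-request) pair, and the prompt-phase path addresses its GEMM inputs with requestsInfo[i].first_token_offset_in_batch rather than an accumulated token counter. A sketch of the new offset arithmetic under those assumptions; the helper functions are illustrative, not part of FlexFlow:

#include <cstddef>

// Illustrative offset helpers mirroring the kernel arguments shown above.
// New KV-cache layout: one contiguous [max_seq_length, hidden_size] slab per
// request, with no (request * max_beam_width + sub_request) expansion.
inline std::size_t kv_cache_offset(int batch_config_request_id,
                                   int max_seq_length,
                                   int hidden_size) {
  return static_cast<std::size_t>(batch_config_request_id) * max_seq_length *
         hidden_size;
}

// The query pointer now starts at the request's own slot in the fused QKV
// buffer, indexed by request_idx rather than a beam-expanded index.
inline std::size_t query_offset(int request_idx,
                                int hidden_size,
                                int qkv_weight_num,
                                int head_idx,
                                int per_head_size) {
  return static_cast<std::size_t>(request_idx) * hidden_size * qkv_weight_num +
         static_cast<std::size_t>(head_idx) * per_head_size;
}
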
5 changes: 3 additions & 2 deletions src/ops/spec_inc_multihead_self_attention.cu
@@ -501,7 +501,8 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m,
assert(m->qProjSize == m->kProjSize);
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
if (bc->request_completed[i] || (!bc->requestsInfo[i].prompt_phase) ||
(bc->requestsInfo[i].num_tokens_in_batch == 0)) {
continue;
} else if (tokens_previous_requests < bc->num_generation_tokens) {
tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch;
@@ -694,7 +695,7 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m,
tokens_prev_requests_squares += num_new_tokens * total_tokens;
}
// assert(tokens_previous_requests == num_tokens);
assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens));
}
template <typename DT>
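
Note: both attention implementations now gate the prompt-phase loop on the new prompt_phase flag (and, here, on a non-zero token count) instead of tracking generation tokens inline, and the closing assert checks the prompt tokens consumed against num_tokens - num_generation_tokens. A condensed sketch of that guard; RequestInfoStub is a stand-in for BatchConfig::PerRequestInfo:

struct RequestInfoStub {            // stand-in for BatchConfig::PerRequestInfo
  bool prompt_phase;
  int num_tokens_in_batch;
};

// A request is skipped unless it is an active prompt-phase request that
// actually contributes tokens to this batch.
inline bool skip_request(bool completed, RequestInfoStub const &info) {
  return completed || !info.prompt_phase || info.num_tokens_in_batch == 0;
}

// Invariant encoded by the assert at the end of the loop:
//   tokens_previous_requests == num_tokens - num_generation_tokens
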
24 changes: 16 additions & 8 deletions src/runtime/request_manager.cc
@@ -468,12 +468,14 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
// Incremental phase
new_bc.requestsInfo[i].num_tokens_in_batch = 1;
num_generation_tokens++;
new_bc.requestsInfo[i].prompt_phase = false;
} else {
// Prompt phase
new_bc.requestsInfo[i].num_tokens_in_batch =
std::min(get_max_tokens_per_batch() - new_bc.num_tokens,
(int)request.tokens.size() -
new_bc.requestsInfo[i].first_token_depth_in_request);
new_bc.requestsInfo[i].prompt_phase = true;
}
for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) {
int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j;
@@ -509,6 +511,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc,
new_bc.requestsInfo[i].max_sequence_length =
new_request.max_sequence_length;
new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].prompt_phase = true;
num_active_req++;
new_bc.requestsInfo[num_active_req].batch_config_request_id = i;
// add profile_info for the new request
@@ -755,6 +758,7 @@ BeamSearchBatchConfig
new_bc.beamRequestsInfo[i].current_depth = 1;

profiling_requests[request.guid].ssm_decoding_steps = 0;
new_bc.requestsInfo[i].prompt_phase = true;

int ssm_decoding_steps = 0;
new_bc.beamRequestsInfo[i].beam_size =
@@ -763,7 +767,9 @@ BeamSearchBatchConfig
: 1;
new_bc.beamRequestsInfo[i].max_depth =
std::min(new_max_depth, BeamSearchBatchConfig::MAX_BEAM_DEPTH);
for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) {
for (int j = 0;
j < BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES;
j++) {
new_bc.beamRequestsInfo[i].parent_id[j] = 0;
new_bc.beamRequestsInfo[i].probs[j] = 1;
}
@@ -836,7 +842,8 @@ BeamSearchBatchConfig
? spec_infer_tree_width[ssm_decoding_steps]
: 1;
new_bc.beamRequestsInfo[i].max_depth = 0;
for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) {
for (int j = 0; j < BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES;
j++) {
new_bc.beamRequestsInfo[i].parent_id[j] = 0;
new_bc.beamRequestsInfo[i].probs[j] = 1;
}
@@ -896,12 +903,15 @@ BeamSearchBatchConfig
std::min(BeamSearchBatchConfig::MAX_BEAM_DEPTH,
get_max_tokens_per_batch() -
new_bc.requestsInfo[i].num_tokens_in_batch - 1);
for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) {
for (int j = 0;
j < BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES;
j++) {
new_bc.beamRequestsInfo[i].parent_id[j] = 0;
new_bc.beamRequestsInfo[i].probs[j] = 1;
}

new_bc.request_completed[i] = false;
new_bc.requestsInfo[i].prompt_phase = true;

new_bc.beamRequestsInfo[i].sub_request_num = 1;
printf("sub request num == 1, %d \n",
@@ -1188,10 +1198,7 @@ BeamSearchBatchConfig
int ssm_decoding_steps =
profiling_requests[request.guid].ssm_decoding_steps;

new_bc.beamRequestsInfo[i].beam_size =
spec_infer_tree_width.size() > ssm_decoding_steps
? spec_infer_tree_width[ssm_decoding_steps]
: 1;
new_bc.beamRequestsInfo[i].beam_size = 1;
// printf("beam size: %d, %d\n",
// new_bc.beamRequestsInfo[i].beam_size,
// ssm_decoding_steps);
@@ -1223,6 +1230,7 @@ BeamSearchBatchConfig
&old_bc.causalMask[i],
sizeof(BatchConfig::BitMask));

new_bc.requestsInfo[i].prompt_phase = true;
if (new_bc.requestsInfo[i].first_token_depth_in_request >=
request.tokens.size()) {
// request is done
@@ -1820,7 +1828,7 @@ void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask,
void RequestManager::appendPendingRequest(BatchConfig::BitMask &bitmask,
int initLength) {
assert(initLength > 0);
std::cout << "append pending bit mask: " << initLength << "\n";
// std::cout << "append pending bit mask: " << initLength << "\n";
// eg. 4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4:
// 0000000..1000
bitmask.non_tree_cache_size = 0;
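
Note: the request manager now records each request's phase when it assembles the next batch, so the attention kernels can branch on prompt_phase directly: decode steps contribute one token with prompt_phase cleared, prompt steps fill the remaining batch budget with prompt_phase set, and the per-branch init loops run over MAX_SPECULATIVE_TREE_BRANCHES. A simplified sketch of that decision; the struct and function are illustrative stand-ins, not FlexFlow code:

#include <algorithm>

struct RequestInfoStub {            // stand-in for BatchConfig::PerRequestInfo
  int num_tokens_in_batch = 0;
  bool prompt_phase = false;
};

// Decode step: exactly one new token, prompt_phase cleared.
// Prompt step: as many prompt tokens as the batch budget allows, prompt_phase set.
inline void fill_request(RequestInfoStub &info,
                         int remaining_prompt_tokens,
                         int remaining_batch_budget,
                         bool in_decode_phase) {
  if (in_decode_phase) {
    info.num_tokens_in_batch = 1;
    info.prompt_phase = false;
  } else {
    info.num_tokens_in_batch =
        std::min(remaining_batch_budget, remaining_prompt_tokens);
    info.prompt_phase = true;
  }
}
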
