diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h
index f5b17a5c99..5e68a65d8c 100644
--- a/include/flexflow/batch_config.h
+++ b/include/flexflow/batch_config.h
@@ -54,6 +54,8 @@ class BatchConfig {
   static BatchConfig const *from_future(BatchConfigFuture const &future);
   static int const MAX_NUM_REQUESTS = 1;
   static int const MAX_NUM_TOKENS = 64;
+  static int const MAX_PROMPT_LENGTH =
+      63; // should be MAX_NUM_TOKENS - 1 for SpecInfer
   static int const MAX_SEQ_LENGTH = 256;
 
   // These are set by update
diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 50f4fceeb1..514d9d8c6e 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -127,8 +127,22 @@ RequestManager::RequestGuid
   Request request;
   request.guid = next_available_guid++;
   request.max_sequence_length = max_sequence_length;
-  request.initial_len = prompt.size();
-  request.tokens = prompt;
+
+  if (prompt.size() > BatchConfig::MAX_PROMPT_LENGTH) {
+    std::cout << "Warning: too many tokens in prompt, only load up to "
+              << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
+              << prompt.size() << ".\n";
+    // Truncate the prompt to MAX_PROMPT_LENGTH tokens
+    request.tokens.insert(request.tokens.end(),
+                          prompt.begin(),
+                          prompt.begin() + BatchConfig::MAX_PROMPT_LENGTH);
+    request.initial_len = BatchConfig::MAX_PROMPT_LENGTH;
+    printf("tokens size: %zu\n", request.tokens.size());
+    // assert(false);
+  } else {
+    request.initial_len = prompt.size();
+    request.tokens = prompt;
+  }
 
   if (get_num_ssms() == 0) {
     std::cout << "No small speculative model registered yet, using incremental "
@@ -175,12 +189,22 @@ RequestManager::RequestGuid
   request.tokens.push_back(this->model_bos_map.at(this->model_type));
   std::vector<int32_t> tokens = this->tokenizer_->Encode(prompt);
 
+  if (tokens.size() > BatchConfig::MAX_PROMPT_LENGTH) {
+    std::cout << "Warning: too many tokens in prompt, only load up to "
+              << BatchConfig::MAX_PROMPT_LENGTH << " tokens, but got "
+              << tokens.size() << ".\n";
+    // Truncate the prompt to MAX_PROMPT_LENGTH tokens
+    tokens.resize(BatchConfig::MAX_PROMPT_LENGTH);
+    printf("tokens size: %zu\n", tokens.size());
+    // assert(false);
+  }
+
   for (int i = 0; i < tokens.size(); i++) {
-    std::cout << tokens.at(i) << "\n";
+    std::cout << "[" << i << "]" << tokens.at(i) << "\n";
   }
   // assert(false);
-  request.tokens.insert(request.tokens.end(), tokens.begin(), tokens.end());
+  request.tokens = tokens;
   request.initial_len = request.tokens.size();
 
   if (get_num_ssms() == 0) {
@@ -809,6 +833,7 @@ BeamSearchBatchConfig
                    (int)new_request.tokens.size());
         new_bc.requestsInfo[i].max_sequence_length =
             new_request.max_sequence_length;
+
         // add profile_info for the new request
         ProfileInfo profile_info;
         profile_info.decoding_steps = 0;
@@ -818,6 +843,10 @@ BeamSearchBatchConfig
         new_bc.beamRequestsInfo[i].beam_size =
             BeamSearchBatchConfig::MAX_BEAM_WIDTH;
         new_bc.beamRequestsInfo[i].current_depth = 1;
+        new_bc.beamRequestsInfo[i].max_depth =
+            std::min(BeamSearchBatchConfig::MAX_BEAM_DEPTH,
+                     BatchConfig::MAX_NUM_TOKENS -
+                         new_bc.requestsInfo[i].num_tokens_in_batch - 1);
         for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) {
           new_bc.beamRequestsInfo[i].parent_id[j] = 0;
           new_bc.beamRequestsInfo[i].probs[j] = 1;
@@ -947,7 +976,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
          new_bc.num_tokens++;
          new_bc.requestsInfo[i].num_tokens_in_batch++;
        }
-        if (new_bc.num_tokens == BatchConfig::MAX_NUM_TOKENS) {
+
+        std::cout << "new_bc.num_tokens: " << new_bc.num_tokens << std::endl;
+        if (new_bc.num_tokens >= BatchConfig::MAX_NUM_TOKENS) {
          assert(false &&
                 "Exceeding the space available in the TreeVerify batch");
          break;
@@ -1065,6 +1096,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify(
          break;
        }
      }
+
+      std::cout << "new_bc.num_tokens: " << new_bc.num_tokens << std::endl;
    }
 
  if (verbose) {
@@ -1346,9 +1379,11 @@ std::vector<std::pair<BatchConfig::TokenId, int>>
     log_req_mgr.print("(%d, %d)", pair.first, pair.second);
   }
 
-  assert(inputSerializedTree.size() == outputSerializedTree.size());
+  // It's safe to have inputSerializedTree.size() > outputSerializedTree.size()
+  // In this case the inputSerializedTree ends with padding 0s
+  assert(inputSerializedTree.size() >= outputSerializedTree.size());
 
-  for (int i = 0; i < inputSerializedTree.size(); i++) {
+  for (int i = 0; i < outputSerializedTree.size(); i++) {
     auto input = inputSerializedTree.at(i);
     auto output = outputSerializedTree.at(i);
 
diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu
index 1ca466ff91..1bbda58b02 100644
--- a/src/runtime/request_manager.cu
+++ b/src/runtime/request_manager.cu
@@ -31,7 +31,16 @@ void RequestManager::load_tokens_task(
   // BatchConfig const batch_config = *((BatchConfig *)task->args);
   BatchConfig const *batch_config = BatchConfig::from_future(task->futures[0]);
   BatchConfig::TokenId dram_copy[BatchConfig::MAX_NUM_TOKENS];
-  assert(batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS);
+
+  // Extremely long prompts are not supported; only load up to MAX_NUM_TOKENS
+  // tokens as the prompt.
+  if (batch_config->num_tokens > BatchConfig::MAX_NUM_TOKENS) {
+    printf("Warning: too many tokens in prompt, only load up to %d tokens\n",
+           BatchConfig::MAX_NUM_TOKENS);
+    printf("Got: %d tokens\n", batch_config->num_tokens);
+  }
+  // assert(batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS);
+
   for (int i = 0; i < batch_config->num_tokens; i++) {
     dram_copy[i] = batch_config->tokensInfo[i].token_id;
   }
@@ -42,11 +51,14 @@ void RequestManager::load_tokens_task(
   assert(batch_config->num_tokens <= domain.get_volume());
   cudaStream_t stream;
   checkCUDA(get_legion_stream(&stream));
-  checkCUDA(cudaMemcpyAsync(fb_ptr,
-                            dram_copy,
-                            sizeof(TokenId) * batch_config->num_tokens,
-                            cudaMemcpyHostToDevice,
-                            stream));
+  checkCUDA(
+      cudaMemcpyAsync(fb_ptr,
+                      dram_copy,
+                      batch_config->num_tokens <= BatchConfig::MAX_NUM_TOKENS
+                          ? sizeof(TokenId) * batch_config->num_tokens
+                          : sizeof(TokenId) * BatchConfig::MAX_NUM_TOKENS,
+                      cudaMemcpyHostToDevice,
+                      stream));
 }
 
 void RequestManager::load_positions_task(
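Note: the sketch below illustrates the truncation rule this patch introduces, kept separate from the diff itself. The constants mirror `MAX_NUM_TOKENS` / `MAX_PROMPT_LENGTH` from `BatchConfig`, but the helper names (`truncate_prompt`, `stage_tokens`) are hypothetical and not part of the FlexFlow API. It also clamps the host-side staging loop, which the diff's `dram_copy` fill in `load_tokens_task` does not do (only the `cudaMemcpyAsync` size is clamped there); treat that as an assumed extra safeguard, not the patch's behavior.

```cpp
// Minimal, self-contained sketch of prompt truncation to MAX_PROMPT_LENGTH
// and overflow-safe staging into a fixed-size host buffer.
// Assumptions: constants mirror BatchConfig; helpers are illustrative only.
#include <algorithm>
#include <cstdio>
#include <vector>

namespace sketch {

constexpr int MAX_NUM_TOKENS = 64;
constexpr int MAX_PROMPT_LENGTH = MAX_NUM_TOKENS - 1; // leave one slot for SpecInfer

using TokenId = int;

// Truncate a prompt to MAX_PROMPT_LENGTH tokens, warning when it is too long.
std::vector<TokenId> truncate_prompt(std::vector<TokenId> prompt) {
  if (prompt.size() > static_cast<size_t>(MAX_PROMPT_LENGTH)) {
    std::printf("Warning: prompt has %zu tokens, truncating to %d\n",
                prompt.size(), MAX_PROMPT_LENGTH);
    prompt.resize(MAX_PROMPT_LENGTH);
  }
  return prompt;
}

// Clamp the number of tokens staged into the fixed-size buffer so the copy
// can never overflow dram_copy[MAX_NUM_TOKENS].
int stage_tokens(std::vector<TokenId> const &tokens,
                 TokenId (&dram_copy)[MAX_NUM_TOKENS]) {
  int n = std::min<int>(static_cast<int>(tokens.size()), MAX_NUM_TOKENS);
  std::copy_n(tokens.begin(), n, dram_copy);
  return n; // caller copies exactly n * sizeof(TokenId) bytes to the device
}

} // namespace sketch

int main() {
  std::vector<sketch::TokenId> prompt(100, 7); // a 100-token prompt, too long
  auto truncated = sketch::truncate_prompt(prompt);

  sketch::TokenId dram_copy[sketch::MAX_NUM_TOKENS];
  int staged = sketch::stage_tokens(truncated, dram_copy);
  std::printf("staged %d of %zu tokens\n", staged, prompt.size());
  return 0;
}
```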