diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 4ad735ef7d..8fcee5e2f6 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -277,6 +277,7 @@ enum TaskIDs { RM_PREPARE_NEXT_BATCH_INIT_TASK_ID, RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID, + RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID, RM_BACKGROUND_SERVING_TASK_ID, // Custom tasks CUSTOM_GPU_TASK_ID_FIRST, diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index f0fab957ee..45731efe33 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -19,6 +19,7 @@ #include "flexflow/inference.h" #include "flexflow/model.h" #include "flexflow/utils/file_loader.h" +#include "suffix_decoding.h" #include #include #include @@ -164,6 +165,7 @@ class RequestManager { void serve_incr_decoding(FFModel *model); void serve_spec_infer(FFModel *model); + void serve_suffix_decoding(FFModel *model); GenerationResult get_generation_result(RequestGuid const &guid); RequestGuid register_new_request(Request const &request_); RequestGuid register_new_peft_request(Request const &request_); @@ -210,6 +212,15 @@ class RequestManager { Legion::Context ctx, Legion::Runtime *runtime); + TreeVerifyBatchConfig + prepare_next_batch_suffix_decode(TreeVerifyBatchConfig const &old_bc, + InferenceResult const &result); + TreeVerifyBatchConfigFuture prepare_next_batch_suffix_decode( + TreeVerifyBatchConfigFuture const &old_bc, + InferenceResultFuture const &result, + Legion::Context ctx, + Legion::Runtime *runtime); + void store_beam_metadata(BeamSearchBatchConfig const &old_bc, BeamInferenceResult const &result); void update_beam_metadata(BeamSearchBatchConfig &new_bc, @@ -280,6 +291,12 @@ class RequestManager { Legion::Context ctx, Legion::Runtime *runtime); + static TreeVerifyBatchConfig prepare_next_batch_suffix_decode_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + private: // configuration parameters int max_requests_per_batch; @@ -295,6 +312,8 @@ class RequestManager { // tree width in each speculative step, if not specified 1 std::vector spec_infer_tree_width; + SuffixTree *suffix_tree; + // private fields std::unique_ptr tokenizer_; bool verbose; diff --git a/src/mapper/mapper.cc b/src/mapper/mapper.cc index d7b9a5e99d..d321aeb583 100644 --- a/src/mapper/mapper.cc +++ b/src/mapper/mapper.cc @@ -284,6 +284,7 @@ void FFMapper::select_task_options(const MapperContext ctx, (task.task_id == RM_PREPARE_NEXT_BATCH_INIT_TASK_ID) || (task.task_id == RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID) || (task.task_id == RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID) || + (task.task_id == RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID) || (task.task_id == RM_BACKGROUND_SERVING_TASK_ID)) { output.initial_proc = all_cpus[0]; return; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 52f1dd2220..4e8aebe2aa 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -4822,6 +4822,27 @@ void register_flexflow_internal_tasks(Runtime *runtime, RequestManager::prepare_next_batch_verify_task>(registrar); } } + // RequestManager prepare_next_batch_suffix_decode + { + TaskVariantRegistrar registrar( + RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID, + "RequestManager Prepare Next Batch (Suffix Decode)"); + registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant< + TreeVerifyBatchConfig, + RequestManager::prepare_next_batch_suffix_decode_task>( + registrar, "RequestManager Prepare Next Batch (Suffix Decode) Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant< + TreeVerifyBatchConfig, + RequestManager::prepare_next_batch_suffix_decode_task>(registrar); + } + } // RequestManager background serving task { TaskVariantRegistrar registrar(RM_BACKGROUND_SERVING_TASK_ID, diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index f13277ddd1..90d060c97f 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -18,7 +18,6 @@ #include "flexflow/ops/lora_linear.h" #include "flexflow/parallel_ops/parallel_op.h" // #include "flexflow/tokenizers.h" -#include "suffix_decoding.h" #include #include #include @@ -97,6 +96,124 @@ std::ostream &operator<<(std::ostream &os, Request const &req) { bool RequestManager::inference_finished = false; +DailSqlTrace load_trace_dail_sql(std::string const &trace_filepath) { + std::filesystem::path cwd = std::filesystem::current_path(); + std::ifstream file(trace_filepath); + assert(file.good() && "File does not exist or cannot be opened"); + + nlohmann::json data; + try { + file >> data; + } catch (nlohmann::json::parse_error &e) { + std::cerr << "JSON parse error: " << e.what() << std::endl; + assert(false); + } + std::cout << "finished loading json file: " << trace_filepath << std::endl; + + DailSqlTrace trace; + for (auto const &question : data["questions"]) { + trace.prompts.push_back(question["prompt"]); + trace.responses.push_back(question["response"]); + } + + assert(trace.prompts.size() == trace.responses.size()); + return trace; +} + +std::string replaceSlashes(std::string str) { + size_t pos = 0; + while ((pos = str.find('/', pos)) != std::string::npos) { + str.replace(pos, 1, "--"); + pos += 2; + } + return str; +} + +std::string getHFHome() { + if (char const *env_p = std::getenv("HF_HOME")) { + return std::string(env_p); + } else { + std::filesystem::path home = std::filesystem::path(getenv("HOME")); + return (home / ".cache" / "huggingface").string(); + } +} + +std::string get_tokenizer_base_folder(std::string const &model_name) { + // Replace '/' with '--' in the model name + std::string model_name_without_slashes = model_name; + model_name_without_slashes = replaceSlashes(model_name_without_slashes); + + // Construct the base path + std::string hf_home = getHFHome(); + + std::string base_path = + hf_home + "/hub/models--" + model_name_without_slashes + "/snapshots/"; + + // Find the first subfolder in the snapshots directory + std::string first_folder_name; + for (auto const &entry : fs::directory_iterator(base_path)) { + if (fs::is_directory(entry)) { + first_folder_name = entry.path().filename().string(); + break; + } + } + return base_path + first_folder_name; +} + +std::string get_tokenizer_path(std::string const &model_name) { + std::string base_folder = get_tokenizer_base_folder(model_name); + + if (fs::exists(base_folder + "/tokenizer.model")) { + return base_folder + "/tokenizer.model"; + } else if (fs::exists(base_folder + "/tokenizer.json")) { + return base_folder + "/tokenizer.json"; + } else { + assert(false); + } +} + +int get_bos_token_id(std::string const &model_name) { + std::string base_folder = get_tokenizer_base_folder(model_name); + std::string filename = base_folder + "/config.json"; + if (!fs::exists(filename)) { + assert(false && "config.json not found"); + return -1; + } + // Read the JSON file + std::ifstream file(filename); + if (!file.is_open()) { + assert(false && "Unable to open file"); + return -1; + } + + // Parse JSON + nlohmann::json j; + file >> j; + + // Get the bos_token_id value + if (j.contains("bos_token_id")) { + return j["bos_token_id"].get(); + } else { + assert(false && "bos_token_id not found in JSON"); + return -1; + } + return -1; +} + +auto get_tokenizer(std::string const &model_name) { + std::string tokenizer_path = get_tokenizer_path(model_name); + // if the tokenizer_path ends with the ".json" extension: + if (tokenizer_path.find("tokenizer.json") != std::string::npos) { + auto blob = LoadBytesFromFile(tokenizer_path); + return Tokenizer::FromBlobJSON(blob); + } else if (tokenizer_path.find("tokenizer.model") != std::string::npos) { + auto blob = LoadBytesFromFile(tokenizer_path); + return Tokenizer::FromBlobSentencePiece(blob); + } else { + assert(false); + } +} + RequestManager::RequestManager() : request_manager_status(INITIALIZED), verbose(false), next_available_guid(1000000), num_processed_requests(0), @@ -110,6 +227,30 @@ RequestManager::RequestManager() max_tokens_per_batch = -1; max_spec_tree_token_num = -1; max_sequence_length = -1; + + std::string model_name = "meta-llama/Meta-Llama-3-70B"; + DailSqlTrace dail_sql_trace = + load_trace_dail_sql("/usr/suffix-tree-decoding/trace/spider.json"); + int num_prompts = dail_sql_trace.prompts.size(); + int num_responses = dail_sql_trace.responses.size(); + assert(num_prompts == num_responses); + auto tokenizer = get_tokenizer(model_name); + int bos_token_id = get_bos_token_id(model_name); + int train_size = num_prompts / 2; + std::cout << "Number of prompts: " << num_prompts << std::endl; + std::cout << "Train size: " << train_size << std::endl; + std::vector> training_dataset; + for (int i = 0; i < train_size; i++) { + std::string text = dail_sql_trace.prompts[i] + dail_sql_trace.responses[i]; + std::vector encoded = tokenizer->Encode(text); + encoded.insert(encoded.begin(), bos_token_id); + training_dataset.push_back(encoded); + } + suffix_tree = new SuffixTree(50); + for (auto const &text : training_dataset) { + suffix_tree->insert(text, suffix_tree->query_guid); + suffix_tree->query_guid++; + } } void RequestManager::set_max_requests_per_batch(int max_num_requests) { @@ -1721,6 +1862,38 @@ BeamSearchBatchConfig return new_bc; } +TreeVerifyBatchConfigFuture RequestManager::prepare_next_batch_suffix_decode( + TreeVerifyBatchConfigFuture const &old_bc, + InferenceResultFuture const &result, + Context ctx, + Runtime *runtime) { + + RequestManager *rm = this; + TaskLauncher launcher(RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID, + TaskArgument(&rm, sizeof(RequestManager *))); + launcher.add_future(old_bc); + launcher.add_future(result); + return runtime->execute_task(ctx, launcher); +} + +TreeVerifyBatchConfig RequestManager::prepare_next_batch_suffix_decode_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + RequestManager *rm = *((RequestManager **)task->args); + TreeVerifyBatchConfig const &bc = + Future(task->futures[0]).get_result(); + InferenceResult const &result = + Future(task->futures[1]).get_result(); + return rm->prepare_next_batch_suffix_decode(bc, result); +} + +TreeVerifyBatchConfig RequestManager::prepare_next_batch_suffix_decode( + TreeVerifyBatchConfig const &old_bc, InferenceResult const &result) { + const std::lock_guard lock(request_queue_mutex); +} + /***** Verify Phase *****/ TreeVerifyBatchConfigFuture RequestManager::prepare_next_batch_verify( @@ -2802,8 +2975,15 @@ void RequestManager::background_serving_task( "###PEFT DEBUGGING### Updated models' configuration."); if (rm->get_num_ssms() == 0) { - // No SSMs: perform incremental decoding - rm->serve_incr_decoding(llm); + + char const *env_var = std::getenv("FF_SUFFIX_DECODING"); + + if (env_var != nullptr && std::string(env_var) == "1") { + rm->serve_suffix_decoding(llm); + } else { + // No SSMs: perform incremental decoding + rm->serve_incr_decoding(llm); + } } else { // Registered SSMs: perform speculative inference rm->serve_spec_infer(llm); @@ -3004,33 +3184,41 @@ void RequestManager::serve_spec_infer(FFModel *llm) { /*static*/ void RequestManager::serve_suffix_decoding(FFModel *llm) { + + // Check if the model object exists + if (llm == nullptr) { + std::cout << "###PEFT DEBUGGING### LLM Model object does not exist." + << std::endl; + return; // Early return to prevent further operations on a nullptr + } else { + std::cout << "###PEFT DEBUGGING### LLM Model object exists." << std::endl; + } + Context ctx = llm->config.lg_ctx; Runtime *runtime = llm->config.lg_hlr; + // Compile the llm InferenceManager *im = InferenceManager::get_inference_manager(); + im->compile_model_and_allocate_buffer(llm); + assert(im->model_weights_loaders.find(llm) != + im->model_weights_loaders.end()); + // Load model weights + im->model_weights_loaders[llm]->load_weights(llm); + // init operators + im->init_operators_inference(llm); + // Legion futures for inc_decoding and spec_infer + TreeVerifyBatchConfigFuture last_bcf; + InferenceResultFuture last_irf; { - // Compile the llm - im->compile_model_and_allocate_buffer(llm); - assert(im->model_weights_loaders.find(llm) != - im->model_weights_loaders.end()); - // Load model weights - im->model_weights_loaders[llm]->load_weights(llm); - // init operators - im->init_operators_inference(llm); + // Initialize futures for incr decoding + TreeVerifyBatchConfig bc; + InferenceResult ir; + last_bcf = Future::from_value(bc); + last_irf = Future::from_value(ir); } std::queue> batch_pipeline; - // Legion futures for inc_decoding and spec_infer - TreeVerifyBatchConfigFuture last_tree_bcf; - InferenceResultFuture last_tree_irf; - { - // Initialize futures for spec infer - TreeVerifyBatchConfig tree_bc; - InferenceResult tree_ir; - last_tree_bcf = Future::from_value(tree_bc); - last_tree_irf = Future::from_value(tree_ir); - } - batch_pipeline.push(std::make_pair(last_tree_bcf, last_tree_irf)); + { batch_pipeline.push(std::make_pair(last_bcf, last_irf)); } while (!is_background_server_terminated()) { @@ -3048,27 +3236,16 @@ void RequestManager::serve_suffix_decoding(FFModel *llm) { break; } } - runtime->begin_trace(ctx, 12347 /*trace_id*/); auto const &next_batch = batch_pipeline.back(); - - BeamSearchBatchConfigFuture beam_bcf = prepare_next_batch_init(next_batch.first, next_batch.second, 0, ctx, runtime); - FutureMap fm = im->suffix_decode(llm, 0, beam_bcf); + TreeVerifyBatchConfigFuture bcf = prepare_next_batch_suffix_decode( + next_batch.first, next_batch.second, ctx, runtime); + FutureMap fm = im->inference(llm, 0, bcf); assert(fm.get_future_map_domain().get_volume() == 1); - BeamInferenceResultFuture beam_irf = fm.get_future(0); - beam_bcf = prepare_next_batch_beam(beam_bcf, beam_irf, ctx, runtime); - std::vector beam_bcf_vec(1); - beam_bcf_vec[0] = beam_bcf; - // Token Tree Verification - { - TreeVerifyBatchConfigFuture tree_bcf = prepare_next_batch_verify(beam_bcf_vec, ctx, runtime); - FutureMap fm = im->inference(llm, 0, tree_bcf); - assert(fm.get_future_map_domain().get_volume() == 1); - InferenceResultFuture tree_irf = fm.get_future(0); - batch_pipeline.push(std::make_pair(tree_bcf, tree_irf)); - last_tree_bcf = tree_bcf; - last_tree_irf = tree_irf; - } + InferenceResultFuture irf = fm.get_future(0); + batch_pipeline.push(std::make_pair(bcf, irf)); + last_bcf = bcf; + last_irf = irf; runtime->end_trace(ctx, 12347 /*trace_id*/); } }