diff --git a/examples/cpp/inference/dataloader.cu b/examples/cpp/inference/dataloader.cu index 7fb3478020..434dc337c9 100644 --- a/examples/cpp/inference/dataloader.cu +++ b/examples/cpp/inference/dataloader.cu @@ -15,6 +15,7 @@ #include "dataloader.h" #include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include "flexflow/utils/cuda_helper.h" void DataLoader::load_input(Task const *task, diff --git a/examples/cpp/inference/mixture_of_experts/moe.cc b/examples/cpp/inference/mixture_of_experts/moe.cc index ff3f6bb53a..4a5c33c9b0 100644 --- a/examples/cpp/inference/mixture_of_experts/moe.cc +++ b/examples/cpp/inference/mixture_of_experts/moe.cc @@ -15,6 +15,7 @@ #include "moe.h" #include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include #include #include diff --git a/examples/cpp/inference/transformers/transformers.cc b/examples/cpp/inference/transformers/transformers.cc index 074e832d47..0717ddc90f 100644 --- a/examples/cpp/inference/transformers/transformers.cc +++ b/examples/cpp/inference/transformers/transformers.cc @@ -15,6 +15,7 @@ #include "transformers.h" #include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include #include #include diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index c6d7e3e2bb..f5b17a5c99 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -52,7 +52,7 @@ class BatchConfig { void print() const; virtual InferenceMode get_mode() const; static BatchConfig const *from_future(BatchConfigFuture const &future); - static int const MAX_NUM_REQUESTS = 16; + static int const MAX_NUM_REQUESTS = 1; static int const MAX_NUM_TOKENS = 64; static int const MAX_SEQ_LENGTH = 256; diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index fd0235cbcc..c30b0c0be3 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -14,20 +14,10 @@ */ #pragma once - #include "flexflow/batch_config.h" -#include "flexflow/model.h" -#include -#include -#include namespace FlexFlow { -class FFModel; -class BeamTree; -class RequestManager; -using tokenizers::Tokenizer; - struct SamplingConfig { bool do_sample = false; float temperature = 0.8; @@ -50,210 +40,4 @@ struct GenerationResult { std::vector output_tokens; }; -class InferenceManager { -public: - InferenceManager(FFConfig const &config, int max_num_tokens_per_batch); - static InferenceManager *get_inference_manager(); - void compile_model_and_allocate_buffer(FFModel *model); - void init_operators_inference(FFModel *model); - MachineView *get_machine_view(int mv_id); - Legion::FutureMap inference(FFModel *model, int index, BatchConfig const &bc); - Legion::FutureMap - inference(FFModel *model, int index, BatchConfigFuture const &bc); - void load_input_tokens_from_batch_config(BatchConfigFuture const &bc, - ParallelTensor const input); - void load_positions(BatchConfigFuture const &bc, - ParallelTensor position_input); - void incr_decoding_loop(FFModel *model, - RequestManager &rm, - int total_num_requests); - void spec_inference_loop(FFModel *model, - RequestManager &rm, - int total_num_requests, - std::vector ssm_model_ids); - -public: - FFConfig ff_config; - std::unordered_map> tensor_buffer; - int max_num_tokens_per_batch; - int num_devices; - std::vector machine_views; -}; - -struct Request { - BatchConfig::RequestGuid guid; - int max_sequence_length; - int initial_len; - std::vector tokens; - - std::vector beam_trees; - std::promise *promise; -}; - -// store the result of beam 
search -struct BeamTree { - struct treeLayer { - BeamSearchBatchConfig::TokenId - tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - int parent_ids[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - }; - treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1]; -}; - -// struct BeamTree_v2 { -// std::vector tokens; -// std::vector parent_ids; -// std::vector probs; -// }; - -class RequestManager { -public: - using RequestGuid = BatchConfig::RequestGuid; - using TokenId = BatchConfig::TokenId; - RequestManager(ModelType model_type, - std::string const &path, - bool verbose = false, - std::string output_filepath = ""); - RequestManager(); - static RequestManager *get_request_manager(); - size_t get_num_processed_requests(); - - int register_new_model(FFModel *model); - void register_tokenizer(ModelType model_type, std::string const &path); - void register_output_filepath(std::string const &); - - FFModel *get_model(int model_id); - void serve(FFModel *model); - - static GenerationResult generate(std::string const &text, int max_seq_length); - RequestGuid register_new_request(std::string const &prompt, - int max_sequence_length); - RequestGuid register_new_request(std::vector const &prompt, - int max_sequence_length); - BatchConfig prepare_next_batch(BatchConfig const &bc, - InferenceResult const &result); - BatchConfigFuture prepare_next_batch(BatchConfigFuture const &bc, - InferenceResultFuture const &result); - BeamSearchBatchConfig - prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, - BeamInferenceResult const &result); - BeamSearchBatchConfigFuture - prepare_next_batch_beam(BeamSearchBatchConfigFuture const &old_bc, - BeamInferenceResultFuture const &result); - BeamSearchBatchConfig - prepare_next_batch_init(TreeVerifyBatchConfig const &old_bc, - InferenceResult const &result, - int model_id); - BeamSearchBatchConfigFuture - prepare_next_batch_init(TreeVerifyBatchConfigFuture const &old_bc, - InferenceResultFuture const &result, - int model_id); - TreeVerifyBatchConfig prepare_next_batch_verify( - std::vector const &old_batches); - TreeVerifyBatchConfigFuture prepare_next_batch_verify( - std::vector const &old_batches); - - void store_beam_metadata(BeamSearchBatchConfig const &old_bc, - BeamInferenceResult const &result); - void update_beam_metadata(BeamSearchBatchConfig &new_bc, - BeamTree &tree, - int request_index); - - std::vector> - traverse_beam_tree(BeamSearchBatchConfig const &old_bc, - int request_index, - int token_start_offset); - - // remove guid after put the cached tree in request - std::vector> merge_dfs_trees( - std::vector>> - input_trees, - int root_depth, - RequestGuid guid); - - std::vector> traverse_verify_tree( - size_t guid, - std::vector> const - &inputSerializedTree, - std::vector> const - &outputSerializedTree); - - static void - load_tokens_task(Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - static void - load_positions_task(Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - - static BatchConfig prepare_next_batch_task( - Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - - static BeamSearchBatchConfig prepare_next_batch_beam_task( - Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - - static BeamSearchBatchConfig prepare_next_batch_init_task( - Legion::Task const *task, - 
std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - - static TreeVerifyBatchConfig prepare_next_batch_verify_task( - Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - - static void llm_serving_background_task( - Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - -private: - std::unique_ptr tokenizer_; - bool verbose; - ModelType model_type; - std::string output_filepath; - std::queue pending_request_queue; - std::unordered_map all_requests; - std::unordered_map request_generation_results; - std::mutex request_queue_mutex; - RequestGuid next_available_guid; - const std::map model_bos_map = {{ModelType::LLAMA, 0}, - {ModelType::OPT, 2}}; - - // TODO: Move this two vector to request struct - std::unordered_map>> - dfs_tree_inputs; - std::unordered_map>> - committed_tokens; - - // Multi-model support - int num_ssms; - std::vector models; - - // Performance profiling - size_t num_processed_requests; - -private: - struct ProfileInfo { - int decoding_steps; - double start_time, finish_time; - }; - std::unordered_map profiling_requests; - double total_request_run_time; -}; - } // namespace FlexFlow diff --git a/include/flexflow/model.h b/include/flexflow/model.h index d39c03c0fd..fb71226118 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -17,6 +17,7 @@ #include "accessor.h" #include "config.h" #include "device.h" +#include "flexflow/inference.h" #include "flexflow/memory_optimization.h" #include "flexflow/node.h" #include "flexflow/operator_params.h" @@ -698,6 +699,10 @@ class FFModel { float scaling_factor = 1.0f, bool qk_prod_scaling = true, char const *name = NULL); + // ======================================== + // Inference APIs + // ======================================== + GenerationResult generate(std::string const &text, int max_seq_length); Tensor create_tensor_legion_ordering(int num_dim, int const dims[], diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index 6c6f0183eb..244100bc6f 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_INC_MULTIHEAD_SELF_ATTENTION_H #define _FLEXFLOW_INC_MULTIHEAD_SELF_ATTENTION_H +#include "flexflow/accessor.h" #include "flexflow/device.h" #include "flexflow/fftype.h" #include "flexflow/inference.h" diff --git a/include/flexflow/ops/inc_multiquery_self_attention.h b/include/flexflow/ops/inc_multiquery_self_attention.h index 3bbc684595..1e36876c57 100644 --- a/include/flexflow/ops/inc_multiquery_self_attention.h +++ b/include/flexflow/ops/inc_multiquery_self_attention.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_INC_MULTIQUERY_ATTENTION_H #define _FLEXFLOW_INC_MULTIQUERY_ATTENTION_H +#include "flexflow/accessor.h" #include "flexflow/device.h" #include "flexflow/fftype.h" #include "flexflow/inference.h" diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index 71f7051b4c..c8c1c4c9cf 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_SPEC_INC_MULTIHEAD_SELF_ATTENTION_H #define _FLEXFLOW_SPEC_INC_MULTIHEAD_SELF_ATTENTION_H +#include "flexflow/accessor.h" #include "flexflow/device.h" #include "flexflow/fftype.h" #include "flexflow/inference.h" 
diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index 1ace56e5d6..ba1d80dd60 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_INC_MULTIHEAD_SELF_ATTENTION_VERIFY_H #define _FLEXFLOW_INC_MULTIHEAD_SELF_ATTENTION_VERIFY_H +#include "flexflow/accessor.h" #include "flexflow/device.h" #include "flexflow/fftype.h" #include "flexflow/inference.h" @@ -9,6 +10,7 @@ #include "flexflow/op_meta.h" #include "flexflow/operator.h" #include "flexflow/ops/inc_multihead_self_attention.h" +#include "flexflow/ops/tree_inc_multihead_self_attention_params.h" #include "math.h" #include #include diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h new file mode 100644 index 0000000000..16e7b87227 --- /dev/null +++ b/include/flexflow/request_manager.h @@ -0,0 +1,243 @@ +/* Copyright 2023 CMU, Stanford, Facebook, LANL + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "flexflow/batch_config.h" +#include "flexflow/inference.h" +#include "flexflow/model.h" +#include +#include +#include + +namespace FlexFlow { + +class FFModel; +class BeamTree; +class RequestManager; +using tokenizers::Tokenizer; + +class InferenceManager { +public: + InferenceManager(FFConfig const &config, int max_num_tokens_per_batch); + static InferenceManager *get_inference_manager(); + void compile_model_and_allocate_buffer(FFModel *model); + void init_operators_inference(FFModel *model); + MachineView *get_machine_view(int mv_id); + Legion::FutureMap inference(FFModel *model, int index, BatchConfig const &bc); + Legion::FutureMap + inference(FFModel *model, int index, BatchConfigFuture const &bc); + void load_input_tokens_from_batch_config(BatchConfigFuture const &bc, + ParallelTensor const input); + void load_positions(BatchConfigFuture const &bc, + ParallelTensor position_input); + void incr_decoding_loop(FFModel *model, + RequestManager &rm, + int total_num_requests); + void spec_inference_loop(FFModel *model, + RequestManager &rm, + int total_num_requests, + std::vector ssm_model_ids); + +public: + FFConfig ff_config; + std::unordered_map> tensor_buffer; + int max_num_tokens_per_batch; + int num_devices; + std::vector machine_views; +}; + +struct Request { + BatchConfig::RequestGuid guid; + int max_sequence_length; + int initial_len; + std::vector tokens; + + std::vector beam_trees; + std::promise *promise; +}; + +// store the result of beam search +struct BeamTree { + struct treeLayer { + BeamSearchBatchConfig::TokenId + tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + int parent_ids[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + }; + treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1]; +}; + +// struct BeamTree_v2 { +// std::vector tokens; +// std::vector parent_ids; +// std::vector probs; +// }; + +class 
RequestManager { +public: + using RequestGuid = BatchConfig::RequestGuid; + using TokenId = BatchConfig::TokenId; + RequestManager(ModelType model_type, + std::string const &path, + bool verbose = false, + std::string output_filepath = ""); + RequestManager(); + static RequestManager *get_request_manager(); + size_t get_num_processed_requests(); + size_t get_num_ssms(); + + int register_ssm_model(FFModel *model); + void register_tokenizer(ModelType model_type, std::string const &path); + void register_output_filepath(std::string const &); + + FFModel *get_model(int model_id); + static void serve(FFModel *model); + + GenerationResult generate_incr_decoding(FFModel *model, + std::string const &text, + int max_seq_length); + GenerationResult generate_spec_infer(FFModel *model, + std::string const &text, + int max_seq_length); + RequestGuid register_new_request(std::string const &prompt, + int max_sequence_length); + RequestGuid register_new_request(std::vector const &prompt, + int max_sequence_length); + BatchConfig prepare_next_batch(BatchConfig const &bc, + InferenceResult const &result); + BatchConfigFuture prepare_next_batch(BatchConfigFuture const &bc, + InferenceResultFuture const &result); + BeamSearchBatchConfig + prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, + BeamInferenceResult const &result); + BeamSearchBatchConfigFuture + prepare_next_batch_beam(BeamSearchBatchConfigFuture const &old_bc, + BeamInferenceResultFuture const &result); + BeamSearchBatchConfig + prepare_next_batch_init(TreeVerifyBatchConfig const &old_bc, + InferenceResult const &result, + int model_id); + BeamSearchBatchConfigFuture + prepare_next_batch_init(TreeVerifyBatchConfigFuture const &old_bc, + InferenceResultFuture const &result, + int model_id); + TreeVerifyBatchConfig prepare_next_batch_verify( + std::vector const &old_batches); + TreeVerifyBatchConfigFuture prepare_next_batch_verify( + std::vector const &old_batches); + + void store_beam_metadata(BeamSearchBatchConfig const &old_bc, + BeamInferenceResult const &result); + void update_beam_metadata(BeamSearchBatchConfig &new_bc, + BeamTree &tree, + int request_index); + + std::vector> + traverse_beam_tree(BeamSearchBatchConfig const &old_bc, + int request_index, + int token_start_offset); + + // remove guid after put the cached tree in request + std::vector> merge_dfs_trees( + std::vector>> + input_trees, + int root_depth, + RequestGuid guid); + + std::vector> traverse_verify_tree( + size_t guid, + std::vector> const + &inputSerializedTree, + std::vector> const + &outputSerializedTree); + + static void + load_tokens_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void + load_positions_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + + static BatchConfig prepare_next_batch_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + + static BeamSearchBatchConfig prepare_next_batch_beam_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + + static BeamSearchBatchConfig prepare_next_batch_init_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + + static TreeVerifyBatchConfig prepare_next_batch_verify_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + + static void 
llm_serving_background_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + +private: + std::unique_ptr tokenizer_; + bool verbose; + ModelType model_type; + std::string output_filepath; + std::queue pending_request_queue; + std::unordered_map all_requests; + std::unordered_map request_generation_results; + std::mutex request_queue_mutex; + RequestGuid next_available_guid; + const std::map model_bos_map = {{ModelType::LLAMA, 0}, + {ModelType::OPT, 2}}; + + // TODO: Move this two vector to request struct + std::unordered_map>> + dfs_tree_inputs; + std::unordered_map>> + committed_tokens; + + // Multi-model support + std::vector models; + + // Performance profiling + size_t num_processed_requests; + +private: + struct ProfileInfo { + int decoding_steps; + double start_time, finish_time; + }; + std::unordered_map profiling_requests; + double total_request_run_time; +}; + +}; // namespace FlexFlow diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index d1925cca70..957c41b103 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -14,6 +14,7 @@ */ #include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include "models/falcon.h" #include "models/llama.h" #include "models/opt.h" @@ -195,7 +196,7 @@ void FlexFlow::top_level_task(Task const *task, printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); total_num_requests++; GenerationResult result = - RequestManager::generate(text, 128 /*max_sequence_length*/); + model.generate(text, 128 /*max_sequence_length*/); } } diff --git a/inference/models/falcon.h b/inference/models/falcon.h index 986d9d6951..d9c330a8b9 100644 --- a/inference/models/falcon.h +++ b/inference/models/falcon.h @@ -17,6 +17,7 @@ #include "file_loader.h" #include "flexflow/batch_config.h" #include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include #include using json = nlohmann::json; diff --git a/inference/models/llama.h b/inference/models/llama.h index 77b57520e3..61d8908d0c 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -17,6 +17,7 @@ #include "file_loader.h" #include "flexflow/batch_config.h" #include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include #include using json = nlohmann::json; diff --git a/inference/models/opt.h b/inference/models/opt.h index 2023adf5af..45ee6e6181 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -17,6 +17,7 @@ #include "file_loader.h" #include "flexflow/batch_config.h" #include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include #include using json = nlohmann::json; diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index a5def745e8..99131edb34 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -43,10 +43,7 @@ void parse_input_args(char **argv, FilePaths &paths, ModelTypes &model_types, bool &use_full_precision, - bool &verbose, - int &data_parallelism_degree, - int &tensor_parallelism_degree, - int &pipeline_parallelism_degree) { + bool &verbose) { for (int i = 1; i < argc; i++) { // llm model type if (!strcmp(argv[i], "-llm-model")) { @@ -117,21 +114,6 @@ void parse_input_args(char **argv, paths.output_file_path = std::string(argv[++i]); continue; } - // data parallelism degree - if (!strcmp(argv[i], "-data-parallelism-degree")) { - data_parallelism_degree = std::stoi(argv[++i]); - continue; - 
} - // tensor parallelism degree - if (!strcmp(argv[i], "-tensor-parallelism-degree")) { - tensor_parallelism_degree = std::stoi(argv[++i]); - continue; - } - // pipeline parallelism degree - if (!strcmp(argv[i], "-pipeline-parallelism-degree")) { - pipeline_parallelism_degree = std::stoi(argv[++i]); - continue; - } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; continue; @@ -160,20 +142,10 @@ void FlexFlow::top_level_task(Task const *task, InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; int argc = command_args.argc; - parse_input_args(argv, - argc, - file_paths, - model_types, - use_full_precision, - verbose, - data_parallelism_degree, - tensor_parallelism_degree, - pipeline_parallelism_degree); - ffconfig.data_parallelism_degree = data_parallelism_degree; - ffconfig.tensor_parallelism_degree = tensor_parallelism_degree; - ffconfig.pipeline_parallelism_degree = pipeline_parallelism_degree; - assert(data_parallelism_degree * tensor_parallelism_degree * - pipeline_parallelism_degree == + parse_input_args( + argv, argc, file_paths, model_types, use_full_precision, verbose); + assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree * + ffconfig.pipeline_parallelism_degree == ffconfig.numNodes * ffconfig.workersPerNode); if (file_paths.ssm_weight_file_paths.size() == 0) { @@ -261,8 +233,7 @@ void FlexFlow::top_level_task(Task const *task, assert(false && "Invalid SSM model type passed."); } - int beam_model_id = rm->register_new_model(&beam_model); - ssm_model_ids.push_back(beam_model_id); + rm->register_ssm_model(&beam_model); } // Register requests from prompt file @@ -279,61 +250,9 @@ void FlexFlow::top_level_task(Task const *task, std::string text = prompt.get(); printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); total_num_requests++; - rm->register_new_request(text, 128 /*max_sequence_length*/); - } - } - - TreeVerifyBatchConfigFuture tree_bcf; - BeamSearchBatchConfigFuture beam_bcf; - InferenceResultFuture tree_irf; - std::vector beam_bcf_vec; - { - TreeVerifyBatchConfig tree_bc; - BeamSearchBatchConfig beam_bc; - InferenceResult tree_ir; - tree_bcf = Future::from_value(tree_bc); - beam_bcf = Future::from_value(beam_bc); - tree_irf = Future::from_value(tree_ir); - for (int ssm_id = 0; ssm_id < num_ssms; ssm_id++) { - beam_bcf_vec.push_back(Future::from_value( - BeamSearchBatchConfig(ssm_model_ids[ssm_id]))); - } - } - - while (rm->get_num_processed_requests() < total_num_requests) { - // Beam Search - beam_bcf = rm->prepare_next_batch_init(tree_bcf, tree_irf, 0); - for (int ssm_id = 0; ssm_id < num_ssms; ssm_id++) { - beam_bcf_vec[ssm_id] = beam_bcf; - } - - if (rm->get_num_processed_requests() >= total_num_requests) { - break; - } - - for (int i = 0; i < num_ssms; i++) { - for (int depth = 0; depth < BeamSearchBatchConfig::MAX_BEAM_DEPTH; - depth++) { - beam_bcf = beam_bcf_vec[i]; - - FutureMap fm = im->inference(rm->get_model(0), 0, beam_bcf_vec[i]); - assert(fm.get_future_map_domain().get_volume() == 1); - BeamInferenceResultFuture beam_irf = fm.get_future(0); - beam_bcf_vec[i] = - rm->prepare_next_batch_beam(beam_bcf_vec[i], beam_irf); - } - // std::cout << "----------beam search finished for model " - // << beam_bc_vec[i].model_id << "------------" << std::endl; - } - // Token Tree Verification - { - tree_bcf = rm->prepare_next_batch_verify(beam_bcf_vec); - FutureMap fm = im->inference(&tree_model, 0, tree_bcf); - assert(fm.get_future_map_domain().get_volume() == 1); - 
tree_irf = fm.get_future(0); + tree_model.generate(text, 128 /*max_sequence_length*/); } } - // im.spec_inference_loop(&tree_model, rm, total_num_requests, ssm_model_ids); // Execution fence { diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index d7f1b70232..1c3103683f 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -15,8 +15,8 @@ #include "flexflow/flexflow_c.h" #include "flexflow/dataloader.h" -#include "flexflow/inference.h" #include "flexflow/mapper.h" +#include "flexflow/request_manager.h" using namespace Legion; using namespace FlexFlow; diff --git a/src/mapper/mapper.cc b/src/mapper/mapper.cc index 3d08eb0bcc..9449e9e44b 100644 --- a/src/mapper/mapper.cc +++ b/src/mapper/mapper.cc @@ -286,7 +286,8 @@ void FFMapper::select_task_options(const MapperContext ctx, if ((task.task_id == RM_PREPARE_NEXT_BATCH_TASK_ID) || (task.task_id == RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID) || (task.task_id == RM_PREPARE_NEXT_BATCH_INIT_TASK_ID) || - (task.task_id == RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID)) { + (task.task_id == RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID) || + (task.task_id == RM_LLM_SERVING_BACKGROUND_TASK_ID)) { output.initial_proc = all_cpus[0]; return; } diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index e6da678908..cfcc938204 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -15,11 +15,11 @@ #include "flexflow/ffconst_utils.h" #include "flexflow/graph.h" -#include "flexflow/inference.h" #include "flexflow/model.h" #include "flexflow/ops/fused.h" #include "flexflow/ops/noop.h" #include "flexflow/parallel_ops/parallel_op.h" +#include "flexflow/request_manager.h" namespace FlexFlow { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 91d50a5161..f531284fce 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -65,6 +65,7 @@ #include "flexflow/parallel_ops/partition.h" #include "flexflow/parallel_ops/reduction.h" #include "flexflow/parallel_ops/replicate.h" +#include "flexflow/request_manager.h" #include "flexflow/substitution.h" #include "flexflow/utils/random_utils.h" #include "flexflow/utils/test_utils.h" diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index b20f6fcf81..bb172ca734 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -13,7 +13,7 @@ * limitations under the License. 
*/ -#include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include "flexflow/parallel_ops/parallel_op.h" // #include "flexflow/tokenizers.h" #include @@ -47,7 +47,7 @@ RequestManager::RequestManager() void RequestManager::register_tokenizer(ModelType type, std::string const &path) { // bos id - this->model_type = model_type; + this->model_type = type; if (model_type == ModelType::LLAMA) { this->tokenizer_ = Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(path)); @@ -78,15 +78,13 @@ void RequestManager::register_tokenizer(ModelType type, void RequestManager::register_output_filepath( std::string const &_output_filepath) { - output_filepath = _output_filepath; + this->output_filepath = _output_filepath; } -int RequestManager::register_new_model(FFModel *model) { +int RequestManager::register_ssm_model(FFModel *model) { int model_id = models.size(); models.push_back(model); std::cout << "Register new model with id: " << model_id << std::endl; - num_ssms++; - assert(models.size() == num_ssms); return model_id; } @@ -95,6 +93,10 @@ FFModel *RequestManager::get_model(int model_id) { return models[model_id]; } +size_t RequestManager::get_num_ssms() { + return models.size(); +} + RequestManager::RequestGuid RequestManager::register_new_request(std::vector const &prompt, int max_sequence_length) { @@ -108,13 +110,13 @@ RequestManager::RequestGuid request.tokens = prompt; request.promise = new std::promise(); - if (num_ssms == 0) { + if (get_num_ssms() == 0) { std::cout << "No small speculative model registered yet, using incremental " "decoding." << std::endl; } else { - std::cout << "Num of models: " << num_ssms << std::endl; - for (int i = 0; i < num_ssms; i++) { + std::cout << "Num of models: " << get_num_ssms() << std::endl; + for (int i = 0; i < get_num_ssms(); i++) { BeamTree beam_tree = BeamTree{}; request.beam_trees.push_back(beam_tree); } @@ -160,13 +162,13 @@ RequestManager::RequestGuid request.tokens.insert(request.tokens.end(), tokens.begin(), tokens.end()); request.initial_len = request.tokens.size(); - if (num_ssms == 0) { + if (get_num_ssms() == 0) { std::cout << "No small speculative model registered yet, using incremental " "decoding." 
<< std::endl; } else { - std::cout << "Num of models: " << num_ssms << std::endl; - for (int i = 0; i < num_ssms; i++) { + std::cout << "Num of models: " << get_num_ssms() << std::endl; + for (int i = 0; i < get_num_ssms(); i++) { BeamTree beam_tree = BeamTree{}; request.beam_trees.push_back(beam_tree); } @@ -1471,26 +1473,142 @@ std::vector> return merged_tree; } -/*static*/ -GenerationResult RequestManager::generate(std::string const &text, - int max_seq_length) { +GenerationResult FFModel::generate(std::string const &text, + int max_seq_length) { RequestManager *rm = RequestManager::get_request_manager(); - RequestGuid guid = rm->register_new_request(text, max_seq_length); - std::future future = - rm->all_requests[guid].promise->get_future(); - return future.get(); + if (rm->get_num_ssms() == 0) { + // No SSMs: perform incremental decoding + return rm->generate_incr_decoding(this, text, max_seq_length); + } else { + // Registered SSMs: perform speculative inference + return rm->generate_spec_infer(this, text, max_seq_length); + } +} + +/*static*/ +GenerationResult RequestManager::generate_incr_decoding(FFModel *llm, + std::string const &text, + int max_seq_length) { + InferenceManager *im = InferenceManager::get_inference_manager(); + RequestGuid guid = register_new_request(text, max_seq_length); + int tokens_to_generate = max_seq_length - all_requests[guid].tokens.size(); + std::queue> + batch_pipeline; + { + BatchConfig bc; + InferenceResult ir; + BatchConfigFuture bcf = Future::from_value(bc); + InferenceResultFuture irf = Future::from_value(ir); + batch_pipeline.push(std::make_pair(bcf, irf)); + } + for (int i = 0; i < tokens_to_generate; i++) { + if (batch_pipeline.size() >= 4) { + // Block here to avoid launching too many batches + auto const &batch = batch_pipeline.front(); + batch.second.get_void_result(); + } + // deque finished batches + while (batch_pipeline.size() > 1) { + auto const &batch = batch_pipeline.front(); + if (batch.second.is_ready()) { + batch_pipeline.pop(); + } else { + break; + } + } + auto const &next_batch = batch_pipeline.back(); + BatchConfigFuture bcf = + prepare_next_batch(next_batch.first, next_batch.second); + FutureMap fm = im->inference(llm, 0, bcf); + assert(fm.get_future_map_domain().get_volume() == 1); + InferenceResultFuture irf = fm.get_future(0); + batch_pipeline.push(std::make_pair(bcf, irf)); + } + + GenerationResult gr = request_generation_results[guid]; + return gr; +} + +/*static*/ +GenerationResult RequestManager::generate_spec_infer(FFModel *llm, + std::string const &text, + int max_seq_length) { + InferenceManager *im = InferenceManager::get_inference_manager(); + RequestGuid guid = register_new_request(text, max_seq_length); + std::queue> + batch_pipeline; + { + TreeVerifyBatchConfig tree_bc; + InferenceResult tree_ir; + TreeVerifyBatchConfigFuture tree_bcf = + Future::from_value(tree_bc); + InferenceResultFuture tree_irf = + Future::from_value(tree_ir); + batch_pipeline.push(std::make_pair(tree_bcf, tree_irf)); + } + size_t num_processed_requests = get_num_processed_requests(); + while (get_num_processed_requests() == num_processed_requests) { + if (batch_pipeline.size() >= 4) { + // Block here to avoid launching too many batches + auto const &batch = batch_pipeline.front(); + batch.second.get_void_result(); + } + // deque finished batches + while (batch_pipeline.size() > 1) { + auto const &batch = batch_pipeline.front(); + if (batch.second.is_ready()) { + batch_pipeline.pop(); + } else { + break; + } + } + auto const &next_batch = 
batch_pipeline.back(); + BeamSearchBatchConfigFuture beam_bcf = + prepare_next_batch_init(next_batch.first, next_batch.second, 0); + std::vector beam_bcf_vec(get_num_ssms()); + for (size_t ssm_id = 0; ssm_id < get_num_ssms(); ssm_id++) { + beam_bcf_vec[ssm_id] = beam_bcf; + } + if (get_num_processed_requests() > num_processed_requests) { + break; + } + + for (size_t i = 0; i < get_num_ssms(); i++) { + for (int depth = 0; depth < BeamSearchBatchConfig::MAX_BEAM_DEPTH; + depth++) { + beam_bcf = beam_bcf_vec[i]; + + FutureMap fm = im->inference(get_model(i), 0, beam_bcf_vec[i]); + assert(fm.get_future_map_domain().get_volume() == 1); + BeamInferenceResultFuture beam_irf = fm.get_future(0); + beam_bcf_vec[i] = prepare_next_batch_beam(beam_bcf_vec[i], beam_irf); + } + } + // Token Tree Verification + { + TreeVerifyBatchConfigFuture tree_bcf = + prepare_next_batch_verify(beam_bcf_vec); + FutureMap fm = im->inference(llm, 0, tree_bcf); + assert(fm.get_future_map_domain().get_volume() == 1); + InferenceResultFuture tree_irf = fm.get_future(0); + batch_pipeline.push(std::make_pair(tree_bcf, tree_irf)); + } + } + + GenerationResult gr = request_generation_results[guid]; + return gr; } /*static*/ void RequestManager::serve(FFModel *llm) { Runtime *runtime = Runtime::get_runtime(); Context ctx = Runtime::get_context(); - TaskLauncher launcher(RM_LLM_SERVING_BACKGROUND_TASK_ID, TaskArgument(&llm, sizeof(FFModel *))); runtime->execute_task(ctx, launcher); } +/*static*/ void RequestManager::llm_serving_background_task( Task const *task, std::vector const ®ions, diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index 31d5afd36d..1ca466ff91 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -13,7 +13,7 @@ * limitations under the License. */ -#include "flexflow/inference.h" +#include "flexflow/request_manager.h" #include "flexflow/utils/cuda_helper.h" namespace FlexFlow {
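
Taken together, the patch moves InferenceManager and RequestManager out of flexflow/inference.h into the new flexflow/request_manager.h, retires the hand-written incr_decoding/spec_inference driver loops, and routes generation through FFModel::generate(), which calls RequestManager::generate_incr_decoding() when no SSM is registered and RequestManager::generate_spec_infer() otherwise. Below is a minimal caller-side sketch of the resulting flow, assuming the LLM and SSM FFModel objects are already built with weights loaded; construction, argument parsing, and tokenizer registration are elided, and the function and variable names here are illustrative rather than taken from the patch.

#include <string>
#include <vector>

#include "flexflow/request_manager.h"

using namespace FlexFlow;

// Serve a list of prompts. Registering an SSM selects the speculative
// path inside FFModel::generate(); skipping register_ssm_model() would
// make generate() fall back to incremental decoding instead.
void serve_prompts(FFModel &llm,
                   FFModel &ssm,
                   std::vector<std::string> const &prompts) {
  RequestManager *rm = RequestManager::get_request_manager();
  rm->register_ssm_model(&ssm); // zero registered SSMs => incremental decoding
  for (std::string const &text : prompts) {
    // generate() registers the request and returns once it has finished,
    // keeping up to four batches in flight internally (see the
    // batch_pipeline logic in src/runtime/request_manager.cc above).
    GenerationResult result = llm.generate(text, 128 /*max_sequence_length*/);
  }
}

Both generate paths pipeline work the same way: they keep at most four BatchConfig/InferenceResult future pairs queued, pop futures as they become ready, and block on the oldest one only when the queue is full, so batch preparation overlaps with in-flight inference instead of serializing on each result.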