Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro committed Oct 7, 2024
1 parent 4a3d1bd commit d2caba8
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 40 deletions.
1 change: 1 addition & 0 deletions include/flexflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ enum TaskIDs {
RM_PREPARE_NEXT_BATCH_INIT_TASK_ID,
RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID,
RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID,
RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID,
RM_BACKGROUND_SERVING_TASK_ID,
// Custom tasks
CUSTOM_GPU_TASK_ID_FIRST,
Expand Down
19 changes: 19 additions & 0 deletions include/flexflow/request_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/utils/file_loader.h"
#include "suffix_decoding.h"
#include <future>
#include <mutex>
#include <tokenizers_cpp.h>
Expand Down Expand Up @@ -164,6 +165,7 @@ class RequestManager {

void serve_incr_decoding(FFModel *model);
void serve_spec_infer(FFModel *model);
// Serving entry point for suffix decoding on `model`, declared next to
// serve_incr_decoding() and serve_spec_infer(), which take the same argument.
// NOTE(review): the definition is not visible from this header; presumably it
// drives the serving loop through prepare_next_batch_suffix_decode() — confirm
// against the implementation.
void serve_suffix_decoding(FFModel *model);
GenerationResult get_generation_result(RequestGuid const &guid);
RequestGuid register_new_request(Request const &request_);
RequestGuid register_new_peft_request(Request const &request_);
Expand Down Expand Up @@ -210,6 +212,15 @@ class RequestManager {
Legion::Context ctx,
Legion::Runtime *runtime);

// Builds the next TreeVerifyBatchConfig for suffix decoding from the previous
// batch configuration (`old_bc`) and the inference result produced for it
// (`result`). Both inputs are taken by const reference; the new configuration
// is returned by value.
TreeVerifyBatchConfig
prepare_next_batch_suffix_decode(TreeVerifyBatchConfig const &old_bc,
InferenceResult const &result);
// Future-based overload of the function above: takes the previous batch
// configuration and inference result as futures, plus the Legion context and
// runtime, and returns a TreeVerifyBatchConfigFuture.
// NOTE(review): presumably launches RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID
// with the two futures as inputs — confirm against the implementation.
TreeVerifyBatchConfigFuture prepare_next_batch_suffix_decode(
TreeVerifyBatchConfigFuture const &old_bc,
InferenceResultFuture const &result,
Legion::Context ctx,
Legion::Runtime *runtime);

void store_beam_metadata(BeamSearchBatchConfig const &old_bc,
BeamInferenceResult const &result);
void update_beam_metadata(BeamSearchBatchConfig &new_bc,
Expand Down Expand Up @@ -280,6 +291,12 @@ class RequestManager {
Legion::Context ctx,
Legion::Runtime *runtime);

// Legion task body for the suffix-decoding "prepare next batch" step. It is
// registered under RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID with
// TreeVerifyBatchConfig as the task's return type.
// NOTE(review): how `task` and `regions` are unpacked is defined in the
// implementation, which is not visible from this header.
static TreeVerifyBatchConfig prepare_next_batch_suffix_decode_task(
Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);

private:
// configuration parameters
int max_requests_per_batch;
Expand All @@ -295,6 +312,8 @@ class RequestManager {
// tree width in each speculative step, if not specified 1
std::vector<int> spec_infer_tree_width;

// Suffix tree used by the suffix-decoding path (type comes from
// "suffix_decoding.h").
// NOTE(review): this is a raw pointer with no in-class initializer; who
// allocates and frees it is not visible here. Confirm it is initialized in the
// constructor and released, or consider std::unique_ptr<SuffixTree>.
SuffixTree *suffix_tree;

// private fields
std::unique_ptr<Tokenizer> tokenizer_;
bool verbose;
Expand Down
1 change: 1 addition & 0 deletions src/mapper/mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ void FFMapper::select_task_options(const MapperContext ctx,
(task.task_id == RM_PREPARE_NEXT_BATCH_INIT_TASK_ID) ||
(task.task_id == RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID) ||
(task.task_id == RM_PREPARE_NEXT_BATCH_VERIFY_TASK_ID) ||
(task.task_id == RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID) ||
(task.task_id == RM_BACKGROUND_SERVING_TASK_ID)) {
output.initial_proc = all_cpus[0];
return;
Expand Down
21 changes: 21 additions & 0 deletions src/runtime/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4822,6 +4822,27 @@ void register_flexflow_internal_tasks(Runtime *runtime,
RequestManager::prepare_next_batch_verify_task>(registrar);
}
}
// RequestManager prepare_next_batch_suffix_decode
// Registers RequestManager::prepare_next_batch_suffix_decode_task as the task
// variant for RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID, with
// TreeVerifyBatchConfig as the task's return type.
{
TaskVariantRegistrar registrar(
RM_PREPARE_NEXT_BATCH_SUFFIX_DECODE_TASK_ID,
"RequestManager Prepare Next Batch (Suffix Decode)");
// Restrict the variant to LOC_PROC processors and mark it as a leaf variant.
registrar.add_constraint(ProcessorConstraint(Processor::LOC_PROC));
registrar.set_leaf();
if (pre_register) {
// Static preregistration path: no runtime instance is used here.
Runtime::preregister_task_variant<
TreeVerifyBatchConfig,
RequestManager::prepare_next_batch_suffix_decode_task>(
registrar, "RequestManager Prepare Next Batch (Suffix Decode) Task");
} else {
// Dynamic registration through the live runtime. Under control replication
// the registration is made non-global before registering.
if (enable_control_replication) {
registrar.global_registration = false;
}
runtime->register_task_variant<
TreeVerifyBatchConfig,
RequestManager::prepare_next_batch_suffix_decode_task>(registrar);
}
}
// RequestManager background serving task
{
TaskVariantRegistrar registrar(RM_BACKGROUND_SERVING_TASK_ID,
Expand Down
Loading

0 comments on commit d2caba8

Please sign in to comment.