Peft python interface #1306

Merged: 33 commits, Mar 27, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -558,6 +558,7 @@ if(NOT BUILD_LEGION_ONLY)
if(FF_BUILD_ALL_INFERENCE_EXAMPLES OR FF_BUILD_ALL_EXAMPLES)
add_subdirectory(inference/spec_infer)
add_subdirectory(inference/incr_decoding)
add_subdirectory(inference/peft)
endif()


8 changes: 6 additions & 2 deletions include/flexflow/ffconst.h
@@ -78,6 +78,11 @@ enum InferenceMode {
TREE_VERIFY_MODE = 2003,
};

enum RequestType {
REQ_INFERENCE = 4001,
REQ_FINETUNING = 4002,
};

// This is consistent with TASO's OpType
// https://github.com/jiazhihao/TASO/blob/master/include/taso/ops.h#L75-L138
enum OperatorType {
@@ -179,8 +184,7 @@ enum OperatorType {
OP_TREE_INC_MULTIHEAD_SELF_ATTENTION,
OP_SAMPLING,
// PEFT Ops
OP_LORA_MLP_FIRST,
OP_LORA_MLP_SECOND,
OP_LORA,
// Parallel Ops
OP_REPARTITION,
OP_COMBINE,
36 changes: 32 additions & 4 deletions include/flexflow/flexflow_c.h
@@ -55,6 +55,8 @@ FF_NEW_OPAQUE_TYPE(flexflow_inference_manager_t);
FF_NEW_OPAQUE_TYPE(flexflow_request_manager_t);
FF_NEW_OPAQUE_TYPE(flexflow_file_data_loader_t);
FF_NEW_OPAQUE_TYPE(flexflow_generation_result_t);
FF_NEW_OPAQUE_TYPE(flexflow_lora_linear_config_t);
FF_NEW_OPAQUE_TYPE(flexflow_peft_model_id_t);

// -----------------------------------------------------------------------
// FFConfig
@@ -593,6 +595,9 @@ flexflow_tensor_t flexflow_model_add_argmax(flexflow_model_t handle_,
bool beam_search,
char const *name);

flexflow_peft_model_id_t flexflow_model_add_lora_layer(
flexflow_model_t handle_, const flexflow_lora_linear_config_t peft_config_);

void flexflow_model_set_sgd_optimizer(flexflow_model_t handle,
flexflow_sgd_optimizer_t optimizer);

@@ -616,10 +621,13 @@ void flexflow_model_set_transformer_layer_id(flexflow_model_t handle, int id);

void flexflow_model_generate(flexflow_model_t handle_,
int num_requests,
char const **input_text,
int max_num_chars,
char **output_text,
int max_seq_length,
enum RequestType *request_types,
char const **input_texts,
char **output_texts,
int *max_seq_lengths,
flexflow_peft_model_id_t *peft_model_ids,
char const **dataset_filepaths,
int *training_steps,
int **output_length_and_tokens);

void flexflow_model_set_position_offset(flexflow_model_t handle, int offset);
@@ -1036,6 +1044,26 @@ void flexflow_file_data_loader_destroy(flexflow_file_data_loader_t handle_,
void flexflow_file_data_loader_load_weights(flexflow_file_data_loader_t handle_,
flexflow_model_t model_handle_);

// -----------------------------------------------------------------------
// LoraLinearConfig
// -----------------------------------------------------------------------

flexflow_lora_linear_config_t
flexflow_lora_linear_config_create(char const *cache_folder_,
char const *peft_model_id_);

void flexflow_lora_linear_config_destroy(flexflow_lora_linear_config_t handle_);

// -----------------------------------------------------------------------
// PEFTModelID
// -----------------------------------------------------------------------

flexflow_peft_model_id_t flexflow_peft_model_id_create();

flexflow_peft_model_id_t flexflow_peft_model_id_create_id(unsigned long id);

void flexflow_peft_model_id_destroy(flexflow_peft_model_id_t handle_);

#ifdef __cplusplus
}
#endif
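For context, here is a minimal sketch of how a binding layer might drive the new C entry points. The cache folder, adapter name, prompt, buffer sizes, and ownership assumptions below are illustrative, not part of this diff, and error handling is omitted:

#include "flexflow/flexflow_c.h"

// Sketch: attach a LoRA adapter and run one inference request through the
// revised C API. All names and sizes here are placeholders.
void run_one_request(flexflow_model_t model) {
  flexflow_lora_linear_config_t config = flexflow_lora_linear_config_create(
      "/path/to/cache", "some-org/some-lora-adapter"); // placeholder ids
  flexflow_peft_model_id_t peft_id =
      flexflow_model_add_lora_layer(model, config);

  // flexflow_model_generate now takes one entry per request in each array.
  RequestType req_types[1] = {REQ_INFERENCE};
  char const *inputs[1] = {"Hello, world"}; // placeholder prompt
  char output_buf[1024] = {0};              // assumed caller-allocated
  char *outputs[1] = {output_buf};
  int max_seq_lengths[1] = {128};
  flexflow_peft_model_id_t peft_ids[1] = {peft_id};
  char const *dataset_paths[1] = {nullptr}; // unused for inference requests
  int training_steps[1] = {0};              // likewise inference-only
  int length_and_tokens[1024];              // assumed caller-allocated
  int *output_meta[1] = {length_and_tokens};

  flexflow_model_generate(model, 1, req_types, inputs, outputs,
                          max_seq_lengths, peft_ids, dataset_paths,
                          training_steps, output_meta);

  // Assumes the model retains its own copy of the config.
  flexflow_lora_linear_config_destroy(config);
}
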
15 changes: 7 additions & 8 deletions include/flexflow/model.h
@@ -837,19 +837,12 @@ class FFModel {
// ========================================
// PEFT Layers
// ========================================
void lora_linear(Tensor const input,
Tensor const output,
OperatorType _type,
char const *name = nullptr);
PEFTModelID *add_lora_layer(LoraLinearConfig const peft_config);
// ========================================
// Inference APIs
// ========================================
std::vector<GenerationResult> generate(std::vector<Request> const &requests);

PEFTModelID register_peft_model(
LoraLinearConfig const mlp_first = LoraLinearConfig::DefaultConfig,
LoraLinearConfig const mlp_second = LoraLinearConfig::DefaultConfig);

Tensor create_tensor_legion_ordering(int num_dim,
int const dims[],
DataType data_type,
@@ -1174,6 +1167,12 @@ class FFModel {
std::vector<Layer *> layers;
std::vector<Op *> operators;
std::vector<ParallelTensor> parameters;
// PEFT related
std::unordered_map<Layer *, Layer *> base_layer_to_peft_layer;
std::unordered_map<Layer *, std::vector<PEFTModelID>> peft_layer_to_peft_id;
std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
// std::vector<Op *> peft_operators;

FFHandler handlers[MAX_NUM_WORKERS];
Legion::Future current_metrics;
// Cached operators: key: operator hash, value: operator pointer
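At the C++ level, the single add_lora_layer call replaces the old register_peft_model pair. A rough sketch of the resulting flow, assuming `model` is a fully built FFModel and the adapter weights are already under the cache folder (names are placeholders):

// Sketch: register a LoRA adapter and run one inference request against it.
void serve_with_lora(FFModel &model,
                     std::string const &cache_folder,
                     std::string const &peft_model_name) {
  LoraLinearConfig peft_config(cache_folder, peft_model_name);
  PEFTModelID *peft_model_id = model.add_lora_layer(peft_config);

  Request req;
  req.prompt = "Hello, world";        // placeholder prompt
  req.max_sequence_length = 128;
  req.peft_model_id = *peft_model_id; // route this request through the adapter
  std::vector<GenerationResult> results = model.generate({req});
}
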
26 changes: 10 additions & 16 deletions include/flexflow/ops/lora_linear.h
@@ -17,12 +17,14 @@ class LoraLinear : public Op {
using Params = LoraLinearParams;
using Input = std::pair<ParallelTensor, ParallelTensor>;

LoraLinear(FFModel &model,
LayerID const &layer_guid,
OperatorType type,
ParallelTensor const input,
ParallelTensor const output,
char const *name = nullptr);
LoraLinear(
FFModel &model,
LayerID const &layer_guid,
OperatorType type,
ParallelTensor const input,
ParallelTensor const output,
std::unordered_map<PEFTModelID, LoraLinearConfig> const &_peft_configs,
char const *name = nullptr);
LoraLinear(FFModel &model,
LoraLinear const &other,
ParallelTensor const input,
@@ -39,11 +41,6 @@
MachineView const *mv = nullptr) override;
void forward(FFModel const &) override;
void backward(FFModel const &) override;
void register_peft_model(FFModel const &ff,
std::vector<ParallelTensor> const &batch_inputs,
std::vector<ParallelTensor> const &batch_outputs,
PEFTModelID const &model_id,
LoraLinearConfig const lora_config);
Legion::FutureMap inference(FFModel const &,
BatchConfigFuture const &,
std::vector<ParallelTensor> const &,
@@ -64,11 +61,6 @@
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void
register_model_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void inference_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
@@ -98,6 +90,8 @@
int num_inputs) const override;
// size_t get_params_hash() const override;
LoraLinearParams get_params() const;

std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
};

}; // namespace FlexFlow
4 changes: 3 additions & 1 deletion include/flexflow/ops/lora_linear_params.h
@@ -12,7 +12,7 @@ namespace FlexFlow {

class LoraLinearConfig {
public:
static const LoraLinearConfig DefaultConfig;
static const LoraLinearConfig EmptyConfig;
LoraLinearConfig();
LoraLinearConfig(int rank,
OptimizerType type = OPTIMIZER_TYPE_SGD,
@@ -33,6 +33,7 @@ class LoraLinearConfig {
std::string peft_model_id;
int lora_alpha;
float lora_dropout;
std::vector<std::string> target_modules;
// whether to load weights from file, instead of initializing them randomly
bool load_weights_from_file;
};
@@ -41,6 +42,7 @@ class LoraLinearParams {
public:
LayerID layer_guid;
OperatorType type;
std::unordered_map<PEFTModelID, LoraLinearConfig> peft_configs;
char name[MAX_OPNAME];

bool is_valid(std::pair<ParallelTensorShape, ParallelTensorShape> const
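Given the members shown above, a hand-assembled config might look like the following sketch; it assumes the fields stay public and that the constructor parameters truncated in this diff keep their defaults, and all values are illustrative:

// Sketch: a LoRA config targeting a single linear layer.
LoraLinearConfig config(/*rank=*/16);   // assumes trailing defaults apply
config.lora_alpha = 32;
config.lora_dropout = 0.05f;
config.target_modules = {"down_proj"};  // placeholder target module name
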
4 changes: 2 additions & 2 deletions include/flexflow/request_manager.h
@@ -65,7 +65,6 @@ struct Request {
COMPLETED = 103, // finished and verified
FINISHING = 104, // finishing request, but not yet verified
};
enum RequestType { REQ_INFERENCE = 201, REQ_FINETUNING = 202 };
BatchConfig::RequestGuid guid;
PEFTModelID peft_model_id = PEFTModelID::NO_ID;
int max_sequence_length = 128;
@@ -81,10 +80,11 @@
RequestType req_type = REQ_INFERENCE;
int completed_training_steps = 0;
int max_training_steps = 1;
std::vector<std::pair<std::string, std::string>> dataset_text;
std::string dataset_filepath;
std::vector<std::pair<std::vector<BatchConfig::TokenId>,
std::vector<BatchConfig::TokenId>>>
dataset;
friend std::ostream &operator<<(std::ostream &os, Request const &req);
};

// store the result of beam search
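Putting the new fields together, a finetuning request now points at a dataset file on disk rather than carrying raw text pairs (dataset_text is gone). A sketch, with placeholder path and step count:

// Sketch: build a finetuning request under the revised struct.
Request make_finetuning_request(PEFTModelID const &peft_model_id) {
  Request req;
  req.req_type = REQ_FINETUNING; // global enum from ffconst.h
  req.peft_model_id = peft_model_id;               // from add_lora_layer
  req.dataset_filepath = "/path/to/dataset.json";  // placeholder path
  req.max_training_steps = 2;
  return req;
}
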
43 changes: 6 additions & 37 deletions inference/incr_decoding/incr_decoding.cc
@@ -40,7 +40,6 @@ void parse_input_args(char **argv,
int argc,
FilePaths &paths,
std::string &llm_model_name,
std::string &peft_model_name,
bool &use_full_precision,
bool &verbose,
bool &do_sample,
Expand All @@ -58,13 +57,6 @@ void parse_input_args(char **argv,
}
continue;
}
if (!strcmp(argv[i], "-peft-model")) {
peft_model_name = std::string(argv[++i]);
for (char &c : peft_model_name) {
c = std::tolower(c);
}
continue;
}
// cache folder
if (!strcmp(argv[i], "-cache-folder")) {
paths.cache_folder_path = std::string(argv[++i]);
@@ -133,7 +125,7 @@ void FlexFlow::top_level_task(Task const *task,
assert(false && "Doesn't support quantization in non-offload mode");
}
FilePaths file_paths;
std::string llm_model_name, peft_model_name;
std::string llm_model_name;
bool use_full_precision = false;
bool verbose = false;
bool do_sample = false;
@@ -150,7 +142,6 @@
argc,
file_paths,
llm_model_name,
peft_model_name,
use_full_precision,
verbose,
do_sample,
Expand All @@ -159,6 +150,7 @@ void FlexFlow::top_level_task(Task const *task,
max_requests_per_batch,
max_tokens_per_batch,
max_sequence_length);

assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree *
ffconfig.pipeline_parallelism_degree ==
ffconfig.numNodes * ffconfig.workersPerNode);
@@ -259,19 +251,6 @@
assert(false && "unknow model type");
}

// Register PEFT layer
LoraLinearConfig mlp_second =
peft_model_name.empty()
? LoraLinearConfig::DefaultConfig
: LoraLinearConfig(file_paths.cache_folder_path, peft_model_name);
PEFTModelID peft_model_id =
peft_model_name.empty()
? PEFTModelID::NO_ID
: model.register_peft_model(
LoraLinearConfig::DefaultConfig /*mlp_first*/,
mlp_second /*mlp_second*/);

// Start background server
rm->start_background_server(&model);

int total_num_requests = 0;
Expand All @@ -288,20 +267,10 @@ void FlexFlow::top_level_task(Task const *task,
for (auto &prompt : prompt_json) {
std::string text = prompt.get<std::string>();
printf("Prompt[%d]: %s\n", total_num_requests, text.c_str());
// Add inference request
// Request inference_req;
// inference_req.prompt = text;
// inference_req.max_sequence_length = 128;
// inference_req.peft_model_id = peft_model_id;
// requests.push_back(inference_req);
// total_num_requests++;
// Add fine-tuning request
Request fine_tuning_req;
fine_tuning_req.req_type = Request::RequestType::REQ_FINETUNING;
fine_tuning_req.max_sequence_length = 128;
fine_tuning_req.peft_model_id = peft_model_id;
fine_tuning_req.dataset_text.push_back(std::make_pair(text, ""));
requests.push_back(fine_tuning_req);
Request inference_req;
inference_req.prompt = text;
inference_req.max_sequence_length = 128;
requests.push_back(inference_req);
total_num_requests++;
}
std::vector<GenerationResult> result = model.generate(requests);