Commit

goliaro committed Oct 5, 2024
1 parent c5e813b commit aa57f98
Showing 3 changed files with 38 additions and 49 deletions.
1 change: 1 addition & 0 deletions include/flexflow/utils/peft_weight_allocator.h
@@ -133,6 +133,7 @@ class PEFTMemoryManager {
void allocate_finetuning_memory();

LoraLinearWeight get_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config);
void check_ft_model_id(PEFTModelID const &model_id);

private:
// Check if the PEFT adapter for the given model is in memory. If not, sets the cache_miss flag to true. If this is the first finetuning request, allocate memory for the finetuning adapter.
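The new check_ft_model_id declaration complements the existing get_peft accessor: the backward kernel in the next file calls it as a guard before fetching the finetuning adapter's weights. A minimal sketch of that call pattern, mirroring the call sites added below (not itself part of the diff):

    // Guard: assert that the adapter being finetuned is the one resident in
    // memory, then fetch its weights through the PEFT memory manager.
    m->peft_memory_manager->check_ft_model_id(bc->requestsInfo[i].peft_model_id);
    LoraLinearWeight weight = m->peft_memory_manager->get_peft(
        bc->requestsInfo[i].peft_model_id, lora_config);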
82 changes: 33 additions & 49 deletions src/ops/kernels/lora_linear_kernels.cu
@@ -395,6 +395,7 @@ void inference_kernel(LoraLinearMeta *m,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
assert(num_peft_requests <= 1);
}
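The assert added above encodes the invariant that a batch carries at most one PEFT finetuning request. The num_peft_requests counter itself is defined in the collapsed part of the function; a plausible sketch of the counting pattern it implies, with the exact condition treated as an assumption rather than what the file actually does:

    // Hypothetical counting pattern implied by the assert (condition assumed):
    int num_peft_requests = 0;
    for (int i = 0; i < bc->max_requests_per_batch(); i++) {
      if (!bc->request_completed[i] &&
          bc->requestsInfo[i].peft_model_id != PEFTModelID::NO_ID) {
        num_peft_requests++;
      }
    }
    assert(num_peft_requests <= 1);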

template <typename DT>
@@ -437,39 +438,24 @@ void peft_bwd_kernel(LoraLinearMeta *m,
cudaDataType_t weight_type = output_type;
cudaDataType_t lr_actv_type = output_type;
cudaDataType_t compute_type = output_type;
// #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
// cudaDataType_t compute_type = output_type;
// #else
// // For best performance, set the default cublas compute type to
// // CUBLAS_COMPUTE_16F for half precision and to
// // CUBLAS_COMPUTE_32F_FAST_16F for full precision
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
// if (m->output_type[0] == DT_FLOAT) {
// compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
// }
// #endif

for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
}
// Skip non-PEFT requests
if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) {
// Skip completed, non-PEFT and PEFT forward-only requests
if (bc->request_completed[i] || bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID || !bc->requestsInfo[i].peft_bwd) {
continue;
}
// Skip PEFT forward-only requests
if (!bc->requestsInfo[i].peft_bwd) {
int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
LoraLinearConfig lora_config = LoraLinearConfig::deserialize_from_json_string(bc->requestsInfo[i].peft_adapters[bc->requestsInfo[i].peft_model_id]);
if (!lora_applies_to_this_layer(m, lora_config)) {
continue;
}
assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd && "Trainable flag mismatch");
m->peft_memory_manager->check_ft_model_id(bc->requestsInfo[i].peft_model_id);
int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
// int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch;
assert(m->model_state.find(bc->requestsInfo[i].peft_model_id) !=
m->model_state.end());
LoraLinearWeight weight =
m->model_state[bc->requestsInfo[i].peft_model_id].weights;
int rank = weight.rank;
float lora_alpha =
m->model_state[bc->requestsInfo[i].peft_model_id].lora_alpha;
DT scaling_constant = (DT)(lora_alpha / rank);
// int max_peft_tokens = bc->requestsInfo[i].max_length;
int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch;
LoraLinearWeight weight = m->peft_memory_manager->get_peft(bc->requestsInfo[i].peft_model_id, lora_config);
DT scaling_constant = (DT)(lora_config.lora_alpha / lora_config.rank);

// Compute LORA_B weight's gradient
if (bc->requestsInfo[i].optimizer_tasks.compute_gradients) {
@@ -480,20 +466,20 @@
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_T,
rank,
lora_config.rank,
out_dim,
num_peft_tokens,
&scaling_constant,
m->low_rank_activation,
weight.low_rank_activation,
lr_actv_type,
rank,
lora_config.rank,
output_grad_ptr,
output_type,
out_dim,
&beta,
weight.w1_grad_ptr,
weight_type,
rank,
lora_config.rank,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
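Reading off the visible cublasGemmEx arguments (CUBLAS_OP_N, CUBLAS_OP_T, dimensions rank x out_dim x num_peft_tokens, alpha = scaling_constant), this call computes the LoRA_B gradient in column-major form: w1_grad = scaling * low_rank_activation * output_grad^T. A reference-only restatement as plain loops (shorthand pointer names; the kernel itself uses cuBLAS, and whether it accumulates or overwrites depends on beta):

    template <typename DT>
    void lora_b_grad_reference(DT *w1_grad,                   // [rank x out_dim], ld = rank
                               DT const *low_rank_activation, // [rank x tokens], ld = rank
                               DT const *output_grad,         // [out_dim x tokens], ld = out_dim
                               int rank, int out_dim, int num_tokens, DT scaling) {
      for (int r = 0; r < rank; r++)
        for (int o = 0; o < out_dim; o++)
          for (int t = 0; t < num_tokens; t++)
            w1_grad[o * rank + r] += scaling * low_rank_activation[t * rank + r] *
                                     output_grad[t * out_dim + o];
    }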
@@ -505,20 +491,20 @@
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_N,
rank,
lora_config.rank,
num_peft_tokens,
out_dim,
&scaling_constant,
weight.w1_ptr,
weight_type,
rank,
lora_config.rank,
output_grad_ptr,
output_type,
out_dim,
&beta,
m->low_rank_activation,
weight.low_rank_activation,
lr_actv_type,
rank,
lora_config.rank,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
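This second GEMM (CUBLAS_OP_N, CUBLAS_OP_N, dimensions rank x num_peft_tokens x out_dim, alpha = scaling_constant) propagates the output gradient back through LoRA_B: in column-major terms it computes scaling * w1 * output_grad, a rank x tokens matrix, and writes it into the low_rank_activation buffer. In other words, the buffer that held the forward low-rank activation is reused to store its own gradient, which works only because the LoRA_B gradient GEMM above, which still needed the forward values, has already run.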
@@ -533,15 +519,15 @@
CUBLAS_OP_N,
CUBLAS_OP_T,
in_dim,
rank,
lora_config.rank,
num_peft_tokens,
&alpha,
m->input_activation,
weight.input_activation,
input_type,
in_dim,
m->low_rank_activation,
weight.low_rank_activation,
lr_actv_type,
rank,
lora_config.rank,
&beta,
weight.w0_grad_ptr,
weight_type,
@@ -559,14 +545,14 @@
CUBLAS_OP_N,
in_dim,
num_peft_tokens,
rank,
lora_config.rank,
&alpha,
weight.w0_ptr,
weight_type,
in_dim,
m->low_rank_activation,
weight.low_rank_activation,
lr_actv_type,
rank,
lora_config.rank,
&beta,
input_grad_ptr,
input_type,
Expand All @@ -576,15 +562,13 @@ void peft_bwd_kernel(LoraLinearMeta *m,
}
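The two GEMMs above consume that low-rank activation gradient. Assuming the collapsed leading arguments follow the same pattern as the earlier calls, they compute, in column-major terms:

    w0_grad    (in_dim x rank)   = alpha * input_activation (in_dim x tokens) * lr_activation_grad^T (tokens x rank)
    input_grad (in_dim x tokens) = alpha * w0 (in_dim x rank) * lr_activation_grad (rank x tokens)

i.e. the LoRA_A gradient and the gradient passed back to the previous layer, each accumulated or overwritten according to beta. Note that both calls pass the kernel's generic alpha rather than scaling_constant, since the scaling was already applied when the low-rank activation buffer was overwritten with its gradient.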

if (bc->requestsInfo[i].optimizer_tasks.update_weights) {
LoraOptimizerConfig const *optimizer_config =
m->model_state[bc->requestsInfo[i].peft_model_id].optimizer_config;
LoraOptimizerConfig const *optimizer_config = lora_config.optimizer_config;
assert(optimizer_config != nullptr);
assert(typeid(*optimizer_config) != typeid(LoraOptimizerConfig));
int w0_num_elements = rank * in_dim;
int w1_num_elements = rank * out_dim;
int w0_num_elements = lora_config.rank * in_dim;
int w1_num_elements = lora_config.rank * out_dim;

// Get optimizer config
if (typeid(*optimizer_config) == typeid(LoraSGDOptimizerConfig)) {
if (optimizer_config->getType() == "SGD") {
LoraSGDOptimizerConfig const *sgd_config =
(LoraSGDOptimizerConfig const *)optimizer_config;
// LoRA_A weight is split in tensor parallelism, so no need to apply
@@ -625,7 +609,7 @@ void peft_bwd_kernel(LoraLinearMeta *m,
static_cast<DT const *>(weight.w1_grad_ptr),
static_cast<DT *>(weight.w1_v_values_ptr),
static_cast<DT *>(weight.w1_ptr));
} else if (typeid(*optimizer_config) == typeid(LoraAdamOptimizerConfig)) {
} else if (optimizer_config->getType() == "Adam") {
assert(false && "Adam optimizer type not implemented yet");
} else {
assert(false && "Unsupported optimizer type");
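The optimizer dispatch also switches from RTTI (typeid comparisons against the concrete config classes) to a string returned by getType(). A hypothetical sketch of the accessor this relies on; the real definitions live in the LoRA config headers and are not part of this diff, so the field names and defaults here are assumptions:

    #include <string>

    // Sketch only: a concrete base class (the typeid assert above compares
    // against it) exposing a virtual type tag the kernel can dispatch on.
    class LoraOptimizerConfig {
    public:
      virtual ~LoraOptimizerConfig() = default;
      virtual std::string getType() const { return "none"; }
    };

    class LoraSGDOptimizerConfig : public LoraOptimizerConfig {
    public:
      std::string getType() const override { return "SGD"; }
      double lr = 1e-3, momentum = 0.0, weight_decay = 0.0; // assumed fields
    };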
4 changes: 4 additions & 0 deletions src/runtime/peft_weight_allocator.cc
@@ -271,4 +271,8 @@ LoraLinearWeight PEFTMemoryManager::get_peft(PEFTModelID const &model_id, LoraLi
}
}

void PEFTMemoryManager::check_ft_model_id(PEFTModelID const &model_id) {
assert(finetuning_model_id == model_id && "PEFT bwd model is not in memory!");
}

}; // namespace FlexFlow
