Commit

update
sfc-gh-goliaro committed Sep 28, 2024
1 parent 8f4bc8b commit fbac32e
Showing 17 changed files with 515 additions and 141 deletions.
Empty file added backup.txt
2 changes: 1 addition & 1 deletion inference/incr_decoding/incr_decoding.cc
@@ -271,7 +271,7 @@ void FlexFlow::top_level_task(Task const *task,
printf("Prompt[%d]: %s\n", total_num_requests, text.c_str());
Request inference_req;
inference_req.prompt = text;
- inference_req.max_sequence_length = 128;
+ inference_req.max_sequence_length = 10;
requests.push_back(inference_req);
total_num_requests++;
}
6 changes: 2 additions & 4 deletions inference/models/mpt.cc
@@ -106,8 +106,7 @@ void MPT::create_mpt_model(FFModel &ff,
nullptr, // ?
REG_MODE_NONE, // no regularization
0.0f, // no dropout
- std::string("layers." + std::to_string(i) + ".attn.qkv_proj")
- .c_str());
+ std::string("layers." + std::to_string(i) + ".attn.qkv_proj").c_str());

Tensor o_proj;
switch (mode) {
@@ -199,8 +198,7 @@ void MPT::create_mpt_model(FFModel &ff,
nullptr,
REG_MODE_NONE,
0.0f,
- std::string("layers." + std::to_string(i) + ".attn.o_proj")
- .c_str());
+ std::string("layers." + std::to_string(i) + ".attn.o_proj").c_str());

ff.residual_layer_norm(
attn_outputs,
10 changes: 8 additions & 2 deletions inference/python/incr_decoding.py
@@ -111,9 +111,15 @@ def main():

if len(configs.prompt) > 0:
prompts = [s for s in json.load(open(configs.prompt))]
- results = llm.generate(prompts)
+ if "max_length" not in configs_dict:
+     results = llm.generate(prompts)
+ else:
+     results = llm.generate(prompts, max_length=configs.max_length)
else:
- result = llm.generate("Three tips for staying healthy are: ")
+ if "max_length" not in configs_dict:
+     result = llm.generate("Three tips for staying healthy are: ")
+ else:
+     result = llm.generate("Three tips for staying healthy are: ", max_length=configs.max_length)

llm.stop_server()

7 changes: 1 addition & 6 deletions src/ops/inc_multihead_self_attention.cc
@@ -599,7 +599,6 @@ OpMeta *IncMultiHeadSelfAttention::init_task(
attn->num_kv_heads / attn->tensor_parallelism_degree +
(attn->num_kv_heads % attn->tensor_parallelism_degree != 0);
-

Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc);
MemoryAllocator gpu_mem_allocator(gpu_mem);
if (attn->offload) {
@@ -809,11 +808,7 @@ void IncMultiHeadSelfAttention::peft_bwd_task(
assert(task->index_point.get_dim() == 1);

IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper(
- m,
- bc,
- task->index_point.point_data[0],
- input_grad,
- output_grad);
+ m, bc, task->index_point.point_data[0], input_grad, output_grad);

if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
17 changes: 3 additions & 14 deletions src/ops/inc_multihead_self_attention.cpp
@@ -951,8 +951,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,

if (bc->num_tokens > bc->num_generation_tokens) {
// phase 4: Compute attention score for prompt tokens;
- compute_attention_kernel_prompt(
- m, bc, shard_id, stream);
+ compute_attention_kernel_prompt(m, bc, shard_id, stream);
}

// compute output production and bias together for all tokens
@@ -1795,25 +1794,15 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
half const *bias_ptr =
use_bias ? bias.get_half_ptr() : static_cast<half const *>(nullptr);
Kernels::IncMultiHeadAttention::inference_kernel(
- m,
- bc,
- shard_id,
- input.get_half_ptr(),
- output.get_half_ptr(),
- stream);
+ m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream);
} else if (input.data_type == DT_FLOAT) {
if (m->offload) {
pre_build_weight_kernel<float>(m, weight, input.data_type, stream);
}
float const *bias_ptr =
use_bias ? bias.get_float_ptr() : static_cast<float const *>(nullptr);
Kernels::IncMultiHeadAttention::inference_kernel(
- m,
- bc,
- shard_id,
- input.get_float_ptr(),
- output.get_float_ptr(),
- stream);
+ m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream);
} else {
assert(false && "Unspported data type");
}
79 changes: 41 additions & 38 deletions src/ops/inc_multihead_self_attention.cu
@@ -542,26 +542,24 @@ template <typename DT>
void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m,
BatchConfig const *bc,
int shard_id,
// DT const *weight_ptr,
DT *output_ptr,
// DT const *bias_ptr,
cudaStream_t stream) {

checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
assert(m->qSize == m->vSize && m->qSize == m->kSize);
- cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
- #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
- cudaDataType_t compute_type = cublas_data_type;
- #else
- // For best performance, set the default cublas compute type to
- // CUBLAS_COMPUTE_16F for half precision and to
- // CUBLAS_COMPUTE_32F_FAST_16F for full precision
- cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
- if (m->output_type[0] == DT_FLOAT) {
- compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
- }
- #endif
+ // cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]);
+ // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
+ // cudaDataType_t compute_type = cublas_data_type;
+ // #else
+ // // For best performance, set the default cublas compute type to
+ // // CUBLAS_COMPUTE_16F for half precision and to
+ // // CUBLAS_COMPUTE_32F_FAST_16F for full precision
+ // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+ // if (m->output_type[0] == DT_FLOAT) {
+ // compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+ // }
+ // #endif

int num_tokens = bc->num_active_tokens();
int parallelism = m->kProjSize * num_tokens * m->num_q_heads;
@@ -820,11 +818,8 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
stream);
// phase 1: Implement kernel to apply rotary embedding and scaling
- compute_qkv_kernel(m,
- bc,
- shard_id,
- static_cast<DT *>(m->devQKVProjArray),
- stream);
+ compute_qkv_kernel(
+ m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
update_kv_cache_kernel<DT>(m, bc, stream);
if (bc->num_generation_tokens > 0) {
@@ -835,8 +830,12 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
if (bc->num_tokens > bc->num_generation_tokens) {
// phase 4: Compute attention score for prompt tokens;
- compute_attention_kernel_prompt(
- m, bc, shard_id, static_cast<DT*>(nullptr), static_cast<DT*>(nullptr), stream);
+ compute_attention_kernel_prompt(m,
+ bc,
+ shard_id,
+ static_cast<DT *>(nullptr),
+ static_cast<DT *>(nullptr),
+ stream);
}
// compute output production and bias together for all tokens
@@ -1345,12 +1344,12 @@ void peft_bwd_kernel(
// matrix C's layout: [m->qSize, num_tokens]
DT *C = input_grad_ptr +
bc->requestsInfo[i].first_token_offset_in_batch * m->qSize;
- int m_ = m->qSize;
+ // int m_ = m->qSize;
int n_ = num_tokens;
int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize);
// The original version uses existing result and attention's projection to
- // do further calculation in a way different than the usual dense layer,
+ // do further calculation in a way different than the usual dense layer,
// they are off by a transpose. So an explicit transpose is needed here.
// The add here is just for gradient accumulation.
transposeAdd(C, B, n_, k_, alpha, beta, stream);
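
A minimal CPU sketch of the transpose-plus-accumulate described in the comment above, assuming transposeAdd computes C = alpha * B^T + beta * C with B read as a row-major k_ x n_ matrix and C as a row-major n_ x k_ matrix; the layouts and the transpose_add_ref name are illustrative assumptions, not FlexFlow's actual implementation:

// Reference (host-side) transpose-with-accumulation:
// C (n x k) = beta * C + alpha * transpose(B (k x n)).
void transpose_add_ref(
    float *C, float const *B, int n, int k, float alpha, float beta) {
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < k; j++) {
      // Element (j, i) of B lands at (i, j) of C, i.e. B is read transposed.
      C[i * k + j] = beta * C[i * k + j] + alpha * B[j * n + i];
    }
  }
}

With beta = 1 this reduces to pure gradient accumulation into the existing contents of C, which matches the intent stated in the comment.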
@@ -1704,8 +1703,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
BatchConfig const *bc,
int shard_id,
GenericTensorAccessorR const &input,
- GenericTensorAccessorW const &output
- ) {
+ GenericTensorAccessorW const &output) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
@@ -1720,20 +1718,10 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper(
if (input.data_type == DT_HALF) {
Kernels::IncMultiHeadAttention::inference_kernel(
- m,
- bc,
- shard_id,
- input.get_half_ptr(),
- output.get_half_ptr(),
- stream);
+ m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream);
} else if (input.data_type == DT_FLOAT) {
Kernels::IncMultiHeadAttention::inference_kernel(
- m,
- bc,
- shard_id,
- input.get_float_ptr(),
- output.get_float_ptr(),
- stream);
+ m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream);
} else {
assert(false && "Unspported data type");
}
@@ -1758,7 +1746,7 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper(
GenericTensorAccessorR const &output_grad) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
- bool use_bias = *m->qkv_bias || *m->final_bias;
+ // bool use_bias = *m->qkv_bias || *m->final_bias;
cudaEvent_t t_start, t_end;
if (m->profiling) {
@@ -2132,4 +2120,19 @@ template void
BatchConfig const *bc,
half *output_ptr,
cudaStream_t stream);
+ template void Kernels::IncMultiHeadAttention::compute_qkv_kernel<float>(
+ IncMultiHeadSelfAttentionMeta const *m,
+ BatchConfig const *bc,
+ int shard_id,
+ float *output_ptr,
+ ffStream_t stream);
+ template void Kernels::IncMultiHeadAttention::compute_qkv_kernel<half>(
+ IncMultiHeadSelfAttentionMeta const *m,
+ BatchConfig const *bc,
+ int shard_id,
+ half *output_ptr,
+ ffStream_t stream);
}; // namespace FlexFlow
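
The template void lines added above are explicit instantiations. Because compute_qkv_kernel is a function template whose definition lives in this .cu translation unit, the float and half specializations must be instantiated there explicitly so that other translation units, which only see a declaration, can link against them after the signature change. A minimal sketch of the pattern with simplified names that are not the real FlexFlow declarations:

// kernels.cu: the template definition lives in one translation unit.
template <typename DT>
void fill_zero(DT *out, int n) {
  // Placeholder body; the real kernels launch CUDA work on a stream.
  for (int i = 0; i < n; i++) {
    out[i] = static_cast<DT>(0);
  }
}

// Explicit instantiations emitted here so that callers elsewhere can link.
template void fill_zero<float>(float *, int);
template void fill_zero<double>(double *, int);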
24 changes: 5 additions & 19 deletions src/ops/spec_inc_multihead_self_attention.cu
@@ -714,11 +714,8 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
// phase 1: Implement kernel to compute KQV for input tokens
// TODO WARNING: this is commented out only because we are fixing the inc_attn
// first
- compute_qkv_kernel(m,
- bc,
- shard_id,
- static_cast<DT *>(m->devQKVProjArray),
- stream);
+ compute_qkv_kernel(
+ m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
// phase 2: Update key/val cache
update_kv_cache_kernel<DT>(m, bc, stream);
if (bc->num_generation_tokens > 0) {
@@ -728,8 +725,7 @@
// phase 3: Compute attention score
// 3 kernels for pahse 3: matmul1 - softmax - matmal2
if (bc->num_tokens > bc->num_generation_tokens) {
- compute_attention_kernel_prompt(
- m, bc, shard_id, output_ptr, stream);
+ compute_attention_kernel_prompt(m, bc, shard_id, output_ptr, stream);
}
// compute output production and bias together for all tokens
int num_tokens = bc->num_active_tokens();
@@ -767,20 +763,10 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper(
if (input.data_type == DT_HALF) {
half const *bias_ptr = static_cast<half const *>(nullptr);
Kernels::SpecIncMultiHeadSelfAttention::inference_kernel(
- m,
- bc,
- shard_id,
- input.get_half_ptr(),
- output.get_half_ptr(),
- stream);
+ m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream);
} else if (input.data_type == DT_FLOAT) {
Kernels::SpecIncMultiHeadSelfAttention::inference_kernel(
- m,
- bc,
- shard_id,
- input.get_float_ptr(),
- output.get_float_ptr(),
- stream);
+ m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream);
} else {
assert(false && "Unspported data type");
}
7 changes: 2 additions & 5 deletions src/ops/tree_inc_multihead_self_attention.cu
@@ -929,11 +929,8 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m,
// phase 1: Implement kernel to compute KQV for input tokens
// TODO WARNING: this is commented out only because we are fixing the inc_attn
// first
- compute_qkv_kernel(m,
- bc,
- shard_id,
- static_cast<DT *>(m->devQKVProjArray),
- stream);
+ compute_qkv_kernel(
+ m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
// phase 2: No need to update key/val cache
compute_attention_kernel_fused<DT>(
15 changes: 9 additions & 6 deletions src/runtime/file_loader.cc
@@ -287,7 +287,9 @@ void load_attention_weights_to_dense_v2(DT *ptr,
size_t one_weight_file_size =
num_heads * single_proj_size; // size of each of Q/K/V/O for all heads

- std::cout<<"hidden_dim: "<<hidden_dim<<", qkv_inner_dim: "<<qkv_inner_dim<<", num_heads: "<<num_heads<<std::endl;
+ std::cout << "hidden_dim: " << hidden_dim
+ << ", qkv_inner_dim: " << qkv_inner_dim
+ << ", num_heads: " << num_heads << std::endl;

size_t q_size = one_weight_file_size, o_size = one_weight_file_size;
size_t k_size = single_proj_size * num_kv_heads,
@@ -374,12 +376,12 @@ void load_attention_weights_to_dense_v2(DT *ptr,

DT temp;

- for(int i = 0; i < one_weight_file_size; i++) {
+ for (int i = 0; i < one_weight_file_size; i++) {
temp = host_array.at(i);
}

- // std::cout<<"o_proj loaded into host array, total size: "<<one_weight_file_size<<std::endl;
-
+ // std::cout<<"o_proj loaded into host array, total size:
+ // "<<one_weight_file_size<<std::endl;

if (in_get_size != loaded_data_size) {
std::cout << "load data error" << std::endl;
Expand All @@ -390,7 +392,7 @@ void load_attention_weights_to_dense_v2(DT *ptr,

// std::cout<<"read data size checked"<<std::endl;

- for(int i = 0; i < one_weight_file_size; i++) {
+ for (int i = 0; i < one_weight_file_size; i++) {
ptr[i] = temp;
}

@@ -928,7 +930,8 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff,
// self_attn.qkv_proj or self_attn.o_proj
// so looking for self_attn. in the name can determine if it is an attention
// projection
- if (weight_filename.find("attn.") != std::string::npos || weight_filename.find("self_attention.") != std::string::npos) {
+ if (weight_filename.find("attn.") != std::string::npos ||
+ weight_filename.find("self_attention.") != std::string::npos) {
size_t pos = weight_filename.find(".o_proj");
if (pos != std::string::npos) {
weight_filename.replace(pos, std::string(".o_proj").length(), "");
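
As background for the comment in the hunk above: attention weights are stored under names such as layers.0.self_attn.qkv_proj and layers.0.self_attn.o_proj (shown here for illustration), so the loader recognizes an attention projection by substring match and strips the .o_proj suffix before further processing. A standalone sketch of just that string handling, with a hypothetical helper name rather than the real loader code:

#include <string>

// Mirror of the substring check and ".o_proj" stripping shown in the diff.
std::string base_attention_layer_name(std::string weight_filename) {
  if (weight_filename.find("attn.") != std::string::npos ||
      weight_filename.find("self_attention.") != std::string::npos) {
    size_t pos = weight_filename.find(".o_proj");
    if (pos != std::string::npos) {
      weight_filename.replace(pos, std::string(".o_proj").length(), "");
    }
  }
  return weight_filename;
}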
3 changes: 1 addition & 2 deletions src/runtime/model.cc
@@ -3424,8 +3424,7 @@ bool FFModel::need_to_add_allreduce(int layer_idx) const {
(
// l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION ||
// l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION ||
- (std::string(l->name).find("attn.o_proj") !=
- std::string::npos) ||
+ (std::string(l->name).find("attn.o_proj") != std::string::npos) ||
// mlp layer
is_mlp_block(layer_idx) ||
// llama mlp layer
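
The condition above decides where to insert an allreduce purely from the layer name: under tensor parallelism, the attention output projection (attn.o_proj) and the MLP blocks typically produce per-shard partial results that must be combined, so matching the name is enough. A compact sketch of that predicate; the helper below is illustrative, not part of the FFModel API:

#include <string>

// True if this layer's output is expected to be a per-shard partial sum
// that needs an allreduce under tensor parallelism.
bool needs_allreduce_by_name(std::string const &layer_name) {
  return layer_name.find("attn.o_proj") != std::string::npos;
}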
