diff --git a/include/flexflow/ops/inc_multihead_self_attention.h b/include/flexflow/ops/inc_multihead_self_attention.h index ee486ff9fe..5b2acba1bc 100644 --- a/include/flexflow/ops/inc_multihead_self_attention.h +++ b/include/flexflow/ops/inc_multihead_self_attention.h @@ -126,13 +126,14 @@ class IncMultiHeadSelfAttention : public Op { int shard_id, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output); - static void peft_bwd_kernel_wrapper(IncMultiHeadSelfAttentionMeta *m, - BatchConfig const *bc, - int shard_id, - GenericTensorAccessorW const &input_grad, - // GenericTensorAccessorR const &weight, - GenericTensorAccessorR const &output_grad); - // GenericTensorAccessorR const &bias); + static void + peft_bwd_kernel_wrapper(IncMultiHeadSelfAttentionMeta *m, + BatchConfig const *bc, + int shard_id, + GenericTensorAccessorW const &input_grad, + // GenericTensorAccessorR const &weight, + GenericTensorAccessorR const &output_grad); + // GenericTensorAccessorR const &bias); Params get_params() const; public: diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 8e8f225955..4b5a3f55ee 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -93,19 +93,20 @@ void LLAMA::create_llama_model(FFModel &ff, } att_norm->print("att_norm"); Tensor qkv_proj = ff.dense( - att_norm, - llama_config.hidden_size * 3, // q, k, v. need to change if want to remove replication. (q_heads + 2 * kv_heads) * proj_size - AC_MODE_NONE, - false, // seems like llama does not use bias - DT_NONE, // what is this - nullptr, // ? - nullptr, // ? - nullptr, // ? - REG_MODE_NONE, // no regularization - 0.0f, // no dropout - std::string("layers." + std::to_string(i) + ".self_attn.qkv_proj") - .c_str() - ); + att_norm, + llama_config.hidden_size * + 3, // q, k, v. need to change if want to remove replication. + // (q_heads + 2 * kv_heads) * proj_size + AC_MODE_NONE, + false, // seems like llama does not use bias + DT_NONE, // what is this + nullptr, // ? + nullptr, // ? + nullptr, // ? + REG_MODE_NONE, // no regularization + 0.0f, // no dropout + std::string("layers." + std::to_string(i) + ".self_attn.qkv_proj") + .c_str()); qkv_proj->print("qkv_proj"); Tensor mha; @@ -189,18 +190,19 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor mha_input = mha; mha_input->print("mha_input"); - mha = ff.dense(mha_input, - llama_config.hidden_size, - AC_MODE_NONE, - false, - DT_NONE, - nullptr, - nullptr, - nullptr, - REG_MODE_NONE, - 0.0f, - std::string("layers." + std::to_string(i) + ".self_attn.o_proj") - .c_str()); + mha = ff.dense( + mha_input, + llama_config.hidden_size, + AC_MODE_NONE, + false, + DT_NONE, + nullptr, + nullptr, + nullptr, + REG_MODE_NONE, + 0.0f, + std::string("layers." 
+ std::to_string(i) + ".self_attn.o_proj") + .c_str()); mha->print("mha"); // step 2: SILU activaion diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 3463c3b235..76bfa89def 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -457,8 +457,7 @@ __host__ void bc, task->index_point.point_data[0], my_input_accessor[0], - my_output_accessor[0] - ); + my_output_accessor[0]); break; } case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { @@ -1042,7 +1041,7 @@ __host__ void FusedOp::peft_bwd_task(Task const *task, my_input_grad_accessor[0], // my_weight_accessor[0], my_output_grad_accessor[0]); - // biases); + // biases); break; } case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index 92cbd65360..f00bddb661 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -394,8 +394,8 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( dims[i] = _input->dims[i]; } dims[0].size = _embed_dim; - // Currently require no parallelism along this dim, is this consistent with the - // removal of the previous assert? + // Currently require no parallelism along this dim, is this consistent with + // the removal of the previous assert? assert(dims[0].degree == 1); if (allocate_weights) { // Create weight tensor @@ -600,10 +600,13 @@ OpMeta *IncMultiHeadSelfAttention::init_task( attn->num_kv_heads / attn->tensor_parallelism_degree + (attn->num_kv_heads % attn->tensor_parallelism_degree != 0); - if(attn->oProjSize != output.domain.hi()[0] - output.domain.lo()[0] + 1) { - printf("attn o_proj size %d does not match output domain %d\n", attn->oProjSize, output.domain.hi()[0] - output.domain.lo()[0] + 1); + if (attn->oProjSize != output.domain.hi()[0] - output.domain.lo()[0] + 1) { + printf("attn o_proj size %d does not match output domain %d\n", + attn->oProjSize, + output.domain.hi()[0] - output.domain.lo()[0] + 1); } - // assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + 1); + // assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + + // 1); Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); MemoryAllocator gpu_mem_allocator(gpu_mem); @@ -709,7 +712,7 @@ void IncMultiHeadSelfAttention::inference_task( GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); Domain input_domain = runtime->get_index_space_domain( @@ -724,7 +727,7 @@ void IncMultiHeadSelfAttention::inference_task( assert(task->index_point.get_dim() == 1); IncMultiHeadSelfAttention::inference_kernel_wrapper( - m, bc, task->index_point.point_data[0], input, output); + m, bc, task->index_point.point_data[0], input, output); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); @@ -822,9 +825,11 @@ void IncMultiHeadSelfAttention::peft_bwd_task( GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); // GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - // m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + // m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, + // runtime); // GenericTensorAccessorW output_grad = 
helperGetGenericTensorAccessorRW( - // m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); + // m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, + // runtime); GenericTensorAccessorW output_grad = helperGetGenericTensorAccessorRW( m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); GenericTensorAccessorR biases; @@ -862,7 +867,7 @@ void IncMultiHeadSelfAttention::peft_bwd_task( input_grad, // weight, output_grad); - // biases); + // biases); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 0ec9bf4ba5..c9b91e5f80 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -938,7 +938,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, compute_qkv_kernel(m, bc, shard_id, - // input_ptr, + // input_ptr, weight_ptr, static_cast
<DT *>(m->devQKVProjArray), bias_ptr, diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index f89321554c..f6993e987a 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -538,7 +538,6 @@ __global__ void fill_entries_above_diagonal(DT *matrix, } } - template <typename DT> void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -564,7 +563,6 @@ void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, } #endif - int num_tokens = bc->num_active_tokens(); int parallelism = m->kProjSize * num_tokens * m->num_q_heads; size_t q_array_size = m->qProjSize * num_tokens * m->num_q_heads; @@ -739,7 +737,7 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, } } -// this kernel is no longer used by the attention operator because +// this kernel is no longer used by the attention operator because // there's no more weights // TODO: check if this is needed by the projection layers? template <typename DT> @@ -814,7 +812,8 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, // phase 0: copy calculated qkv into devQKVProjArray // [qProjSize, num_heads, 3, num_new_tokens] - size_t qkv_proj_size = m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); cudaMemcpyAsync(m->devQKVProjArray, qkv_ptr, @@ -826,11 +825,11 @@ compute_qkv_kernel(m, bc, shard_id, - // input_ptr, - // weight_ptr, - // nullptr, // does not use weight + // input_ptr, + // weight_ptr, + // nullptr, // does not use weight static_cast
<DT *>(m->devQKVProjArray), - // bias_ptr, + // bias_ptr, stream); update_kv_cache_kernel<DT>
(m, bc, stream); @@ -871,50 +870,79 @@ std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, return dst_filepath.string(); } -__global__ void transposeAdd_half_kernel(half *out, const half *in, int width, int height, half alpha, half beta) { - int t_id = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - for(int i = t_id; i < width * height; i += num_threads) { - int row = i / width; - int col = i % width; - out[col * height + row] = alpha * in[row * width + col] + beta * out[col * height + row]; - } +__global__ void transposeAdd_half_kernel( + half *out, half const *in, int width, int height, half alpha, half beta) { + int t_id = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (int i = t_id; i < width * height; i += num_threads) { + int row = i / width; + int col = i % width; + out[col * height + row] = + alpha * in[row * width + col] + beta * out[col * height + row]; + } } -__global__ void transposeAdd_float_kernel(float *out, const float *in, int width, int height, float alpha, float beta) { - int t_id = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - for(int i = t_id; i < width * height; i += num_threads) { - int row = i / width; - int col = i % width; - out[col * height + row] = alpha * in[row * width + col] + beta * out[col * height + row]; - } +__global__ void transposeAdd_float_kernel(float *out, + float const *in, + int width, + int height, + float alpha, + float beta) { + int t_id = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + for (int i = t_id; i < width * height; i += num_threads) { + int row = i / width; + int col = i % width; + out[col * height + row] = + alpha * in[row * width + col] + beta * out[col * height + row]; + } } template <typename DT> -void transposeAdd(DT *out, const DT *in, int width, int height, float alpha, float beta, cudaStream_t stream) { - assert(false && "Unsupported data type"); +void transposeAdd(DT *out, + const DT *in, + int width, + int height, + float alpha, + float beta, + cudaStream_t stream) { + assert(false && "Unsupported data type"); } -template<> -void transposeAdd(float *out, const float *in, int width, int height, float alpha, float beta, cudaStream_t stream) { - transposeAdd_float_kernel<<<4, 1024, 0, stream>>>(out, in, width, height, alpha, beta); +template <> +void transposeAdd(float *out, + float const *in, + int width, + int height, + float alpha, + float beta, + cudaStream_t stream) { + transposeAdd_float_kernel<<<4, 1024, 0, stream>>>( + out, in, width, height, alpha, beta); } -template<> -void transposeAdd(half *out, const half *in, int width, int height, float alpha, float beta, cudaStream_t stream) { - transposeAdd_half_kernel<<<4, 1024, 0, stream>>>(out, in, width, height, __float2half(alpha), __float2half(beta)); +template <> +void transposeAdd(half *out, + half const *in, + int width, + int height, + float alpha, + float beta, + cudaStream_t stream) { + transposeAdd_half_kernel<<<4, 1024, 0, stream>>>( + out, in, width, height, __float2half(alpha), __float2half(beta)); } template <typename DT> -void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - DT *input_grad_ptr, - DT const *weight_ptr, // this is unused, kept for consistency - DT const *output_grad_ptr, - DT const *bias_ptr, - cudaStream_t stream) { +void peft_bwd_kernel( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + DT *input_grad_ptr, + DT
const *weight_ptr, // this is unused, kept for consistency + DT const *output_grad_ptr, + DT const *bias_ptr, + cudaStream_t stream) { assert(!m->offload); checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); @@ -1327,12 +1355,14 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, int n_ = num_tokens; int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize); - // TODO: checkout if the input grad ptr has some relation with m->devQKVProjArray - // so we may potentially skip this transpose and copy - // TODO: check if this transposeAdd can correctly implement gradient accumulation + // TODO: checkout if the input grad ptr has some relation with + // m->devQKVProjArray so we may potentially skip this transpose and copy + // TODO: check if this transposeAdd can correctly implement gradient + // accumulation transposeAdd(C, B, n_, k_, alpha, beta, stream); - - // printf("backward of raw attn grad: %d, %d, with redudant dimension %d\n", k_, n_, m_); + + // printf("backward of raw attn grad: %d, %d, with redudant dimension + // %d\n", k_, n_, m_); if (m->inference_debugging) { std::string filename = get_peft_dbg_folder(m, shard_id) + ".self_attn.input_gradient_0"; @@ -1685,7 +1715,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( // GenericTensorAccessorR const &weight, GenericTensorAccessorW const &output // GenericTensorAccessorR const &bias - ) { +) { // printf("inf_k_warpper start\n"); cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -1710,7 +1740,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( bc, shard_id, input.get_half_ptr(), - static_cast<half const *>(nullptr), //weight_ptr is no longer used + static_cast<half const *>(nullptr), // weight_ptr is no longer used output.get_half_ptr(), static_cast<half const *>(nullptr), // bias_ptr is no longer used stream); @@ -1720,7 +1750,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( bc, shard_id, input.get_float_ptr(), - static_cast<float const *>(nullptr), //weight_ptr is no longer used + static_cast<float const *>(nullptr), // weight_ptr is no longer used output.get_float_ptr(), static_cast<float const *>(nullptr), // bias_ptr is no longer used stream); @@ -1747,7 +1777,7 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( GenericTensorAccessorW const &input_grad, // GenericTensorAccessorR const &weight, GenericTensorAccessorR const &output_grad) { - // GenericTensorAccessorR const &bias) { + // GenericTensorAccessorR const &bias) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); bool use_bias = *m->qkv_bias || *m->final_bias; @@ -1769,30 +1799,33 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( assert(!m->offload); // half const *bias_ptr = // use_bias ? bias.get_half_ptr() : static_cast<half const *>(nullptr); - Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, - bc, - shard_id, - input_grad.get_half_ptr(), - // weight.get_half_ptr(), - static_cast<half const *>(nullptr), - output_grad.get_half_ptr(), - // bias_ptr, - static_cast<half const *>(nullptr), - stream); + Kernels::IncMultiHeadAttention::peft_bwd_kernel( + m, + bc, + shard_id, + input_grad.get_half_ptr(), + // weight.get_half_ptr(), + static_cast<half const *>(nullptr), + output_grad.get_half_ptr(), + // bias_ptr, + static_cast<half const *>(nullptr), + stream); } else if (input_grad.data_type == DT_FLOAT) { assert(!m->offload); // float const *bias_ptr = - // use_bias ?
bias.get_float_ptr() : static_cast<float const *>(nullptr); - Kernels::IncMultiHeadAttention::peft_bwd_kernel(m, - bc, - shard_id, - input_grad.get_float_ptr(), - // weight.get_float_ptr(), - static_cast<float const *>(nullptr), - output_grad.get_float_ptr(), - // bias_ptr, - static_cast<float const *>(nullptr), - stream); + // use_bias ? bias.get_float_ptr() : static_cast<float const *>(nullptr); + Kernels::IncMultiHeadAttention::peft_bwd_kernel( + m, + bc, + shard_id, + input_grad.get_float_ptr(), + // weight.get_float_ptr(), + static_cast<float const *>(nullptr), + output_grad.get_float_ptr(), + // bias_ptr, + static_cast<float const *>(nullptr), + stream); } else { assert(false && "Unspported data type"); } diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index ee7dd9f4e7..29dc969687 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -631,7 +631,8 @@ void peft_bwd_kernel(LinearMeta const *m, in_dim, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // printf("%s: input_grad has shape %d, %d\n", m->op_name, in_dim, num_peft_tokens); + // printf("%s: input_grad has shape %d, %d\n", m->op_name, in_dim, + // num_peft_tokens); } } diff --git a/src/ops/linear.cc b/src/ops/linear.cc index 45d85f6f39..88a3d2e3e4 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -779,7 +779,8 @@ void Linear::peft_bwd_task(Task const *task, if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; - printf("%s: in_dim = %d, out_dim = %d, num_infr_tokens = %d, num_peft_tokens = %d, volume = %d\n", + printf("%s: in_dim = %d, out_dim = %d, num_infr_tokens = %d, " + "num_peft_tokens = %d, volume = %d\n", m->op_name, in_dim, out_dim, diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index 4cd54763ec..bd7f1624ae 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -705,14 +705,14 @@ void SpecIncMultiHeadSelfAttention::inference_task( SpecIncMultiHeadSelfAttentionMeta *m = *((SpecIncMultiHeadSelfAttentionMeta **)task->local_args); - assert(regions.size() ==2); + assert(regions.size() == 2); GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); GenericTensorAccessorR biases; - + Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); Domain output_domain = runtime->get_index_space_domain( diff --git a/src/ops/spec_inc_multihead_self_attention.cpp b/src/ops/spec_inc_multihead_self_attention.cpp index b48c4bf734..0bf2b3346e 100644 --- a/src/ops/spec_inc_multihead_self_attention.cpp +++ b/src/ops/spec_inc_multihead_self_attention.cpp @@ -501,17 +501,19 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, stream)); // phase 0: copy calculated qkv into devQKVProjArray // [qProjSize, num_heads, 3, num_new_tokens] - size_t qkv_proj_size = m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); cudaMemcpyAsync(m->devQKVProjArray, qkv_ptr, - qkv_proj_size * sizeof(DT), // is this right, do we need layers etc here + qkv_proj_size * + sizeof(DT), // is this right, do we need layers etc here cudaMemcpyDeviceToDevice, stream); // phase 1: Implement kernel to compute KQV
for input tokens - // TODO WARNING: this is commented out only because we are fixing the inc_attn first - // compute_qkv_kernel(m, + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first compute_qkv_kernel(m, // bc, // shard_id, // // input_ptr, diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 6144b9bd4c..30cbdc6b10 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -706,22 +706,25 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 0: copy calculated qkv into devQKVProjArray // [qProjSize, num_heads, 3, num_new_tokens] - size_t qkv_proj_size = m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); cudaMemcpyAsync(m->devQKVProjArray, qkv_ptr, - qkv_proj_size * sizeof(DT), // is this right, do we need layers etc here + qkv_proj_size * + sizeof(DT), // is this right, do we need layers etc here cudaMemcpyDeviceToDevice, stream); // phase 1: Implement kernel to compute KQV for input tokens - // TODO WARNING: this is commented out only because we are fixing the inc_attn first + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first compute_qkv_kernel(m, bc, shard_id, - // input_ptr, - // weight_ptr, + // input_ptr, + // weight_ptr, static_cast
<DT *>(m->devQKVProjArray), - // bias_ptr, + // bias_ptr, stream); // phase 2: Update key/val cache update_kv_cache_kernel<DT>
(m, bc, stream); diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index a3f6757df3..4564ca6cc2 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -159,7 +159,7 @@ Tensor FFModel::inc_multiquery_self_attention_verify( int one_head_size = qParas + kParas + vParas + oParas; int weight_size = qParas * num_q_heads + kParas * num_q_heads + vParas * num_q_heads + oParas * num_q_heads; - + li->data_type = data_type; li->add_int_property("embed_dim", embed_dim); li->add_int_property("num_q_heads", num_q_heads); @@ -392,7 +392,8 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( dims[i] = _input->dims[i]; } dims[0].size = _embed_dim; - // Currently require no parallelism along this dim, is this aligned with the previous removal of assert? + // Currently require no parallelism along this dim, is this aligned with the + // previous removal of assert? assert(dims[0].degree == 1); if (allocate_weights) { // Create weight tensor @@ -597,10 +598,13 @@ OpMeta *TreeIncMultiHeadSelfAttention::init_task( int num_kv_heads = attn->num_kv_heads / attn->tensor_parallelism_degree + (attn->num_kv_heads % attn->tensor_parallelism_degree != 0); - if(attn->oProjSize != output.domain.hi()[0] - output.domain.lo()[0] + 1) { - std::cout<<"attn->oProjSize: "<<attn->oProjSize<<" does not match output domain dim[0]: "<<output.domain.hi()[0] - output.domain.lo()[0] + 1<<std::endl; + if (attn->oProjSize != output.domain.hi()[0] - output.domain.lo()[0] + 1) { + std::cout << "attn->oProjSize: " << attn->oProjSize + << " does not match output domain dim[0]: " + << output.domain.hi()[0] - output.domain.lo()[0] + 1 << std::endl; } - // assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + 1); + // assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + + // 1); Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); MemoryAllocator gpu_mem_allocator(gpu_mem); diff --git a/src/ops/tree_inc_multihead_self_attention.cpp b/src/ops/tree_inc_multihead_self_attention.cpp index 585bf3fa46..ff592ddccb 100644 --- a/src/ops/tree_inc_multihead_self_attention.cpp +++ b/src/ops/tree_inc_multihead_self_attention.cpp @@ -936,8 +936,8 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, bias_ptr = static_cast
<DT *>(m->bias_ptr); } // phase 1: Implement kernel to compute KQV for input tokens - // TODO WARNING: this is commented out only because we are fixing the inc_attn first - // compute_qkv_kernel(m, + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first compute_qkv_kernel(m, // bc, // shard_id, // // input_ptr, diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 9619070737..c2ba0ecbde 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -916,23 +916,26 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, } // phase 0: copy calculated qkv into devQKVProjArray // [qProjSize, num_heads, 3, num_new_tokens] - size_t qkv_proj_size = m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); + size_t qkv_proj_size = + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM * bc->num_active_tokens(); cudaMemcpyAsync(m->devQKVProjArray, qkv_ptr, - qkv_proj_size * sizeof(DT), // is this right, do we need layers etc here + qkv_proj_size * + sizeof(DT), // is this right, do we need layers etc here cudaMemcpyDeviceToDevice, stream); // phase 1: Implement kernel to compute KQV for input tokens - // TODO WARNING: this is commented out only because we are fixing the inc_attn first + // TODO WARNING: this is commented out only because we are fixing the inc_attn + // first compute_qkv_kernel(m, bc, shard_id, - // input_ptr, - // weight_ptr, + // input_ptr, + // weight_ptr, static_cast<DT *>
(m->devQKVProjArray), - // bias_ptr, + // bias_ptr, stream); // phase 2: No need to update key/val cache @@ -985,25 +988,23 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( assert(input.data_type == output.data_type); if (input.data_type == DT_HALF) { - Kernels::TreeIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - (half*)nullptr, - output.get_half_ptr(), - (half*)nullptr, - stream); + Kernels::TreeIncMultiHeadAttention::inference_kernel(m, + bc, + shard_id, + input.get_half_ptr(), + (half *)nullptr, + output.get_half_ptr(), + (half *)nullptr, + stream); } else if (input.data_type == DT_FLOAT) { - Kernels::TreeIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - (float*)nullptr, - output.get_float_ptr(), - (float*)nullptr, - stream); + Kernels::TreeIncMultiHeadAttention::inference_kernel(m, + bc, + shard_id, + input.get_float_ptr(), + (float *)nullptr, + output.get_float_ptr(), + (float *)nullptr, + stream); } else { assert(false && "Unspported data type"); } diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index 0cb12e3b0e..9a6c561f18 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -129,12 +129,12 @@ void load_attention_weights_multi_query(DT *ptr, template void load_attention_o_proj_bias_to_dense_v2(DT *ptr, - int num_heads, - int num_kv_heads, - size_t hidden_dim, - size_t qkv_inner_dim, - std::string layer_name, - std::string weights_folder) { + int num_heads, + int num_kv_heads, + size_t hidden_dim, + size_t qkv_inner_dim, + std::string layer_name, + std::string weights_folder) { std::string filename = layer_name + ".o_proj.bias"; int file_index = 0; @@ -262,15 +262,15 @@ void load_attention_bias_v2(DT *ptr, template void load_attention_weights_to_dense_v2(DT *ptr, - int num_heads, - int num_kv_heads, - size_t hidden_dim, - size_t qkv_inner_dim, - std::string layer_name, - std::string weights_folder, - size_t volume, - int tensor_parallelism_degree, - bool load_o_proj) { + int num_heads, + int num_kv_heads, + size_t hidden_dim, + size_t qkv_inner_dim, + std::string layer_name, + std::string weights_folder, + size_t volume, + int tensor_parallelism_degree, + bool load_o_proj) { // layers_0_attention_wq_weight // layers_0_self_attn_q_proj_weight std::string q_file = layer_name + ".q_proj.weight"; @@ -299,9 +299,10 @@ void load_attention_weights_to_dense_v2(DT *ptr, // stride for q, k, v, o size_t stride_size = (q_size + v_replicate_size + k_replicate_size) / tensor_parallelism_degree; - if(!load_o_proj) { + if (!load_o_proj) { for (auto filename : weight_filenames) { - std::cout << "Loading weight file " << filename << " to dense"<< std::endl; + std::cout << "Loading weight file " << filename << " to dense" + << std::endl; std::string weight_filepath = join_path({weights_folder, filename}); int data_index = 0; @@ -342,17 +343,18 @@ void load_attention_weights_to_dense_v2(DT *ptr, int head_idx = i % (num_heads / tensor_parallelism_degree); int tp_idx = (i / (num_heads / tensor_parallelism_degree)); for (int j = 0; j < single_proj_size; j++) { - ptr[base_index + tp_idx * stride_size + single_proj_size * head_idx + - j] = host_array.at(kv_idx * single_proj_size + j); + ptr[base_index + tp_idx * stride_size + + single_proj_size * head_idx + j] = + host_array.at(kv_idx * single_proj_size + j); } } } - std::cout<<"host array going out of scope, releasing"<config.benchmarking) { std::cout << "Initializing weight " << weight_filename << " with random data 
(benchmarking mode)" << std::endl; @@ -957,9 +959,9 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, // weight_filename, // weights_folder); // } - } else if(is_attn_proj) { - if(is_o_proj) { - if(weight_idx == 0) { + } else if (is_attn_proj) { + if (is_o_proj) { + if (weight_idx == 0) { load_attention_weights_to_dense_v2(data, num_heads, num_kv_heads, @@ -978,10 +980,9 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, qkv_inner_dim, weight_filename, weights_folder); - } } else { - if(weight_idx == 0) { + if (weight_idx == 0) { load_attention_weights_to_dense_v2(data, num_heads, num_kv_heads, diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 40d4ca9766..e3bc433302 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1154,19 +1154,25 @@ bool Op::check_output_input_weight_same_parallel_is() const { IndexSpace parallel_is = outputs[0]->parallel_is; for (int i = 0; i < numOutputs; i++) { if (outputs[i]->parallel_is != parallel_is) { - std::cout<<"outputs["<parallel_is<<" than output[0] "<parallel_is << " than output[0] " << parallel_is + << std::endl; return false; } } for (int i = 0; i < numInputs; i++) { if (inputs[i]->parallel_is != parallel_is) { - std::cout<<"inputs["<parallel_is<<" than output[0] "<parallel_is << " than output[0] " << parallel_is + << std::endl; return false; } } for (int i = 0; i < numWeights; i++) { if (weights[i]->parallel_is != parallel_is) { - std::cout<<"weights["<parallel_is<<" than output[0] "<parallel_is << " than output[0] " << parallel_is + << std::endl; return false; } } @@ -3416,27 +3422,28 @@ bool FFModel::need_to_add_allreduce(int layer_idx) const { if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && ( - // l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || - // l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || - (std::string(l->name).find(".self_attn.o_proj") != std::string::npos) || - // mlp layer - is_mlp_block(layer_idx) || - // llama mlp layer - (l->op_type == OP_LINEAR && layer_idx >= 2 && - layers[layer_idx - 1]->op_type == OP_GELU && - layers[layer_idx - 2]->op_type == OP_LINEAR) || - // LLAMA without element-wise operator fusion - (l->op_type == OP_LINEAR && layer_idx >= 5 && - layers[layer_idx - 1]->op_type == OP_EW_MUL && - layers[layer_idx - 2]->op_type == OP_EW_MUL && - layers[layer_idx - 3]->op_type == OP_SIGMOID && - layers[layer_idx - 4]->op_type == OP_LINEAR && - layers[layer_idx - 5]->op_type == OP_LINEAR) || - // LLAMA with element-wise operator fusion - (l->op_type == OP_LINEAR && layer_idx >= 3 && - layers[layer_idx - 1]->op_type == OP_SIGMOID_SILU_MULTI && - layers[layer_idx - 2]->op_type == OP_LINEAR && - layers[layer_idx - 3]->op_type == OP_LINEAR))) { + // l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || + // l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || + (std::string(l->name).find(".self_attn.o_proj") != + std::string::npos) || + // mlp layer + is_mlp_block(layer_idx) || + // llama mlp layer + (l->op_type == OP_LINEAR && layer_idx >= 2 && + layers[layer_idx - 1]->op_type == OP_GELU && + layers[layer_idx - 2]->op_type == OP_LINEAR) || + // LLAMA without element-wise operator fusion + (l->op_type == OP_LINEAR && layer_idx >= 5 && + layers[layer_idx - 1]->op_type == OP_EW_MUL && + layers[layer_idx - 2]->op_type == OP_EW_MUL && + layers[layer_idx - 3]->op_type == OP_SIGMOID && + layers[layer_idx - 4]->op_type == OP_LINEAR && + layers[layer_idx - 5]->op_type == OP_LINEAR) || + // LLAMA with element-wise operator fusion + (l->op_type 
== OP_LINEAR && layer_idx >= 3 && + layers[layer_idx - 1]->op_type == OP_SIGMOID_SILU_MULTI && + layers[layer_idx - 2]->op_type == OP_LINEAR && + layers[layer_idx - 3]->op_type == OP_LINEAR))) { return true; } return false; diff --git a/src/runtime/operator.cc b/src/runtime/operator.cc index 52f192902b..d5bfcfc48e 100644 --- a/src/runtime/operator.cc +++ b/src/runtime/operator.cc @@ -2,8 +2,8 @@ #include "flexflow/ffconst_utils.h" #include "flexflow/simulator.h" #include -#include #include +#include namespace FlexFlow { @@ -33,11 +33,12 @@ fs::path get_dst_folder(std::string const &subdir, char cwd[PATH_MAX]; getcwd(cwd, sizeof(cwd)); - // char const *ff_cache_path = std::string(std::getenv("FF_DEBUG_PATH")) == "." ? + // char const *ff_cache_path = std::string(std::getenv("FF_DEBUG_PATH")) == + // "." ? // cwd : std::getenv("FF_DEBUG_PATH"); char const *ff_cache_path = std::getenv("FF_CACHE_PATH"); - + std::string debug_dir_ = ff_cache_path ? std::string(ff_cache_path) + "/debug/flexflow" : std::string("~/.cache/flexflow/debug/flexflow"); @@ -46,7 +47,7 @@ fs::path get_dst_folder(std::string const &subdir, debug_dir_ = p.we_wordv[0]; wordfree(&p); fs::path debug_dir = debug_dir_; - if(!fs::is_directory(debug_dir)) { + if (!fs::is_directory(debug_dir)) { printf("invalid debug directory: %s\n", debug_dir.c_str()); } assert(fs::is_directory(debug_dir));