From fbac32ea33289f19e3a7dc4abee194ed2feda5a6 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sat, 28 Sep 2024 04:37:07 +0000 Subject: [PATCH] update --- backup.txt | 0 inference/incr_decoding/incr_decoding.cc | 2 +- inference/models/mpt.cc | 6 +- inference/python/incr_decoding.py | 10 +- src/ops/inc_multihead_self_attention.cc | 7 +- src/ops/inc_multihead_self_attention.cpp | 17 +- src/ops/inc_multihead_self_attention.cu | 79 ++--- src/ops/spec_inc_multihead_self_attention.cu | 24 +- src/ops/tree_inc_multihead_self_attention.cu | 7 +- src/runtime/file_loader.cc | 15 +- src/runtime/model.cc | 3 +- tests/fine_grained_alignment_test.sh | 78 +++++ tests/inference/huggingface_inference.py | 49 +-- tests/inference/inference_alignment_test.py | 329 +++++++++++++++++++ tests/peft/alignment/align_test_utils.py | 13 +- tests/peft/hf_finetune.py | 2 +- tests/peft/hf_utils.py | 15 +- 17 files changed, 515 insertions(+), 141 deletions(-) create mode 100644 backup.txt create mode 100755 tests/fine_grained_alignment_test.sh create mode 100644 tests/inference/inference_alignment_test.py diff --git a/backup.txt b/backup.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index c9ffff5c07..8c70c19eb9 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -271,7 +271,7 @@ void FlexFlow::top_level_task(Task const *task, printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); Request inference_req; inference_req.prompt = text; - inference_req.max_sequence_length = 128; + inference_req.max_sequence_length = 10; requests.push_back(inference_req); total_num_requests++; } diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index 9986182495..64e5924753 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -106,8 +106,7 @@ void MPT::create_mpt_model(FFModel &ff, nullptr, // ? REG_MODE_NONE, // no regularization 0.0f, // no dropout - std::string("layers." + std::to_string(i) + ".attn.qkv_proj") - .c_str()); + std::string("layers." + std::to_string(i) + ".attn.qkv_proj").c_str()); Tensor o_proj; switch (mode) { @@ -199,8 +198,7 @@ void MPT::create_mpt_model(FFModel &ff, nullptr, REG_MODE_NONE, 0.0f, - std::string("layers." + std::to_string(i) + ".attn.o_proj") - .c_str()); + std::string("layers." 
+ std::to_string(i) + ".attn.o_proj").c_str()); ff.residual_layer_norm( attn_outputs, diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index f888982f2c..1df5a05a8f 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -111,9 +111,15 @@ def main(): if len(configs.prompt) > 0: prompts = [s for s in json.load(open(configs.prompt))] - results = llm.generate(prompts) + if "max_length" not in configs_dict: + results = llm.generate(prompts) + else: + results = llm.generate(prompts, max_length=configs.max_length) else: - result = llm.generate("Three tips for staying healthy are: ") + if "max_length" not in configs_dict: + result = llm.generate("Three tips for staying healthy are: ") + else: + result = llm.generate("Three tips for staying healthy are: ", max_length=configs.max_length) llm.stop_server() diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index 31dab57b3a..1bea204601 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -599,7 +599,6 @@ OpMeta *IncMultiHeadSelfAttention::init_task( attn->num_kv_heads / attn->tensor_parallelism_degree + (attn->num_kv_heads % attn->tensor_parallelism_degree != 0); - Memory gpu_mem = get_proc_mem(Machine::get_machine(), task->target_proc); MemoryAllocator gpu_mem_allocator(gpu_mem); if (attn->offload) { @@ -809,11 +808,7 @@ void IncMultiHeadSelfAttention::peft_bwd_task( assert(task->index_point.get_dim() == 1); IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( - m, - bc, - task->index_point.point_data[0], - input_grad, - output_grad); + m, bc, task->index_point.point_data[0], input_grad, output_grad); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 0093d417b5..81a3401da3 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -951,8 +951,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, if (bc->num_tokens > bc->num_generation_tokens) { // phase 4: Compute attention score for prompt tokens; - compute_attention_kernel_prompt( - m, bc, shard_id, stream); + compute_attention_kernel_prompt(m, bc, shard_id, stream); } // compute output production and bias together for all tokens @@ -1795,12 +1794,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( half const *bias_ptr = use_bias ? bias.get_half_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - output.get_half_ptr(), - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { if (m->offload) { pre_build_weight_kernel(m, weight, input.data_type, stream); @@ -1808,12 +1802,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( float const *bias_ptr = use_bias ? 
bias.get_float_ptr() : static_cast(nullptr); Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - output.get_float_ptr(), - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 0fe728be86..0ac8653b4a 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -542,26 +542,24 @@ template void compute_qkv_kernel(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, int shard_id, - // DT const *weight_ptr, DT *output_ptr, - // DT const *bias_ptr, cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); assert(m->qSize == m->vSize && m->qSize == m->kSize); - cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); -#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - cudaDataType_t compute_type = cublas_data_type; -#else - // For best performance, set the default cublas compute type to - // CUBLAS_COMPUTE_16F for half precision and to - // CUBLAS_COMPUTE_32F_FAST_16F for full precision - cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - if (m->output_type[0] == DT_FLOAT) { - compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - } -#endif + // cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); + // #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) + // cudaDataType_t compute_type = cublas_data_type; + // #else + // // For best performance, set the default cublas compute type to + // // CUBLAS_COMPUTE_16F for half precision and to + // // CUBLAS_COMPUTE_32F_FAST_16F for full precision + // cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + // if (m->output_type[0] == DT_FLOAT) { + // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + // } + // #endif int num_tokens = bc->num_active_tokens(); int parallelism = m->kProjSize * num_tokens * m->num_q_heads; @@ -820,11 +818,8 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, stream); // phase 1: Implement kernel to apply rotary embedding and scaling - compute_qkv_kernel(m, - bc, - shard_id, - static_cast
<DT *>(m->devQKVProjArray),
-                     stream);
+  compute_qkv_kernel(
+      m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
   update_kv_cache_kernel<DT>(m, bc, stream);
 
   if (bc->num_generation_tokens > 0) {
@@ -835,8 +830,12 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m,
 
   if (bc->num_tokens > bc->num_generation_tokens) {
     // phase 4: Compute attention score for prompt tokens;
-    compute_attention_kernel_prompt(
-        m, bc, shard_id, static_cast<DT *>(nullptr), static_cast<DT *>(nullptr), stream);
+    compute_attention_kernel_prompt(m,
+                                    bc,
+                                    shard_id,
+                                    static_cast<DT *>(nullptr),
+                                    static_cast<DT *>
(nullptr), + stream); } // compute output production and bias together for all tokens @@ -1345,12 +1344,12 @@ void peft_bwd_kernel( // matrix C's layout: [m->qSize, num_tokens] DT *C = input_grad_ptr + bc->requestsInfo[i].first_token_offset_in_batch * m->qSize; - int m_ = m->qSize; + // int m_ = m->qSize; int n_ = num_tokens; int k_ = m->num_q_heads * (m->qProjSize + m->kProjSize + m->vProjSize); // The original version uses existing result and attention's projection to - // do further calculation in a way different than the usual dense layer, + // do further calculation in a way different than the usual dense layer, // they are off by a transpose. So an explicit transpose is needed here. // The add here is just for gradient accumulation. transposeAdd(C, B, n_, k_, alpha, beta, stream); @@ -1704,8 +1703,7 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( BatchConfig const *bc, int shard_id, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &output -) { + GenericTensorAccessorW const &output) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -1720,20 +1718,10 @@ void IncMultiHeadSelfAttention::inference_kernel_wrapper( if (input.data_type == DT_HALF) { Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - output.get_half_ptr(), - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { Kernels::IncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - output.get_float_ptr(), - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } @@ -1758,7 +1746,7 @@ void IncMultiHeadSelfAttention::peft_bwd_kernel_wrapper( GenericTensorAccessorR const &output_grad) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; + // bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; if (m->profiling) { @@ -2132,4 +2120,19 @@ template void BatchConfig const *bc, half *output_ptr, cudaStream_t stream); + +template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + float *output_ptr, + ffStream_t stream); + +template void Kernels::IncMultiHeadAttention::compute_qkv_kernel( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + half *output_ptr, + ffStream_t stream); + }; // namespace FlexFlow diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 88c59c2053..4c65a8baa8 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -714,11 +714,8 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 1: Implement kernel to compute KQV for input tokens // TODO WARNING: this is commented out only because we are fixing the inc_attn // first - compute_qkv_kernel(m, - bc, - shard_id, - static_cast
<DT *>(m->devQKVProjArray),
-                     stream);
+  compute_qkv_kernel(
+      m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
   // phase 2: Update key/val cache
   update_kv_cache_kernel<DT>
(m, bc, stream); if (bc->num_generation_tokens > 0) { @@ -728,8 +725,7 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 3: Compute attention score // 3 kernels for pahse 3: matmul1 - softmax - matmal2 if (bc->num_tokens > bc->num_generation_tokens) { - compute_attention_kernel_prompt( - m, bc, shard_id, output_ptr, stream); + compute_attention_kernel_prompt(m, bc, shard_id, output_ptr, stream); } // compute output production and bias together for all tokens int num_tokens = bc->num_active_tokens(); @@ -767,20 +763,10 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( if (input.data_type == DT_HALF) { half const *bias_ptr = static_cast(nullptr); Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - output.get_half_ptr(), - stream); + m, bc, shard_id, input.get_half_ptr(), output.get_half_ptr(), stream); } else if (input.data_type == DT_FLOAT) { Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - output.get_float_ptr(), - stream); + m, bc, shard_id, input.get_float_ptr(), output.get_float_ptr(), stream); } else { assert(false && "Unspported data type"); } diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index e88fe95b22..43e8e46d49 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -929,11 +929,8 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // phase 1: Implement kernel to compute KQV for input tokens // TODO WARNING: this is commented out only because we are fixing the inc_attn // first - compute_qkv_kernel(m, - bc, - shard_id, - static_cast
<DT *>(m->devQKVProjArray),
-                     stream);
+  compute_qkv_kernel(
+      m, bc, shard_id, static_cast<DT *>(m->devQKVProjArray), stream);
   // phase 2: No need to update key/val cache
   compute_attention_kernel_fused<DT>
( diff --git a/src/runtime/file_loader.cc b/src/runtime/file_loader.cc index 6aa4e418a6..561db0c76b 100644 --- a/src/runtime/file_loader.cc +++ b/src/runtime/file_loader.cc @@ -287,7 +287,9 @@ void load_attention_weights_to_dense_v2(DT *ptr, size_t one_weight_file_size = num_heads * single_proj_size; // size of each of Q/K/V/O for all heads - std::cout<<"hidden_dim: "<op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || // l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || - (std::string(l->name).find("attn.o_proj") != - std::string::npos) || + (std::string(l->name).find("attn.o_proj") != std::string::npos) || // mlp layer is_mlp_block(layer_idx) || // llama mlp layer diff --git a/tests/fine_grained_alignment_test.sh b/tests/fine_grained_alignment_test.sh new file mode 100755 index 0000000000..681a015600 --- /dev/null +++ b/tests/fine_grained_alignment_test.sh @@ -0,0 +1,78 @@ +#! /usr/bin/env bash +# set -x +set -e + +MODEL_NAME=${MODEL_NAME:-"JackFram/llama-160m"} +MEMORY_PER_GPU=${MEMORY_PER_GPU:-14000} +ZCOPY_MEMORY=${ZCOPY_MEMORY:-40000} +CACHE_PATH=${FF_CACHE_PATH:-"~/.cache/flexflow"} + +cleanup() { + rm -rf ${CACHE_PATH}/debug ./fine_grained_alignment_config.json ./inference/output/fine_grained_alignment_test_ff.txt ./inference/output/fine_grained_alignment_test_hf.txt +} + +# Cd into directory holding this script +cd "${BASH_SOURCE[0]%/*}/.." + +# Initial cleanup +cleanup + +# Create test prompt file +mkdir -p ./inference/prompt +echo '["Three tips for staying healthy are: "]' > ./inference/prompt/test.json + +# Create output folder +mkdir -p ./inference/output + +# Enable backtrace in case we run into a segfault or assertion failure +export LEGION_BACKTRACE=1 + +python ./tests/inference/huggingface_inference.py --model-name $MODEL_NAME --max-length 10 --prompt-file ../../inference/prompt/test.json --output-file ../../inference/output/fine_grained_alignment_test_hf.txt --use-full-precision --inference-debugging + +json_config=$(cat <<-END + { + "num_gpus": 4, + "memory_per_gpu": ${MEMORY_PER_GPU}, + "zero_copy_memory_per_node": ${ZCOPY_MEMORY}, + "num_cpus": 4, + "legion_utility_processors": 4, + "data_parallelism_degree": 1, + "tensor_parallelism_degree": 2, + "pipeline_parallelism_degree": 2, + "inference_debugging": true, + "fusion": true, + "refresh_cache": false, + "llm_model": "${MODEL_NAME}", + "cache_path": "${CACHE_PATH}", + "full_precision": true, + "prompt": "./inference/prompt/test.json", + "max_length": 10, + "output_file": "./inference/output/fine_grained_alignment_test_ff.txt" + } +END +) +echo $json_config > ./fine_grained_alignment_config.json + +python ./inference/python/incr_decoding.py -config-file ./fine_grained_alignment_config.json + +# # C++ test +# echo "C++ test" +# ./build/inference/incr_decoding/incr_decoding \ +# -ll:gpu 2 -ll:cpu 4 -ll:util 4 \ +# -tensor-parallelism-degree 2 \ +# -ll:fsize 8192 -ll:zsize 12000 \ +# -llm-model $MODEL_NAME \ +# -prompt ./inference/prompt/peft.json \ +# --use-full-precision \ +# --inference-debugging + +# Check alignment +python ./tests/inference/inference_alignment_test.py -m $MODEL_NAME -tp 2 -n 2 + +# Print succeess message +echo "" +echo "Inference alignment tests passed!" 
+echo "" + +# Cleanup after the test +cleanup diff --git a/tests/inference/huggingface_inference.py b/tests/inference/huggingface_inference.py index 5e563c9974..1a2bcf9509 100644 --- a/tests/inference/huggingface_inference.py +++ b/tests/inference/huggingface_inference.py @@ -10,30 +10,9 @@ LlamaTokenizer, GenerationConfig, ) -######################### debugging helper functions ######################### -def pre_forward_hook(module, input): - assert module.name is not None and module.decoding_step is not None - name = module.name.replace("model.", "") - print( - f"Pre-forward hook activated on module: {name}, decoding step: {module.decoding_step}" - ) - print("Pre-Input: ", input[0].shape) - torch.save( - input, f"./hf_tensors/decoding_step_{module.decoding_step}_{name}.input" - ) -def post_forward_hook(module, input, output): - assert module.name is not None and module.decoding_step is not None - name = module.name.replace("model.", "") - print( - f"Post-forward Hook activated for module: {name}, decoding step: {module.decoding_step}" - ) - print("Post-Input/Output: ", input[0].shape, output[0].shape) - torch.save( - output, f"./hf_tensors/decoding_step_{module.decoding_step}_{name}.output" - ) - print("===") - module.decoding_step += 1 -############################################################################## +import sys +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "peft")) +from hf_utils import * def main(): # Change working dir to folder storing this script @@ -91,26 +70,20 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) generation_config = GenerationConfig.from_pretrained(args.model_name) generation_config.do_sample = args.do_sample + if not args.do_sample: + generation_config.num_beams=1 + generation_config.temperature = None + generation_config.top_p = None ################# debugging ################# if args.inference_debugging: # Print model and configs print(hf_config) print(model) - # Save weights to file - shutil.rmtree("./hf_tensors") - # Check that the output folder exists - os.makedirs("./hf_tensors", exist_ok=True) + make_debug_dirs() + register_inference_hooks(model) # Save weights - for name, params in model.named_parameters(): - torch.save(params, f"./hf_tensors/{name}") - # params.detach().cpu().numpy().tofile(f"./hf_tensors/{name}") - # Register hooks to save per-op hidden states - for name, layer in dict(model.named_modules()).items(): - layer.name = name - layer.decoding_step = 0 - print(f"Adding hooks to layer {layer.name}") - layer.register_forward_pre_hook(pre_forward_hook) - layer.register_forward_hook(post_forward_hook) + # save_model_weights(model, target_modules=["lora", "lm_head", "down_proj"]) + ############################################### # Generate output with open(args.output_file, "w") as f: diff --git a/tests/inference/inference_alignment_test.py b/tests/inference/inference_alignment_test.py new file mode 100644 index 0000000000..614723e2c4 --- /dev/null +++ b/tests/inference/inference_alignment_test.py @@ -0,0 +1,329 @@ +import numpy as np +import os, torch, argparse, sys +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "peft")) +from alignment.align_test_utils import * +from transformers import AutoConfig +from tqdm import tqdm + +class AlignmentTest: + def __init__(self, model_name, tp_degree=1): + raise NotImplementedError() + def check_weights_alignment(self): + raise NotImplementedError() + def 
check_fwd_pass(self): + raise NotImplementedError() + def check_bwd_pass(self): + raise NotImplementedError() + def check_step(self, step_idx, learning_rate=0.001): + raise NotImplementedError() + +class LllamaAlignmentTest(AlignmentTest): + def __init__(self, model_name, tp_degree=1): + self.model_name = model_name + self.hf_config = AutoConfig.from_pretrained(model_name) + self.num_layers = self.hf_config.num_hidden_layers + self.hidden_size = self.hf_config.hidden_size + self.intermediate_size = self.hf_config.intermediate_size + self.num_attention_heads = self.hf_config.num_attention_heads + self.num_key_value_heads = self.num_attention_heads + self.projsize = self.hidden_size // self.num_attention_heads + self.tp_degree = tp_degree + + self.num_tokens = None + self.ff_batch_size = None + + + def check_weights_alignment(self): + def convert_hf_filename_to_ff(hf_filename): + if hf_filename == "lm_head.weight": + f_version = f"layers.{self.num_layers-1}.lm_head.weight_0" + elif hf_filename == "norm.weight": + f_version = f"layers.{self.num_layers-1}.norm.weight_0" + else: + f_version = "" + if hf_filename.startswith("layers."): + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version += f"layers.{layernum}." + f_version += hf_filename.replace(".base_layer", "").replace(".default", "") + # compute weight index, then rename lora if needed if needed + weight_index="0" + if "lora_A" in f_version: + weight_index="A" + elif "lora_B" in f_version: + weight_index="B" + f_version = f_version.replace("lora_A", "lora").replace("lora_B", "lora") + if f_version.endswith(".weight"): + if weight_index == "0": + f_version += f"_{weight_index}" + else: + f_version += f"_{weight_index}.original" + elif f_version.endswith(".gradient"): + prefix = f_version.split(".gradient")[0] + f_version = prefix + f".weight_{weight_index}.gradient" + return f_version + def get_tp_partition_dim(ff_weight_name) -> int: + # MLP layers split the intermediate size dimension + # gate_proj, up_proj: [hidden_size, intermediate_size] + # down_proj: [intermediate_size, hidden_size] + if self.tp_degree == 1: + return -1 + if "lora.weight_B" in ff_weight_name: + return -1 + if "lm_head" in ff_weight_name or "norm" in ff_weight_name: + return 1 + if "gate_proj" in ff_weight_name or "up_proj" in ff_weight_name: + return 1 + elif "down_proj" in ff_weight_name: + return 0 + else: + return -1 + print("-- Weights alignment --") + hf_weights_folder = os.path.join(hf_path, "weights", "step_0") + ff_weights_folder = os.path.join(ff_path, "weights", "step_0", "shard_0") + files_list = os.listdir(hf_weights_folder) + for hf_weight_name in tqdm(sorted(files_list)): + if hf_weight_name.endswith(".weight"): + ff_weight_name = convert_hf_filename_to_ff(hf_weight_name) + # print(hf_weight_name, ff_weight_name) + hf_w_path = os.path.join(hf_weights_folder, hf_weight_name) + ff_w_path = os.path.join(ff_weights_folder, ff_weight_name) + if not os.path.isfile(hf_w_path): + print(f"File '{hf_w_path}' not found") + if not os.path.isfile(ff_w_path): + print(f"File '{ff_w_path}' not found") + assert(os.path.isfile(hf_w_path)) + assert(os.path.isfile(ff_w_path)) + + # 1. get shape of hf weight + hf_weight = torch.load(hf_w_path, map_location='cpu') + hf_weigth_shape = hf_weight.shape + ff_partition_dim = get_tp_partition_dim(ff_weight_name) + ff_weigth_shape = list(hf_weigth_shape)[::-1] + if ff_partition_dim >= 0: + ff_weigth_shape[ff_partition_dim] //= self.tp_degree + + # 2. 
handle flexflow shards in case of tensor parallelism + ff_weights = [load_ff_tensor(ff_w_path.replace("shard_0", f"shard_{tp_idx}"), ff_weigth_shape) for tp_idx in range(self.tp_degree)] + if self.tp_degree > 1: + if ff_partition_dim >= 0: + ff_weight = np.concatenate(ff_weights, axis=ff_partition_dim) + else: + assert(are_np_arrays_identical(ff_weights)) + ff_weight = ff_weights[0] + else: + ff_weight = ff_weights[0] + ff_weight = torch.from_numpy(ff_weight).to(hf_weight.dtype) + + # check equivalence + try: + torch.testing.assert_close(ff_weight, hf_weight.T) + except Exception as e: + print(f"Error comparing {ff_w_path} weight to {hf_w_path}:\n{e}\n") + raise e + + def check_fwd_pass(self, step_idx=0): + hf_fwd_folder = os.path.join(hf_path, "fwd", f"step_{step_idx}") + ff_fwd_folder = os.path.join(ff_path, "fwd", f"step_{step_idx}", "shard_0") + + def convert_hf_filename_to_ff(hf_filename): + if hf_filename == "embed_tokens": + f_version = f"layers.0.embed_tokens" + elif hf_filename == "lm_head" or hf_filename == "norm": + f_version = f"layers.{self.num_layers-1}.{hf_filename}" + else: + assert hf_filename.startswith("layers.") + layernum = hf_filename.split("layers.")[1].split(".")[0] + f_version = f"layers.{layernum}." + f_version += hf_filename.replace(".base_layer", "").replace(".default", "") + # right now, attention in flexflow is done with a single operator, so there is a single output file without the projection suffix + f_version = f_version.replace(".q_proj", "").replace(".k_proj", "").replace(".v_proj", "").replace(".o_proj", "") + return f_version + + def get_hf_tensor(hf_tensor_name, tensor_comparison_idx): + hf_tensor_filename = f"{hf_tensor_name}.{tensor_comparison_idx.hf_tensor_type}_{tensor_comparison_idx.hf_tensor_idx}" + hf_tensor_path = os.path.join(hf_fwd_folder, hf_tensor_filename) + + if not os.path.isfile(hf_tensor_path): + raise FileNotFoundError(f"File '{hf_tensor_path}' not found") + print("loading hf tensor: ", hf_tensor_filename) + hf_tensor = torch.load(hf_tensor_path, map_location='cpu') + if hf_tensor_name == "embed_tokens": + self.num_tokens = hf_tensor.shape[1] + return hf_tensor + + def get_ff_tensor(ff_tensor_name, tensor_comparison_idx, hf_shape, tp_type=TPType.REPLICATE): + ff_tensor_suffix = f".{tensor_comparison_idx.ff_tensor_type}" if len(tensor_comparison_idx.ff_tensor_type) > 0 else "" + ff_tensor_idx_suffix = f"_{tensor_comparison_idx.ff_tensor_idx}" if tensor_comparison_idx.ff_tensor_idx is not None else "" + ff_tensor_filename = f"{ff_tensor_name}{ff_tensor_suffix}{ff_tensor_idx_suffix}" + ff_tensor_path = os.path.join(ff_fwd_folder, ff_tensor_filename) + if not os.path.isfile(ff_tensor_path): + raise FileNotFoundError(f"File '{ff_tensor_path}' not found") + + print("loading ff tensor: ", ff_tensor_filename) + ff_shape = list(hf_shape)[::-1] + if tp_type == TPType.PARTITION: + ff_shape[0] //= self.tp_degree + + if "layers.0.embed_tokens.input_0" in ff_tensor_path: + # get number of tokens + ff_tensor = np.loadtxt(ff_tensor_path, delimiter=',') + self.ff_batch_size = ff_tensor.shape[0] + + ff_shape = replace_value(ff_shape, self.num_tokens, self.ff_batch_size) + ff_tensors = [load_ff_tensor(ff_tensor_path.replace("shard_0", f"shard_{tp_idx}"), ff_shape) for tp_idx in range(self.tp_degree)] + if self.tp_degree > 1: + # if replicate, check that they are identical + if tp_type == TPType.REPLICATE: + assert(are_np_arrays_identical(ff_tensors)) + ff_tensor = ff_tensors[0] + # if partition, concatenate along the partition dimension + elif tp_type 
== TPType.PARTITION: + ff_tensor = np.concatenate(ff_tensors, axis=0) + # if to_reduce, sum along the partition dimension + elif tp_type == TPType.TO_REDUCE: + ff_tensor = np.sum(ff_tensors, axis=0) + else: + ff_tensor = ff_tensors[0] + ff_tensor = torch.from_numpy(ff_tensor) + ff_tensor = truncate_dimension(ff_tensor, self.ff_batch_size, self.num_tokens) + return ff_tensor + + def compare(hf_tensor, ff_tensor, label="", additional_ff_tensor=None, tolerance=1e-2): + ff_tensor = ff_tensor.to(hf_tensor.dtype) + hf_tensor = hf_tensor.T + if additional_ff_tensor is not None: + additional_ff_tensor = additional_ff_tensor.to(hf_tensor.dtype) + ff_tensor = ff_tensor - additional_ff_tensor + try: + # torch.testing.assert_close(hf_tensor, ff_tensor, rtol=1.3e-6, atol=tolerance) + if not np.allclose(hf_tensor.detach().numpy(), ff_tensor.detach().numpy(), atol=tolerance): + mismatches = np.where(~np.isclose(hf_tensor.detach().numpy(), ff_tensor.detach().numpy(), atol=tolerance))[0] + print(f"Pct mismatch {label}: {100.0*(np.prod(mismatches.shape) / ff_tensor.numel()):.3f}%") + assert(np.prod(mismatches.shape) <= .05 * ff_tensor.numel()) + except Exception as e: + print(f"Error in comparison {label}:\n{e}\n") + print("HF tensor:") + print(hf_tensor.squeeze()) + print(hf_tensor.shape) + print("FF tensor:") + print(ff_tensor.squeeze()) + print(ff_tensor.shape) + raise e + + print(f"-- FWD pass {step_idx}--") + + # Embedding layer + hf_tensor_name = "embed_tokens" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Embedding input") + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Embedding output") + + # Transformers blocks + for i in range(self.num_layers): + # Input laye norm + hf_tensor_name = f"layers.{i}.input_layernorm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + if i == 0: + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + else: + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Input layernorm {i} input") + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Input layernorm {i} output") + + # Attention + hf_tensor_name = f"layers.{i}.self_attn.o_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + # the raw attention result, w/o o_proj. 
This is the output of senf_attn of FF and the input of o_proj in HF + output_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + # ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + # TP for self-attn partitions the attention heads across TP workers + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + print("comparing attention tensor: ", hf_tensor_name, " and ", ff_tensor_name) + compare(hf_tensor, ff_tensor, label=f"Attention {i} output") + + # Post-attention layernorm + hf_tensor_name = f"layers.{i}.post_attention_layernorm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label=f"Post-attention layernorm {i} output") + + # W1 (gate_proj) + hf_tensor_name = f"layers.{i}.mlp.gate_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"W1 {i} output") + + # W3 (up_proj) + hf_tensor_name = f"layers.{i}.mlp.up_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"W3 {i} output") + + # W2 (down_proj) + hf_tensor_name = f"layers.{i}.mlp.down_proj" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_down_proj_out = get_hf_tensor(hf_tensor_name, output_comparison) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label=f"W2 {i} input") + + hf_down_proj_in = hf_tensor.clone() + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_down_proj_out = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.TO_REDUCE) + + # Norm + hf_tensor_name = "norm" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=1) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape) + compare(hf_tensor, ff_tensor, label="Norm output") + + # LM head + hf_tensor_name = "lm_head" + ff_tensor_name = convert_hf_filename_to_ff(hf_tensor_name) 
+ input_comparison = TensorComparisonIdxs(hf_tensor_type="input", ff_tensor_type="input", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, input_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, input_comparison, hf_tensor.shape, tp_type=TPType.REPLICATE) + compare(hf_tensor, ff_tensor, label="LM head input") + output_comparison = TensorComparisonIdxs(hf_tensor_type="output", ff_tensor_type="output", hf_tensor_idx=0, ff_tensor_idx=0) + hf_tensor = get_hf_tensor(hf_tensor_name, output_comparison) + ff_tensor = get_ff_tensor(ff_tensor_name, output_comparison, hf_tensor.shape, tp_type=TPType.PARTITION) + compare(hf_tensor, ff_tensor, label="LM head output") + + +parser = argparse.ArgumentParser(description='Argument Parser Example') +# Adding arguments +parser.add_argument('-m', '--model-name', type=str, default="goliaro/llama-160m-lora", help='Name of the model') +parser.add_argument('-n', '--num-steps', type=int, default=1, help='Number of decoding steps') +parser.add_argument('-tp', '--tensor-parallelism-degree', type=int, default=1, help='The tensor parallelism degree used when running FlexFlow') + +# Parse the arguments from command line +args = parser.parse_args() + +if __name__ == "__main__": + llama_alignment = LllamaAlignmentTest(args.model_name, tp_degree=args.tensor_parallelism_degree) + # llama_alignment.check_weights_alignment() + for i in range(args.num_steps): + llama_alignment.check_fwd_pass(i) diff --git a/tests/peft/alignment/align_test_utils.py b/tests/peft/alignment/align_test_utils.py index 93727bdc89..3085bbda56 100644 --- a/tests/peft/alignment/align_test_utils.py +++ b/tests/peft/alignment/align_test_utils.py @@ -3,6 +3,8 @@ from typing import List from enum import Enum from dataclasses import dataclass +import warnings + abs_dirname = os.path.dirname(os.path.abspath(__file__)) cache_folder = os.path.expanduser(os.getenv("FF_CACHE_PATH", "~/.cache/flexflow")) @@ -472,7 +474,16 @@ def replace_value(lst, old_value, new_value): if occurrences == 0: raise ValueError(f"Value {old_value} not found in the list.") elif occurrences > 1: - raise ValueError(f"Multiple instances of {old_value} found in the list.") + warnings.warn(f"Multiple instances of {old_value} found in the list.") + occurrence_idx=0 + for i, value in enumerate(lst): + if value == old_value: + occurrence_idx += 1 + if occurrence_idx == 2: + lst[i] = new_value + break + return lst + # raise ValueError(f"Multiple instances of {old_value} found in the list.") else: index = lst.index(old_value) lst[index] = new_value diff --git a/tests/peft/hf_finetune.py b/tests/peft/hf_finetune.py index 16b46cfa81..a2fc5548ab 100644 --- a/tests/peft/hf_finetune.py +++ b/tests/peft/hf_finetune.py @@ -77,7 +77,7 @@ def main(): if args.save_peft_tensors: make_debug_dirs() register_peft_hooks(model) - save_peft_weights(model, target_modules=["lora", "lm_head", "down_proj"]) + save_model_weights(model, target_modules=["lora", "lm_head", "down_proj"]) # Load fine-tuning dataset data = load_dataset("Abirate/english_quotes") diff --git a/tests/peft/hf_utils.py b/tests/peft/hf_utils.py index 9332c803b2..b7b7997dee 100644 --- a/tests/peft/hf_utils.py +++ b/tests/peft/hf_utils.py @@ -40,7 +40,7 @@ def get_dst_folder(subdir, step_idx=0): def simplify_name(name): - return name.replace("base_model.model.model.", "").replace("base_model.model.", "") + return name.replace("base_model.model.model.", "").replace("base_model.model.", "").replace("model.layers.", "layers.").replace("model.", "") def 
get_optim_type(args):
@@ -114,7 +114,7 @@ def peft_backward_hook(module, grad_input, grad_output):
     module.bwd_step += 1
 
 
-def peft_forward_hook(module, input, output):
+def fwd_hook(module, input, output):
     if len(input) == 0 or len(output) == 0:
         return
     assert module.name is not None and module.fwd_step is not None
@@ -312,11 +312,18 @@ def register_peft_hooks(model):
         layer.bwd_step = 0
         if verbose:
             print(f"Adding hooks to layer {layer.name}")
-        layer.register_forward_hook(peft_forward_hook)
+        layer.register_forward_hook(fwd_hook)
         layer.register_full_backward_hook(peft_backward_hook)
 
+def register_inference_hooks(model):
+    for name, layer in dict(model.named_modules()).items():
+        layer.name = name
+        layer.fwd_step = 0
+        if verbose:
+            print(f"Adding hooks to layer {layer.name}")
+        layer.register_forward_hook(fwd_hook)
 
-def save_peft_weights(model, target_modules=[]):
+def save_model_weights(model, target_modules=[]):
     # Save any weights of interest
     for name, params in model.named_parameters():
         simplified_name = simplify_name(name)
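
The huggingface_inference.py change drops the inline pre/post-forward hooks and reuses register_inference_hooks / fwd_hook from tests/peft/hf_utils.py, which tag every module with its name and a per-module step counter and dump each output tensor to disk for the alignment check. Below is a minimal standalone sketch of that hook pattern, separate from the patch, with a toy model and a hard-coded ./hf_tensors directory; the real helpers also simplify module names, save inputs, and honor the verbose flag.

import os
import torch
import torch.nn as nn

def fwd_hook(module, inputs, output):
    # Save this module's output for the current step, then bump the counter.
    name = module.name.replace("model.", "")
    torch.save(output, f"./hf_tensors/decoding_step_{module.fwd_step}_{name}.output")
    module.fwd_step += 1

def register_inference_hooks(model):
    # Attach a name and a step counter to every submodule, then hook it.
    for name, layer in dict(model.named_modules()).items():
        layer.name = name
        layer.fwd_step = 0
        layer.register_forward_hook(fwd_hook)

# Toy usage: one forward pass plays the role of one decoding step.
os.makedirs("./hf_tensors", exist_ok=True)
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
register_inference_hooks(model)
model(torch.randn(1, 4))  # every module dumps its output for step 0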
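
In inference_alignment_test.py, get_ff_tensor reconciles the per-GPU shards according to TPType: replicated tensors must be identical across shards, partitioned tensors are concatenated along dim 0 of the FlexFlow layout, and partial results are summed. The sketch below is separate from the test and uses made-up toy data (combine_shards is not a helper in the patch); it only illustrates the three recombination modes.

import numpy as np
from enum import Enum

class TPType(Enum):  # mirrors the enum used by the alignment test
    REPLICATE = 0
    PARTITION = 1
    TO_REDUCE = 2

def combine_shards(shards, tp_type):
    """Recombine per-GPU shards the way check_fwd_pass does: replicated
    tensors must match exactly, partitioned tensors are concatenated along
    the partition dimension (dim 0), and partial sums are added together."""
    if tp_type == TPType.REPLICATE:
        assert all(np.array_equal(shards[0], s) for s in shards[1:])
        return shards[0]
    elif tp_type == TPType.PARTITION:
        return np.concatenate(shards, axis=0)
    elif tp_type == TPType.TO_REDUCE:
        return np.sum(shards, axis=0)

# Toy example with tp_degree = 2: a [hidden, tokens] activation split on dim 0.
full = np.arange(12, dtype=np.float32).reshape(4, 3)
shards = np.split(full, 2, axis=0)
assert np.array_equal(combine_shards(shards, TPType.PARTITION), full)
assert np.array_equal(combine_shards([full, full], TPType.REPLICATE), full)
assert np.array_equal(combine_shards([full / 2, full / 2], TPType.TO_REDUCE), full)

This is also why the attention output is compared with TPType.PARTITION rather than TO_REDUCE: as the comment in check_fwd_pass notes, tensor parallelism for self-attention partitions the attention heads across workers, so each shard holds a disjoint slice of the raw attention result.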
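
compare() in the same file accepts small numerical drift rather than exact equality: elements are compared at atol=1e-2, and the check only fails when more than 5% of them mismatch. A self-contained sketch of that acceptance rule on toy data follows; within_tolerance is an illustrative name, not a helper from the patch, and the real compare() also transposes the HF tensor and can subtract an additional FF tensor first.

import numpy as np

def within_tolerance(hf, ff, atol=1e-2, max_mismatch_frac=0.05):
    """Accept the comparison if at most max_mismatch_frac of the elements
    differ by more than atol, mirroring the rule used by compare()."""
    mismatches = ~np.isclose(hf, ff, atol=atol)
    frac = mismatches.sum() / hf.size
    print(f"Pct mismatch: {100.0 * frac:.3f}%")
    return frac <= max_mismatch_frac

hf = np.zeros(1000)
ff = np.zeros(1000)
ff[:30] += 0.5            # 3% of elements far off -> still accepted
assert within_tolerance(hf, ff)
ff[:80] += 0.5            # 8% of elements off -> rejected
assert not within_tolerance(hf, ff)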
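
The align_test_utils.py hunk relaxes replace_value(): a duplicated value now produces a warning and rewrites the second occurrence instead of raising, which can happen when the number of tokens coincides with another dimension of the ff_shape being rewritten. A small runnable mirror of the new behavior, with toy values only:

import warnings

def replace_value(lst, old_value, new_value):
    # Mirrors the patched helper: error if the value is missing, replace the
    # second occurrence (with a warning) if duplicated, else replace the only one.
    occurrences = lst.count(old_value)
    if occurrences == 0:
        raise ValueError(f"Value {old_value} not found in the list.")
    elif occurrences > 1:
        warnings.warn(f"Multiple instances of {old_value} found in the list.")
        seen = 0
        for i, value in enumerate(lst):
            if value == old_value:
                seen += 1
                if seen == 2:
                    lst[i] = new_value
                    break
        return lst
    else:
        lst[lst.index(old_value)] = new_value
        return lst

# A toy shape where the value 7 appears twice: only the second one is rewritten.
print(replace_value([7, 768, 7], 7, 8))   # -> [7, 768, 8]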