From 26d333f83bcd84122ffe68459dbbe292a80b5c03 Mon Sep 17 00:00:00 2001
From: pppppM <67539920+pppppM@users.noreply.github.com>
Date: Wed, 24 Jan 2024 19:37:07 +0800
Subject: [PATCH] [Feature] Support MMBench DDP Evaluate (#300)

* support ddp mmbench evaluate

* Update xtuner/tools/mmbench.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* Update xtuner/tools/mmbench.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* update minimum version of mmengine

* Update runtime.txt

---------

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
---
 requirements/runtime.txt                   |   6 +-
 xtuner/tools/mmbench.py                    | 150 +++++++++++++--------
 xtuner/tools/model_converters/pth_to_hf.py |   5 +-
 3 files changed, 103 insertions(+), 58 deletions(-)

diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 9be3701ab..fbfb37687 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -5,9 +5,9 @@ datasets>=2.16.0
 einops
 # Minimum 0.1.2 to fix some bugs, see https://github.com/InternLM/lagent/pull/44
 lagent>=0.1.2
-# Minimum 0.10.1 to support exclude_frozen_parameters for DeepSpeedStrategy,
-# see https://github.com/open-mmlab/mmengine/pull/1415, https://github.com/open-mmlab/mmengine/pull/1424
-mmengine>=0.10.1
+# Minimum 0.10.3 to support distributed evaluation for MMBench
+# see https://github.com/open-mmlab/mmengine/pull/1469
+mmengine>=0.10.3
 openpyxl
 # Minimum 0.4.0 to support QLoRA, see https://github.com/huggingface/peft/pull/476
 peft>=0.4.0
diff --git a/xtuner/tools/mmbench.py b/xtuner/tools/mmbench.py
index 2cf32987f..f6bcb7036 100644
--- a/xtuner/tools/mmbench.py
+++ b/xtuner/tools/mmbench.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import json
+import math
 import os
 import os.path as osp
 import re
@@ -13,6 +14,9 @@
 import tqdm
 from huggingface_hub import snapshot_download
 from mmengine import mkdir_or_exist
+from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
+                           master_only)
+from mmengine.utils.dl_utils import set_multi_processing
 from peft import PeftModel
 from rich.console import Console
 from rich.table import Table
@@ -22,7 +26,7 @@
                           CLIPVisionModel, GenerationConfig)
 
 from xtuner.dataset.utils import decode_base64_to_image, expand2square
-from xtuner.model.utils import prepare_inputs_labels_for_multimodal
+from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
 from xtuner.tools.utils import get_stop_criteria, is_cn_string
 from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE)
@@ -78,10 +82,20 @@ def parse_args():
         type=int,
         default=0,
         help='Random seed for reproducible text generation')
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
     args = parser.parse_args()
     return args
 
 
+@master_only
+def master_print(msg):
+    print(msg)
+
+
 class MMBenchDataset(Dataset):
     ABBRS = {
         'coarse_perception': 'CP',
@@ -155,6 +169,7 @@ def load_from_df(self, idx, key):
         else:
             return None
 
+    @master_only
     def eval_result(self, result_df, show=True):
 
         def calc_acc(df, group='category'):
@@ -255,28 +270,18 @@ def show_result(ret_json):
 
 
 def main():
     args = parse_args()
+    torch.manual_seed(args.seed)
 
-    # work_dir
-    if args.work_dir is not None:
-        # update configs according to CLI args if args.work_dir is not None
-        save_dir = args.work_dir
+    if args.launcher != 'none':
+        set_multi_processing(distributed=True)
+        init_dist(args.launcher)
+
+        rank, world_size = get_dist_info()
+        torch.cuda.set_device(rank)
     else:
-        # use config filename as default work_dir
-        save_dir = osp.join('./work_dirs',
-                            osp.splitext(osp.basename(args.data_path))[0])
-    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
-    save_dir = osp.join(save_dir, timestamp)
-    mkdir_or_exist(osp.abspath(save_dir))
-    print('=======================================================')
-    print(f'Dataset path: {osp.abspath(args.data_path)}\n'
-          f'Results will be saved to {osp.abspath(save_dir)}')
-    print('=======================================================')
-    results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
-    results_json_path = osp.join(save_dir, 'mmbench_result.json')
-    args_path = osp.join(save_dir, 'args.json')
-    with open(args_path, 'w') as f:
-        json.dump(args.__dict__, f, indent=2)
+        rank = 0
+        world_size = 1
 
     # build llm
     quantization_config = None
@@ -295,20 +300,21 @@ def main():
     model_kwargs = {
         'quantization_config': quantization_config,
         'load_in_8bit': load_in_8bit,
-        'device_map': 'auto',
+        'device_map': rank if world_size > 1 else 'auto',
         'offload_folder': args.offload_folder,
         'trust_remote_code': True,
         'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
     }
 
     # build llm
-    llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
-                                               **model_kwargs)
+    with LoadWoInit():
+        llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
+                                                   **model_kwargs)
     tokenizer = AutoTokenizer.from_pretrained(
         args.model_name_or_path,
         trust_remote_code=True,
         encode_special_tokens=True)
-    print(f'Load LLM from {args.model_name_or_path}')
+    master_print(f'Load LLM from {args.model_name_or_path}')
 
     llava_path = snapshot_download(
         repo_id=args.llava) if not osp.isdir(args.llava) else args.llava
@@ -323,41 +329,42 @@ def main():
         assert args.visual_encoder is not None, (
             'Please specify the `--visual-encoder`!')
         visual_encoder_path = args.visual_encoder
-    visual_encoder = CLIPVisionModel.from_pretrained(
-        visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
-    image_processor = CLIPImageProcessor.from_pretrained(visual_encoder_path)
-    print(f'Load visual_encoder from {visual_encoder_path}')
+    with LoadWoInit():
+        visual_encoder = CLIPVisionModel.from_pretrained(
+            visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
+    image_processor = CLIPImageProcessor.from_pretrained(
+        visual_encoder_path)
+    master_print(f'Load visual_encoder from {visual_encoder_path}')
 
     # load adapter
     if 'llm_adapter' in os.listdir(llava_path):
         adapter_path = osp.join(llava_path, 'llm_adapter')
-        llm = PeftModel.from_pretrained(
-            llm,
-            adapter_path,
-            offload_folder=args.offload_folder,
-            trust_remote_code=True)
-        print(f'Load LLM adapter from {args.llava}')
+
+        with LoadWoInit():
+            llm = PeftModel.from_pretrained(
+                llm, adapter_path, offload_folder=args.offload_folder)
+
+        master_print(f'Load LLM adapter from {args.llava}')
+
     if 'visual_encoder_adapter' in os.listdir(llava_path):
         adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
         visual_encoder = PeftModel.from_pretrained(
-            visual_encoder,
-            adapter_path,
-            offload_folder=args.offload_folder,
-            trust_remote_code=True)
-        print(f'Load visual_encoder adapter from {args.llava}')
+            visual_encoder, adapter_path, offload_folder=args.offload_folder)
+        master_print(f'Load visual_encoder adapter from {args.llava}')
 
     # build projector
     projector_path = osp.join(llava_path, 'projector')
-    projector = AutoModel.from_pretrained(
-        projector_path,
-        torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
-        trust_remote_code=True)
-    print(f'Load projector from {args.llava}')
+    with LoadWoInit():
+        projector = AutoModel.from_pretrained(
+            projector_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
+    master_print(f'Load projector from {args.llava}')
 
     projector.cuda()
     projector.eval()
+
     visual_encoder.cuda()
     visual_encoder.eval()
+
     llm.eval()
 
     stop_words = args.stop_words
@@ -375,10 +382,40 @@ def main():
         if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
     )
 
+    # work_dir
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        save_dir = args.work_dir
+    else:
+        # use config filename as default work_dir
+        save_dir = osp.join('./work_dirs',
+                            osp.splitext(osp.basename(args.data_path))[0])
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
+    save_dir = osp.join(save_dir, timestamp)
+
+    if rank == 0:
+        mkdir_or_exist(osp.abspath(save_dir))
+        print('=======================================================')
+        print(f'Dataset path: {osp.abspath(args.data_path)}\n'
+              f'Results will be saved to {osp.abspath(save_dir)}')
+        print('=======================================================')
+
+        args_path = osp.join(save_dir, 'args.json')
+        with open(args_path, 'w') as f:
+            json.dump(args.__dict__, f, indent=2)
+
+    results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
+    results_json_path = osp.join(save_dir, 'mmbench_result.json')
+
     dataset = MMBenchDataset(args.data_path)
+    results = []
 
     n_samples = len(dataset)
-    for i in tqdm.tqdm(range(n_samples)):
+    per_rank_samples = math.ceil(n_samples / world_size)
+
+    per_rank_ids = range(per_rank_samples * rank,
+                         min(n_samples, per_rank_samples * (rank + 1)))
+    for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
         data_sample = dataset[i]
         if data_sample['context'] is not None:
             text = data_sample['context'] + '\n' + data_sample[
@@ -452,17 +489,22 @@ def main():
         cur_result['answer'] = data_sample.get('answer')
         results.append(cur_result)
 
-    results_df = pd.DataFrame(results)
-    with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
-        results_df.to_excel(writer, index=False)
+    results = collect_results(results, n_samples)
 
-    if dataset.split == 'dev':
-        results_dict = dataset.eval_result(results_df, show=True)
-        with open(results_json_path, 'w') as f:
-            json.dump(results_dict, f, indent=2)
-    else:
-        print('All done!')
+    if get_rank() == 0:
+
+        results_df = pd.DataFrame(results)
+        with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
+            results_df.to_excel(writer, index=False)
+
+        if dataset.split == 'dev':
+            results_dict = dataset.eval_result(results_df, show=True)
+            with open(results_json_path, 'w') as f:
+                json.dump(results_dict, f, indent=2)
+        else:
+            print('All done!')
 
 
 if __name__ == '__main__':
+
     main()
diff --git a/xtuner/tools/model_converters/pth_to_hf.py b/xtuner/tools/model_converters/pth_to_hf.py
index aab7de595..9e68afe67 100644
--- a/xtuner/tools/model_converters/pth_to_hf.py
+++ b/xtuner/tools/model_converters/pth_to_hf.py
@@ -126,7 +126,10 @@ def main():
         tokenizer = BUILDER.build(cfg.tokenizer)
         tokenizer.save_pretrained(llm_path)
         print(f'Saving LLM to {llm_path}')
-        model.llm.save_pretrained(llm_path, max_shard_size=args.max_shard_size)
+        model.llm.save_pretrained(
+            llm_path,
+            max_shard_size=args.max_shard_size,
+            safe_serialization=False)
 
     shutil.copyfile(args.config, osp.join(args.save_dir, 'xtuner_config.py'))
     print('All done!')
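
The mmbench.py changes above boil down to a shard-then-gather pattern: every rank runs generation on a contiguous slice of the dataset, collect_results gathers the per-rank result lists, and only rank 0 evaluates and writes the XLSX/JSON reports. The snippet below is a minimal sketch of that pattern, using the same mmengine helpers the patch imports; `run_inference` is a hypothetical stand-in for the real per-sample model call and is not part of the patch.

import math

from mmengine.dist import collect_results, get_dist_info, init_dist


def distributed_eval(dataset, run_inference, launcher='pytorch'):
    # Bring up the process group; with launcher == 'none' this is skipped
    # and get_dist_info() falls back to rank 0 / world_size 1.
    if launcher != 'none':
        init_dist(launcher)
    rank, world_size = get_dist_info()

    # Contiguous, near-equal shards; the last rank may get fewer samples.
    n_samples = len(dataset)
    per_rank = math.ceil(n_samples / world_size)
    shard = range(per_rank * rank, min(n_samples, per_rank * (rank + 1)))

    results = [run_inference(dataset[i]) for i in shard]  # hypothetical call

    # Gather the per-rank lists: rank 0 receives all n_samples results in
    # dataset order, every other rank receives None.
    return collect_results(results, n_samples)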
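
With the new --launcher flag, a multi-GPU run can then be started along the lines of `torchrun --nproc_per_node=8 xtuner/tools/mmbench.py <model/data args> --launcher pytorch`, while the default --launcher none keeps the original single-process behavior.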