[Feature] Support MMBench DDP Evaluate (#300)
* support ddp mmbench evaluate

* Update xtuner/tools/mmbench.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* Update xtuner/tools/mmbench.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* update minimum version of mmengine

* Update runtime.txt

---------

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
pppppM and LZHgrla authored Jan 24, 2024
1 parent f225761 commit 26d333f
Showing 3 changed files with 103 additions and 58 deletions.
6 changes: 3 additions & 3 deletions requirements/runtime.txt
@@ -5,9 +5,9 @@ datasets>=2.16.0
einops
# Minimum 0.1.2 to fix some bugs, see https://github.com/InternLM/lagent/pull/44
lagent>=0.1.2
# Minimum 0.10.1 to support exclude_frozen_parameters for DeepSpeedStrategy,
# see https://github.com/open-mmlab/mmengine/pull/1415, https://github.com/open-mmlab/mmengine/pull/1424
mmengine>=0.10.1
# Minimum 0.10.3 to support distributed evaluation for MMBench
# see https://github.com/open-mmlab/mmengine/pull/1469
mmengine>=0.10.3
openpyxl
# Minimum 0.4.0 to support QLoRA, see https://github.com/huggingface/peft/pull/476
peft>=0.4.0
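For context, the raised mmengine floor is what enables the DDP path added below: the tool shards the benchmark across ranks and merges the partial results with mmengine.dist.collect_results. A minimal sketch of that pattern, assuming a hypothetical evaluate_one helper and an already-initialized process group:

    import math
    from mmengine.dist import collect_results, get_dist_info, get_rank

    rank, world_size = get_dist_info()
    per_rank = math.ceil(len(dataset) / world_size)
    shard = range(per_rank * rank, min(len(dataset), per_rank * (rank + 1)))
    results = [evaluate_one(dataset[i]) for i in shard]  # hypothetical helper
    # collect_results returns the full, ordered result list on rank 0
    # and None on every other rank.
    results = collect_results(results, len(dataset))
    if get_rank() == 0:
        ...  # only rank 0 writes results to disk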
150 changes: 96 additions & 54 deletions xtuner/tools/mmbench.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import math
import os
import os.path as osp
import re
@@ -13,6 +14,9 @@
import tqdm
from huggingface_hub import snapshot_download
from mmengine import mkdir_or_exist
from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
master_only)
from mmengine.utils.dl_utils import set_multi_processing
from peft import PeftModel
from rich.console import Console
from rich.table import Table
@@ -22,7 +26,7 @@
CLIPVisionModel, GenerationConfig)

from xtuner.dataset.utils import decode_base64_to_image, expand2square
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
from xtuner.tools.utils import get_stop_criteria, is_cn_string
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
PROMPT_TEMPLATE)
@@ -78,10 +82,20 @@ def parse_args():
type=int,
default=0,
help='Random seed for reproducible text generation')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
args = parser.parse_args()
return args


@master_only
def master_print(msg):
print(msg)
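# Note: `master_only` (from mmengine.dist) makes the decorated function run
# on rank 0 only and become a no-op on every other rank -- roughly equivalent
# to this sketch (illustrative, not mmengine's actual source):
#
#   def master_only(func):
#       @functools.wraps(func)
#       def wrapper(*args, **kwargs):
#           if get_rank() == 0:
#               return func(*args, **kwargs)
#       return wrapper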


class MMBenchDataset(Dataset):
ABBRS = {
'coarse_perception': 'CP',
@@ -155,6 +169,7 @@ def load_from_df(self, idx, key):
else:
return None

@master_only
def eval_result(self, result_df, show=True):

def calc_acc(df, group='category'):
@@ -255,28 +270,18 @@ def show_result(ret_json):

def main():
args = parse_args()

torch.manual_seed(args.seed)

# work_dir
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
save_dir = args.work_dir
if args.launcher != 'none':
set_multi_processing(distributed=True)
init_dist(args.launcher)

rank, world_size = get_dist_info()
torch.cuda.set_device(rank)
else:
# use config filename as default work_dir
save_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.data_path))[0])
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
save_dir = osp.join(save_dir, timestamp)
mkdir_or_exist(osp.abspath(save_dir))
print('=======================================================')
print(f'Dataset path: {osp.abspath(args.data_path)}\n'
f'Results will be saved to {osp.abspath(save_dir)}')
print('=======================================================')
results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
results_json_path = osp.join(save_dir, 'mmbench_result.json')
args_path = osp.join(save_dir, 'args.json')
with open(args_path, 'w') as f:
json.dump(args.__dict__, f, indent=2)
rank = 0
world_size = 1
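# Example launch for the distributed path (assumed invocation; the
# entrypoint, flags, and GPU count here are illustrative):
#
#   torchrun --nproc_per_node=8 xtuner/tools/mmbench.py MODEL_NAME_OR_PATH \
#       --launcher pytorch --data-path MMBench_DEV_EN.tsv ...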

# build llm
quantization_config = None
@@ -295,20 +300,21 @@ def main():
model_kwargs = {
'quantization_config': quantization_config,
'load_in_8bit': load_in_8bit,
'device_map': 'auto',
'device_map': rank if world_size > 1 else 'auto',
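# Under DDP, each rank places the whole model on its own GPU
# (device_map=rank); in the single-process case, 'auto' lets the
# accelerate backend shard the model across all visible GPUs.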
'offload_folder': args.offload_folder,
'trust_remote_code': True,
'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
}

# build llm
llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
**model_kwargs)
with LoadWoInit():
llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
**model_kwargs)
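# LoadWoInit (from xtuner.model.utils) skips the default random parameter
# initialization inside from_pretrained -- wasted work here, since the
# pretrained weights are loaded immediately afterwards.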
tokenizer = AutoTokenizer.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
encode_special_tokens=True)
print(f'Load LLM from {args.model_name_or_path}')
master_print(f'Load LLM from {args.model_name_or_path}')

llava_path = snapshot_download(
repo_id=args.llava) if not osp.isdir(args.llava) else args.llava
@@ -323,41 +329,42 @@ def main():
assert args.visual_encoder is not None, (
'Please specify the `--visual-encoder`!')
visual_encoder_path = args.visual_encoder
visual_encoder = CLIPVisionModel.from_pretrained(
visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
image_processor = CLIPImageProcessor.from_pretrained(visual_encoder_path)
print(f'Load visual_encoder from {visual_encoder_path}')
with LoadWoInit():
visual_encoder = CLIPVisionModel.from_pretrained(
visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
image_processor = CLIPImageProcessor.from_pretrained(
visual_encoder_path)
master_print(f'Load visual_encoder from {visual_encoder_path}')

# load adapter
if 'llm_adapter' in os.listdir(llava_path):
adapter_path = osp.join(llava_path, 'llm_adapter')
llm = PeftModel.from_pretrained(
llm,
adapter_path,
offload_folder=args.offload_folder,
trust_remote_code=True)
print(f'Load LLM adapter from {args.llava}')

with LoadWoInit():
llm = PeftModel.from_pretrained(
llm, adapter_path, offload_folder=args.offload_folder)

master_print(f'Load LLM adapter from {args.llava}')

if 'visual_encoder_adapter' in os.listdir(llava_path):
adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
visual_encoder = PeftModel.from_pretrained(
visual_encoder,
adapter_path,
offload_folder=args.offload_folder,
trust_remote_code=True)
print(f'Load visual_encoder adapter from {args.llava}')
visual_encoder, adapter_path, offload_folder=args.offload_folder)
master_print(f'Load visual_encoder adapter from {args.llava}')

# build projector
projector_path = osp.join(llava_path, 'projector')
projector = AutoModel.from_pretrained(
projector_path,
torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
trust_remote_code=True)
print(f'Load projector from {args.llava}')
with LoadWoInit():
projector = AutoModel.from_pretrained(
projector_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
master_print(f'Load projector from {args.llava}')

projector.cuda()
projector.eval()

visual_encoder.cuda()
visual_encoder.eval()

llm.eval()

stop_words = args.stop_words
@@ -375,10 +382,40 @@ def main():
if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
)

# work_dir
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
save_dir = args.work_dir
else:
# use config filename as default work_dir
save_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.data_path))[0])
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
save_dir = osp.join(save_dir, timestamp)

if rank == 0:
mkdir_or_exist(osp.abspath(save_dir))
print('=======================================================')
print(f'Dataset path: {osp.abspath(args.data_path)}\n'
f'Results will be saved to {osp.abspath(save_dir)}')
print('=======================================================')

args_path = osp.join(save_dir, 'args.json')
with open(args_path, 'w') as f:
json.dump(args.__dict__, f, indent=2)

results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
results_json_path = osp.join(save_dir, 'mmbench_result.json')

dataset = MMBenchDataset(args.data_path)

results = []
n_samples = len(dataset)
for i in tqdm.tqdm(range(n_samples)):
per_rank_samples = math.ceil(n_samples / world_size)

per_rank_ids = range(per_rank_samples * rank,
min(n_samples, per_rank_samples * (rank + 1)))
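# e.g. n_samples=10, world_size=4 -> per_rank_samples=3 and the shards
# are [0:3], [3:6], [6:9], [9:10]; the min() clips the last rank's range.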
for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
data_sample = dataset[i]
if data_sample['context'] is not None:
text = data_sample['context'] + '\n' + data_sample[
@@ -452,17 +489,22 @@ def main():
cur_result['answer'] = data_sample.get('answer')
results.append(cur_result)

results_df = pd.DataFrame(results)
with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
results_df.to_excel(writer, index=False)
results = collect_results(results, n_samples)

if dataset.split == 'dev':
results_dict = dataset.eval_result(results_df, show=True)
with open(results_json_path, 'w') as f:
json.dump(results_dict, f, indent=2)
else:
print('All done!')
if get_rank() == 0:

results_df = pd.DataFrame(results)
with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
results_df.to_excel(writer, index=False)

if dataset.split == 'dev':
results_dict = dataset.eval_result(results_df, show=True)
with open(results_json_path, 'w') as f:
json.dump(results_dict, f, indent=2)
else:
print('All done!')


if __name__ == '__main__':

main()
5 changes: 4 additions & 1 deletion xtuner/tools/model_converters/pth_to_hf.py
@@ -126,7 +126,10 @@ def main():
tokenizer = BUILDER.build(cfg.tokenizer)
tokenizer.save_pretrained(llm_path)
print(f'Saving LLM to {llm_path}')
model.llm.save_pretrained(llm_path, max_shard_size=args.max_shard_size)
model.llm.save_pretrained(
llm_path,
max_shard_size=args.max_shard_size,
safe_serialization=False)

shutil.copyfile(args.config, osp.join(args.save_dir, 'xtuner_config.py'))
print('All done!')
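For context, safe_serialization=False makes save_pretrained write pickle-based pytorch_model-*.bin shards instead of model-*.safetensors. Reloading is unchanged either way; a minimal sketch, with a hypothetical output directory:

    from transformers import AutoModelForCausalLM

    # Loads .bin and .safetensors checkpoints transparently.
    llm = AutoModelForCausalLM.from_pretrained(
        'work_dirs/hf_model/llm', trust_remote_code=True)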
