Skip to content

Commit

Permalink
[Feature] Support Baichuan2 models (#102)
Browse files Browse the repository at this point in the history
* add baichuan2_7b_base

* fix lm_head bug for Baichuan2

* Update README.md

* Update README_zh-CN.md

* Update README.md

* remove infrequent configs

* Update README.md

* Update README_zh-CN.md

* add baichuan2 chat template

* Update README_zh-CN.md
  • Loading branch information
LZHgrla authored Sep 6, 2023
1 parent 86faaaa commit 1d0259b
Show file tree
Hide file tree
Showing 16 changed files with 2,314 additions and 22 deletions.
24 changes: 13 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ English | [简体中文](README_zh-CN.md)

## 🎉 News

- **\[2023.09.06\]** Support the training of [Baichuan2](https://huggingface.co/baichuan-inc) models! Try it out by `xtuner train baichuan2_7b_base_qlora_oasst1_e3`!
- **\[2023.08.30\]** XTuner is released, with multiple fine-tuned adapters on [HuggingFace](https://huggingface.co/xtuner).

## 📖 Introduction

XTuner is a toolkit for efficiently fine-tuning LLM, developed by the [MMRazor](https://github.com/open-mmlab/mmrazor) and [MMDeploy](https://github.com/open-mmlab/mmdeploy) teams.

- **Efficiency**: Support LLM fine-tuning on consumer-grade GPUs. The minimum GPU memory required for 7B LLM fine-tuning is only **8GB**, indicating that users can use nearly any GPU (even the free resource, *e.g.*, Colab) to fine-tune custom LLMs.
- **Versatile**: Support various **LLMs** ([InternLM](https://github.com/InternLM/InternLM), [Llama2](https://github.com/facebookresearch/llama), [ChatGLM2](https://huggingface.co/THUDM/chatglm2-6b), [Qwen](https://github.com/QwenLM/Qwen-7B), [Baichuan](https://github.com/baichuan-inc), ...), **datasets** ([MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca), [WizardLM](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k), [oasst1](https://huggingface.co/datasets/timdettmers/openassistant-guanaco), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), ...) and **algorithms** ([QLoRA](http://arxiv.org/abs/2305.14314), [LoRA](http://arxiv.org/abs/2106.09685)), allowing users to choose the most suitable solution for their requirements.
- **Versatile**: Support various **LLMs** ([InternLM](https://huggingface.co/internlm), [Llama2](https://huggingface.co/meta-llama), [ChatGLM2](https://huggingface.co/THUDM/chatglm2-6b), [Qwen](https://huggingface.co/Qwen), [Baichuan2](https://huggingface.co/baichuan-inc), ...), **datasets** ([MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca), [WizardLM](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k), [oasst1](https://huggingface.co/datasets/timdettmers/openassistant-guanaco), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), ...) and **algorithms** ([QLoRA](http://arxiv.org/abs/2305.14314), [LoRA](http://arxiv.org/abs/2106.09685)), allowing users to choose the most suitable solution for their requirements.
- **Compatibility**: Compatible with [DeepSpeed](https://github.com/microsoft/DeepSpeed) 🚀 and [HuggingFace](https://huggingface.co) 🤗 training pipeline, enabling effortless integration and utilization.

## 🌟 Demos
Expand Down Expand Up @@ -70,17 +71,18 @@ XTuner is a toolkit for efficiently fine-tuning LLM, developed by the [MMRazor](
<tr valign="top">
<td align="left" valign="top">
<ul>
<li><a href="https://github.com/InternLM/InternLM">InternLM</a></li>
<li><a href="https://github.com/InternLM/InternLM">InternLM-Chat</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2-Chat</a></li>
<li><a href="https://huggingface.co/internlm/internlm-7b">InternLM</a></li>
<li><a href="https://huggingface.co/internlm/internlm-chat-7b">InternLM-Chat</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama2</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama2-Chat</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen-Chat</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-7B">Baichuan-7B</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-13B">Baichuan-13B-Base</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-13B">Baichuan-13B-Chat</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B">Qwen</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B-Chat">Qwen-Chat</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-7B">Baichuan-7B</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-13B-Base">Baichuan-13B-Base</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-13B-Chat">Baichuan-13B-Chat</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan2-7B-Base">Baichuan2-7B-Base</a></li>
<li>...</li>
</ul>
</td>
Expand Down
24 changes: 13 additions & 11 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

## 🎉 更新

- **\[2023.09.06\]** 支持 [Baichuan2](https://huggingface.co/baichuan-inc) 系列模型训练!快速体验:`xtuner train baichuan2_7b_base_qlora_oasst1_e3`
- **\[2023.08.30\]** XTuner 正式发布!众多微调模型已上传至 [HuggingFace](https://huggingface.co/xtuner)

## 📖 介绍

XTuner 是一个轻量级微调大语言模型的工具库,由 [MMRazor](https://github.com/open-mmlab/mmrazor)[MMDeploy](https://github.com/open-mmlab/mmdeploy) 团队联合开发。

- **轻量级**: 支持在消费级显卡上微调大语言模型。对于 7B 参数量,微调所需的最小显存仅为 **8GB**,这使得用户可以使用几乎任何显卡(甚至免费资源,例如Colab)来微调获得自定义大语言模型助手。
- **多样性**: 支持多种**大语言模型**[InternLM](https://github.com/InternLM/InternLM)[Llama2](https://github.com/facebookresearch/llama)[ChatGLM2](https://huggingface.co/THUDM/chatglm2-6b)[Qwen](https://github.com/QwenLM/Qwen-7B)[Baichuan](https://github.com/baichuan-inc), ...),**数据集**[MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca), [WizardLM](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k), [oasst1](https://huggingface.co/datasets/timdettmers/openassistant-guanaco), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), ...)和**微调算法**[QLoRA](http://arxiv.org/abs/2305.14314)[LoRA](http://arxiv.org/abs/2106.09685)),支撑用户根据自身具体需求选择合适的解决方案。
- **多样性**: 支持多种**大语言模型**[InternLM](https://huggingface.co/internlm)[Llama2](https://huggingface.co/meta-llama)[ChatGLM2](https://huggingface.co/THUDM/chatglm2-6b)[Qwen](https://huggingface.co/Qwen)[Baichuan2](https://huggingface.co/baichuan-inc), ...),**数据集**[MOSS_003_SFT](https://huggingface.co/datasets/fnlp/moss-003-sft-data), [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca), [WizardLM](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k), [oasst1](https://huggingface.co/datasets/timdettmers/openassistant-guanaco), [Open-Platypus](https://huggingface.co/datasets/garage-bAInd/Open-Platypus), [Code Alpaca](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K), [Colorist](https://huggingface.co/datasets/burkelibbey/colors), ...)和**微调算法**[QLoRA](http://arxiv.org/abs/2305.14314)[LoRA](http://arxiv.org/abs/2106.09685)),支撑用户根据自身具体需求选择合适的解决方案。
- **兼容性**: 兼容 [DeepSpeed](https://github.com/microsoft/DeepSpeed) 🚀 和 [HuggingFace](https://huggingface.co) 🤗 的训练流程,支撑用户无感式集成与使用。

## 🌟 示例
Expand Down Expand Up @@ -70,17 +71,18 @@ XTuner 是一个轻量级微调大语言模型的工具库,由 [MMRazor](https
<tr valign="top">
<td align="left" valign="top">
<ul>
<li><a href="https://github.com/InternLM/InternLM">InternLM</a></li>
<li><a href="https://github.com/InternLM/InternLM">InternLM-Chat</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2</a></li>
<li><a href="https://github.com/facebookresearch/llama">Llama2-Chat</a></li>
<li><a href="https://huggingface.co/internlm/internlm-7b">InternLM</a></li>
<li><a href="https://huggingface.co/internlm/internlm-chat-7b">InternLM-Chat</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama2</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama2-Chat</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen</a></li>
<li><a href="https://github.com/QwenLM/Qwen-7B">Qwen-Chat</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-7B">Baichuan-7B</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-13B">Baichuan-13B-Base</a></li>
<li><a href="https://github.com/baichuan-inc/Baichuan-13B">Baichuan-13B-Chat</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B">Qwen</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B-Chat">Qwen-Chat</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-7B">Baichuan-7B</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-13B-Base">Baichuan-13B-Base</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan-13B-Chat">Baichuan-13B-Chat</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan2-7B-Base">Baichuan2-7B-Base</a></li>
<li>...</li>
</ul>
</td>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from bitsandbytes.optim import PagedAdamW32bit
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = 'baichuan-inc/Baichuan2-7B-Base'

# Data
alpaca_en_path = 'tatsu-lab/alpaca'
prompt_template = PROMPT_TEMPLATE.alpaca
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1 # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 3
optim_type = PagedAdamW32bit
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip

# Evaluate the generation performance during the training
evaluation_freq = 500
evaluation_inputs = [
'请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
padding_side='right')

model = dict(
type=SupervisedFinetune,
llm=dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.float16,
quantization_config=dict(
type=BitsAndBytesConfig,
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')),
lora=dict(
type=LoraConfig,
r=64,
lora_alpha=16,
lora_dropout=0.1,
bias='none',
task_type='CAUSAL_LM'))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
alpaca_en = dict(
type=process_hf_dataset,
dataset=dict(type=load_dataset, path=alpaca_en_path),
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=alpaca_map_fn,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=alpaca_en,
sampler=dict(type=DefaultSampler, shuffle=True),
collate_fn=dict(type=default_collate_fn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = dict(
type=CosineAnnealingLR,
eta_min=lr * 0.1,
by_epoch=True,
T_max=max_epochs,
convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
dict(type=DatasetInfoHook, tokenizer=tokenizer),
dict(
type=EvaluateChatHook,
tokenizer=tokenizer,
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
instruction=prompt_template.INSTRUCTION_START)
]

# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 100 iterations.
logger=dict(type=LoggerHook, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per epoch.
checkpoint=dict(type=CheckpointHook, interval=1),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
Loading

0 comments on commit 1d0259b

Please sign in to comment.