Skip to content

Commit

Permalink
[Feature] Add mistral pretrain (#204)
Browse files Browse the repository at this point in the history
* [Feature] Add mistral pretrain

* [feat] rename pretrain_map_fn

* [feat] add custom hook

* [feat] change mistral config name

* Update chat.py

* Update xtuner/utils/templates.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* Update xtuner/configs/mistral/mistral_7b_qlora_skypile_pretrain_e1.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* Update xtuner/configs/mistral/mistral_7b_qlora_skypile_pretrain_e1.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* Update xtuner/configs/mistral/mistral_7b_qlora_skypile_pretrain_e1.py

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>

* fix pre-commit

---------

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
Co-authored-by: LZHgrla <linzhihao@pjlab.org.cn>
  • Loading branch information
3 people authored Nov 8, 2023
1 parent 0badead commit 8ce2569
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 35 deletions.
173 changes: 173 additions & 0 deletions xtuner/configs/mistral/mistral_7b_qlora_skypile_pretrain_e1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from bitsandbytes.optim import PagedAdamW32bit
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from transformers import BitsAndBytesConfig, LlamaTokenizer, MistralForCausalLM

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import pretrain_map_fn
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import SupervisedFinetune

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = 'mistralai/Mistral-7B-v0.1'

# Data
data_path = 'Skywork/SkyPile-150B'
max_length = 2048
pack_to_max_length = True

# Scheduler & Optimizer
batch_size = 1 # per_device
accumulative_counts = 16
dataloader_num_workers = 0
max_epochs = 1
optim_type = PagedAdamW32bit
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1 # grad clip

# Evaluate the generation performance during the training
evaluation_freq = 500
evaluation_inputs = ['上海的景点有']

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
type=LlamaTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
padding_side='right')

model = dict(
type=SupervisedFinetune,
llm=dict(
type=MistralForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.float16,
quantization_config=dict(
type=BitsAndBytesConfig,
load_in_4bit=True,
load_in_8bit=False,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4')),
lora=dict(
type=LoraConfig,
r=64,
lora_alpha=16,
lora_dropout=0.05,
bias='none',
task_type='CAUSAL_LM'))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
train_dataset = dict(
type=process_hf_dataset,
dataset=dict(type=load_dataset, path=data_path),
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=pretrain_map_fn,
template_map_fn=None,
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=train_dataset,
sampler=dict(type=DefaultSampler, shuffle=True),
collate_fn=dict(type=default_collate_fn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = dict(
type=CosineAnnealingLR,
eta_min=lr * 0.1,
by_epoch=True,
T_max=max_epochs,
convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
# PART 5 Runtime #
#######################################################################
custom_hooks = [
dict(type=DatasetInfoHook, tokenizer=tokenizer),
dict(
type=EvaluateChatHook,
tokenizer=tokenizer,
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
max_new_tokens=100)
]

# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 100 iterations.
logger=dict(type=LoggerHook, interval=10),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per epoch.
checkpoint=dict(type=CheckpointHook, interval=1),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
4 changes: 3 additions & 1 deletion xtuner/dataset/map_fns/dataset_map_fns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .oasst1_map_fn import oasst1_map_fn
from .openai_map_fn import openai_map_fn
from .openorca_map_fn import openorca_map_fn
from .pretrain_map_fn import pretrain_map_fn
from .sql_map_fn import sql_map_fn
from .stack_exchange_map_fn import stack_exchange_map_fn
from .tiny_codes_map_fn import tiny_codes_map_fn
Expand All @@ -21,5 +22,6 @@
'medical_map_fn', 'openorca_map_fn', 'code_alpaca_map_fn',
'tiny_codes_map_fn', 'colors_map_fn', 'law_reference_map_fn',
'crime_kg_assitant_map_fn', 'sql_map_fn', 'openai_map_fn',
'wizardlm_map_fn', 'stack_exchange_map_fn', 'msagent_react_map_fn'
'wizardlm_map_fn', 'stack_exchange_map_fn', 'msagent_react_map_fn',
'pretrain_map_fn'
]
14 changes: 14 additions & 0 deletions xtuner/dataset/map_fns/dataset_map_fns/pretrain_map_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
def pretrain_map_fn(example):
r"""Example before preprocessing:
example['text'] = 'xxx'
Example after preprocessing:
example['conversation'] = [
{
'input': '',
'output': 'xxx'
},
]
"""
return {'conversation': [{'input': '', 'output': example['text'].strip()}]}
71 changes: 37 additions & 34 deletions xtuner/tools/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def parse_args():
parser.add_argument(
'--prompt-template',
choices=PROMPT_TEMPLATE.keys(),
default=PROMPT_TEMPLATE.default,
default=None,
help='Specify a prompt template')
system_group = parser.add_mutually_exclusive_group()
system_group.add_argument(
Expand Down Expand Up @@ -241,39 +241,42 @@ def main():
print('Log: Exit!')
exit(0)

template = PROMPT_TEMPLATE[args.prompt_template]
prompt_text = ''
if 'SYSTEM' in template and n_turn == 0:
system_text = None
if args.system_template is not None:
system_text = SYSTEM_TEMPLATE[args.system_template].format(
round=n_turn + 1, bot_name=args.bot_name)
elif args.system is not None:
system_text = args.system
if system_text is not None:
prompt_text += template['SYSTEM'].format(
system=system_text,
round=n_turn + 1,
bot_name=args.bot_name)
prompt_text += template['INSTRUCTION'].format(
input=text, round=n_turn + 1, bot_name=args.bot_name)
if args.prompt_template == args.system_template == 'moss_sft':
if not inner_thoughts_open:
prompt_text.replace('- Inner thoughts: enabled.',
'- Inner thoughts: disabled.')
if not calculate_open:
prompt_text.replace(
'- Calculator: enabled. API: Calculate(expression)',
'- Calculator: disabled.')
if not solve_open:
prompt_text.replace(
'- Equation solver: enabled. API: Solve(equation)',
'- Equation solver: disabled.')
if not search_open:
prompt_text.replace(
'- Web search: enabled. API: Search(query)',
'- Web search: disabled.')

if args.prompt_template:
prompt_text = ''
template = PROMPT_TEMPLATE[args.prompt_template]
if 'SYSTEM' in template and n_turn == 0:
system_text = None
if args.system_template is not None:
system_text = SYSTEM_TEMPLATE[
args.system_template].format(
round=n_turn + 1, bot_name=args.bot_name)
elif args.system is not None:
system_text = args.system
if system_text is not None:
prompt_text += template['SYSTEM'].format(
system=system_text,
round=n_turn + 1,
bot_name=args.bot_name)
prompt_text += template['INSTRUCTION'].format(
input=text, round=n_turn + 1, bot_name=args.bot_name)
if args.prompt_template == args.system_template == 'moss_sft':
if not inner_thoughts_open:
prompt_text.replace('- Inner thoughts: enabled.',
'- Inner thoughts: disabled.')
if not calculate_open:
prompt_text.replace(('- Calculator: enabled. API: '
'Calculate(expression)'),
'- Calculator: disabled.')
if not solve_open:
prompt_text.replace(
'- Equation solver: enabled. API: Solve(equation)',
'- Equation solver: disabled.')
if not search_open:
prompt_text.replace(
'- Web search: enabled. API: Search(query)',
'- Web search: disabled.')
else:
prompt_text = text
inputs += prompt_text
ids = tokenizer.encode(inputs, return_tensors='pt')
streamer = Streamer(tokenizer) if Streamer is not None else None
Expand Down

0 comments on commit 8ce2569

Please sign in to comment.