[Refactor] Decouple map fn (#54)
* decouple map_fn

* decouple map_fn

* decouple map_fn

* modify config

* advance the --prompt-template args

* fix config

* Update sql_map_fn.py

* Update sql_map_fn.py

* Update alpaca_map_fn.py

* rename hooks

* refactor template_map_fn

* add openai_map_fn

* fix config

* fix config

* fix print_log

* fix medical map_fn

* improve

* rename dataset_map_fn to dataset_map_fns

* add alpaca_enzh & oasst1 concat dataset config

---------

Co-authored-by: LZHgrla <36994684+LZHgrla@users.noreply.github.com>
Co-authored-by: LZHgrla <linzhihao@pjlab.org.cn>
3 people authored Aug 28, 2023
1 parent 7e8fb2c commit 830ad06
Showing 40 changed files with 519 additions and 293 deletions.
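The heart of this refactor is splitting the old, monolithic `map_fn` into two stages: a dataset-specific `dataset_map_fn` that normalizes the raw columns into a common conversation format, and a `template_map_fn` built by `template_map_fn_factory` that applies the chosen prompt template. A minimal sketch of how the two stages might compose is shown below; the function names, field names, and signatures are illustrative assumptions, not the exact xtuner implementation.

# Illustrative sketch only: names, fields, and signatures are assumptions,
# not the exact xtuner implementation.
def example_dataset_map_fn(example):
    # Dataset-specific stage: normalize raw columns into a conversation.
    return {
        'conversation': [{
            'input': example['instruction'],
            'output': example['output']
        }]
    }

def example_template_map_fn_factory(template):
    # Template-specific stage: wrap every turn with the prompt template.
    def template_map_fn(example):
        for turn in example['conversation']:
            turn['input'] = template.INSTRUCTION_START.format(input=turn['input'])
        return example
    return template_map_fn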
@@ -0,0 +1,210 @@
import torch
from bitsandbytes.optim import PagedAdamW32bit
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.datasets import ConcatDataset, process_hf_dataset
from xtuner.datasets.collate_fns import default_collate_fn
from xtuner.datasets.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
                                     oasst1_map_fn, template_map_fn_factory)
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.models import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
# PART 1 Settings #
#######################################################################
# path
pretrained_model_name_or_path = 'internlm/internlm-7b'
alpaca_zh_path = 'silk-road/alpaca-data-gpt4-chinese'
alpaca_en_path = 'tatsu-lab/alpaca'
oasst1_path = 'timdettmers/openassistant-guanaco'

# data
prompt_template = PROMPT_TEMPLATE.alpaca
batch_size = 1 # per_device
accumulative_counts = 16
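# Note: the effective global batch size is
# batch_size * accumulative_counts * number of GPUs.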
dataloader_num_workers = 0
max_epochs = 3

# optim
optim_type = PagedAdamW32bit
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0.01
max_norm = 1 # grad clip

# Assess the progress of the model's training via interactive dialogue.
evaluation_freq = 500
human_inputs = ['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai']

# other
max_length = 2048
pack_to_max_length = True

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,
        lora_alpha=16,
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
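# Each dataset below pairs a dataset-specific `dataset_map_fn` (normalizes raw
# columns into a conversation) with a shared `template_map_fn` that applies the
# chosen prompt template.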
alpaca_en = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_en_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

alpaca_zh = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=alpaca_zh_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=alpaca_zh_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

oasst1 = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path=oasst1_path),
    tokenizer=tokenizer,
    max_length=max_length,
    dataset_map_fn=oasst1_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataset = dict(
    type=ConcatDataset,
    datasets_cfg=dict(alpaca_en=alpaca_en, alpaca_zh=alpaca_zh, oasst1=oasst1))

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
# PART 4 Scheduler #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = dict(
    type=CosineAnnealingLR,
    eta_min=lr * 0.1,
    by_epoch=True,
    T_max=max_epochs,
    convert_to_iter_based=True)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        sample_inputs=human_inputs,
        instruction=prompt_template.INSTRUCTION_START)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
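
For reference, a minimal sketch of launching a config like this one with MMEngine's Runner; the config filename and work_dir below are hypothetical, and xtuner also ships its own training entry script.

# Minimal launch sketch; the config path and work_dir are hypothetical.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('internlm_7b_qlora_alpaca_enzh_oasst1_e3.py')
cfg.work_dir = './work_dirs/alpaca_enzh_oasst1'  # Runner requires a work_dir
runner = Runner.from_cfg(cfg)
runner.train()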
35 changes: 21 additions & 14 deletions xtuner/configs/llama/llama_7b/llama_7b_qlora_alpaca_e3.py
@@ -11,8 +11,9 @@

from xtuner.datasets import ConcatDataset, process_hf_dataset
from xtuner.datasets.collate_fns import default_collate_fn
from xtuner.datasets.map_fns import alpaca_map_fn, alpaca_zh_map_fn
from xtuner.engine import LogSampleHook, SampleGenerateHook
from xtuner.datasets.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
                                     template_map_fn_factory)
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.models import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

@@ -25,6 +26,7 @@
alpaca_en_path = 'tatsu-lab/alpaca'

# data
prompt_template = PROMPT_TEMPLATE.alpaca
batch_size = 1 # per_device
accumulative_counts = 16
dataloader_num_workers = 0
@@ -37,10 +39,13 @@
weight_decay = 0.01
max_norm = 1 # grad clip

# Assess the progress of the model's training via interactive dialogue.
evaluation_freq = 500
human_inputs = ['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai']

# other
max_length = 2048
pack_to_max_length = True
generate_test_freq = 500

#######################################################################
# PART 2 Model & Tokenizer #
@@ -83,8 +88,10 @@
    dataset=dict(type=load_dataset, path=alpaca_en_path),
    tokenizer=tokenizer,
    max_length=max_length,
    map_fn=alpaca_map_fn,
    remove_columns=['instruction', 'text'],
    dataset_map_fn=alpaca_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

@@ -93,8 +100,10 @@
    dataset=dict(type=load_dataset, path=alpaca_zh_path),
    tokenizer=tokenizer,
    max_length=max_length,
    map_fn=alpaca_zh_map_fn,
    remove_columns=['instruction', 'instruction_zh', 'input_zh', 'output_zh'],
    dataset_map_fn=alpaca_zh_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

@@ -139,15 +148,13 @@
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=LogSampleHook, tokenizer=tokenizer),
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=SampleGenerateHook,
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=generate_test_freq,
        sample_inputs=[
            '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
        ],
        instruction=PROMPT_TEMPLATE.alpaca.INSTRUCTION_START)
        every_n_iters=evaluation_freq,
        sample_inputs=human_inputs,
        instruction=prompt_template.INSTRUCTION_START)
]

# configure default hooks
27 changes: 16 additions & 11 deletions xtuner/configs/llama/llama_7b/llama_7b_qlora_oasst1_e3.py
@@ -11,8 +11,8 @@

from xtuner.datasets import process_hf_dataset
from xtuner.datasets.collate_fns import default_collate_fn
from xtuner.datasets.map_fns import oasst1_map_fn
from xtuner.engine import LogSampleHook, SampleGenerateHook
from xtuner.datasets.map_fns import oasst1_map_fn, template_map_fn_factory
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.models import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

@@ -24,6 +24,7 @@
data_path = 'timdettmers/openassistant-guanaco'

# data
prompt_template = PROMPT_TEMPLATE.openassistant
batch_size = 1 # per_device
accumulative_counts = 16
dataloader_num_workers = 0
@@ -36,10 +37,13 @@
weight_decay = 0.01
max_norm = 1 # grad clip

# Assess the progress of the model's training via interactive dialogue.
evaluation_freq = 500
human_inputs = ['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai']

# other
max_length = 2048
pack_to_max_length = True
generate_test_freq = 500
#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
@@ -81,7 +85,10 @@
    dataset=dict(type=load_dataset, path=data_path),
    tokenizer=tokenizer,
    max_length=max_length,
    map_fn=oasst1_map_fn,
    dataset_map_fn=oasst1_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

@@ -122,15 +129,13 @@
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=LogSampleHook, tokenizer=tokenizer),
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=SampleGenerateHook,
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=generate_test_freq,
        sample_inputs=[
            '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
        ],
        instruction=PROMPT_TEMPLATE.openassistant.INSTRUCTION_START)
        every_n_iters=evaluation_freq,
        sample_inputs=human_inputs,
        instruction=prompt_template.INSTRUCTION_START)
]

# configure default hooks