From 4d32b1ab86b310a46884ec01f6be8038b6c9b62e Mon Sep 17 00:00:00 2001 From: liudan <403644786@qq.com> Date: Mon, 21 Oct 2024 20:26:31 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=82=E9=85=8D=E4=BA=86minicpm3,=E5=B9=B6?= =?UTF-8?q?=E4=B8=94=E6=B5=8B=E8=AF=95=E5=8F=AF=E4=BB=A5=E8=BF=90=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../minicpm3_4b_full_custom_pretrain_e1.py | 200 +++++++++++++++ .../minicpm3_4b_chat_qlora_custom_sft_e1.py | 228 ++++++++++++++++++ .../minicpm3_4b_full_alpaca_zh_e3.py | 201 +++++++++++++++ .../minicpm3_4b/minicpm_2b_dpo_qlora.py | 221 +++++++++++++++++ xtuner/utils/templates.py | 8 + 5 files changed, 858 insertions(+) create mode 100644 xtuner/configs/custom_dataset/pretrain/minicpm/minicpm3_4b_full_custom_pretrain_e1.py create mode 100644 xtuner/configs/custom_dataset/sft/minicpm/minicpm3_4b_chat_qlora_custom_sft_e1.py create mode 100644 xtuner/configs/minicpm/minicpm3_4b/minicpm3_4b_full_alpaca_zh_e3.py create mode 100644 xtuner/configs/minicpm/minicpm3_4b/minicpm_2b_dpo_qlora.py diff --git a/xtuner/configs/custom_dataset/pretrain/minicpm/minicpm3_4b_full_custom_pretrain_e1.py b/xtuner/configs/custom_dataset/pretrain/minicpm/minicpm3_4b_full_custom_pretrain_e1.py new file mode 100644 index 000000000..799ea18a1 --- /dev/null +++ b/xtuner/configs/custom_dataset/pretrain/minicpm/minicpm3_4b_full_custom_pretrain_e1.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Data format: + +[ + { + "text": "xxx" + }, + { + "text": "xxx" + }, + ... +] +""" # noqa: E501 + +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import AutoModelForCausalLM, AutoTokenizer + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import pretrain_map_fn +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'openbmb/MiniCPM3-4B' +use_varlen_attn = False + +# Data +data_files = ['/path/to/your.json'] +max_length = 1024 +pack_to_max_length = True + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 1 # bs = 1 GPU * 1 batch_size_per_device * 16 acc +dataloader_num_workers = 0 +max_steps = 10000 +optim_type = AdamW +lr = 2e-5 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = '' +evaluation_inputs = ['上海是', 'Shanghai is'] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right', + eos_token="<|im_end|>") + +model = dict( + type=SupervisedFinetune, + use_varlen_attn=use_varlen_attn, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True)) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path='json', data_files=data_files), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=pretrain_map_fn, + template_map_fn=None, + remove_unused_columns=True, + shuffle_before_pack=False, + pack_to_max_length=pack_to_max_length, + use_varlen_attn=use_varlen_attn) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=max_steps * warmup_ratio, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=max_steps * warmup_ratio, + end=max_steps, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_iters=max_steps) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/custom_dataset/sft/minicpm/minicpm3_4b_chat_qlora_custom_sft_e1.py b/xtuner/configs/custom_dataset/sft/minicpm/minicpm3_4b_chat_qlora_custom_sft_e1.py new file mode 100644 index 000000000..9b0b61384 --- /dev/null +++ b/xtuner/configs/custom_dataset/sft/minicpm/minicpm3_4b_chat_qlora_custom_sft_e1.py @@ -0,0 +1,228 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Data format: + +[{ + "messages": [ + { "role": "system", "content": "xxx." }, + { "role": "user", "content": "xxx." }, + { "role": "assistant", "content": "xxx.", "loss": false}, + { "role": "user", "content": "xxx." }, + { "role": "assistant", "content": "xxx.", "loss": true} + ] +}, +... +] +""" # noqa: E501 +import torch +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import openai_map_fn, template_map_fn_factory +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune +from xtuner.utils import PROMPT_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'openbmb/MiniCPM3-4B' +use_varlen_attn = False + +# Data +data_files = ['/path/to/your.json'] +prompt_template = PROMPT_TEMPLATE.minicpm3 +max_length = 2048 +pack_to_max_length = True + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 16 # bs = 1 GPU * 1 batch_size_per_device * 16 acc +dataloader_num_workers = 0 +max_steps = 10000 +optim_type = AdamW +lr = 2e-4 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = '' +evaluation_inputs = [ + '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' +] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right', + eos_token='<|im_end|>') + +model = dict( + type=SupervisedFinetune, + use_varlen_attn=use_varlen_attn, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + lora=dict( + type=LoraConfig, + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path='json', data_files=data_files), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=openai_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length, + use_varlen_attn=use_varlen_attn) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_steps, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_steps, + end=max_steps, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_iters=max_steps) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/minicpm/minicpm3_4b/minicpm3_4b_full_alpaca_zh_e3.py b/xtuner/configs/minicpm/minicpm3_4b/minicpm3_4b_full_alpaca_zh_e3.py new file mode 100644 index 000000000..1a9e249a6 --- /dev/null +++ b/xtuner/configs/minicpm/minicpm3_4b/minicpm3_4b_full_alpaca_zh_e3.py @@ -0,0 +1,201 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import AutoModelForCausalLM, AutoTokenizer + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import alpaca_zh_map_fn, template_map_fn_factory +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune +from xtuner.parallel.sequence import SequenceParallelSampler +from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'openbmb/MiniCPM3-4B' +use_varlen_attn = False + +# Data +alpaca_en_path = 'silk-road/alpaca-data-gpt4-chinese' +prompt_template = PROMPT_TEMPLATE.minicpm3 +max_length = 2048 +pack_to_max_length = True + +# parallel +sequence_parallel_size = 1 + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 16 +accumulative_counts *= sequence_parallel_size +dataloader_num_workers = 0 +max_steps = 10000 +optim_type = AdamW +lr = 2e-5 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = SYSTEM_TEMPLATE.alpaca +evaluation_inputs = [ + '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' +] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right', + eos_token='') + +model = dict( + type=SupervisedFinetune, + use_varlen_attn=use_varlen_attn, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True)) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +alpaca_en = dict( + type=process_hf_dataset, + dataset=dict(type=load_dataset, path=alpaca_en_path), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=alpaca_zh_map_fn, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length, + use_varlen_attn=use_varlen_attn) + +sampler = SequenceParallelSampler \ + if sequence_parallel_size > 1 else DefaultSampler + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=alpaca_en, + sampler=dict(type=sampler, shuffle=True), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_steps, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_steps, + end=max_steps, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_iters=max_steps) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/minicpm/minicpm3_4b/minicpm_2b_dpo_qlora.py b/xtuner/configs/minicpm/minicpm3_4b/minicpm_2b_dpo_qlora.py new file mode 100644 index 000000000..dcb3344db --- /dev/null +++ b/xtuner/configs/minicpm/minicpm3_4b/minicpm_2b_dpo_qlora.py @@ -0,0 +1,221 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from datasets import load_dataset +from mmengine.dataset import DefaultSampler +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from peft import LoraConfig +from torch.optim import AdamW +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + +from xtuner.dataset.collate_fns.preference_collate_fn import \ + preference_collate_fn +from xtuner.dataset.preference_dataset import (build_preference_dataset, + orpo_dpo_mix_40k_map_fn) +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model.dpo import DPO +from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'openbmb/MiniCPM3-4B' +use_varlen_attn = False +dpo_loss_type = 'sigmoid' # One of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo_hard', 'nca_pair', 'robust'] # noqa: E501 +loss_beta = 0.1 +label_smoothing = 0.0 + +# Data +prompt_template = PROMPT_TEMPLATE.minicpm +max_length = 2048 + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 16 +dataloader_num_workers = 0 +max_steps = 3 +optim_type = AdamW +lr = 5e-7 # refer to alignment handbook +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = SYSTEM_TEMPLATE.alpaca +evaluation_inputs = [ + 'What famous British author, known for his tales of mystery and the macabre, shares his initials with a common abbreviation for "rest in peace"?', # noqa: E501 + 'Please tell me five scenic spots in Shanghai', + '890729 - 425663? Only respond with math and no words.' +] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + type=DPO, + use_varlen_attn=use_varlen_attn, + loss_type=dpo_loss_type, + beta=loss_beta, + label_smoothing=label_smoothing, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.float16, + quantization_config=dict( + type=BitsAndBytesConfig, + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4')), + lora=dict( + type=LoraConfig, + r=64, + lora_alpha=16, + lora_dropout=0.1, + bias='none', + task_type='CAUSAL_LM')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=build_preference_dataset, + dataset=dict(type=load_dataset, path='mlabonne/orpo-dpo-mix-40k'), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=orpo_dpo_mix_40k_map_fn, + is_dpo=True, + is_reward=False, + reward_token_id=-1, + num_proc=32, + use_varlen_attn=use_varlen_attn, + shuffle_before_pack=True, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=DefaultSampler, shuffle=True), + collate_fn=dict( + type=preference_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_steps, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_steps, + end=max_steps, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_iters=max_steps) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/utils/templates.py b/xtuner/utils/templates.py index 59b472731..0e5732a3e 100644 --- a/xtuner/utils/templates.py +++ b/xtuner/utils/templates.py @@ -131,6 +131,14 @@ INSTRUCTION=('[INST] {input} [/INST]'), SEP='\n'), minicpm=dict(INSTRUCTION=('<用户> {input} '), SEP='\n'), + minicpm3=dict( + SYSTEM=('<|im_start|>system\n{system}<|im_end|>\n'), + INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n' + '<|im_start|>assistant\n'), + SUFFIX='<|im_end|>', + SUFFIX_AS_EOS=True, + SEP='\n', + STOP_WORDS=['<|im_end|>', '<|endoftext|>']), gemma=dict( # `system` field is extended by xtuner SYSTEM=('system\n{system}\n'),