From b92481fb40c6df2941bd18dafb512e91514abedc Mon Sep 17 00:00:00 2001 From: Xu Song Date: Thu, 11 Jul 2024 17:21:07 +0800 Subject: [PATCH] [Feature] Support the DatasetInfoHook of DPO training (#787) * [Feature] Support the DatasetInfoHook of DPO training * fix yapf check --- .../internlm/internlm2_chat_1_8b_dpo_full.py | 4 +-- ...internlm2_chat_1_8b_dpo_full_varlenattn.py | 4 +-- ..._1_8b_dpo_full_varlenattn_jsonl_dataset.py | 4 +-- .../internlm2_chat_7b_dpo_qlora_varlenattn.py | 4 +-- ...llama3_8b_instruct_dpo_qlora_varlenattn.py | 4 +-- xtuner/engine/hooks/dataset_info_hook.py | 29 ++++++++++++------- 6 files changed, 28 insertions(+), 21 deletions(-) diff --git a/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full.py b/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full.py index dd3909f72..908683fe6 100644 --- a/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full.py +++ b/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full.py @@ -11,7 +11,7 @@ preference_collate_fn from xtuner.dataset.preference_dataset import (build_preference_dataset, orpo_dpo_mix_40k_map_fn) -from xtuner.engine.hooks import (EvaluateChatHook, +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook) from xtuner.engine.runner import TrainLoop from xtuner.model.dpo import DPO @@ -141,7 +141,7 @@ ####################################################################### # Log the dialogue periodically during the training process, optional custom_hooks = [ - # dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict(type=DatasetInfoHook, tokenizer=tokenizer), dict( type=EvaluateChatHook, tokenizer=tokenizer, diff --git a/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn.py b/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn.py index 3e5cdc35a..787ad68bb 100644 --- a/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn.py +++ b/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn.py @@ -11,7 +11,7 @@ preference_collate_fn from xtuner.dataset.preference_dataset import (build_preference_dataset, orpo_dpo_mix_40k_map_fn) -from xtuner.engine.hooks import (EvaluateChatHook, +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook) from xtuner.engine.runner import TrainLoop from xtuner.model.dpo import DPO @@ -151,7 +151,7 @@ ####################################################################### # Log the dialogue periodically during the training process, optional custom_hooks = [ - # dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict(type=DatasetInfoHook, tokenizer=tokenizer), dict( type=EvaluateChatHook, tokenizer=tokenizer, diff --git a/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn_jsonl_dataset.py b/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn_jsonl_dataset.py index 55bb270a4..ae1a3cdca 100644 --- a/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn_jsonl_dataset.py +++ b/xtuner/configs/dpo/internlm/internlm2_chat_1_8b_dpo_full_varlenattn_jsonl_dataset.py @@ -10,7 +10,7 @@ preference_collate_fn from xtuner.dataset.preference_dataset import (build_preference_dataset, load_jsonl_dataset) -from xtuner.engine.hooks import (EvaluateChatHook, +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook) from xtuner.engine.runner import TrainLoop from xtuner.model.dpo import DPO @@ -155,7 +155,7 @@ ####################################################################### # Log the dialogue periodically during the training process, optional custom_hooks = [ - # dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict(type=DatasetInfoHook, tokenizer=tokenizer), dict( type=EvaluateChatHook, tokenizer=tokenizer, diff --git a/xtuner/configs/dpo/internlm/internlm2_chat_7b_dpo_qlora_varlenattn.py b/xtuner/configs/dpo/internlm/internlm2_chat_7b_dpo_qlora_varlenattn.py index b051ea2a1..659d029b3 100644 --- a/xtuner/configs/dpo/internlm/internlm2_chat_7b_dpo_qlora_varlenattn.py +++ b/xtuner/configs/dpo/internlm/internlm2_chat_7b_dpo_qlora_varlenattn.py @@ -14,7 +14,7 @@ preference_collate_fn from xtuner.dataset.preference_dataset import (build_preference_dataset, orpo_dpo_mix_40k_map_fn) -from xtuner.engine.hooks import (EvaluateChatHook, +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook) from xtuner.engine.runner import TrainLoop from xtuner.model.dpo import DPO @@ -170,7 +170,7 @@ ####################################################################### # Log the dialogue periodically during the training process, optional custom_hooks = [ - # dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict(type=DatasetInfoHook, tokenizer=tokenizer), dict( type=EvaluateChatHook, tokenizer=tokenizer, diff --git a/xtuner/configs/dpo/llama/llama3_8b_instruct_dpo_qlora_varlenattn.py b/xtuner/configs/dpo/llama/llama3_8b_instruct_dpo_qlora_varlenattn.py index 0ca90f51c..e94b88fd0 100644 --- a/xtuner/configs/dpo/llama/llama3_8b_instruct_dpo_qlora_varlenattn.py +++ b/xtuner/configs/dpo/llama/llama3_8b_instruct_dpo_qlora_varlenattn.py @@ -14,7 +14,7 @@ preference_collate_fn from xtuner.dataset.preference_dataset import (build_preference_dataset, orpo_dpo_mix_40k_map_fn) -from xtuner.engine.hooks import (EvaluateChatHook, +from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook) from xtuner.engine.runner import TrainLoop from xtuner.model.dpo import DPO @@ -170,7 +170,7 @@ ####################################################################### # Log the dialogue periodically during the training process, optional custom_hooks = [ - # dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict(type=DatasetInfoHook, tokenizer=tokenizer), dict( type=EvaluateChatHook, tokenizer=tokenizer, diff --git a/xtuner/engine/hooks/dataset_info_hook.py b/xtuner/engine/hooks/dataset_info_hook.py index d997373ec..84dc9498a 100644 --- a/xtuner/engine/hooks/dataset_info_hook.py +++ b/xtuner/engine/hooks/dataset_info_hook.py @@ -25,19 +25,26 @@ def __init__(self, tokenizer, is_intern_repo_dataset=False): self.is_intern_repo_dataset = is_intern_repo_dataset def log(self, runner, dataset, mode='train'): + + def _log(input_ids, log_prefix=''): + if self.is_intern_repo_dataset: + input_ids = [abs(x) for x in input_ids] + # Try to split list to be compatible with IMAGE token + input_ids = split_list(input_ids, IMAGE_TOKEN_INDEX) + text = log_prefix + for idx, ids in enumerate(input_ids): + text += self.tokenizer.decode(ids) + if idx != len(input_ids) - 1: + text += DEFAULT_IMAGE_TOKEN + runner.logger.info(text) + runner.logger.info(f'Num {mode} samples {len(dataset)}') runner.logger.info(f'{mode} example:') - input_ids = dataset[0]['input_ids'] - if self.is_intern_repo_dataset: - input_ids = [abs(x) for x in input_ids] - # Try to split list to be compatible with IMAGE token - input_ids = split_list(input_ids, IMAGE_TOKEN_INDEX) - text = '' - for idx, ids in enumerate(input_ids): - text += self.tokenizer.decode(ids) - if idx != len(input_ids) - 1: - text += DEFAULT_IMAGE_TOKEN - runner.logger.info(text) + if 'chosen_ids' in dataset[0]: + _log(dataset[0]['chosen_ids'], log_prefix='chosen: ') + _log(dataset[0]['rejected_ids'], log_prefix='rejected: ') + else: + _log(dataset[0]['input_ids']) def before_train(self, runner) -> None: do_train = runner.train_loop is not None