diff --git a/xtuner/dataset/utils.py b/xtuner/dataset/utils.py
index c6b780dc5..31dc2c845 100644
--- a/xtuner/dataset/utils.py
+++ b/xtuner/dataset/utils.py
@@ -37,14 +37,19 @@ def encode_fn(example, tokenizer, max_length, input_ids_with_output=True):
         ]
     """
     if tokenizer.__class__.__name__ == 'QWenTokenizer':
-        bos_token = ''
-        eos_token = '<|endoftext|>'
+        bos_token_id = []
+        eos_token_id = tokenizer.encode(
+            '<|endoftext|>', add_special_tokens=False)
     elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
-        bos_token = ''
-        eos_token = tokenizer.eos_token
+        bos_token_id = []
+        eos_token_id = tokenizer.eos_token_id
     else:
-        bos_token = tokenizer.bos_token
-        eos_token = tokenizer.eos_token
+        bos_token_id = tokenizer.bos_token_id
+        eos_token_id = tokenizer.eos_token_id
+    if isinstance(bos_token_id, int):
+        bos_token_id = [bos_token_id]
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
     is_multi_turn_conversation = len(example['conversation']) > 1
     if is_multi_turn_conversation:
         assert input_ids_with_output
@@ -52,16 +57,15 @@ def encode_fn(example, tokenizer, max_length, input_ids_with_output=True):
     input_ids, labels = [], []
     for single_turn_conversation in example['conversation']:
         input = single_turn_conversation['input']
-        input_encode = tokenizer(
-            f'{bos_token}{input}', add_special_tokens=False)
-        input_ids += input_encode['input_ids']
-        labels += [IGNORE_INDEX] * len(input_encode['input_ids'])
+        input_encode = tokenizer(f'{input}', add_special_tokens=False)
+        input_ids += bos_token_id + input_encode['input_ids']
+        labels += [IGNORE_INDEX] * (
+            len(bos_token_id + input_encode['input_ids']))
         if input_ids_with_output:
             output = single_turn_conversation['output']
-            output_encode = tokenizer(
-                f'{output}{eos_token}', add_special_tokens=False)
-            input_ids += output_encode['input_ids']
-            labels += copy.deepcopy(output_encode['input_ids'])
+            output_encode = tokenizer(f'{output}', add_special_tokens=False)
+            input_ids += output_encode['input_ids'] + eos_token_id
+            labels += copy.deepcopy(output_encode['input_ids'] + eos_token_id)
 
     if len(input_ids) > max_length:
         input_ids = input_ids[:max_length]
diff --git a/xtuner/engine/hooks/evaluate_chat_hook.py b/xtuner/engine/hooks/evaluate_chat_hook.py
index b6f49bd04..2565a8f33 100644
--- a/xtuner/engine/hooks/evaluate_chat_hook.py
+++ b/xtuner/engine/hooks/evaluate_chat_hook.py
@@ -57,6 +57,7 @@ def _generate_samples(self, runner, max_new_tokens=None):
         # Cast to inference mode
         model.llm.gradient_checkpointing_disable()
         model.llm.config.use_cache = True
+        model.eval()
 
         for sample_input in self.evaluation_inputs:
             inputs = self.instruction.format(
@@ -77,6 +78,7 @@ def _generate_samples(self, runner, max_new_tokens=None):
         if is_checkpointing:
             model.llm.gradient_checkpointing_enable()
         model.llm.config.use_cache = use_cache
+        model.train()
 
     def before_train(self, runner):
         runner.logger.info('before_train in EvaluateChatHook .')
diff --git a/xtuner/tools/chat.py b/xtuner/tools/chat.py
index 54aa927dd..002dc62ca 100644
--- a/xtuner/tools/chat.py
+++ b/xtuner/tools/chat.py
@@ -144,6 +144,7 @@ def main():
     if args.adapter is not None:
         model = PeftModel.from_pretrained(model, args.adapter)
         print(f'Load adapter from {args.adapter}')
+    model.eval()
 
     Streamer, stop_criteria = get_chat_utils(model)
     if args.no_streamer:
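
Why the encode_fn change: the old code spliced bos_token/eos_token into the prompt as strings and re-tokenized them, which is fragile. Some tokenizers split the literal special-token string into several ordinary tokens instead of mapping it back to the single special id, and QWenTokenizer/ChatGLMTokenizer expose no usable bos string at all. Working directly with token-id lists, and normalizing int ids to one-element lists so they concatenate cleanly, avoids re-tokenization entirely. The sketch below is not part of the patch; it assumes a Hugging Face tokenizer, and the model name and as_id_list helper are illustrative only.

# Illustrative only -- mirrors the isinstance() normalization added to encode_fn.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder model


def as_id_list(token_id):
    # bos_token_id / eos_token_id may be an int, a list, or None
    if token_id is None:
        return []
    if isinstance(token_id, int):
        return [token_id]
    return token_id


bos_token_id = as_id_list(tokenizer.bos_token_id)
eos_token_id = as_id_list(tokenizer.eos_token_id)
text_ids = tokenizer('Hello', add_special_tokens=False)['input_ids']
# Exact special ids by construction; nothing is re-tokenized.
input_ids = bos_token_id + text_ids + eos_token_id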
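
On the model.eval()/model.train() additions in evaluate_chat_hook.py and chat.py: without eval(), generation runs with dropout (and any other train-only behavior) active, so the samples logged during training are noisier than what the model actually produces at inference time, and chat.py previously relied on whatever mode the loaded model happened to be in. The usual toggle pattern is sketched below; it is not part of the patch and assumes a Hugging Face-style generate() method.

# Illustrative only -- toggle eval/train around generation, as the hook now does.
import torch


@torch.no_grad()
def generate_sample(model, input_ids, max_new_tokens=64):
    was_training = model.training
    model.eval()  # disable dropout etc. while sampling
    try:
        return model.generate(input_ids, max_new_tokens=max_new_tokens)
    finally:
        if was_training:
            model.train()  # restore training mode for subsequent steps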