
Commit

[Fix] Use token_id instead of token for encode_fn & Set eval mode before generate (#107)

* set eval mode before generate

* use token_id instead of token
LZHgrla authored Sep 7, 2023
1 parent 36cf646 commit abd9de1
Showing 3 changed files with 21 additions and 14 deletions.
32 changes: 18 additions & 14 deletions xtuner/dataset/utils.py
@@ -37,31 +37,35 @@ def encode_fn(example, tokenizer, max_length, input_ids_with_output=True):
         ]
     """
     if tokenizer.__class__.__name__ == 'QWenTokenizer':
-        bos_token = ''
-        eos_token = '<|endoftext|>'
+        bos_token_id = []
+        eos_token_id = tokenizer.encode(
+            '<|endoftext|>', add_special_tokens=False)
     elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
-        bos_token = ''
-        eos_token = tokenizer.eos_token
+        bos_token_id = []
+        eos_token_id = tokenizer.eos_token_id
     else:
-        bos_token = tokenizer.bos_token
-        eos_token = tokenizer.eos_token
+        bos_token_id = tokenizer.bos_token_id
+        eos_token_id = tokenizer.eos_token_id
+    if isinstance(bos_token_id, int):
+        bos_token_id = [bos_token_id]
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
     is_multi_turn_conversation = len(example['conversation']) > 1
     if is_multi_turn_conversation:
         assert input_ids_with_output

     input_ids, labels = [], []
     for single_turn_conversation in example['conversation']:
         input = single_turn_conversation['input']
-        input_encode = tokenizer(
-            f'{bos_token}{input}', add_special_tokens=False)
-        input_ids += input_encode['input_ids']
-        labels += [IGNORE_INDEX] * len(input_encode['input_ids'])
+        input_encode = tokenizer(f'{input}', add_special_tokens=False)
+        input_ids += bos_token_id + input_encode['input_ids']
+        labels += [IGNORE_INDEX] * (
+            len(bos_token_id + input_encode['input_ids']))
         if input_ids_with_output:
             output = single_turn_conversation['output']
-            output_encode = tokenizer(
-                f'{output}{eos_token}', add_special_tokens=False)
-            input_ids += output_encode['input_ids']
-            labels += copy.deepcopy(output_encode['input_ids'])
+            output_encode = tokenizer(f'{output}', add_special_tokens=False)
+            input_ids += output_encode['input_ids'] + eos_token_id
+            labels += copy.deepcopy(output_encode['input_ids'] + eos_token_id)

     if len(input_ids) > max_length:
         input_ids = input_ids[:max_length]
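Working with token-id lists rather than interpolating `bos_token`/`eos_token` strings into the text sidesteps two problems: some tokenizers (e.g. QWenTokenizer) have no BOS string at all, and concatenating special-token text into the prompt lets the tokenizer merge it with adjacent characters. A minimal standalone sketch of the resulting pattern (hypothetical `tokenizer` object; `IGNORE_INDEX = -100` as is conventional for label masking, not taken from this diff):

```python
import copy

IGNORE_INDEX = -100  # positions with this label are ignored by the loss


def build_sample(conversation, tokenizer, bos_token_id, eos_token_id):
    """Sketch of token-id based encoding; bos/eos are lists of ints."""
    input_ids, labels = [], []
    for turn in conversation:
        # Prompt tokens: prepend BOS ids and mask them out of the loss.
        prompt_ids = tokenizer(turn['input'], add_special_tokens=False)['input_ids']
        input_ids += bos_token_id + prompt_ids
        labels += [IGNORE_INDEX] * (len(bos_token_id) + len(prompt_ids))
        # Response tokens: append EOS ids and keep them as training targets.
        output_ids = tokenizer(turn['output'], add_special_tokens=False)['input_ids']
        input_ids += output_ids + eos_token_id
        labels += copy.deepcopy(output_ids + eos_token_id)
    return input_ids, labels
```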
2 changes: 2 additions & 0 deletions xtuner/engine/hooks/evaluate_chat_hook.py
@@ -57,6 +57,7 @@ def _generate_samples(self, runner, max_new_tokens=None):
         # Cast to inference mode
         model.llm.gradient_checkpointing_disable()
         model.llm.config.use_cache = True
+        model.eval()

         for sample_input in self.evaluation_inputs:
             inputs = self.instruction.format(
@@ -77,6 +78,7 @@ def _generate_samples(self, runner, max_new_tokens=None):
         if is_checkpointing:
             model.llm.gradient_checkpointing_enable()
         model.llm.config.use_cache = use_cache
+        model.train()

     def before_train(self, runner):
         runner.logger.info('before_train in EvaluateChatHook .')
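`model.eval()` disables dropout and switches normalization layers to inference behavior, so samples generated mid-training are deterministic and match deployment; `model.train()` afterwards restores training-time behavior for the next step. A sketch of the save-and-restore pattern, assuming a HuggingFace-style module with a `.generate()` method (the helper name is hypothetical):

```python
import torch


@torch.no_grad()
def generate_in_eval_mode(model, *args, **kwargs):
    """Run generation with dropout off, then restore the previous mode."""
    was_training = model.training
    model.eval()  # dropout off; norm layers use inference statistics
    try:
        return model.generate(*args, **kwargs)
    finally:
        if was_training:
            model.train()  # put the model back for the next training step
```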
1 change: 1 addition & 0 deletions xtuner/tools/chat.py
@@ -144,6 +144,7 @@ def main():
     if args.adapter is not None:
         model = PeftModel.from_pretrained(model, args.adapter)
         print(f'Load adapter from {args.adapter}')
+    model.eval()

     Streamer, stop_criteria = get_chat_utils(model)
     if args.no_streamer:
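The same fix applies to the standalone chat script: once the (optionally adapter-wrapped) model is built, it should be placed in eval mode before any generation. A sketch of that loading sequence, with placeholder model and adapter paths:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Placeholder paths; substitute a real base model and adapter directory.
model = AutoModelForCausalLM.from_pretrained('path/to/base-model')
model = PeftModel.from_pretrained(model, 'path/to/adapter')
model.eval()  # chat/inference should never run with dropout active
```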
