[Feature] support saving eval output before save checkpoint (#385)
* support saving eval output before save checkpoint

* refactor
HIT-cwh authored Feb 1, 2024
1 parent 58537c3 commit 47c08d8
Showing 1 changed file with 119 additions and 65 deletions.
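In effect, the hook now also fires (and dumps its generations to disk) whenever the current iteration lines up with the checkpoint interval, so the saved text reflects the weights captured by the checkpoint written at that iteration. Below is a minimal standalone sketch of that gating, assuming a 1-based iteration count as in MMEngine's every_n_train_iters; the function itself is hypothetical and only mirrors the logic in the diff that follows.

# Hypothetical, simplified illustration of the new gating in after_train_iter;
# not the hook's actual code.
def should_run_chat_eval(iter_1based, every_n_iters, ckpt_interval):
    """Return (run_eval, save_eval_output) for the current training iteration."""
    if every_n_iters is None:  # chat evaluation disabled during training
        return False, False
    # Save generations to disk only on iterations where a checkpoint will be written.
    save_eval_output = (ckpt_interval is not None
                        and iter_1based % ckpt_interval == 0)
    # Evaluate on the regular cadence, or whenever outputs are being saved.
    run_eval = save_eval_output or iter_1based % every_n_iters == 0
    return run_eval, save_eval_output
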
184 changes: 119 additions & 65 deletions xtuner/engine/hooks/evaluate_chat_hook.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings

import torch
@@ -85,15 +86,111 @@ def __init__(self,
self.stop_criteria.append(
StopWordStoppingCriteria(self.tokenizer, word))

def _generate_samples(self, runner, max_new_tokens=None):
def _save_eval_output(self, runner, eval_outputs):
save_path = os.path.join(runner.log_dir, 'vis_data',
f'eval_outputs_iter_{runner.iter}.txt')
with open(save_path, 'w') as f:
for i, output in enumerate(eval_outputs):
f.write(f'Eval output {i + 1}:\n{output}\n\n')

def _eval_images(self,
runner,
model,
device,
max_new_tokens=None,
save_eval_output=False):
if save_eval_output:
eval_outputs = []

for sample_image, sample_input in zip(self.evaluation_images,
self.evaluation_inputs):
image = expand2square(
sample_image,
tuple(int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
image = image.to(device)
sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
inputs = (self.system + self.instruction).format(
input=sample_input, round=1, **runner.cfg)
chunk_encode = []
for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
if idx == 0:
cur_encode = self.tokenizer.encode(chunk)
else:
cur_encode = self.tokenizer.encode(
chunk, add_special_tokens=False)
chunk_encode.append(cur_encode)
assert len(chunk_encode) == 2
input_ids = []
for idx, cur_chunk_encode in enumerate(chunk_encode):
input_ids.extend(cur_chunk_encode)
if idx != len(chunk_encode) - 1:
input_ids.append(IMAGE_TOKEN_INDEX)
input_ids = torch.tensor(input_ids).to(device)
visual_outputs = model.visual_encoder(
image.unsqueeze(0), output_hidden_states=True)
pixel_values = model.projector(
visual_outputs.hidden_states[model.visual_select_layer][:, 1:])

mm_inputs = prepare_inputs_labels_for_multimodal(
llm=model.llm,
input_ids=input_ids.unsqueeze(0),
pixel_values=pixel_values)

generation_output = model.generate(
**mm_inputs,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
bos_token_id=self.tokenizer.bos_token_id,
stopping_criteria=self.stop_criteria)
generation_output = self.tokenizer.decode(generation_output[0])
runner.logger.info(f'Sample output:\n'
f'{inputs + generation_output}\n')
if save_eval_output:
eval_outputs.append(f'{inputs + generation_output}\n')

if save_eval_output:
self._save_eval_output(runner, eval_outputs)

def _eval_language(self,
runner,
model,
device,
max_new_tokens=None,
save_eval_output=False):
if save_eval_output:
eval_outputs = []

for sample_input in self.evaluation_inputs:
inputs = (self.system + self.instruction).format(
input=sample_input, round=1, **runner.cfg)
input_ids = self.tokenizer.encode(inputs, return_tensors='pt')
input_ids = input_ids.to(device)
generation_output = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
stopping_criteria=self.stop_criteria)
generation_output = self.tokenizer.decode(generation_output[0])
runner.logger.info(f'Sample output:\n{generation_output}\n')
if save_eval_output:
eval_outputs.append(f'{generation_output}\n')

if save_eval_output:
self._save_eval_output(runner, eval_outputs)

def _generate_samples(self,
runner,
max_new_tokens=None,
save_eval_output=False):
if max_new_tokens is None:
max_new_tokens = self.max_new_tokens
model = runner.model
if is_model_wrapper(model):
model = model.module

device = next(iter(model.parameters())).device

is_checkpointing = model.llm.is_gradient_checkpointing
use_cache = model.llm.config.use_cache

@@ -102,67 +199,11 @@ def _generate_samples(self, runner, max_new_tokens=None):
model.llm.config.use_cache = True
model.eval()
if self.evaluation_images is not None:
for sample_image, sample_input in zip(self.evaluation_images,
self.evaluation_inputs):
image = expand2square(
sample_image,
tuple(
int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
image = image.to(device)
sample_input = DEFAULT_IMAGE_TOKEN + '\n' + sample_input
inputs = (self.system + self.instruction).format(
input=sample_input, round=1, **runner.cfg)
chunk_encode = []
for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
if idx == 0:
cur_encode = self.tokenizer.encode(chunk)
else:
cur_encode = self.tokenizer.encode(
chunk, add_special_tokens=False)
chunk_encode.append(cur_encode)
assert len(chunk_encode) == 2
input_ids = []
for idx, cur_chunk_encode in enumerate(chunk_encode):
input_ids.extend(cur_chunk_encode)
if idx != len(chunk_encode) - 1:
input_ids.append(IMAGE_TOKEN_INDEX)
input_ids = torch.tensor(input_ids).to(device)
visual_outputs = model.visual_encoder(
image.unsqueeze(0), output_hidden_states=True)
pixel_values = model.projector(visual_outputs.hidden_states[
model.visual_select_layer][:, 1:])

mm_inputs = prepare_inputs_labels_for_multimodal(
llm=model.llm,
input_ids=input_ids.unsqueeze(0),
pixel_values=pixel_values)

generation_output = model.generate(
**mm_inputs,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
bos_token_id=self.tokenizer.bos_token_id,
stopping_criteria=self.stop_criteria)
runner.logger.info(
f'Sample output:\n'
f'{inputs + self.tokenizer.decode(generation_output[0])}\n'
)
self._eval_images(runner, model, device, max_new_tokens,
save_eval_output)
else:
for sample_input in self.evaluation_inputs:
inputs = (self.system + self.instruction).format(
input=sample_input, round=1, **runner.cfg)
input_ids = self.tokenizer.encode(inputs, return_tensors='pt')
input_ids = input_ids.to(device)
generation_output = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
generation_config=self.gen_config,
stopping_criteria=self.stop_criteria)
runner.logger.info(
f'Sample output:\n'
f'{self.tokenizer.decode(generation_output[0])}\n')
self._eval_language(runner, model, device, max_new_tokens,
save_eval_output)

# Cast to training mode
if is_checkpointing:
@@ -179,11 +220,24 @@ def after_train_iter(self,
batch_idx: int,
data_batch=None,
outputs=None) -> None:
if (self.every_n_iters is None or batch_idx == 0
or batch_idx % self.every_n_iters != 0):
if self.every_n_iters is None:
return

save_eval_output = False
try:
save_ckpt_freq = runner.cfg.default_hooks.checkpoint.interval
save_eval_output = self.every_n_train_iters(runner, save_ckpt_freq)
except KeyError:
pass

do_chat = (
save_eval_output
or self.every_n_train_iters(runner, self.every_n_iters))
if not do_chat:
return

runner.logger.info('after_train_iter in EvaluateChatHook.')
self._generate_samples(runner)
self._generate_samples(runner, save_eval_output=save_eval_output)

def after_train(self, runner):
runner.logger.info('after_train in EvaluateChatHook.')
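
For reference, a hedged sketch of how this pairing might look in an xtuner-style config: the hook reads the checkpoint interval from default_hooks.checkpoint.interval, and the remaining field names below are illustrative assumptions rather than part of this commit.

# Illustrative MMEngine-style config fragment. Only
# `default_hooks.checkpoint.interval` is actually read by the hook in this
# commit; the other fields are assumptions.
default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=500),
)

custom_hooks = [
    dict(
        type='EvaluateChatHook',
        every_n_iters=500,  # regular chat-eval cadence
        evaluation_inputs=['Give me a short introduction to deep learning.'],
        # On iterations where runner.iter + 1 is a multiple of the checkpoint
        # interval, the generated samples are also written to
        # <runner.log_dir>/vis_data/eval_outputs_iter_<iter>.txt
    ),
]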