From 58f7ad425776ca21ab96600106b98893b9f9ef59 Mon Sep 17 00:00:00 2001 From: huanghaian Date: Fri, 21 Jun 2024 10:45:05 +0800 Subject: [PATCH] update config --- .../internvl_v1_5_internlm2_1_8b_finetune.py | 170 ++++++++++++++++ ...ernvl_v1_5_internlm2_1_8b_lora_finetune.py | 183 +++++++++++++++++ ...rnvl_v1_5_internlm2_1_8b_qlora_finetune.py | 187 ++++++++++++++++++ ...> internvl_v1_5_internlm2_26b_finetune.py} | 54 ++--- ...ernvl_v1_5_internlm2_26b_lora_finetune.py} | 52 ++--- ...rnvl_v1_5_internlm2_26b_qlora_finetune.py} | 52 ++--- ...t.py => internvl_v1_5_phi3_4b_finetune.py} | 48 +---- ...=> internvl_v1_5_phi3_4b_lora_finetune.py} | 48 +---- ...> internvl_v1_5_phi3_4b_qlora_finetune.py} | 48 +---- xtuner/model/__init__.py | 4 +- xtuner/model/internvl.py | 132 +++++++++++-- 11 files changed, 720 insertions(+), 258 deletions(-) create mode 100644 xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_finetune.py create mode 100644 xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_lora_finetune.py create mode 100644 xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_qlora_finetune.py rename xtuner/configs/internvl/v1_5/{internvl_26b_sft.py => internvl_v1_5_internlm2_26b_finetune.py} (80%) rename xtuner/configs/internvl/v1_5/{internvl_26b_sft_lora.py => internvl_v1_5_internlm2_26b_lora_finetune.py} (82%) rename xtuner/configs/internvl/v1_5/{internvl_26b_sft_qlora.py => internvl_v1_5_internlm2_26b_qlora_finetune.py} (82%) rename xtuner/configs/internvl/v1_5/{mini_internvl_phi3_sft.py => internvl_v1_5_phi3_4b_finetune.py} (81%) rename xtuner/configs/internvl/v1_5/{mini_internvl_phi3_sft_lora.py => internvl_v1_5_phi3_4b_lora_finetune.py} (82%) rename xtuner/configs/internvl/v1_5/{mini_internvl_phi3_sft_qlora.py => internvl_v1_5_phi3_4b_qlora_finetune.py} (82%) diff --git a/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_finetune.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_finetune.py new file mode 100644 index 000000000..54b387bd4 --- /dev/null +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_finetune.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW + +from xtuner.dataset import InternVL_V1_5_Dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.samplers import LengthGroupedSampler +from xtuner.engine.hooks import DatasetInfoHook +from xtuner.engine.runner import TrainLoop +from xtuner.model import InternVL_V1_5 +from xtuner.utils import PROMPT_TEMPLATE +from transformers import AutoTokenizer +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +path = "/mnt/hwfile/xtuner/huanghaian/model/Mini-InternVL-Chat-2B-V1-5" +prompt_template = PROMPT_TEMPLATE.internlm2_chat + +# Data +data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' +image_folder = data_root + 'llava_images' +max_length = 8192 + +# Scheduler & Optimizer +batch_size = 4 # per_device +accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +# 1024 -> 4e-5 +# 128 -> 5e-6 +lr = 1e-6 +betas = (0.9, 0.999) +weight_decay = 0.05 +max_norm = 1 # grad clip +warmup_ratio = 0.03 + +# Save +save_steps = 1000 +save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) + +####################################################################### +# PART 2 Model & Tokenizer & Image Processor # +####################################################################### +model = dict( + type=InternVL_V1_5, + model_path=path, + freeze_llm=False, + freeze_visual_encoder=True # or False +) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +llava_dataset = dict( + type=InternVL_V1_5_Dataset, + model_path=path, + data_path=data_path, + image_folder=image_folder, + template=prompt_template, + max_length=max_length) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=llava_dataset, + sampler=dict( + type=LengthGroupedSampler, + length_property='modality_length', + per_device_batch_size=batch_size * accumulative_counts), + collate_fn=dict(type=default_collate_fn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', + dtype='float16') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# 
PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=path, + trust_remote_code=True) + +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + save_optimizer=False, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_lora_finetune.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_lora_finetune.py new file mode 100644 index 000000000..9b0eed439 --- /dev/null +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_lora_finetune.py @@ -0,0 +1,183 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+
+from xtuner.dataset import InternVL_V1_5_Dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.samplers import LengthGroupedSampler
+from xtuner.engine.hooks import DatasetInfoHook
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import InternVL_V1_5
+from xtuner.utils import PROMPT_TEMPLATE
+from transformers import AutoTokenizer
+from peft import LoraConfig
+#######################################################################
+#                          PART 1  Settings                           #
+#######################################################################
+# Model
+path = "/mnt/hwfile/xtuner/huanghaian/model/Mini-InternVL-Chat-2B-V1-5"
+prompt_template = PROMPT_TEMPLATE.internlm2_chat
+
+# Data
+data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/'
+data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
+image_folder = data_root + 'llava_images'
+max_length = 8192
+
+# Scheduler & Optimizer
+batch_size = 8  # per_device
+accumulative_counts = 2
+dataloader_num_workers = 4
+max_epochs = 1
+optim_type = AdamW
+# 1024 -> 4e-5
+# 128 -> 5e-6
+lr = 1e-6
+betas = (0.9, 0.999)
+weight_decay = 0.05
+max_norm = 1  # grad clip
+warmup_ratio = 0.03
+
+# Save
+save_steps = 1000
+save_total_limit = 1  # Maximum checkpoints to keep (-1 means unlimited)
+
+#######################################################################
+#            PART 2  Model & Tokenizer & Image Processor              #
+#######################################################################
+model = dict(
+    type=InternVL_V1_5,
+    model_path=path,
+    freeze_llm=True,
+    freeze_visual_encoder=True,
+    # comment out the following lines if you don't want to use LoRA in the LLM
+    llm_lora=dict(
+        type=LoraConfig,
+        r=128,
+        lora_alpha=256,
+        lora_dropout=0.05,
+        target_modules=None,
+        task_type='CAUSAL_LM'),
+    # uncomment the following lines if you want to use LoRA in the visual encoder
+    # visual_encoder_lora=dict(
+    #     type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
+    #     target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'])
+)
+
+#######################################################################
+#                     PART 3  Dataset & Dataloader                    #
+#######################################################################
+llava_dataset = dict(
+    type=InternVL_V1_5_Dataset,
+    model_path=path,
+    data_path=data_path,
+    image_folder=image_folder,
+    template=prompt_template,
+    max_length=max_length)
+
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=dataloader_num_workers,
+    dataset=llava_dataset,
+    sampler=dict(
+        type=LengthGroupedSampler,
+        length_property='modality_length',
+        per_device_batch_size=batch_size * accumulative_counts),
+    collate_fn=dict(type=default_collate_fn))
+
+#######################################################################
+#                    PART 4  Scheduler & Optimizer                    #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+    type=AmpOptimWrapper,
+    optimizer=dict(
+        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+    accumulative_counts=accumulative_counts,
+    loss_scale='dynamic',
+    dtype='float16')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
+param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=path, + trust_remote_code=True) + +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + save_optimizer=False, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_qlora_finetune.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_qlora_finetune.py new file mode 100644 index 000000000..b139abf46 --- /dev/null +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_1_8b_qlora_finetune.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+
+from xtuner.dataset import InternVL_V1_5_Dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.samplers import LengthGroupedSampler
+from xtuner.engine.hooks import DatasetInfoHook
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import InternVL_V1_5
+from xtuner.utils import PROMPT_TEMPLATE
+from transformers import AutoTokenizer
+from peft import LoraConfig
+
+#######################################################################
+#                          PART 1  Settings                           #
+#######################################################################
+# Model
+path = "/mnt/hwfile/xtuner/huanghaian/model/Mini-InternVL-Chat-2B-V1-5"
+prompt_template = PROMPT_TEMPLATE.internlm2_chat
+
+# Data
+data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/'
+data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
+image_folder = data_root + 'llava_images'
+max_length = 8192
+
+# Scheduler & Optimizer
+batch_size = 8  # per_device
+accumulative_counts = 2
+dataloader_num_workers = 4
+max_epochs = 1
+optim_type = AdamW
+# 1024 -> 4e-5
+# 128 -> 5e-6
+lr = 1e-6
+betas = (0.9, 0.999)
+weight_decay = 0.05
+max_norm = 1  # grad clip
+warmup_ratio = 0.03
+
+# Save
+save_steps = 1000
+save_total_limit = 1  # Maximum checkpoints to keep (-1 means unlimited)
+
+#######################################################################
+#            PART 2  Model & Tokenizer & Image Processor              #
+#######################################################################
+model = dict(
+    type=InternVL_V1_5,
+    model_path=path,
+    freeze_llm=True,
+    freeze_visual_encoder=True,
+    quantization_llm=True,  # or False
+    quantization_vit=False,  # or True and uncomment visual_encoder_lora
+    # comment out the following lines if you don't want to use LoRA in the LLM
+    llm_lora=dict(
+        type=LoraConfig,
+        r=128,
+        lora_alpha=256,
+        lora_dropout=0.05,
+        target_modules=None,
+        task_type='CAUSAL_LM'),
+    # uncomment the following lines if you want to use LoRA in the visual encoder
+    # visual_encoder_lora=dict(
+    #     type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05,
+    #     target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'])
+)
+
+#######################################################################
+#                     PART 3  Dataset & Dataloader                    #
+#######################################################################
+llava_dataset = dict(
+    type=InternVL_V1_5_Dataset,
+    model_path=path,
+    data_path=data_path,
+    image_folder=image_folder,
+    template=prompt_template,
+    max_length=max_length)
+
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=dataloader_num_workers,
+    dataset=llava_dataset,
+    sampler=dict(
+        type=LengthGroupedSampler,
+        length_property='modality_length',
+        per_device_batch_size=batch_size * accumulative_counts),
+    collate_fn=dict(type=default_collate_fn))
+
+#######################################################################
+#                    PART 4  Scheduler & Optimizer                    #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+    type=AmpOptimWrapper,
+    optimizer=dict(
+        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+    accumulative_counts=accumulative_counts,
+    loss_scale='dynamic',
+    dtype='float16')
+
+# learning policy
+# More
information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1e-5, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=0.0, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=path, + trust_remote_code=True) + +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), +] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + save_optimizer=False, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict(by_epoch=False) diff --git a/xtuner/configs/internvl/v1_5/internvl_26b_sft.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_finetune.py similarity index 80% rename from xtuner/configs/internvl/v1_5/internvl_26b_sft.py rename to xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_finetune.py index d86dea32d..5e9b67812 100644 --- a/xtuner/configs/internvl/v1_5/internvl_26b_sft.py +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_finetune.py @@ -4,13 +4,12 @@ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from torch.optim import AdamW -from xtuner.dataset import InternVL_V1_5_LLaVADataset +from xtuner.dataset import InternVL_V1_5_Dataset from xtuner.dataset.collate_fns import default_collate_fn -from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory from xtuner.dataset.samplers import LengthGroupedSampler from xtuner.engine.hooks import DatasetInfoHook from xtuner.engine.runner import TrainLoop -from xtuner.model import InternVL +from xtuner.model import InternVL_V1_5 from xtuner.utils import PROMPT_TEMPLATE from transformers import AutoTokenizer ####################################################################### @@ -18,32 +17,12 @@ ####################################################################### # Model path = "/mnt/hwfile/xtuner/huanghaian/model/InternVL-Chat-V1-5" 
+prompt_template = PROMPT_TEMPLATE.internlm2_chat # Data data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/' - -# 为了高效训练,请确保数据格式为: -""" -{ - "id": "000000033471", - "image": ["coco/train2017/000000033471.jpg"], # 如果是纯文本,则该字段为 None 或者不存在 - "image_wh": [[640, 427]], # 如果是纯文本,则该字段为 None 或者不存在 - "conversations": [ - { - "from": "human", - "value": "\nWhat are the colors of the bus in the image?" - }, - { - "from": "gpt", - "value": "The bus in the image is white and red." - } - ] - } -""" - -data_path = '/mnt/hwfile/xtuner/huanghaian/data/llava_v1_5_mix665k_processed.json' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' image_folder = data_root + 'llava_images' -prompt_template = PROMPT_TEMPLATE.internlm2_chat max_length = 8192 # Scheduler & Optimizer @@ -52,28 +31,24 @@ dataloader_num_workers = 4 max_epochs = 1 optim_type = AdamW -# 1024 -> 2e-5 -# 128 -> 2.5e-6 +# 1024 -> 4e-5 +# 128 -> 5e-6 lr = 1e-6 betas = (0.9, 0.999) -weight_decay = 0 +weight_decay = 0.05 max_norm = 1 # grad clip warmup_ratio = 0.03 # Save -save_steps = 100 +save_steps = 1000 save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) -# Evaluate the generation performance during the training -evaluation_freq = 100 -SYSTEM = '' - ####################################################################### # PART 2 Model & Tokenizer & Image Processor # ####################################################################### model = dict( - type=InternVL, - path=path, + type=InternVL_V1_5, + model_path=path, freeze_llm=False, freeze_visual_encoder=True # or False ) @@ -82,14 +57,11 @@ # PART 3 Dataset & Dataloader # ####################################################################### llava_dataset = dict( - type=InternVL_V1_5_LLaVADataset, - offline_processed_text_folder='/mnt/petrelfs/huanghaian/code/xtuner/intervl/internlm_26b_llava_sft', - path=path, + type=InternVL_V1_5_Dataset, + model_path=path, data_path=data_path, image_folder=image_folder, - dataset_map_fn=llava_map_fn, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template=prompt_template, max_length=max_length) train_dataloader = dict( diff --git a/xtuner/configs/internvl/v1_5/internvl_26b_sft_lora.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_lora_finetune.py similarity index 82% rename from xtuner/configs/internvl/v1_5/internvl_26b_sft_lora.py rename to xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_lora_finetune.py index b27b909a1..00ab354c3 100644 --- a/xtuner/configs/internvl/v1_5/internvl_26b_sft_lora.py +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_lora_finetune.py @@ -4,13 +4,12 @@ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from torch.optim import AdamW -from xtuner.dataset import InternVL_V1_5_LLaVADataset +from xtuner.dataset import InternVL_V1_5_Dataset from xtuner.dataset.collate_fns import default_collate_fn -from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory from xtuner.dataset.samplers import LengthGroupedSampler from xtuner.engine.hooks import DatasetInfoHook from xtuner.engine.runner import TrainLoop -from xtuner.model import InternVL +from xtuner.model import InternVL_V1_5 from xtuner.utils import PROMPT_TEMPLATE from transformers import AutoTokenizer from peft import LoraConfig @@ -19,32 +18,12 @@ ####################################################################### # Model path = "/mnt/hwfile/xtuner/huanghaian/model/InternVL-Chat-V1-5" +prompt_template = 
PROMPT_TEMPLATE.internlm2_chat # Data data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/' - -# 为了高效训练,请确保数据格式为: -""" -{ - "id": "000000033471", - "image": ["coco/train2017/000000033471.jpg"], # 如果是纯文本,则该字段为 None 或者不存在 - "image_wh": [[640, 427]], # 如果是纯文本,则该字段为 None 或者不存在 - "conversations": [ - { - "from": "human", - "value": "\nWhat are the colors of the bus in the image?" - }, - { - "from": "gpt", - "value": "The bus in the image is white and red." - } - ] - } -""" - -data_path = '/mnt/hwfile/xtuner/huanghaian/data/llava_v1_5_mix665k_processed.json' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' image_folder = data_root + 'llava_images' -prompt_template = PROMPT_TEMPLATE.internlm2_chat max_length = 8192 # Scheduler & Optimizer @@ -54,27 +33,23 @@ max_epochs = 1 optim_type = AdamW # 1024 -> 4e-5 -# 128 -> 2.5e-6 +# 128 -> 5e-6 lr = 1e-6 betas = (0.9, 0.999) -weight_decay = 0 +weight_decay = 0.05 max_norm = 1 # grad clip warmup_ratio = 0.03 # Save -save_steps = 100 +save_steps = 1000 save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) -# Evaluate the generation performance during the training -evaluation_freq = 100 -SYSTEM = '' - ####################################################################### # PART 2 Model & Tokenizer & Image Processor # ####################################################################### model = dict( - type=InternVL, - path=path, + type=InternVL_V1_5, + model_path=path, freeze_llm=True, freeze_visual_encoder=True, # comment the following lines if you don't want to use Lora in llm @@ -95,14 +70,11 @@ # PART 3 Dataset & Dataloader # ####################################################################### llava_dataset = dict( - type=InternVL_V1_5_LLaVADataset, - offline_processed_text_folder='/mnt/petrelfs/huanghaian/code/xtuner/intervl/internlm_26b_llava_sft', - path=path, + type=InternVL_V1_5_Dataset, + model_path=path, data_path=data_path, image_folder=image_folder, - dataset_map_fn=llava_map_fn, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template=prompt_template, max_length=max_length) train_dataloader = dict( diff --git a/xtuner/configs/internvl/v1_5/internvl_26b_sft_qlora.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_qlora_finetune.py similarity index 82% rename from xtuner/configs/internvl/v1_5/internvl_26b_sft_qlora.py rename to xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_qlora_finetune.py index b0531cbfd..f4b3b7503 100644 --- a/xtuner/configs/internvl/v1_5/internvl_26b_sft_qlora.py +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_internlm2_26b_qlora_finetune.py @@ -4,13 +4,12 @@ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from torch.optim import AdamW -from xtuner.dataset import InternVL_V1_5_LLaVADataset +from xtuner.dataset import InternVL_V1_5_Dataset from xtuner.dataset.collate_fns import default_collate_fn -from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory from xtuner.dataset.samplers import LengthGroupedSampler from xtuner.engine.hooks import DatasetInfoHook from xtuner.engine.runner import TrainLoop -from xtuner.model import InternVL +from xtuner.model import InternVL_V1_5 from xtuner.utils import PROMPT_TEMPLATE from transformers import AutoTokenizer from peft import LoraConfig @@ -20,32 +19,12 @@ ####################################################################### # Model path = "/mnt/hwfile/xtuner/huanghaian/model/InternVL-Chat-V1-5" +prompt_template = 
PROMPT_TEMPLATE.internlm2_chat # Data data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/' - -# 为了高效训练,请确保数据格式为: -""" -{ - "id": "000000033471", - "image": ["coco/train2017/000000033471.jpg"], # 如果是纯文本,则该字段为 None 或者不存在 - "image_wh": [[640, 427]], # 如果是纯文本,则该字段为 None 或者不存在 - "conversations": [ - { - "from": "human", - "value": "\nWhat are the colors of the bus in the image?" - }, - { - "from": "gpt", - "value": "The bus in the image is white and red." - } - ] - } -""" - -data_path = '/mnt/hwfile/xtuner/huanghaian/data/llava_v1_5_mix665k_processed.json' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' image_folder = data_root + 'llava_images' -prompt_template = PROMPT_TEMPLATE.internlm2_chat max_length = 8192 # Scheduler & Optimizer @@ -55,27 +34,23 @@ max_epochs = 1 optim_type = AdamW # 1024 -> 4e-5 -# 128 -> 2.5e-6 +# 128 -> 5e-6 lr = 1e-6 betas = (0.9, 0.999) -weight_decay = 0 +weight_decay = 0.05 max_norm = 1 # grad clip warmup_ratio = 0.03 # Save -save_steps = 100 +save_steps = 1000 save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) -# Evaluate the generation performance during the training -evaluation_freq = 100 -SYSTEM = '' - ####################################################################### # PART 2 Model & Tokenizer & Image Processor # ####################################################################### model = dict( - type=InternVL, - path=path, + type=InternVL_V1_5, + model_path=path, freeze_llm=True, freeze_visual_encoder=True, quantization_llm=True, # or False @@ -98,14 +73,11 @@ # PART 3 Dataset & Dataloader # ####################################################################### llava_dataset = dict( - type=InternVL_V1_5_LLaVADataset, - offline_processed_text_folder='/mnt/petrelfs/huanghaian/code/xtuner/intervl/internlm_26b_llava_sft', - path=path, + type=InternVL_V1_5_Dataset, + model_path=path, data_path=data_path, image_folder=image_folder, - dataset_map_fn=llava_map_fn, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template=prompt_template, max_length=max_length) train_dataloader = dict( diff --git a/xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_finetune.py similarity index 81% rename from xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft.py rename to xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_finetune.py index 93a1b1d6c..b59a13216 100644 --- a/xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft.py +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_finetune.py @@ -4,13 +4,12 @@ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from torch.optim import AdamW -from xtuner.dataset import InternVL_V1_5_LLaVADataset +from xtuner.dataset import InternVL_V1_5_Dataset from xtuner.dataset.collate_fns import default_collate_fn -from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory from xtuner.dataset.samplers import LengthGroupedSampler from xtuner.engine.hooks import DatasetInfoHook from xtuner.engine.runner import TrainLoop -from xtuner.model import InternVL +from xtuner.model import InternVL_V1_5 from xtuner.utils import PROMPT_TEMPLATE from transformers import AutoTokenizer ####################################################################### @@ -21,27 +20,7 @@ # Data data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/' - -# 为了高效训练,请确保数据格式为: -""" -{ - "id": "000000033471", - "image": ["coco/train2017/000000033471.jpg"], # 如果是纯文本,则该字段为 None 或者不存在 - 
"image_wh": [[640, 427]], # 如果是纯文本,则该字段为 None 或者不存在 - "conversations": [ - { - "from": "human", - "value": "\nWhat are the colors of the bus in the image?" - }, - { - "from": "gpt", - "value": "The bus in the image is white and red." - } - ] - } -""" - -data_path = '/mnt/hwfile/xtuner/huanghaian/data/llava_v1_5_mix665k_processed.json' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' image_folder = data_root + 'llava_images' prompt_template = PROMPT_TEMPLATE.phi3_chat max_length = 8192 @@ -56,24 +35,20 @@ # 128 -> 5e-6 lr = 1e-6 betas = (0.9, 0.999) -weight_decay = 0 +weight_decay = 0.05 max_norm = 1 # grad clip warmup_ratio = 0.03 # Save -save_steps = 100 +save_steps = 1000 save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) -# Evaluate the generation performance during the training -evaluation_freq = 100 -SYSTEM = '' - ####################################################################### # PART 2 Model & Tokenizer & Image Processor # ####################################################################### model = dict( - type=InternVL, - path=path, + type=InternVL_V1_5, + model_path=path, freeze_llm=False, freeze_visual_encoder=True # or False ) @@ -82,14 +57,11 @@ # PART 3 Dataset & Dataloader # ####################################################################### llava_dataset = dict( - type=InternVL_V1_5_LLaVADataset, - offline_processed_text_folder='/mnt/petrelfs/huanghaian/code/xtuner/intervl/mini_phi3_llava_sft', - path=path, + type=InternVL_V1_5_Dataset, + model_path=path, data_path=data_path, image_folder=image_folder, - dataset_map_fn=llava_map_fn, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template=prompt_template, max_length=max_length) train_dataloader = dict( diff --git a/xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft_lora.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_lora_finetune.py similarity index 82% rename from xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft_lora.py rename to xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_lora_finetune.py index 6350cbab7..abb267f47 100644 --- a/xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft_lora.py +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_lora_finetune.py @@ -4,13 +4,12 @@ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from torch.optim import AdamW -from xtuner.dataset import InternVL_V1_5_LLaVADataset +from xtuner.dataset import InternVL_V1_5_Dataset from xtuner.dataset.collate_fns import default_collate_fn -from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory from xtuner.dataset.samplers import LengthGroupedSampler from xtuner.engine.hooks import DatasetInfoHook from xtuner.engine.runner import TrainLoop -from xtuner.model import InternVL +from xtuner.model import InternVL_V1_5 from xtuner.utils import PROMPT_TEMPLATE from transformers import AutoTokenizer from peft import LoraConfig @@ -22,27 +21,7 @@ # Data data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/' - -# 为了高效训练,请确保数据格式为: -""" -{ - "id": "000000033471", - "image": ["coco/train2017/000000033471.jpg"], # 如果是纯文本,则该字段为 None 或者不存在 - "image_wh": [[640, 427]], # 如果是纯文本,则该字段为 None 或者不存在 - "conversations": [ - { - "from": "human", - "value": "\nWhat are the colors of the bus in the image?" - }, - { - "from": "gpt", - "value": "The bus in the image is white and red." 
- } - ] - } -""" - -data_path = '/mnt/hwfile/xtuner/huanghaian/data/llava_v1_5_mix665k_processed.json' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' image_folder = data_root + 'llava_images' prompt_template = PROMPT_TEMPLATE.phi3_chat max_length = 8192 @@ -57,24 +36,20 @@ # 128 -> 5e-6 lr = 1e-6 betas = (0.9, 0.999) -weight_decay = 0 +weight_decay = 0.05 max_norm = 1 # grad clip warmup_ratio = 0.03 # Save -save_steps = 100 +save_steps = 1000 save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) -# Evaluate the generation performance during the training -evaluation_freq = 100 -SYSTEM = '' - ####################################################################### # PART 2 Model & Tokenizer & Image Processor # ####################################################################### model = dict( - type=InternVL, - path=path, + type=InternVL_V1_5, + model_path=path, freeze_llm=True, freeze_visual_encoder=True, # comment the following lines if you don't want to use Lora in llm @@ -95,14 +70,11 @@ # PART 3 Dataset & Dataloader # ####################################################################### llava_dataset = dict( - type=InternVL_V1_5_LLaVADataset, - offline_processed_text_folder='/mnt/petrelfs/huanghaian/code/xtuner/intervl/mini_phi3_llava_sft', - path=path, + type=InternVL_V1_5_Dataset, + model_path=path, data_path=data_path, image_folder=image_folder, - dataset_map_fn=llava_map_fn, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template=prompt_template, max_length=max_length) train_dataloader = dict( diff --git a/xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft_qlora.py b/xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_qlora_finetune.py similarity index 82% rename from xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft_qlora.py rename to xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_qlora_finetune.py index 628598933..c2aad8289 100644 --- a/xtuner/configs/internvl/v1_5/mini_internvl_phi3_sft_qlora.py +++ b/xtuner/configs/internvl/v1_5/internvl_v1_5_phi3_4b_qlora_finetune.py @@ -4,13 +4,12 @@ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR from torch.optim import AdamW -from xtuner.dataset import InternVL_V1_5_LLaVADataset +from xtuner.dataset import InternVL_V1_5_Dataset from xtuner.dataset.collate_fns import default_collate_fn -from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory from xtuner.dataset.samplers import LengthGroupedSampler from xtuner.engine.hooks import DatasetInfoHook from xtuner.engine.runner import TrainLoop -from xtuner.model import InternVL +from xtuner.model import InternVL_V1_5 from xtuner.utils import PROMPT_TEMPLATE from transformers import AutoTokenizer from peft import LoraConfig @@ -23,27 +22,7 @@ # Data data_root = '/mnt/hwfile/xtuner/linzhihao/dataset/llava_data/' - -# 为了高效训练,请确保数据格式为: -""" -{ - "id": "000000033471", - "image": ["coco/train2017/000000033471.jpg"], # 如果是纯文本,则该字段为 None 或者不存在 - "image_wh": [[640, 427]], # 如果是纯文本,则该字段为 None 或者不存在 - "conversations": [ - { - "from": "human", - "value": "\nWhat are the colors of the bus in the image?" - }, - { - "from": "gpt", - "value": "The bus in the image is white and red." 
- } - ] - } -""" - -data_path = '/mnt/hwfile/xtuner/huanghaian/data/llava_v1_5_mix665k_processed.json' +data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json' image_folder = data_root + 'llava_images' prompt_template = PROMPT_TEMPLATE.phi3_chat max_length = 8192 @@ -58,24 +37,20 @@ # 128 -> 5e-6 lr = 1e-6 betas = (0.9, 0.999) -weight_decay = 0 +weight_decay = 0.05 max_norm = 1 # grad clip warmup_ratio = 0.03 # Save -save_steps = 100 +save_steps = 1000 save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) -# Evaluate the generation performance during the training -evaluation_freq = 100 -SYSTEM = '' - ####################################################################### # PART 2 Model & Tokenizer & Image Processor # ####################################################################### model = dict( - type=InternVL, - path=path, + type=InternVL_V1_5, + model_path=path, freeze_llm=True, freeze_visual_encoder=True, quantization_llm=True, # or False @@ -98,14 +73,11 @@ # PART 3 Dataset & Dataloader # ####################################################################### llava_dataset = dict( - type=InternVL_V1_5_LLaVADataset, - offline_processed_text_folder='/mnt/petrelfs/huanghaian/code/xtuner/intervl/mini_phi3_llava_sft', - path=path, + type=InternVL_V1_5_Dataset, + model_path=path, data_path=data_path, image_folder=image_folder, - dataset_map_fn=llava_map_fn, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template=prompt_template, max_length=max_length) train_dataloader = dict( diff --git a/xtuner/model/__init__.py b/xtuner/model/__init__.py index a9d3dd35f..41814f309 100644 --- a/xtuner/model/__init__.py +++ b/xtuner/model/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .llava import LLaVAModel from .sft import SupervisedFinetune -from .internvl import InternVL +from .internvl import InternVL_V1_5 -__all__ = ['SupervisedFinetune', 'LLaVAModel', 'InternVL'] +__all__ = ['SupervisedFinetune', 'LLaVAModel', 'InternVL_V1_5'] diff --git a/xtuner/model/internvl.py b/xtuner/model/internvl.py index c0853ae18..3a2a62767 100644 --- a/xtuner/model/internvl.py +++ b/xtuner/model/internvl.py @@ -3,25 +3,24 @@ from collections import OrderedDict from transformers import AutoTokenizer, AutoModel, AutoConfig import torch -import torch.nn as nn from mmengine.config import Config, ConfigDict from mmengine.model import BaseModel from peft import get_peft_model, prepare_model_for_kbit_training from transformers import BitsAndBytesConfig -from transformers.integrations import is_deepspeed_zero3_enabled from xtuner.registry import BUILDER -from .modules import ProjectorConfig, ProjectorModel, dispatch_modules -from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 -from .utils import (LoadWoInit, find_all_linear_names, +from .utils import (find_all_linear_names, get_peft_model_state_dict, guess_load_checkpoint, - make_inputs_require_grad, - prepare_inputs_labels_for_multimodal, traverse_dict) + make_inputs_require_grad) from mmengine import print_log +from torch.nn import CrossEntropyLoss +from typing import List, Optional, Tuple, Union +from transformers.modeling_outputs import CausalLMOutputWithPast -class InternVL(BaseModel): - def __init__(self, path, freeze_llm=False, +class InternVL_V1_5(BaseModel): + def __init__(self, model_path, + freeze_llm=False, freeze_visual_encoder=False, llm_lora=None, visual_encoder_lora=None, @@ -40,7 +39,7 @@ def __init__(self, path, freeze_llm=False, if quantization_llm: assert quantization_llm and llm_lora is not None - config = AutoConfig.from_pretrained(path, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) config.llm_config._attn_implementation = 'flash_attention_2' if quantization_vit is False and quantization_llm is False: @@ -67,13 +66,13 @@ def __init__(self, path, freeze_llm=False, quantization = quantization_clazz(**quantization_config) self.model = AutoModel.from_pretrained( - path, + model_path, torch_dtype=torch.bfloat16, quantization_config=quantization, config=config, trust_remote_code=True) - tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) img_context_token_id = tokenizer.convert_tokens_to_ids('') self.model.img_context_token_id = img_context_token_id @@ -102,6 +101,7 @@ def __init__(self, path, freeze_llm=False, self.load_state_dict(pretrained_state_dict, strict=False) print(f'Load pretrained weight from {pretrained_pth}') + self._count = 0 print_log(self, logger='current') def _parse_lora_config(self, lora_config): @@ -132,14 +132,12 @@ def gradient_checkpointing_enable(self): self.activation_checkpointing_enable() def activation_checkpointing_enable(self): - # self.model.vision_model.gradient_checkpointing_enable() self.model.language_model.gradient_checkpointing_enable() def gradient_checkpointing_disable(self): self.activation_checkpointing_disable() def activation_checkpointing_disable(self): - # self.model.vision_model.gradient_checkpointing_disable() self.model.language_model.gradient_checkpointing_disable() def state_dict(self, *args, **kwargs): @@ -195,12 +193,104 @@ def forward(self, data, data_samples=None, mode='loss'): labels = data['labels'] use_cache 
= False - outputs = self.model(input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - image_flags=image_flags, - pixel_values=concat_images, - labels=labels, - use_cache=use_cache) + # Directly calling this code in LORA fine-tuning will result in an error, + # so we must rewrite it. + # TODO: Once the official is fixed, we can remove it. + # outputs = self.model(input_ids=input_ids, + # position_ids=position_ids, + # attention_mask=attention_mask, + # image_flags=image_flags, + # pixel_values=concat_images, + # labels=labels, + # use_cache=use_cache) + outputs = self._llm_forward(input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + image_flags=image_flags, + pixel_values=concat_images, + labels=labels, + use_cache=use_cache) loss_dict = {'loss': outputs.loss} return loss_dict + + def _llm_forward( + self, + pixel_values: torch.FloatTensor, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + image_flags: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict + + image_flags = image_flags.squeeze(-1) + # We only added the clone code here to avoid the error. + input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone() + + vit_embeds = self.model.extract_feature(pixel_values) + vit_embeds = vit_embeds[image_flags == 1] + vit_batch_size = pixel_values.shape[0] + + B, N, C = input_embeds.shape + input_embeds = input_embeds.reshape(B * N, C) + + if torch.distributed.get_rank() == 0 and self._count % 100 == 0: + print( + f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}') + self._count += 1 + + input_ids = input_ids.reshape(B * N) + selected = (input_ids == self.model.img_context_token_id) + try: + input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) + except Exception as e: + vit_embeds = vit_embeds.reshape(-1, C) + print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, ' + f'vit_embeds.shape={vit_embeds.shape}') + n_token = selected.sum() + input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token] + + input_embeds = input_embeds.reshape(B, N, C) + + outputs = self.model.language_model( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + logits = outputs.logits + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.model.language_model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + 
return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )
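
A minimal launch sketch for the renamed configs, assuming the standard `xtuner train` CLI and that the `/mnt/hwfile/...` model and data paths above are adjusted to the local environment:

    # single GPU
    xtuner train internvl_v1_5_internlm2_1_8b_qlora_finetune --deepspeed deepspeed_zero2

    # multiple GPUs on one node
    NPROC_PER_NODE=8 xtuner train internvl_v1_5_internlm2_26b_finetune --deepspeed deepspeed_zero2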
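
For reference, the updated `InternVL_V1_5_Dataset` configs point `data_path` at the plain LLaVA instruction file rather than an offline-processed folder. A sketch of one record, assuming the standard `llava_v1_5_mix665k.json` layout (shown only as an illustration, not as the dataset's authoritative schema):

    # One training sample as stored in llava_v1_5_mix665k.json.
    # Text-only samples simply omit the "image" field.
    sample = {
        "id": "000000033471",
        "image": "coco/train2017/000000033471.jpg",
        "conversations": [
            {"from": "human",
             "value": "<image>\nWhat are the colors of the bus in the image?"},
            {"from": "gpt",
             "value": "The bus in the image is white and red."},
        ],
    }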