From 089f8a1f547fbbc823c515562e5ba598b1acf1c1 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Thu, 31 Aug 2023 18:50:29 +0800 Subject: [PATCH] fix config --- .../starcoder/starcoder_qlora_stack_exchange_example.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xtuner/configs/starcoder/starcoder_qlora_stack_exchange_example.py b/xtuner/configs/starcoder/starcoder_qlora_stack_exchange_example.py index b7e4828a0..591c3f3d3 100644 --- a/xtuner/configs/starcoder/starcoder_qlora_stack_exchange_example.py +++ b/xtuner/configs/starcoder/starcoder_qlora_stack_exchange_example.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +from bitsandbytes.optim import PagedAdamW32bit from datasets import load_dataset from mmengine.dataset import DefaultSampler from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, LoggerHook, ParamSchedulerHook) from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR from peft import LoraConfig -from torch.optim import AdamW from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig) @@ -35,7 +35,7 @@ accumulative_counts = 16 # 1bs * 16acc * 1gpu = 16 batchsize dataloader_num_workers = 0 max_epochs = 1 -optim_type = AdamW +optim_type = PagedAdamW32bit lr = 1e-4 betas = (0.9, 0.999) weight_decay = 0.05 @@ -62,7 +62,6 @@ type=AutoModelForCausalLM.from_pretrained, pretrained_model_name_or_path=pretrained_model_name_or_path, trust_remote_code=True, - load_in_8bit=True, torch_dtype=torch.float16, quantization_config=dict( type=BitsAndBytesConfig,