From 999255cff2c2002ed9627ce10a970c504bbd39d4 Mon Sep 17 00:00:00 2001
From: HIT-cwh <2892770585@qq.com>
Date: Thu, 29 Aug 2024 12:30:06 +0000
Subject: [PATCH] recover sft.py

---
 xtuner/model/sft.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index c0fc3617a..4c93b520f 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -80,10 +80,8 @@ def __init__(self,
                  tokenizer=None,
                  max_position_embeddings=None):
         super().__init__()
-        with LoadWoInit():
-            if isinstance(llm, dict):
-                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-            self.llm = self._build_from_cfg_or_module(llm)
+        self.llm = self.build_llm_from_cfg(llm, use_varlen_attn,
+                                           max_position_embeddings)
 
         if tokenizer is not None:
             if isinstance(tokenizer, dict):
@@ -120,6 +118,19 @@ def __init__(self,
         # the sequence.
         self.use_varlen_attn = use_varlen_attn
 
+    def build_llm_from_cfg(self, llm_cfg, use_varlen_attn,
+                           max_position_embeddings):
+        # For forward
+        with LoadWoInit():
+            if isinstance(llm_cfg, dict):
+                llm = self._dispatch_lm_model_cfg(llm_cfg,
+                                                  max_position_embeddings)
+            llm = self._build_from_cfg_or_module(llm)
+
+        llm.config.use_cache = False
+        dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
+        return llm
+
     def gradient_checkpointing_enable(self):
         self.activation_checkpointing_enable()
 