Skip to content

Commit

Permalink
recover sft.py
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Aug 29, 2024
1 parent cd72523 commit 999255c
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions xtuner/model/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,8 @@ def __init__(self,
tokenizer=None,
max_position_embeddings=None):
super().__init__()
with LoadWoInit():
if isinstance(llm, dict):
llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
self.llm = self._build_from_cfg_or_module(llm)
self.llm = self.build_llm_from_cfg(llm, use_varlen_attn,
max_position_embeddings)

if tokenizer is not None:
if isinstance(tokenizer, dict):
Expand Down Expand Up @@ -120,6 +118,19 @@ def __init__(self,
# the sequence.
self.use_varlen_attn = use_varlen_attn

def build_llm_from_cfg(self, llm_cfg, use_varlen_attn,
                       max_position_embeddings):
    """Build the LLM used for the forward pass and prepare it for training.

    Args:
        llm_cfg: Either a config dict (dispatched through
            ``_dispatch_lm_model_cfg``) or an already-constructed model
            module, mirroring what the old ``__init__`` path accepted.
        use_varlen_attn: Whether to dispatch variable-length attention
            when patching the model's modules.
        max_position_embeddings: Optional override forwarded to the config
            dispatcher; ignored when ``llm_cfg`` is already a module.

    Returns:
        The constructed LLM with KV caching disabled and module
        dispatching applied.
    """
    # For forward
    # Build without weight initialization (weights are presumably loaded
    # from a checkpoint afterwards — LoadWoInit is a project helper).
    with LoadWoInit():
        if isinstance(llm_cfg, dict):
            llm_cfg = self._dispatch_lm_model_cfg(llm_cfg,
                                                  max_position_embeddings)
        # BUGFIX: always build from llm_cfg. Previously the dispatched
        # result was bound to `llm` only inside the dict branch, so passing
        # an already-built module raised UnboundLocalError on `llm`.
        llm = self._build_from_cfg_or_module(llm_cfg)

    # KV caching is an inference-time feature; disable it for training.
    llm.config.use_cache = False
    dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
    return llm

def gradient_checkpointing_enable(self):
self.activation_checkpointing_enable()

Expand Down

0 comments on commit 999255c

Please sign in to comment.