From ba7afc79b1f338b8485d6f51763c078f877116da Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Fri, 19 Jul 2024 13:32:02 +0800
Subject: [PATCH] [Fix] fix initialization of ref_llm for full param dpo
 training with zero-3 (#778)

* Fix initialization of ref_llm

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update sft.py

* Update dpo.py

* Update dpo.py

* Update dpo.py
---
 xtuner/model/dpo.py | 45 ++++++++++++++++++++++++---------------------
 xtuner/model/sft.py | 21 +++++++++++++++------
 2 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/xtuner/model/dpo.py b/xtuner/model/dpo.py
index 9a7b97a19..9384ddb34 100644
--- a/xtuner/model/dpo.py
+++ b/xtuner/model/dpo.py
@@ -16,6 +16,26 @@
 from .sft import SupervisedFinetune
 
 
+def disable_grad(model):
+    # freeze parameters
+    parameter_names = [n for n, _ in model.named_parameters()]
+    for param_name in parameter_names:
+        param = model.get_parameter(param_name)
+        param.requires_grad = False
+    return model.eval()
+
+
+def create_reference_model(model):
+    if is_deepspeed_zero3_enabled():
+        raise ValueError('DeepSpeed ZeRO-3 is enabled and is not compatible '
+                         'with `create_reference_model()`. Please instantiate '
+                         'your reference model directly with '
+                         '`AutoCausalLM.from_pretrained()`.')
+    ref_model = deepcopy(model)
+    ref_model = disable_grad(ref_model)
+    return ref_model
+
+
 class DPO(SupervisedFinetune):
     """A general class of DPO and its variants."""
 
@@ -27,32 +47,15 @@ def __init__(self,
                  label_smoothing=0.0,
                  **kwargs):
         super().__init__(llm, **kwargs)
-        self.ref_llm = ref_llm
         self.loss_type = loss_type
         self.label_smoothing = label_smoothing
         self.beta = beta
 
-        if not self.use_lora:
-            self.ref_llm = self.create_reference_model(ref_llm, **kwargs)
-
-    def create_reference_model(self, ref_llm=None, **kwargs):
-        ref_model = None
-        if ref_llm is None:
-            if is_deepspeed_zero3_enabled():
-                raise ValueError(
-                    'DeepSpeed ZeRO-3 is enabled and is not compatible '
-                    'with `deepcopy(self.llm)`. Please instantiate '
-                    'your reference model by modifying key `model.ref_llm` '
-                    'in your config with `AutoCausalLM.from_pretrained()`.')
-            ref_model = deepcopy(self.llm)
+        if ref_llm is not None:
+            ref_llm = self._build_llm_from_cfg(ref_llm, kwargs.get("use_varlen_attn"), kwargs.get("max_position_embeddings"))
+            self.ref_llm = disable_grad(ref_llm)
         else:
-            ref_model = SupervisedFinetune(ref_llm, **kwargs).llm
-        # freeze parameters
-        parameter_names = [n for n, _ in ref_model.named_parameters()]
-        for param_name in parameter_names:
-            param = ref_model.get_parameter(param_name)
-            param.requires_grad = False
-        return ref_model.eval()
+            self.ref_llm = None if self.use_lora else create_reference_model(self.llm)
 
     def _gather_masked_logits(self, logits, labels, mask):
         logits = torch.gather(
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index d030c6c20..9c3fa38c9 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -79,10 +79,8 @@ def __init__(self,
                  tokenizer=None,
                  max_position_embeddings=None):
         super().__init__()
-        with LoadWoInit():
-            if isinstance(llm, dict):
-                llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
-            self.llm = self._build_from_cfg_or_module(llm)
+
+        self.llm = self._build_llm_from_cfg(llm, use_varlen_attn, max_position_embeddings)
 
         if tokenizer is not None:
             if isinstance(tokenizer, dict):
@@ -90,8 +88,6 @@ def __init__(self,
                 smart_tokenizer_and_embedding_resize(tokenizer, self.llm)
 
         self.llm.config.use_cache = False
-        dispatch_modules(self.llm, use_varlen_attn=use_varlen_attn)
-
         if use_activation_checkpointing:
             # For backward compatibility
             if hasattr(self.llm, 'enable_input_require_grads'):
@@ -119,6 +115,19 @@ def __init__(self,
         # the sequence.
         self.use_varlen_attn = use_varlen_attn
 
+
+    def _build_llm_from_cfg(self, llm_cfg, use_varlen_attn, max_position_embeddings):
+        # For forward
+        with LoadWoInit():
+            if isinstance(llm_cfg, dict):
+                llm = self._dispatch_lm_model_cfg(llm_cfg, max_position_embeddings)
+            llm = self._build_from_cfg_or_module(llm)
+
+        llm.config.use_cache = False
+        dispatch_modules(llm, use_varlen_attn=use_varlen_attn)
+        return llm
+
+
     def gradient_checkpointing_enable(self):
         self.activation_checkpointing_enable()
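With this patch, the reference model for full-parameter DPO under ZeRO-3 is expected to be passed explicitly through `ref_llm` (the `model.ref_llm` key mentioned by the old error message) rather than obtained via `deepcopy(self.llm)`. Below is a minimal config sketch of how that might look, assuming an mmengine-style xtuner config; the checkpoint path and the surrounding keys are illustrative only and are not part of this patch:

    from transformers import AutoModelForCausalLM

    from xtuner.model.dpo import DPO

    # Illustrative base checkpoint; any causal LM path would work here.
    pretrained_model_name_or_path = 'internlm/internlm2-chat-1_8b-sft'

    model = dict(
        type=DPO,
        # Policy model, trained with full parameters under ZeRO-3.
        llm=dict(
            type=AutoModelForCausalLM.from_pretrained,
            pretrained_model_name_or_path=pretrained_model_name_or_path),
        # Reference model: built through `_build_llm_from_cfg` and frozen by
        # `disable_grad` in the patched `DPO.__init__`, instead of being
        # deep-copied from the policy model.
        ref_llm=dict(
            type=AutoModelForCausalLM.from_pretrained,
            pretrained_model_name_or_path=pretrained_model_name_or_path))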