From 9c28b40790f139e29625ad273e0878d3649f542e Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Tue, 9 Jul 2024 08:41:43 +0800
Subject: [PATCH] [Bugs] fix dispatch bugs when model not in LOWEST_TRANSFORMERS_VERSION (#802)

* fix dispatch bugs when model not in LOWEST_TRANSFORMERS_VERSION

* move rope_theta
---
 xtuner/model/modules/dispatch/__init__.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index 7cb159515..e81ec7a3a 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -228,14 +228,14 @@ def replace_rote(model):
     from mmengine import print_log
     print_log = log_once(print_log)
 
-    assert hasattr(model.config, 'rope_theta'), \
-        '`rope_theta` should be in the model config.'
-    rope_theta = model.config.rope_theta
-
     def traverse(module):
         for name, child in module.named_children():
             cls_name = type(child).__name__
             if cls_name in ROTE_DISPATCH_MAPPING:
+                assert hasattr(model.config, 'rope_theta'), \
+                    '`rope_theta` should be in the model config.'
+                rope_theta = model.config.rope_theta
+
                 rote = ROTE_DISPATCH_MAPPING[cls_name]
                 rote = rote.build()
                 print_log(f'replace {cls_name}', 'current')
@@ -258,10 +258,11 @@ def check(model_name):
             # a walkaround for reward model
             model_name = model_name[:-5] + 'ForCausalLM'
         msg = '{} requires transformers version at least {}, but got {}'
-        assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[
-            model_name], msg.format(model_name,
-                                    LOWEST_TRANSFORMERS_VERSION[model_name],
-                                    TRANSFORMERS_VERSION)
+        if model_name in LOWEST_TRANSFORMERS_VERSION:
+            assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[
+                model_name], msg.format(
+                    model_name, LOWEST_TRANSFORMERS_VERSION[model_name],
+                    TRANSFORMERS_VERSION)
 
     check(type(model).__name__)
     if use_varlen_attn: