diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index 2089a048b..d98454507 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -13,8 +13,23 @@
                        baichuan_13b_attn_forward)
 from .yi import yi_attn_forward
 
-IS_LOW_VERSION_TRANSFORMERS = digit_version(
-    transformers.__version__) < digit_version('4.38')
+LOWEST_TRANSFORMERS_VERSION = dict(
+    internlm2=digit_version('4.36'),
+    internlm=digit_version('4.36'),
+    llama=digit_version('4.36'),
+    phi3=digit_version('4.39'),
+    yi=digit_version('4.36'),
+    mistral=digit_version('4.36'),
+    # Training mixtral with a lower version may lead to nccl timeout
+    # Refer to https://github.com/microsoft/DeepSpeed/issues/5066
+    mixtral=digit_version('4.40'),
+    cohere=digit_version('4.40'),
+    qwen2=digit_version('4.39'),
+    qwen2_moe=digit_version('4.40'),
+)
+
+TRANSFORMERS_VERSION = digit_version(transformers.__version__)
+IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.38')
 # Transformers requires torch version >= 2.1.1 when using Torch SDPA.
 # Refer to https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/modeling_utils.py#L1611  # noqa: E501
 SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.1.1')
@@ -448,7 +463,47 @@ def set_qwen_moe_blocks_z3_leaf_modules(model):
     set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
 
+def check_transformers_version(model):
+
+    def check(model_name):
+        if model_name not in LOWEST_TRANSFORMERS_VERSION:
+            # No version floor is registered for this model (e.g. baichuan)
+            return
+        msg = '{} requires transformers version at least {}, but got {}'
+        assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[
+            model_name], msg.format(model_name,
+                                    LOWEST_TRANSFORMERS_VERSION[model_name],
+                                    TRANSFORMERS_VERSION)
+
+    model_name = model.__class__.__name__.lower()
+
+    if 'internlm2' in model_name:
+        check('internlm2')
+    elif 'internlm' in model_name:
+        check('internlm')
+    elif 'llama' in model_name:
+        check('llama')
+    elif 'phi3' in model_name:
+        check('phi3')
+    elif 'baichuan' in model_name:
+        check('baichuan')
+    elif 'yi' in model_name:
+        check('yi')
+    elif 'mistral' in model_name:
+        check('mistral')
+    elif 'mixtral' in model_name:
+        check('mixtral')
+    elif 'cohere' in model_name:
+        check('cohere')
+    elif 'qwen2moe' in model_name:
+        check('qwen2_moe')
+    elif 'qwen2' in model_name:
+        check('qwen2')
+
+
 def dispatch_modules(model, use_varlen_attn=False):
+    check_transformers_version(model)
+
     model_name = model.__class__.__name__.lower()
     if 'internlm2' in model_name:
         dispatch_internlm2_attn_forward(model, use_varlen_attn)
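
For reference, here is a minimal standalone sketch (not part of the diff) of how the new gate behaves. It stands in `digit_version` with `packaging.version.parse`, hard-codes a pretend installed version, and stubs the model class locally, so it runs without transformers or xtuner installed; the dict keys and the error message mirror the diff above.

```python
# Standalone sketch of the version gate; assumptions: digit_version is stood
# in by packaging.version.parse, and Qwen2MoeForCausalLM is a local stub
# rather than the real transformers class of that name.
from packaging.version import parse as digit_version

LOWEST_TRANSFORMERS_VERSION = dict(
    qwen2=digit_version('4.39'),
    qwen2_moe=digit_version('4.40'),
)

TRANSFORMERS_VERSION = digit_version('4.39.2')  # pretend installed version


def check(model_name):
    if model_name not in LOWEST_TRANSFORMERS_VERSION:
        return  # no version floor registered for this model
    msg = '{} requires transformers version at least {}, but got {}'
    assert TRANSFORMERS_VERSION >= LOWEST_TRANSFORMERS_VERSION[model_name], \
        msg.format(model_name, LOWEST_TRANSFORMERS_VERSION[model_name],
                   TRANSFORMERS_VERSION)


class Qwen2MoeForCausalLM:  # stub for the transformers class of that name
    pass


model_name = Qwen2MoeForCausalLM().__class__.__name__.lower()
# 'qwen2moe' must be tested before 'qwen2': both are substrings of
# 'qwen2moeforcausallm', so matching 'qwen2' first would apply the
# wrong (lower) floor and let an unsupported setup through.
try:
    if 'qwen2moe' in model_name:
        check('qwen2_moe')
    elif 'qwen2' in model_name:
        check('qwen2')
except AssertionError as e:
    # qwen2_moe requires transformers version at least 4.40, but got 4.39.2
    print(e)
```

The same ordering constraint is why `'internlm2'` is tested before `'internlm'` in both `check_transformers_version` and `dispatch_modules`.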