Skip to content

Commit

Permalink
check transformers version before dispatch (#672)
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh authored May 10, 2024
1 parent d12bc05 commit b283b99
Showing 1 changed file with 54 additions and 2 deletions.
56 changes: 54 additions & 2 deletions xtuner/model/modules/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,23 @@
baichuan_13b_attn_forward)
from .yi import yi_attn_forward

IS_LOW_VERSION_TRANSFORMERS = digit_version(
transformers.__version__) < digit_version('4.38')
# Minimum transformers version required by each supported architecture,
# keyed by architecture name. Values are digit_version(...) results so they
# can be compared directly with the parsed installed version.
LOWEST_TRANSFORMERS_VERSION = {
    'internlm2': digit_version('4.36'),
    'internlm': digit_version('4.36'),
    'llama': digit_version('4.36'),
    'phi3': digit_version('4.39'),
    'yi': digit_version('4.36'),
    'mistral': digit_version('4.36'),
    # Training mixtral with a lower version may lead to an nccl timeout.
    # Refer to https://github.com/microsoft/DeepSpeed/issues/5066
    'mixtral': digit_version('4.40'),
    'cohere': digit_version('4.40'),
    'qwen2': digit_version('4.39'),
    'qwen2_moe': digit_version('4.40'),
}

# Installed transformers version, parsed once with digit_version so it can be
# compared against other digit_version(...) results.
TRANSFORMERS_VERSION = digit_version(transformers.__version__)
# True when the installed transformers version is older than 4.38.
IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.38')
# Transformers requires torch version >= 2.1.1 when using Torch SDPA.
# Refer to https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/modeling_utils.py#L1611 # noqa: E501
# True when the installed torch version is at least 2.1.1.
SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.1.1')
Expand Down Expand Up @@ -448,7 +463,44 @@ def set_qwen_moe_blocks_z3_leaf_modules(model):
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])


def check_transformers_version(model):
    """Verify that the installed transformers version supports ``model``.

    The lower-cased class name of ``model`` is matched against a list of
    known architecture substrings; the first match wins. The minimum version
    registered for that architecture in ``LOWEST_TRANSFORMERS_VERSION`` is
    then compared with ``TRANSFORMERS_VERSION``.

    Architectures that match no substring, or that have no minimum version
    registered (e.g. baichuan), are accepted without a check.

    Args:
        model: The model about to be dispatched. Only the name of its class
            is inspected.

    Raises:
        AssertionError: If the installed transformers version is lower than
            the minimum registered for the matched architecture.
    """
    # (substring of the lower-cased class name, LOWEST_TRANSFORMERS_VERSION
    # key). Order matters: a more specific substring must precede any
    # substring it contains ('internlm2' before 'internlm', 'qwen2moe'
    # before 'qwen2').
    name_to_key = (
        ('internlm2', 'internlm2'),
        ('internlm', 'internlm'),
        ('llama', 'llama'),
        ('phi3', 'phi3'),
        ('baichuan', 'baichuan'),
        ('yi', 'yi'),
        ('mistral', 'mistral'),
        ('mixtral', 'mixtral'),
        ('cohere', 'cohere'),
        # Class names contain 'qwen2moe' (no underscore) while the version
        # table is keyed by 'qwen2_moe'.
        ('qwen2moe', 'qwen2_moe'),
        ('qwen2', 'qwen2'),
    )
    msg = '{} requires transformers version at least {}, but got {}'

    model_name = model.__class__.__name__.lower()
    for pattern, key in name_to_key:
        if pattern not in model_name:
            continue
        lowest = LOWEST_TRANSFORMERS_VERSION.get(key)
        if lowest is None:
            # No minimum version is registered for this architecture, so
            # there is nothing to enforce.
            return
        if TRANSFORMERS_VERSION < lowest:
            # Raised explicitly (instead of via ``assert``) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError(
                msg.format(key, lowest, TRANSFORMERS_VERSION))
        return


def dispatch_modules(model, use_varlen_attn=False):
check_transformers_version(model)

model_name = model.__class__.__name__.lower()
if 'internlm2' in model_name:
dispatch_internlm2_attn_forward(model, use_varlen_attn)
Expand Down

0 comments on commit b283b99

Please sign in to comment.