diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index cac49836548..3ec10b6ac48 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -735,6 +735,8 @@ def _optimize_pre(model):
     if model.config.model_type == "qwen2":
         from ipex_llm.transformers.models.qwen2 import merge_qkv
         model.apply(merge_qkv)
+        from ipex_llm.transformers.models.qwen2 import padding_mlp
+        model.apply(padding_mlp)
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 6d209aa5567..80cd4299030 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -49,7 +49,8 @@
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError
 
-from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
+from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -288,6 +289,37 @@ def merge_qkv(module: torch.nn.Module):
     del module.q_proj, module.k_proj, module.v_proj
 
 
+def padding_mlp(module: torch.nn.Module):
+    # for qwen 1.5 14B
+    if isinstance(module, Qwen2MLP):
+        hidden_size = module.hidden_size
+        intermediate_size = module.intermediate_size
+        padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256
+        if intermediate_size % 256 == 0:
+            return
+
+        gate_weight = module.gate_proj.weight.data
+        new_gate_weight = torch.zeros([padding_intermediate_size, hidden_size],
+                                      dtype=gate_weight.dtype, device=gate_weight.device)
+        new_gate_weight[:intermediate_size, :] = gate_weight
+        module.gate_proj.out_features = padding_intermediate_size
+        module.gate_proj.weight = torch.nn.Parameter(new_gate_weight, requires_grad=False)
+
+        up_weight = module.up_proj.weight.data
+        new_up_weight = torch.zeros([padding_intermediate_size, hidden_size],
+                                    dtype=up_weight.dtype, device=up_weight.device)
+        new_up_weight[:intermediate_size, :] = up_weight
+        module.up_proj.out_features = padding_intermediate_size
+        module.up_proj.weight = torch.nn.Parameter(new_up_weight, requires_grad=False)
+
+        down_weight = module.down_proj.weight.data
+        new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
+                                      dtype=down_weight.dtype, device=down_weight.device)
+        new_down_weight[:, :intermediate_size] = down_weight
+        module.down_proj.in_features = padding_intermediate_size
+        module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
+
+
 def qwen2_attention_forward(
     self,
     hidden_states: torch.Tensor,
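
For reference, below is a minimal standalone sketch (not part of the patch; the toy sizes are made up) of why zero-padding the gate/up/down projections to a multiple of 256, as padding_mlp does above, leaves the MLP output unchanged: the padded gate/up rows produce silu(0) * 0 = 0 activations, and the padded down columns contribute nothing.

    import torch
    import torch.nn.functional as F

    hidden_size, intermediate_size = 64, 300                  # 300 is not a multiple of 256
    padded_size = (intermediate_size + 256 - 1) // 256 * 256  # rounds up to 512

    gate = torch.randn(intermediate_size, hidden_size)
    up = torch.randn(intermediate_size, hidden_size)
    down = torch.randn(hidden_size, intermediate_size)
    x = torch.randn(1, hidden_size)

    # reference SwiGLU MLP as used by Qwen2: down(silu(gate(x)) * up(x))
    ref = F.linear(F.silu(F.linear(x, gate)) * F.linear(x, up), down)

    # zero-pad the intermediate dimension the same way padding_mlp does
    gate_p = torch.zeros(padded_size, hidden_size)
    gate_p[:intermediate_size] = gate
    up_p = torch.zeros(padded_size, hidden_size)
    up_p[:intermediate_size] = up
    down_p = torch.zeros(hidden_size, padded_size)
    down_p[:, :intermediate_size] = down

    out = F.linear(F.silu(F.linear(x, gate_p)) * F.linear(x, up_p), down_p)
    assert torch.allclose(ref, out, atol=1e-5)  # padding does not change the result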