Optimize qwen 1.5 14B batch performance (#11370)
MeouSker77 authored Jun 20, 2024
1 parent 5aa3e42 commit f0fdfa0
Showing 2 changed files with 35 additions and 1 deletion.
python/llm/src/ipex_llm/transformers/convert.py (2 additions, 0 deletions)
@@ -735,6 +735,8 @@ def _optimize_pre(model):
     if model.config.model_type == "qwen2":
         from ipex_llm.transformers.models.qwen2 import merge_qkv
         model.apply(merge_qkv)
+        from ipex_llm.transformers.models.qwen2 import padding_mlp
+        model.apply(padding_mlp)
     if model.config.model_type == "qwen2_moe":
         from ipex_llm.transformers.models.qwen2_moe import merge_qkv
         model.apply(merge_qkv)
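In _optimize_pre, padding_mlp is applied through torch.nn.Module.apply, which walks the module tree recursively and calls the given function on every submodule; that is why padding_mlp (added below in qwen2.py) only has to handle one module at a time and returns early for anything that is not a Qwen2MLP. A minimal sketch of that traversal, using a toy model and a hypothetical stand-in function rather than anything from this commit:

import torch

def visit(module: torch.nn.Module) -> None:
    # Stand-in for padding_mlp: Module.apply hands every submodule to this
    # function one at a time, so it can check isinstance and modify in place.
    print(type(module).__name__)

toy = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.SiLU(), torch.nn.Linear(16, 8))
toy.apply(visit)  # prints Linear, SiLU, Linear, then Sequential itself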
python/llm/src/ipex_llm/transformers/models/qwen2.py (33 additions, 1 deletion)
@@ -49,7 +49,8 @@
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError

-from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
+from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
+from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -288,6 +289,37 @@ def merge_qkv(module: torch.nn.Module):
         del module.q_proj, module.k_proj, module.v_proj


+def padding_mlp(module: torch.nn.Module):
+    # for qwen 1.5 14B
+    if isinstance(module, Qwen2MLP):
+        hidden_size = module.hidden_size
+        intermediate_size = module.intermediate_size
+        padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256
+        if intermediate_size % 256 == 0:
+            return
+
+        gate_weight = module.gate_proj.weight.data
+        new_gate_weight = torch.zeros([padding_intermediate_size, hidden_size],
+                                      dtype=gate_weight.dtype, device=gate_weight.device)
+        new_gate_weight[:intermediate_size, :] = gate_weight
+        module.gate_proj.out_features = padding_intermediate_size
+        module.gate_proj.weight = torch.nn.Parameter(new_gate_weight, requires_grad=False)
+
+        up_weight = module.up_proj.weight.data
+        new_up_weight = torch.zeros([padding_intermediate_size, hidden_size],
+                                    dtype=up_weight.dtype, device=up_weight.device)
+        new_up_weight[:intermediate_size, :] = up_weight
+        module.up_proj.out_features = padding_intermediate_size
+        module.up_proj.weight = torch.nn.Parameter(new_up_weight, requires_grad=False)
+
+        down_weight = module.down_proj.weight.data
+        new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
+                                      dtype=down_weight.dtype, device=down_weight.device)
+        new_down_weight[:, :intermediate_size] = down_weight
+        module.down_proj.in_features = padding_intermediate_size
+        module.down_proj.weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
+
+
 def qwen2_attention_forward(
     self,
     hidden_states: torch.Tensor,
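Why the zero padding is numerically safe: Qwen2MLP computes down_proj(silu(gate_proj(x)) * up_proj(x)), so the padded gate/up rows contribute silu(0) * 0 = 0 in the extra intermediate channels, and the padded down_proj columns are zero as well; the output is unchanged while intermediate_size is rounded up to a multiple of 256, presumably so batched matmuls hit better kernel shapes. A small self-contained check of that equivalence, using toy sizes rather than the real model's dimensions:

import torch
import torch.nn.functional as F

# Toy sizes for illustration only; the real Qwen1.5-14B dimensions are what
# actually trigger the padding (intermediate_size not divisible by 256).
hidden_size, intermediate_size = 8, 10
padded_size = (intermediate_size + 256 - 1) // 256 * 256  # rounds 10 up to 256

gate = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
down = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(4, hidden_size)
ref = down(F.silu(gate(x)) * up(x))  # original SwiGLU-style MLP output

# Pad the weights the same way padding_mlp does: zeros in the new rows/columns.
gate_w = torch.zeros(padded_size, hidden_size)
gate_w[:intermediate_size] = gate.weight.data
up_w = torch.zeros(padded_size, hidden_size)
up_w[:intermediate_size] = up.weight.data
down_w = torch.zeros(hidden_size, padded_size)
down_w[:, :intermediate_size] = down.weight.data

out = F.linear(F.silu(F.linear(x, gate_w)) * F.linear(x, up_w), down_w)
torch.testing.assert_close(out, ref)  # extra channels contribute silu(0) * 0 = 0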
