diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index cd7d91014b6..56231b20e68 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1458,7 +1458,7 @@ def _optimize_post(model, lightweight_bmm=False):
             from ipex_llm.transformers.models.chatglm4v import vision_model_forward
             convert_forward(model, vision_module.VisionModel, vision_model_forward)
 
-        elif model.config.num_layers == 40:
+        elif model.config.num_layers in [40, 28]:
             # glm-4-9b
             from ipex_llm.transformers.models.chatglm4 import chatglm4_attention_forward
             from ipex_llm.transformers.models.chatglm4 import chatglm4_model_forward
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 2daffedb118..282ce5bf8c7 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -44,6 +44,7 @@ def chatglm4_model_forward(
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else
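
Note for reviewers (not part of the patch): the convert.py change extends the glm-4 optimization branch to models whose config.num_layers is 28 as well as 40, and the chatglm4.py change adds a trailing **kwargs to chatglm4_model_forward so it tolerates extra keyword arguments a caller may forward. A minimal sketch of that tolerant-signature pattern follows; the names model_forward and cache_position are illustrative assumptions, not the actual ipex-llm code:

from typing import Optional


# Sketch: a trailing **kwargs absorbs keyword arguments the function
# does not declare, so newer callers cannot crash it with a TypeError.
def model_forward(
        input_ids=None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,  # silently accepts e.g. a hypothetical cache_position
):
    return {"input_ids": input_ids, "extra": kwargs}


# Works even though cache_position is not a declared parameter; without
# **kwargs this call would raise:
#   TypeError: model_forward() got an unexpected keyword argument 'cache_position'
out = model_forward(input_ids=[1, 2, 3], cache_position=[0, 1, 2])
print(out["extra"])  # {'cache_position': [0, 1, 2]}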