Support PP inference for chatglm3 #11375

Merged 3 commits on Jun 21, 2024
16 changes: 16 additions & 0 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -12,6 +12,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements
- [Qwen/Qwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
- [microsoft/Phi-3-mini-4k-instruct](./run_phi3_arc_2_card.sh)
@@ -71,6 +72,21 @@ bash run_qwen1.5_arc_2_card.sh

</details>

<details>
<summary> Show chatglm example </summary>

#### Run chatglm3-6B on two Intel Arc A770

You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id of the chatglm3 model to be downloaded, or the path to the huggingface checkpoint folder. You could also change `NUM_GPUS` to the number of GPUs you have on your machine.

```bash
bash run_chatglm_arc_2_card.sh
```

</details>


<details>
<summary> Show Baichuan2 example </summary>

24 changes: 16 additions & 8 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -19,7 +19,7 @@
import time
import argparse

from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
from ipex_llm.transformers import AutoModel, AutoModelForCausalLM, init_pipeline_parallel
from transformers import AutoTokenizer

init_pipeline_parallel()
@@ -41,13 +41,21 @@

# Load model in 4 bit,
# which converts the relevant layers in the model into INT4 format
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True,
                                             use_cache=True,
                                             torch_dtype=torch.float16,
                                             pipeline_parallel_stages=args.gpu_num)
try:
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True,
                                                 torch_dtype=torch.float16,
                                                 pipeline_parallel_stages=args.gpu_num)
except:
    # fall back to AutoModel (used for chatglm3-6b) when loading with
    # AutoModelForCausalLM fails
    model = AutoModel.from_pretrained(model_path,
                                      load_in_4bit=True,
                                      optimize_model=True,
                                      trust_remote_code=True,
                                      use_cache=True,
                                      pipeline_parallel_stages=args.gpu_num)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
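The remainder of generate.py is collapsed in the diff above. Purely as an illustration of how the pipeline-parallel model could then be used, a minimal generation step might look like the sketch below; the device selection via torch.distributed and decoding only on the last rank are assumptions for this sketch, not code taken from this PR.

```python
# Illustrative sketch, assuming `model` and `tokenizer` were created as in the
# diff above and that init_pipeline_parallel() has initialized the process group.
import torch
import torch.distributed as dist

local_rank = dist.get_rank()

prompt = "What is AI?"
with torch.inference_mode():
    # each rank moves the prompt to its own XPU device
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f"xpu:{local_rank}")
    output = model.generate(input_ids, max_new_tokens=32)

# Only the last stage keeps the output projection (see pipeline_parallel below),
# so that rank is a natural place to decode and print the result.
if local_rank == dist.get_world_size() - 1:
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```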
31 changes: 31 additions & 0 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/run_chatglm_arc_2_card.sh
@@ -0,0 +1,31 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

source /opt/intel/oneapi/setvars.sh
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=9090
export FI_PROVIDER=tcp
export USE_XETLA=OFF
export OMP_NUM_THREADS=6
# Note: KERNEL_VERSION is not set in this script; export it beforehand
# (for example, KERNEL_VERSION=$(uname -r)), otherwise the branch below always runs.
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
  export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
fi
export TORCH_LLM_ALLREDUCE=0

NUM_GPUS=2 # number of GPUs used
# To run chatglm3-6b
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
generate.py --repo-id-or-model-path 'THUDM/chatglm3-6b' --gpu-num $NUM_GPUS
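For orientation, `torchrun --nproc-per-node $NUM_GPUS` starts one Python process per GPU, and each process derives its pipeline stage from its distributed rank. The snippet below is only an illustration of that mapping; the actual stage assignment happens inside `ipex_llm.transformers.pipeline_parallel`, shown further down.

```python
# Illustration only: how a process launched by torchrun can identify its stage.
import os

# torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for every spawned process;
# with NUM_GPUS=2 the ranks are 0 and 1, one per Arc A770.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
print(f"process {rank}/{world_size} would own pipeline stage {rank}")
```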
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -74,10 +74,12 @@ def chatglm2_model_forward(
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    batch_size, seq_length = input_ids.shape

    if inputs_embeds is None:
        batch_size, seq_length = input_ids.shape
        inputs_embeds = self.embedding(input_ids)
    else:
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        seq_length, batch_size, _ = inputs_embeds.shape

    if full_attention_mask is None:
        if (attention_mask is not None and not attention_mask.all()) or (
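The change above lets `chatglm2_model_forward` accept precomputed `inputs_embeds`: with pipeline parallelism only the first stage receives `input_ids`, while later stages are fed hidden states in `(batch, seq, hidden)` order and must switch to ChatGLM's internal `(seq, batch, hidden)` layout. A small, self-contained shape check illustrates the two paths; the sizes are arbitrary examples.

```python
import torch

batch, seq, hidden = 2, 8, 4096  # example sizes only

# First stage: token ids are embedded inside the model.
input_ids = torch.zeros(batch, seq, dtype=torch.long)
batch_size, seq_length = input_ids.shape                  # (2, 8)

# Later stages: hidden states arrive batch-first and are transposed to
# ChatGLM's sequence-first layout before the encoder layers run.
inputs_embeds = torch.randn(batch, seq, hidden)
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
seq_length, batch_size, _ = inputs_embeds.shape           # (8, 2, 4096)
```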
67 changes: 51 additions & 16 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -71,6 +71,19 @@ def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs)
        return outputs


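# Pass-through stand-in for ChatGLM's GLMBlock: pipeline stages replace the
# transformer layers they do not own with this module, keeping the module tree
# intact (input_layernorm/mlp exist only to avoid AttributeError) while the
# forward pass returns hidden_states and kv_cache unchanged.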
class Dummy_GLMBlock(nn.Module):
    def __init__(self, *args):
        super().__init__()
        # to avoid AttributeError
        self.input_layernorm = DummyLayer()
        self.mlp = Dummy_MLPLayer()

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
        return hidden_states, kv_cache


def init_pipeline_parallel():
    import oneccl_bindings_for_pytorch
    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
@@ -79,28 +92,49 @@ def init_pipeline_parallel():


def pipeline_parallel(model, pipeline_parallel_stages):
    slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \
        pipeline_parallel_stages
    global num_layers
    if hasattr(model.config, 'num_hidden_layers'):
        num_layers = model.config.num_hidden_layers
    elif hasattr(model.config, 'num_layers'):
        # for chatglm3-6b
        num_layers = model.config.num_layers

    slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages

    local_rank = dist.get_rank()

    global layer_start
    global layer_end
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)

    for i in range(model.config.num_hidden_layers):
        if i < layer_start or i >= layer_end:
            model._modules['model'].layers[i] = Dummy_DecoderLayer()
        else:
            # align layer_idx and len(past_key_values), otherwise abnormal output
            model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start

    if local_rank != 0:
        model._modules['model'].embed_tokens = DummyLayer()
    if local_rank != pipeline_parallel_stages - 1:
        model._modules['model'].norm = DummyLayer()
        model._modules['lm_head'] = DummyLayer()
    layer_end = layer_start + min(slice_size, num_layers - layer_start)

    if model.config.architectures is not None \
            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
        # for chatglm3-6b
        for i in range(num_layers):
            if i < layer_start or i >= layer_end:
                model._modules['transformer'].encoder.layers[i] = Dummy_GLMBlock()
            else:
                model._modules['transformer'].encoder.layers[i].self_attention.num_layers = \
                    i - layer_start

        if local_rank != 0:
            model._modules['transformer'].embedding = DummyLayer()
        if local_rank != pipeline_parallel_stages - 1:
            model._modules['transformer'].encoder.final_layernorm = DummyLayer()
            model._modules['transformer'].output_layer = DummyLayer()
    else:
        for i in range(num_layers):
            if i < layer_start or i >= layer_end:
                model._modules['model'].layers[i] = Dummy_DecoderLayer()
            else:
                model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start

        if local_rank != 0:
            model._modules['model'].embed_tokens = DummyLayer()
        if local_rank != pipeline_parallel_stages - 1:
            model._modules['model'].norm = DummyLayer()
            model._modules['lm_head'] = DummyLayer()

    model.pipeline_parallel_stages = pipeline_parallel_stages
    model = model.to(f'xpu:{local_rank}')
@@ -176,6 +210,7 @@ def pipeline_parallel_generate(self,

    global layer_start
    global layer_end
    global num_layers

    self.first_token_time = 0
    self.next_token_time = []
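For concreteness, the slicing logic in `pipeline_parallel` above divides the decoder layers evenly across stages. A worked example, assuming chatglm3-6b's 28 transformer layers on the two-GPU setup from the run script:

```python
# Worked example of the layer slicing above (illustration only).
num_layers = 28                # chatglm3-6b reports config.num_layers == 28
pipeline_parallel_stages = 2   # NUM_GPUS in run_chatglm_arc_2_card.sh

slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages  # 14

for local_rank in range(pipeline_parallel_stages):
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, num_layers - layer_start)
    print(f"rank {local_rank}: owns layers [{layer_start}, {layer_end})")
# rank 0: owns layers [0, 14)  -> the remaining layers become Dummy_GLMBlock instances
# rank 1: owns layers [14, 28) -> also keeps final_layernorm and output_layer
```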