[NPU] Qwen2 divide linear in decode layers #12089

Draft
wants to merge 7 commits into base: main
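For orientation (not part of the diff): the PR computes each projection as a sum of partial matmuls over slices of the hidden dimension. A minimal PyTorch sketch of the identity this relies on, with toy shapes that are illustrative only:

```python
# Standalone sketch (not code from this PR): for y = x @ W.T, slicing x and W along the
# input (hidden_size) dimension into n chunks and summing the partial matmuls reproduces
# the full projection. Toy sizes below are hypothetical.
import torch

hidden_size, out_features, n_splits = 8, 4, 2
x = torch.randn(1, 3, hidden_size)          # [batch, seq_len, hidden_size]
W = torch.randn(out_features, hidden_size)  # weight of a bias-free nn.Linear

full = x @ W.T

group = hidden_size // n_splits
partial = sum(
    x[..., i * group:(i + 1) * group] @ W[:, i * group:(i + 1) * group].T
    for i in range(n_splits)
)

assert torch.allclose(full, partial, atol=1e-5)
```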
12 changes: 7 additions & 5 deletions python/llm/dev/benchmark/all-in-one/run.py
@@ -614,26 +614,28 @@ def transformers_int4_npu_win(repo_id,
# Load model in 4 bit,
# which convert the relevant layers in the model into INT4 format
st = time.perf_counter()
transpose_value = True
if repo_id in CHATGLM_IDS:
model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=True,
optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=transpose_value,
torch_dtype=torch.float16, attn_implementation="eager").eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
elif repo_id in LLAMA_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, torch_dtype=torch.float16,
optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=True,
optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=transpose_value,
use_cache=True, attn_implementation="eager").eval()
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True, torch_dtype=torch.float16,
optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=True,
optimize_model=optimize_model, max_output_len=max_output_len, max_prompt_len=int(in_out_len[0]), transpose_value_cache=transpose_value,
use_cache=True, attn_implementation="eager").eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
end = time.perf_counter()
load_time = end - st
print(">> loading of model costs {}s".format(load_time))

model = BenchmarkWrapper(model)
print(model)
model = BenchmarkWrapper(model, do_print=True)

result = {}
with torch.inference_mode():
@@ -658,7 +660,7 @@ def transformers_int4_npu_win(repo_id,
min_new_tokens=out_len, num_beams=num_beams)
end = time.perf_counter()
print("model generate cost: " + str(end - st))
output = tokenizer.batch_decode(output_ids)
output = tokenizer.batch_decode(output_ids[:, actual_in_len:])
print(output[0])
actual_out_len = output_ids.shape[1] - actual_in_len
if i >= warm_up:
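As context for the `batch_decode` change above: `generate()` returns the prompt tokens followed by the newly generated ones, so slicing off the first `actual_in_len` positions decodes only the generated text. A small illustrative sketch (tensor values are hypothetical, not taken from run.py):

```python
# Illustrative only: output_ids = prompt tokens + generated tokens, so dropping the
# first actual_in_len columns before batch_decode removes the echoed prompt.
import torch

actual_in_len = 4
prompt_ids = torch.tensor([[11, 12, 13, 14]])
new_ids = torch.tensor([[21, 22, 23]])
output_ids = torch.cat([prompt_ids, new_ids], dim=1)  # what generate() would return

generated_only = output_ids[:, actual_in_len:]         # tensor([[21, 22, 23]])
actual_out_len = output_ids.shape[1] - actual_in_len   # 3, matching the calculation in run.py
```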
12 changes: 11 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -87,7 +87,17 @@ def optimize_llm_pre(model: torch.nn.Module, qtype):

if model.config.model_type == "qwen2":
from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
model.apply(split_mlp_down_proj)
from ipex_llm.transformers.npu_models.qwen2_mp import split_linears
# model.apply(split_mlp_down_proj)
n_splits_linear = max(1, int(os.environ.get("IPEX_LLM_N_SPLITS_LINEAR", "1")))
n_splits_down_proj = max(1, int(os.environ.get("IPEX_LLM_N_SPLITS_DOWN_PROJ", "1")))

if n_splits_down_proj == 1 and model.config.intermediate_size == 18944:
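# 18944 matches the intermediate_size of Qwen2-7B-Instruct, so its down_proj is split
# in two even when IPEX_LLM_N_SPLITS_DOWN_PROJ is left unset.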
n_splits_down_proj = 2

model.apply(lambda m: split_linears(m, n_splits_hidden_size=n_splits_linear,
n_splits_down_proj=n_splits_down_proj))
print(model)

# for Qwen2-7B-Instruct, divide lm_head into 14 parts
if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
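The split counts are driven by environment variables read in `optimize_llm_pre`. A hedged usage sketch follows; the values are examples only, and it assumes `os` is imported in convert_mp.py (the diff does not show the import):

```python
# Hypothetical usage: set the split counts before loading the model so that
# optimize_llm_pre picks them up (variable names taken from the diff above).
import os

os.environ["IPEX_LLM_N_SPLITS_LINEAR"] = "2"     # q/k/v/o and gate/up projections -> 2 slices each
os.environ["IPEX_LLM_N_SPLITS_DOWN_PROJ"] = "2"  # down_proj -> 2 slices
```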
195 changes: 158 additions & 37 deletions python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -94,7 +94,8 @@ def run_model(

class LLMBaseNNFactory(NNFactory):

def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU"):
def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU",
n_splits_linear=1, n_splits_down_proj=1):
super().__init__(profile, device)
self.cache_parameter_ops = []
self.input_ops = []
@@ -104,6 +105,12 @@ def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="N
self.max_seq_len = max_seq_len
self.transpose_value = transpose_value
self.dtype = dtype
self.n_splits_linear = n_splits_linear
self.n_splits_down_proj = n_splits_down_proj

def reduce_linear(self, to_concat):
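# Sum the partial outputs of a split linear: concatenate them along a new leading axis,
# then reduce over that axis.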
concat = self.sequence_concat(to_concat, axis=0)
return self.reduce_sum(concat, reduction_axes=0, keep_dims=True)

def attention(self,
*,
@@ -124,31 +131,79 @@ def attention(self,
v_bias=None):
hidden_size = num_heads * head_dim
num_key_value_groups = num_heads // num_key_value_heads
query_states = self.linear(
hidden_states,
num_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
)
groupsize = hidden_size // self.n_splits_linear
# print(f"hidden states: {hidden_states.shape}")
if self.n_splits_linear == 1:
query_states = self.linear(
hidden_states,
num_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
)

key_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
)

value_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
)
else:
query_states_to_concat = []
key_states_to_concat = []
value_states_to_concat = []
for i in range(self.n_splits_linear):
sub_hidden_states = self.slice(hidden_states,
begin=[0, i * groupsize],
end=[seq_len, (i + 1) * groupsize])
query_states_to_concat.append(
self.linear(
sub_hidden_states,
num_heads * head_dim,
groupsize,
bias=False,
wt_dtype=self.dtype,
)
)
key_states_to_concat.append(
self.linear(
sub_hidden_states,
num_key_value_heads * head_dim,
groupsize,
bias=False,
wt_dtype=self.dtype,
)
)
value_states_to_concat.append(
self.linear(
sub_hidden_states,
num_key_value_heads * head_dim,
groupsize,
bias=False,
wt_dtype=self.dtype,
)
)
if mode == "decode":
query_states = self.reduce_linear(query_states_to_concat)
key_states = self.reduce_linear(key_states_to_concat)
value_states = self.reduce_linear(value_states_to_concat)
else:
query_states = sum(query_states_to_concat)
key_states = sum(key_states_to_concat)
value_states = sum(value_states_to_concat)
if q_bias is not None:
query_states = query_states + q_bias
key_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
)
if k_bias is not None:
key_states = key_states + k_bias
value_states = self.linear(
hidden_states,
num_key_value_heads * head_dim,
hidden_size,
bias=False,
wt_dtype=self.dtype,
)
if v_bias is not None:
value_states = value_states + v_bias

@@ -215,23 +270,89 @@ def attention(self,
attn_output = self.transpose(attn_output, [0, 2, 1, 3])
attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])

attn_output = self.linear(
attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
)

# print(f"attn_output: {attn_output.shape}")
if self.n_splits_linear == 1:
attn_output = self.linear(
attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
)
else:
attn_output_to_concat = []
for i in range(self.n_splits_linear):
sub_attn_output = self.slice(attn_output,
begin=[0, 0, i * groupsize],
end=[1, seq_len, (i + 1) * groupsize])
attn_output_to_concat.append(
self.linear(
sub_attn_output, hidden_size, groupsize, bias=False, wt_dtype=self.dtype
)
)
if mode == "decode":
attn_output = self.reduce_linear(attn_output_to_concat)
else:
attn_output = sum(attn_output_to_concat)

return attn_output, new_key_states, new_value_states

def mlp(self, hidden_states):
mm1 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
mm2 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
) # type: ignore[attr-defined]
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
)
def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
print(f"mp_models_base mlp")
use_concat_reduce = (mode == "decode" and False)
if self.n_splits_linear == 1:
mm1 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
)
mm2 = self.linear(
hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
) # type: ignore[attr-defined]
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
else:
invalidInputError(seq_len > 0, "seq_len should be provided when using split linear")
gate_up_groupsize = self.hidden_size // self.n_splits_linear
mm1_to_concat = []
mm2_to_concat = []
for i in range(self.n_splits_linear):
sub_hidden_states = self.slice(hidden_states, begin=[0, 0, i * gate_up_groupsize],
end=[1, seq_len, (i + 1) * gate_up_groupsize])
mm1_to_concat.append(
self.linear(
sub_hidden_states, self.intermediate_size, gate_up_groupsize, bias=False, wt_dtype=self.dtype
)
)
mm2_to_concat.append(
self.linear(
sub_hidden_states, self.intermediate_size, gate_up_groupsize, bias=False, wt_dtype=self.dtype
)
)
if use_concat_reduce:
mm1 = self.reduce_linear(mm1_to_concat)
mm2 = self.reduce_linear(mm2_to_concat)
else:
mm1 = sum(mm1_to_concat)
mm2 = sum(mm2_to_concat)
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]

if self.n_splits_down_proj == 1:
hidden_states = self.linear(
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
)
else:
invalidInputError(seq_len > 0, "seq_len should be provided when using split linear")
down_groupsize = self.intermediate_size // self.n_splits_down_proj
hidden_states_to_concat = []
for i in range(self.n_splits_down_proj):
sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
end=[1, seq_len, (i + 1) * down_groupsize])
hidden_states_to_concat.append(
self.linear(
sub_mm1, self.hidden_size, down_groupsize, bias=False, wt_dtype=self.dtype
)
)
# print(hidden_states_to_concat[0].shape)
# hidden_states = self.concat_list(hidden_states_to_concat, 0)
# hidden_states = self.reduce_sum(hidden_states, 0)
if use_concat_reduce:
hidden_states = self.reduce_linear(hidden_states_to_concat)
else:
hidden_states = sum(hidden_states_to_concat)
return hidden_states

def layer_norm(self, hidden_states, layernorm_weight):
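The `reduce_linear` helper combines partial outputs by concatenating them and reducing with a sum, which is numerically the same as the plain `sum()` used on the fallback path. A standalone sketch of that equivalence with toy tensors (plain PyTorch, not the NNFactory graph ops used in this file):

```python
# Toy demonstration: stacking partial projection outputs and reducing over the stack
# axis equals summing them directly (analogous to sequence_concat + reduce_sum).
import torch

parts = [torch.randn(1, 1, 4) for _ in range(3)]  # partial outputs from 3 weight slices

summed = sum(parts)                    # plain Python sum, as on the fallback path
reduced = torch.stack(parts).sum(0)    # concat-then-reduce, as in reduce_linear

assert torch.allclose(summed, reduced, atol=1e-6)
```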