Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
plusbang committed Nov 1, 2024
1 parent eda7649 commit a63cc49
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Arguments info:
- `--max-context-len MAX_CONTEXT_LEN`: Defines the maximum sequence length for both input and output tokens. The default is `1024`.
- `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. The default is `512`.
- `--disable-transpose-value-cache`: Disables the optimization that transposes the value cache.
- `--disable-streaming`: Disables streaming output during generation.

### Sample Output
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -61,6 +61,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path
Expand Down Expand Up @@ -92,6 +93,11 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)

if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)

DEFAULT_SYSTEM_PROMPT = """\
"""

Expand All @@ -105,7 +111,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
_input_ids, max_new_tokens=args.n_predict, do_print=True, streamer=streamer
)
end = time.time()
print(f"Inference time: {end-st} s")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -62,6 +62,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path
Expand Down Expand Up @@ -91,6 +92,11 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],

if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)

if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)

DEFAULT_SYSTEM_PROMPT = """\
"""
Expand All @@ -105,7 +111,7 @@ def get_prompt(message: str, chat_history: list[tuple[str, str]],
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
_input_ids, max_new_tokens=args.n_predict, do_print=True, streamer=streamer
)
end = time.time()
print(f"Inference time: {end-st} s")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -68,6 +68,7 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--quantization_group_size", type=int, default=0)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path
Expand Down Expand Up @@ -98,6 +99,11 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)

if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)

print("-" * 80)
print("done")
with torch.inference_mode():
Expand All @@ -108,7 +114,7 @@ def get_prompt(user_input: str, chat_history: list[tuple[str, str]],
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
_input_ids, max_new_tokens=args.n_predict, do_print=True, streamer=streamer
)
end = time.time()
print(f"Inference time: {end-st} s")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging
import os

Expand Down Expand Up @@ -48,6 +48,7 @@
parser.add_argument("--max-context-len", type=int, default=1024)
parser.add_argument("--max-prompt-len", type=int, default=512)
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path
Expand Down Expand Up @@ -79,6 +80,11 @@
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)

if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)

print("-" * 80)
print("done")
with torch.inference_mode():
Expand All @@ -89,7 +95,7 @@
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
_input_ids, max_new_tokens=args.n_predict, do_print=True, streamer=streamer
)
end = time.time()
print(f"Inference time: {end-st} s")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoTokenizer, TextStreamer
from transformers.utils import logging

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -50,6 +50,7 @@
parser.add_argument('--load_in_low_bit', type=str, default="sym_int4",
help='Load in low bit to use')
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
parser.add_argument("--disable-streaming", action="store_true", default=False)

args = parser.parse_args()
model_path = args.repo_id_or_model_path
Expand Down Expand Up @@ -81,6 +82,11 @@
if args.lowbit_path and not os.path.exists(args.lowbit_path):
model.save_low_bit(args.lowbit_path)

if args.disable_streaming:
streamer = None
else:
streamer = TextStreamer(tokenizer=tokenizer, skip_special_tokens=True)

print("-" * 80)
print("done")
messages = [{"role": "system", "content": "You are a helpful assistant."},
Expand All @@ -95,7 +101,7 @@
print("input length:", len(_input_ids[0]))
st = time.time()
output = model.generate(
_input_ids, max_new_tokens=args.n_predict, do_print=True
_input_ids, max_new_tokens=args.n_predict, do_print=True, streamer=streamer
)
end = time.time()
print(f"Inference time: {end-st} s")
Expand Down

0 comments on commit a63cc49

Please sign in to comment.