
Commit

Add opt_causallm_samples.py. (#1916)
godot73 authored Oct 25, 2023
1 parent 841773f commit 0c38c33
Showing 4 changed files with 123 additions and 38 deletions.
61 changes: 36 additions & 25 deletions tank/examples/opt/opt_causallm.py
@@ -10,6 +10,7 @@
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from typing import Iterable


def create_module(model_name, tokenizer, device, args):
@@ -70,11 +71,11 @@ def shouldStop(tokens):
    return False


-def generate_new_token(shark_model, tokenizer, new_text, args):
def generate_new_token(shark_module, tokenizer, new_text, max_seq_len: int):
    model_inputs = tokenizer(
        new_text,
        padding="max_length",
-        max_length=args.max_seq_len,
        max_length=max_seq_len,
        truncation=True,
        return_tensors="pt",
    )
@@ -83,7 +84,7 @@ def generate_new_token(shark_model, tokenizer, new_text, args):
model_inputs["attention_mask"],
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
output = shark_model("forward", inputs)
output = shark_module("forward", inputs)
output = torch.FloatTensor(output[0])
next_toks = torch.topk(output, 1)
stop_generation = False
@@ -135,6 +136,34 @@ def parse_args():
    return args


def generate_tokens(
    opt_shark_module: "SharkInference",
    tokenizer,
    input_text: str,
    max_output_len: int,
    print_intermediate_results: bool = True,
) -> Iterable[str]:
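    """Greedily generate up to max_output_len tokens continuing input_text.

    Detokenized pieces are returned as a list; generation stops early on an
    end-of-text signal from generate_new_token or an empty detokenized string.
    """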
    words_list = []
    new_text = input_text
    try:
        for _ in range(max_output_len):
            generated_token_op = generate_new_token(
                opt_shark_module, tokenizer, new_text, max_output_len
            )
            detok = generated_token_op["detok"]
            if generated_token_op["stop_generation"]:
                break
            if print_intermediate_results:
                print(detok, end="", flush=True)
            words_list.append(detok)
            if detok == "":
                break
            new_text += detok
    except KeyboardInterrupt:
        print("Exiting token generation.")
    return words_list


if __name__ == "__main__":
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
@@ -155,25 +184,7 @@ def parse_args():
        vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
        opt_shark_module.load_module(vmfb_path)
    while True:
-        try:
-            new_text = input("Give me a sentence to complete:")
-            new_text_init = new_text
-            words_list = []
-
-            for i in range(args.max_seq_len):
-                generated_token_op = generate_new_token(
-                    opt_shark_module, tokenizer, new_text, args
-                )
-                detok = generated_token_op["detok"]
-                stop_generation = generated_token_op["stop_generation"]
-                if stop_generation:
-                    break
-                print(detok, end="", flush=True)
-                words_list.append(detok)
-                if detok == "":
-                    break
-                new_text = new_text + detok
-
-        except KeyboardInterrupt:
-            print("Exiting program.")
-            break
        input_text = input("Give me a sentence to complete:")
        generate_tokens(
            opt_shark_module, tokenizer, input_text, args.max_seq_len
        )
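Taken together, the changes split the old interactive loop into a reusable pair: generate_new_token runs one forward pass per call, and generate_tokens drives it greedily until a stop condition or max_output_len is reached. A minimal sketch of calling the pair directly, assuming a loaded SharkInference module and tokenizer as set up in the __main__ block above (the prompt string is illustrative):

words = generate_tokens(
    opt_shark_module,
    tokenizer,
    "What is recursion?",
    max_output_len=32,
    print_intermediate_results=False,
)
print("".join(words))

The new opt_causallm_samples.py below does exactly this over a canned list of prompts.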
74 changes: 74 additions & 0 deletions tank/examples/opt/opt_causallm_samples.py
@@ -0,0 +1,74 @@
import argparse
import os

import opt_causallm
import opt_util

from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, OPTForCausalLM


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-seq-len", type=int, default=32)
    parser.add_argument(
        "--model-name",
        help="Model name",
        type=str,
        choices=[
            "facebook/opt-125m",
            "facebook/opt-350m",
            "facebook/opt-1.3b",
            "facebook/opt-6.7b",
        ],
        default="facebook/opt-1.3b",
    )
    parser.add_argument(
        "--recompile",
        help="If set, recompiles MLIR -> .vmfb",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    parser.add_argument(
        "--plugin-path",
        help="path to executable plugin",
        type=str,
        default=None,
    )
    args = parser.parse_args()
    print("args={}".format(args))
    return args


if __name__ == "__main__":
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
    opt_fs_name = "-".join(
        "_".join(args.model_name.split("/")[1].split("-")).split(".")
    )
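    # e.g. "facebook/opt-1.3b" -> "opt_1-3b": a filesystem-friendly name
    # used in the compiled-artifact (.vmfb) filename below.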
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
    if args.plugin_path is not None:
        rt_flags = [f"--executable_plugin={args.plugin_path}"]
    else:
        rt_flags = []
    opt_shark_module = SharkInference(
        mlir_module=None, device="cpu-task", rt_flags=rt_flags
    )
    if os.path.isfile(vmfb_path):
        opt_shark_module.load_module(vmfb_path)
    else:
        vmfb_path = opt_causallm.create_module(
            args.model_name, tokenizer, "cpu-task", args
        )
        opt_shark_module.load_module(vmfb_path)

    for prompt in opt_util.PROMPTS:
        print("\n\nprompt: {}".format(prompt))
        response = opt_causallm.generate_tokens(
            opt_shark_module,
            tokenizer,
            prompt,
            args.max_seq_len,
            print_intermediate_results=False,
        )
        print("response: {}".format("".join(response)))
14 changes: 1 addition & 13 deletions tank/examples/opt/opt_perf_comparison.py
@@ -22,6 +22,7 @@
import numpy as np
from typing import Tuple

from opt_util import PROMPTS
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
@@ -44,19 +45,6 @@
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"

-PROMPTS = [
-    "What is the meaning of life?",
-    "Tell me something you don't know.",
-    "What does Xilinx do?",
-    "What is the mass of earth?",
-    "What is a poem?",
-    "What is recursion?",
-    "Tell me a one line joke.",
-    "Who is Gilgamesh?",
-    "Tell me something about cryptocurrency.",
-    "How did it all begin?",
-]

ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])


12 changes: 12 additions & 0 deletions tank/examples/opt/opt_util.py
@@ -0,0 +1,12 @@
PROMPTS = [
"What is the meaning of life?",
"Tell me something you don't know.",
"What does Xilinx do?",
"What is the mass of earth?",
"What is a poem?",
"What is recursion?",
"Tell me a one line joke.",
"Who is Gilgamesh?",
"Tell me something about cryptocurrency.",
"How did it all begin?",
]
