Updates to opt_causallm example (#1905)
* Updates to opt_causallm example

* Fixup opt_perf_comparison.py

* Use same filenames across opt examples.
monorimet authored Oct 24, 2023
1 parent 0361db4 commit 841773f
Showing 2 changed files with 60 additions and 23 deletions.
79 changes: 58 additions & 21 deletions tank/examples/opt/opt_causallm.py
@@ -1,3 +1,4 @@
+import argparse
 import os
 import torch
 import numpy as np
@@ -10,21 +11,16 @@
 from shark.shark_importer import import_with_fx
 from transformers import AutoTokenizer, OPTForCausalLM
 
-OPT_MODEL = "opt-1.3b"
-OPT_FS_NAME = "opt-1_3b"
-MAX_SEQUENCE_LENGTH = 128
-MAX_NEW_TOKENS = 60
-
 
-def create_module(model_name, tokenizer, device):
-    opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
+def create_module(model_name, tokenizer, device, args):
+    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
         "What is the meaning of life?",
         padding="max_length",
         truncation=True,
-        max_length=MAX_SEQUENCE_LENGTH,
+        max_length=args.max_seq_len,
         return_tensors="pt",
     )
     inputs = (
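
For reference, a minimal standalone sketch of what the padded encoding above produces with the new defaults (the model name and length below are just the parse_args() defaults shown later in this diff, not requirements); both returned tensors come out with shape (1, max_seq_len):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b", use_fast=False)
encoded = tokenizer(
    "What is the meaning of life?",
    padding="max_length",
    truncation=True,
    max_length=32,
    return_tensors="pt",
)
print(encoded["input_ids"].shape)       # torch.Size([1, 32])
print(encoded["attention_mask"].shape)  # torch.Size([1, 32])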
@@ -33,16 +29,19 @@ def create_module(model_name, tokenizer, device):
     )
     # np.save("model_inputs_0.npy", inputs[0])
     # np.save("model_inputs_1.npy", inputs[1])
+    opt_fs_name = "-".join(
+        "_".join(args.model_name.split("/")[1].split("-")).split(".")
+    )
 
-    mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
+    mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
     if os.path.isfile(mlir_path):
         print(f"Found .mlir from {mlir_path}")
     else:
         (model_mlir, func_name) = import_with_fx(
             model=opt_model,
             inputs=inputs,
             is_f16=False,
-            model_name=OPT_FS_NAME,
+            model_name=opt_fs_name,
             return_str=True,
         )
         with open(mlir_path, "w") as f:
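
The inline opt_fs_name expression above derives a filesystem-safe stem from the Hugging Face model id by dropping the org prefix and swapping separators; a standalone sketch of the same transformation (the helper name is hypothetical):

def to_fs_name(model_name: str) -> str:
    # Drop the org prefix ("facebook/"), turn "-" into "_", then turn "." into "-",
    # mirroring the inline expression in create_module above.
    return "-".join("_".join(model_name.split("/")[1].split("-")).split("."))

print(to_fs_name("facebook/opt-125m"))  # opt_125m
print(to_fs_name("facebook/opt-1.3b"))  # opt_1-3b

With the default --max-seq-len of 32 this yields cache files such as ./opt_1-3b_causallm_32_torch.mlir and, further down, opt_1-3b_causallm_32_torch_cpu.vmfb.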
@@ -57,7 +56,7 @@ def create_module(model_name, tokenizer, device):
         is_benchmark=False,
     )
 
-    vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
+    vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
     shark_module.save_module(module_name=vmfb_name, debug=False)
     vmfb_path = vmfb_name + ".vmfb"
     return vmfb_path
@@ -71,11 +70,11 @@ def shouldStop(tokens):
     return False
 
 
-def generate_new_token(shark_model, tokenizer, new_text):
+def generate_new_token(shark_model, tokenizer, new_text, args):
     model_inputs = tokenizer(
         new_text,
         padding="max_length",
-        max_length=MAX_SEQUENCE_LENGTH,
+        max_length=args.max_seq_len,
         truncation=True,
         return_tensors="pt",
     )
@@ -104,28 +103,66 @@ def generate_new_token(shark_model, tokenizer, new_text):
     return ret_dict
 
 
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--max-seq-len", type=int, default=32)
+    parser.add_argument(
+        "--model-name",
+        help="Model name",
+        type=str,
+        choices=[
+            "facebook/opt-125m",
+            "facebook/opt-350m",
+            "facebook/opt-1.3b",
+            "facebook/opt-6.7b",
+        ],
+        default="facebook/opt-1.3b",
+    )
+    parser.add_argument(
+        "--recompile",
+        help="If set, recompiles MLIR -> .vmfb",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+    )
+    parser.add_argument(
+        "--plugin-path",
+        help="path to executable plugin",
+        type=str,
+        default=None,
+    )
+    args = parser.parse_args()
+    print("args={}".format(args))
+    return args
+
+
 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained(
-        "facebook/" + OPT_MODEL, use_fast=False
+    args = parse_args()
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
+    opt_fs_name = "-".join(
+        "_".join(args.model_name.split("/")[1].split("-")).split(".")
     )
-    vmfb_path = (
-        f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
+    vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
+    if args.plugin_path is not None:
+        rt_flags = [f"--executable_plugin={args.plugin_path}"]
+    else:
+        rt_flags = []
+    opt_shark_module = SharkInference(
+        mlir_module=None, device="cpu-task", rt_flags=rt_flags
    )
-    opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
     if os.path.isfile(vmfb_path):
         opt_shark_module.load_module(vmfb_path)
     else:
-        vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
+        vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
         opt_shark_module.load_module(vmfb_path)
     while True:
         try:
             new_text = input("Give me a sentence to complete:")
             new_text_init = new_text
             words_list = []
 
-            for i in range(MAX_NEW_TOKENS):
+            for i in range(args.max_seq_len):
                 generated_token_op = generate_new_token(
-                    opt_shark_module, tokenizer, new_text
+                    opt_shark_module, tokenizer, new_text, args
                 )
                 detok = generated_token_op["detok"]
                 stop_generation = generated_token_op["stop_generation"]
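
With parse_args() in place the example is driven from the command line; hypothetical invocations (the working directory and plugin filename are placeholders) look like:

python tank/examples/opt/opt_causallm.py --model-name facebook/opt-125m --max-seq-len 32
python tank/examples/opt/opt_causallm.py --plugin-path ./my_plugin.so

argparse converts the hyphenated flag names to underscored attributes, which is why the code above reads args.max_seq_len and args.plugin_path.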
4 changes: 2 additions & 2 deletions tank/examples/opt/opt_perf_comparison.py
@@ -147,7 +147,7 @@ def load_shark_model(
     plugin_path: str = [],
 ) -> ModelWrapper:
     opt_fs_name = get_opt_fs_name(model_name)
-    vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
+    vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
     if recompile_shark or not os.path.isfile(vmfb_name):
         print(f"vmfb not found. compiling and saving to {vmfb_name}")
@@ -344,7 +344,7 @@ def parse_args():
         default=PLATFORM_SHARK,
     )
     parser.add_argument(
-        "--plugin_path",
+        "--plugin-path",
         help="path to executable plugin",
         type=str,
         default=None,
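
The rename keeps the flag spelling consistent with opt_causallm.py; argparse still exposes it as args.plugin_path, so only the command line changes, e.g. (plugin filename is a placeholder):

python tank/examples/opt/opt_perf_comparison.py --plugin-path ./my_plugin.so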
