Add opt_causallm_samples.py. (#1916)
Showing 4 changed files with 123 additions and 38 deletions.
opt_causallm_samples.py (new file)
@@ -0,0 +1,74 @@
import argparse
import os

import opt_causallm
import opt_util

from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, OPTForCausalLM


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-seq-len", type=int, default=32)
    parser.add_argument(
        "--model-name",
        help="Model name",
        type=str,
        choices=[
            "facebook/opt-125m",
            "facebook/opt-350m",
            "facebook/opt-1.3b",
            "facebook/opt-6.7b",
        ],
        default="facebook/opt-1.3b",
    )
    parser.add_argument(
        "--recompile",
        help="If set, recompiles MLIR -> .vmfb",
        action=argparse.BooleanOptionalAction,
        default=False,
    )
    parser.add_argument(
        "--plugin-path",
        help="path to executable plugin",
        type=str,
        default=None,
    )
    args = parser.parse_args()
    print("args={}".format(args))
    return args


if __name__ == "__main__":
    args = parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
    # Filesystem-friendly model name, e.g. "facebook/opt-1.3b" -> "opt_1-3b".
    opt_fs_name = "-".join(
        "_".join(args.model_name.split("/")[1].split("-")).split(".")
    )
    vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
    if args.plugin_path is not None:
        rt_flags = [f"--executable_plugin={args.plugin_path}"]
    else:
        rt_flags = []
    opt_shark_module = SharkInference(
        mlir_module=None, device="cpu-task", rt_flags=rt_flags
    )
    # Reuse a previously compiled .vmfb if one exists; otherwise compile it.
    if os.path.isfile(vmfb_path):
        opt_shark_module.load_module(vmfb_path)
    else:
        vmfb_path = opt_causallm.create_module(
            args.model_name, tokenizer, "cpu-task", args
        )
        opt_shark_module.load_module(vmfb_path)

    for prompt in opt_util.PROMPTS:
        print("\n\nprompt: {}".format(prompt))
        response = opt_causallm.generate_tokens(
            opt_shark_module,
            tokenizer,
            prompt,
            args.max_seq_len,
            print_intermediate_results=False,
        )
        print("response: {}".format("".join(response)))
opt_util.py (new file)
@@ -0,0 +1,12 @@
PROMPTS = [
    "What is the meaning of life?",
    "Tell me something you don't know.",
    "What does Xilinx do?",
    "What is the mass of earth?",
    "What is a poem?",
    "What is recursion?",
    "Tell me a one line joke.",
    "Who is Gilgamesh?",
    "Tell me something about cryptocurrency.",
    "How did it all begin?",
]
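The prompts are plain strings; the driver passes each one together with the tokenizer to opt_causallm.generate_tokens, which presumably encodes it before generation. A minimal sketch of that encoding step in isolation, assuming only the transformers package (the prompt and the use_fast=False setting come from the files above; everything else is illustrative):

from transformers import AutoTokenizer

# Encode one of the sample prompts with the same tokenizer settings the
# driver script uses, just to inspect the token ids that feed generation.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=False)
encoded = tokenizer("What is recursion?", return_tensors="pt")
print(encoded.input_ids)        # token id tensor; exact values depend on the tokenizer
print(encoded.input_ids.shape)  # (1, number_of_prompt_tokens)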