diff --git a/tank/examples/opt/opt_causallm.py b/tank/examples/opt/opt_causallm.py
index 210671c5ed..d1b87d0a5e 100644
--- a/tank/examples/opt/opt_causallm.py
+++ b/tank/examples/opt/opt_causallm.py
@@ -1,3 +1,4 @@
+import argparse
 import os
 import torch
 import numpy as np
@@ -10,21 +11,16 @@ from shark.shark_importer import import_with_fx
 from transformers import AutoTokenizer, OPTForCausalLM
 
 
-OPT_MODEL = "opt-1.3b"
-OPT_FS_NAME = "opt-1_3b"
-MAX_SEQUENCE_LENGTH = 128
-MAX_NEW_TOKENS = 60
-
 
-def create_module(model_name, tokenizer, device):
-    opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
+def create_module(model_name, tokenizer, device, args):
+    opt_base_model = OPTForCausalLM.from_pretrained(model_name)
     opt_base_model.eval()
     opt_model = OPTForCausalLMModel(opt_base_model)
     encoded_inputs = tokenizer(
         "What is the meaning of life?",
         padding="max_length",
         truncation=True,
-        max_length=MAX_SEQUENCE_LENGTH,
+        max_length=args.max_seq_len,
         return_tensors="pt",
     )
     inputs = (
@@ -33,8 +29,11 @@ def create_module(model_name, tokenizer, device):
     )
     # np.save("model_inputs_0.npy", inputs[0])
     # np.save("model_inputs_1.npy", inputs[1])
+    opt_fs_name = "-".join(
+        "_".join(args.model_name.split("/")[1].split("-")).split(".")
+    )
 
-    mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
+    mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
     if os.path.isfile(mlir_path):
         print(f"Found .mlir from {mlir_path}")
     else:
@@ -42,7 +41,7 @@ def create_module(model_name, tokenizer, device):
             model=opt_model,
             inputs=inputs,
             is_f16=False,
-            model_name=OPT_FS_NAME,
+            model_name=opt_fs_name,
             return_str=True,
         )
         with open(mlir_path, "w") as f:
@@ -57,7 +56,7 @@ def create_module(model_name, tokenizer, device):
         is_benchmark=False,
     )
 
-    vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
+    vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
     shark_module.save_module(module_name=vmfb_name, debug=False)
     vmfb_path = vmfb_name + ".vmfb"
     return vmfb_path
@@ -71,11 +70,11 @@ def shouldStop(tokens):
     return False
 
 
-def generate_new_token(shark_model, tokenizer, new_text):
+def generate_new_token(shark_model, tokenizer, new_text, args):
     model_inputs = tokenizer(
         new_text,
         padding="max_length",
-        max_length=MAX_SEQUENCE_LENGTH,
+        max_length=args.max_seq_len,
         truncation=True,
         return_tensors="pt",
     )
@@ -104,18 +103,56 @@ def generate_new_token(shark_model, tokenizer, new_text):
     return ret_dict
 
 
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--max-seq-len", type=int, default=32)
+    parser.add_argument(
+        "--model-name",
+        help="Model name",
+        type=str,
+        choices=[
+            "facebook/opt-125m",
+            "facebook/opt-350m",
+            "facebook/opt-1.3b",
+            "facebook/opt-6.7b",
+        ],
+        default="facebook/opt-1.3b",
+    )
+    parser.add_argument(
+        "--recompile",
+        help="If set, recompiles MLIR -> .vmfb",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+    )
+    parser.add_argument(
+        "--plugin-path",
+        help="path to executable plugin",
+        type=str,
+        default=None,
+    )
+    args = parser.parse_args()
+    print("args={}".format(args))
+    return args
+
+
 if __name__ == "__main__":
-    tokenizer = AutoTokenizer.from_pretrained(
-        "facebook/" + OPT_MODEL, use_fast=False
+    args = parse_args()
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
+    opt_fs_name = "-".join(
+        "_".join(args.model_name.split("/")[1].split("-")).split(".")
     )
-    vmfb_path = (
-        f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
+    vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
+    if args.plugin_path is not None:
+        rt_flags = [f"--executable_plugin={args.plugin_path}"]
+    else:
+        rt_flags = []
+    opt_shark_module = SharkInference(
+        mlir_module=None, device="cpu-task", rt_flags=rt_flags
     )
-    opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
     if os.path.isfile(vmfb_path):
         opt_shark_module.load_module(vmfb_path)
     else:
-        vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
+        vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
         opt_shark_module.load_module(vmfb_path)
     while True:
         try:
@@ -123,9 +160,9 @@ def generate_new_token(shark_model, tokenizer, new_text):
             new_text_init = new_text
             words_list = []
 
-            for i in range(MAX_NEW_TOKENS):
+            for i in range(args.max_seq_len):
                 generated_token_op = generate_new_token(
-                    opt_shark_module, tokenizer, new_text
+                    opt_shark_module, tokenizer, new_text, args
                 )
                 detok = generated_token_op["detok"]
                 stop_generation = generated_token_op["stop_generation"]
diff --git a/tank/examples/opt/opt_perf_comparison.py b/tank/examples/opt/opt_perf_comparison.py
index 8d95934e2b..f105183449 100644
--- a/tank/examples/opt/opt_perf_comparison.py
+++ b/tank/examples/opt/opt_perf_comparison.py
@@ -147,7 +147,7 @@ def load_shark_model(
     plugin_path: str = [],
 ) -> ModelWrapper:
     opt_fs_name = get_opt_fs_name(model_name)
-    vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
+    vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
     if recompile_shark or not os.path.isfile(vmfb_name):
         print(f"vmfb not found. compiling and saving to {vmfb_name}")
@@ -344,7 +344,7 @@ def parse_args():
         default=PLATFORM_SHARK,
    )
     parser.add_argument(
-        "--plugin_path",
+        "--plugin-path",
         help="path to executable plugin",
         type=str,
         default=None,
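With this patch, opt_causallm.py takes its configuration from the command line instead of the old OPT_MODEL / MAX_SEQUENCE_LENGTH constants. A sketch of a typical invocation of the patched script, using the defaults declared in parse_args and an illustrative (hypothetical) plugin path:

    python tank/examples/opt/opt_causallm.py \
        --model-name facebook/opt-1.3b \
        --max-seq-len 32 \
        --plugin-path /path/to/executable_plugin.so

If --plugin-path is omitted, rt_flags stays an empty list and SharkInference is constructed without the --executable_plugin runtime flag.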