From ad7bebf0767b433e2ca44ed78daa4d9b5d59d200 Mon Sep 17 00:00:00 2001
From: PhaneeshB
Date: Sun, 13 Aug 2023 21:15:32 +0530
Subject: [PATCH] add support passing iree flags for LLMs

---
 apps/language_models/scripts/vicuna.py      | 30 ++++++++++++++++------
 apps/stable_diffusion/web/ui/stablelm_ui.py | 11 ++++++++
 2 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index b38dea9b4b..fbc408e208 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -123,7 +123,12 @@
     action=argparse.BooleanOptionalAction,
     help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk",
 )
-
+parser.add_argument(
+    "--iree_vulkan_target_triple",
+    type=str,
+    default="",
+    help="Specify target triple for vulkan.",
+)
 
 # fmt: off
 def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
@@ -160,11 +165,13 @@ def __init__(
         max_num_tokens=512,
         device="cpu",
         precision="int8",
+        extra_args_cmd=[],
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
+        self.extra_args = extra_args_cmd
 
     def get_tokenizer(self):
         # Retrieve the tokenizer from Huggingface
@@ -433,8 +440,9 @@ def __init__(
         config_json=None,
         weight_group_size=128,
         compressed=False,
+        extra_args_cmd=[],
     ) -> None:
-        super().__init__(model_name, hf_model_path, max_num_tokens)
+        super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
@@ -940,7 +948,7 @@ def compile_to_vmfb_one_model(
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
                     "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-                ],
+                ] + self.extra_args,
             )
             module.load_module(vmfb_path)
             modules.append(module)
@@ -1008,7 +1016,7 @@ def compile_to_vmfb_one_model4(
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
                     "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-                ],
+                ] + self.extra_args,
             )
             module.load_module(vmfb_path)
             modules.append(module)
@@ -1220,8 +1228,9 @@ def __init__(
         weight_group_size=128,
         download_vmfb=False,
         cache_vicunas=False,
+        extra_args_cmd=[],
     ) -> None:
-        super().__init__(model_name, hf_model_path, max_num_tokens)
+        super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
         if "llama2" in self.model_name and hf_auth_token == None:
             raise ValueError(
                 "HF auth token required. Pass it using --hf_auth_token flag."
@@ -1604,7 +1613,7 @@ def compile(self, download_vmfb=False):
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
                     "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-                ],
+                ] + self.extra_args,
             )
             print("Saved vic vmfb at ", str(path))
             shark_module.load_module(path)
@@ -1683,6 +1692,13 @@ def autocomplete(self, prompt):
 
 if __name__ == "__main__":
     args, unknown = parser.parse_known_args()
+    _extra_args = []
+    # vulkan target triple
+    if args.iree_vulkan_target_triple != "":
+        _extra_args.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+
     vic = None
     if not args.sharded:
         vic_mlir_path = (
@@ -1706,6 +1722,7 @@ def autocomplete(self, prompt):
             weight_group_size=args.weight_group_size,
             download_vmfb=args.download_vmfb,
             cache_vicunas=args.cache_vicunas,
+            extra_args_cmd=_extra_args,
         )
     else:
         if args.config is not None:
@@ -1720,6 +1737,7 @@ def autocomplete(self, prompt):
             precision=args.precision,
             config_json=config_json,
             weight_group_size=args.weight_group_size,
+            extra_args_cmd=_extra_args,
         )
     if args.model_name == "vicuna":
         system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py
index cda8466caa..2e9e56ff56 100644
--- a/apps/stable_diffusion/web/ui/stablelm_ui.py
+++ b/apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -180,6 +180,15 @@ def chat(
         print("unrecognized device")
 
     max_toks = 128 if model_name == "codegen" else 512
+
+    # get iree flags that need to be overridden, from commandline args
+    _extra_args = []
+    # vulkan target triple
+    if args.iree_vulkan_target_triple != "":
+        _extra_args.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+
     if model_name == "vicuna4":
         vicuna_model = ShardedVicuna(
             model_name,
@@ -188,6 +197,7 @@ def chat(
             precision=precision,
             max_num_tokens=max_toks,
             compressed=True,
+            extra_args_cmd=_extra_args,
         )
     else:
         # if config_file is None:
@@ -198,6 +208,7 @@ def chat(
             device=device,
             precision=precision,
             max_num_tokens=max_toks,
+            extra_args_cmd=_extra_args,
         )
     # else:
     #     if config_file is not None:
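
A minimal, self-contained sketch of the plumbing this patch adds, assuming
nothing beyond argparse: the command-line value becomes a single iree-compile
flag string, and each compile call appends it to the fixed flag list it
already passes. The target triple used here is an illustrative placeholder,
not a value the patch prescribes.

    # Sketch only: mirrors the patch's _extra_args handling in isolation.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--iree_vulkan_target_triple",
        type=str,
        default="",
        help="Specify target triple for vulkan.",
    )
    # Placeholder triple for illustration; substitute your GPU's triple.
    args, unknown = parser.parse_known_args(
        ["--iree_vulkan_target_triple=rdna2-unknown-linux"]
    )

    _extra_args = []
    if args.iree_vulkan_target_triple != "":
        _extra_args.append(
            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
        )

    # What a compile call ends up passing, per this patch:
    flags = [
        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
        "--iree-opt-const-expr-hoisting=False",
    ] + _extra_args
    print(flags)

With the flag left at its default "", _extra_args stays empty and the compile
invocations are byte-for-byte what they were before, so existing workflows
are unaffected.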