Commit 22d358c
add support for passing IREE flags for LLMs
PhaneeshB committed Aug 14, 2023
1 parent 16f46f8 commit 22d358c
Showing 2 changed files with 35 additions and 6 deletions.
apps/language_models/scripts/vicuna.py (24 additions, 6 deletions)
@@ -123,7 +123,12 @@
     action=argparse.BooleanOptionalAction,
     help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk",
 )
-
+parser.add_argument(
+    "--iree_vulkan_target_triple",
+    type=str,
+    default="",
+    help="Specify target triple for vulkan.",
+)
 
 # fmt: off
 def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
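With the new flag wired in, a typical invocation might look like the following (the triple value is illustrative and depends on your GPU and IREE build; --device is the script's existing device flag):

    python apps/language_models/scripts/vicuna.py --device=vulkan --iree_vulkan_target_triple=rdna2-unknown-linux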
@@ -160,11 +165,13 @@ def __init__(
         max_num_tokens=512,
         device="cpu",
         precision="int8",
+        extra_args_cmd=[],
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
+        self.extra_args = extra_args_cmd
 
     def get_tokenizer(self):
         # Retrieve the tokenizer from Huggingface
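A note on the extra_args_cmd=[] default introduced here (and repeated in the constructors below): Python evaluates a mutable default once and shares it across calls. That is benign in this commit because the list is only ever read (the compile paths concatenate it with +), but a defensive variant would copy it. A minimal sketch of that pattern, with names shortened for illustration:

    # Sketch of the safer idiom (not what this commit does):
    class VicunaBase:
        def __init__(self, extra_args_cmd=None):
            # Copy so the shared default list is never aliased or mutated.
            self.extra_args = list(extra_args_cmd or [])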
@@ -433,8 +440,9 @@ def __init__(
         config_json=None,
         weight_group_size=128,
         compressed=False,
+        extra_args_cmd=[],
     ) -> None:
-        super().__init__(model_name, hf_model_path, max_num_tokens)
+        super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
         self.max_sequence_length = 256
         self.device = device
         self.precision = precision
@@ -940,7 +948,7 @@ def compile_to_vmfb_one_model(
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
                     "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-                ],
+                ] + self.extra_args,
             )
             module.load_module(vmfb_path)
             modules.append(module)
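Because self.extra_args is concatenated after the default flags, user-supplied flags always land last in the list handed to the compiler. List concatenation itself is order-preserving; whether a later duplicate flag overrides an earlier one is up to IREE's flag parser and is not verified here:

    defaults = ["--iree-opt-const-expr-hoisting=False"]
    extra = ["-iree-vulkan-target-triple=rdna2-unknown-linux"]  # illustrative value
    print(defaults + extra)
    # ['--iree-opt-const-expr-hoisting=False', '-iree-vulkan-target-triple=rdna2-unknown-linux']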
@@ -1008,7 +1016,7 @@ def compile_to_vmfb_one_model4(
                     "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                     "--iree-opt-const-expr-hoisting=False",
                     "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-                ],
+                ] + self.extra_args,
             )
             module.load_module(vmfb_path)
             modules.append(module)
@@ -1220,8 +1228,9 @@ def __init__(
         weight_group_size=128,
         download_vmfb=False,
         cache_vicunas=False,
+        extra_args_cmd=[],
     ) -> None:
-        super().__init__(model_name, hf_model_path, max_num_tokens)
+        super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
         if "llama2" in self.model_name and hf_auth_token == None:
             raise ValueError(
                 "HF auth token required. Pass it using --hf_auth_token flag."
@@ -1604,7 +1613,7 @@ def compile(self, download_vmfb=False):
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
                 "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
-            ],
+            ] + self.extra_args,
         )
         print("Saved vic vmfb at ", str(path))
         shark_module.load_module(path)
@@ -1683,6 +1692,13 @@ def autocomplete(self, prompt):
 if __name__ == "__main__":
     args, unknown = parser.parse_known_args()
 
+    _extra_args = []
+    # vulkan target triple
+    if args.iree_vulkan_target_triple != "":
+        _extra_args.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+
     vic = None
     if not args.sharded:
         vic_mlir_path = (
@@ -1706,6 +1722,7 @@ def autocomplete(self, prompt):
             weight_group_size=args.weight_group_size,
             download_vmfb=args.download_vmfb,
             cache_vicunas=args.cache_vicunas,
+            extra_args_cmd=_extra_args,
         )
     else:
         if args.config is not None:
@@ -1720,6 +1737,7 @@ def autocomplete(self, prompt):
             precision=args.precision,
             config_json=config_json,
             weight_group_size=args.weight_group_size,
+            extra_args_cmd=_extra_args,
         )
     if args.model_name == "vicuna":
         system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
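The __main__ wiring above also makes the flag usable when driving the model from Python directly; a minimal sketch, assuming the constructor shape shown in this diff (class name, model path, and triple are placeholders, not verified against the full file):

    from apps.language_models.scripts.vicuna import UnshardedVicuna  # class name assumed

    vic = UnshardedVicuna(
        "vicuna",
        hf_model_path="lmsys/vicuna-7b-v1.3",  # placeholder
        device="vulkan",
        precision="fp16",
        extra_args_cmd=["-iree-vulkan-target-triple=rdna2-unknown-linux"],  # illustrative
    )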
apps/stable_diffusion/web/ui/stablelm_ui.py (11 additions, 0 deletions)
@@ -180,6 +180,15 @@ def chat(
         print("unrecognized device")
 
     max_toks = 128 if model_name == "codegen" else 512
+
+    # get iree flags that need to be overridden, from commandline args
+    _extra_args = []
+    # vulkan target triple
+    if args.iree_vulkan_target_triple != "":
+        _extra_args.append(
+            f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
+        )
+
     if model_name == "vicuna4":
         vicuna_model = ShardedVicuna(
             model_name,
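This flag-collection block now exists in both vicuna.py and stablelm_ui.py; if more IREE flags get forwarded later, a shared helper would keep the two copies in sync. A possible refactor, with the function name and location hypothetical:

    # Hypothetical shared helper, not part of this commit:
    def collect_iree_flags(args):
        flags = []
        if getattr(args, "iree_vulkan_target_triple", ""):
            flags.append(
                f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
            )
        return flags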
@@ -188,6 +197,7 @@ def chat(
             precision=precision,
             max_num_tokens=max_toks,
             compressed=True,
+            extra_args_cmd=_extra_args,
         )
     else:
         # if config_file is None:
@@ -198,6 +208,7 @@ def chat(
             device=device,
             precision=precision,
             max_num_tokens=max_toks,
+            extra_args_cmd=_extra_args,
         )
     # else:
     #     if config_file is not None:
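On the web-UI side, args comes from the app's shared command-line parser, so a target triple set at launch reaches the chat models through the same path. Assuming SHARK's usual entry point (not shown in this diff), that would look like:

    python apps/stable_diffusion/web/index.py --iree_vulkan_target_triple=rdna2-unknown-linux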
