From e8c1203be2ab67d897331529d4a4faf92084ac21 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Thu, 10 Aug 2023 18:41:14 +0530
Subject: [PATCH] Fix vicuna script (#1745)

---
 apps/language_models/scripts/vicuna.py | 47 +++----------------------
 1 file changed, 5 insertions(+), 42 deletions(-)

diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index 2c2d4cacf6..04d5df23cb 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -419,44 +419,6 @@ def generate_new_token(self, params, sharded=True):
 
         return ret_dict
 
-    def generate_new_token(self, params):
-        is_first = params["is_first"]
-        if is_first:
-            prompt = params["prompt"]
-            input_ids = self.tokenizer(prompt).input_ids
-            # crop input_ids
-            # input_ids = input_ids[len(input_ids) - 20 :]
-            ############
-            input_id_len = len(input_ids)
-            input_ids = torch.tensor(input_ids)
-            input_ids = input_ids.reshape([1, input_id_len])
-            output = self.shark_model.forward(input_ids, is_first=is_first)
-        else:
-            token = params["token"]
-            past_key_values = params["past_key_values"]
-            input_ids = [token]
-            input_id_len = len(input_ids)
-            input_ids = torch.tensor(input_ids)
-            input_ids = input_ids.reshape([1, input_id_len])
-            output = self.shark_model.forward(
-                input_ids, past_key_values=past_key_values, is_first=is_first
-            )
-
-        _logits = output["logits"]
-        _past_key_values = output["past_key_values"]
-        _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
-        _detok = self.tokenizer.decode(_token)
-
-        ret_dict = {
-            "token": _token,
-            "detok": _detok,
-            "past_key_values": _past_key_values,
-        }
-
-        print(f" token : {_token} | detok : {_detok}")
-
-        return ret_dict
-
 
 class ShardedVicuna(VicunaBase):
     # Class representing Sharded Vicuna Model
@@ -976,7 +938,7 @@ def compile_to_vmfb_one_model(
                 "--iree-codegen-check-ir-before-llvm-conversion=false",
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
-                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             ],
         )
         module.load_module(vmfb_path)
@@ -1044,7 +1006,7 @@ def compile_to_vmfb_one_model4(
                 "--iree-codegen-check-ir-before-llvm-conversion=false",
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
-                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             ],
         )
         module.load_module(vmfb_path)
@@ -1640,7 +1602,7 @@ def compile(self, download_vmfb=False):
                 "--iree-codegen-check-ir-before-llvm-conversion=false",
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
-                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             ],
         )
         print("Saved vic vmfb at ", str(path))
@@ -1792,8 +1754,9 @@ def autocomplete(self, prompt):
             system_message,
            history,
             model=model_list[args.model_name],
-            device=args.device,
+            devices=args.device,
             precision=args.precision,
+            config_file=None,
             cli=args.cli,
         )
     )[0]