Fix vicuna script (#1745)
vivekkhandelwal1 authored Aug 10, 2023
1 parent e4d7abb commit e8c1203
Showing 1 changed file with 5 additions and 42 deletions.
47 changes: 5 additions & 42 deletions apps/language_models/scripts/vicuna.py
@@ -419,44 +419,6 @@ def generate_new_token(self, params, sharded=True):
 
         return ret_dict
 
-    def generate_new_token(self, params):
-        is_first = params["is_first"]
-        if is_first:
-            prompt = params["prompt"]
-            input_ids = self.tokenizer(prompt).input_ids
-            # crop input_ids
-            # input_ids = input_ids[len(input_ids) - 20 :]
-            ############
-            input_id_len = len(input_ids)
-            input_ids = torch.tensor(input_ids)
-            input_ids = input_ids.reshape([1, input_id_len])
-            output = self.shark_model.forward(input_ids, is_first=is_first)
-        else:
-            token = params["token"]
-            past_key_values = params["past_key_values"]
-            input_ids = [token]
-            input_id_len = len(input_ids)
-            input_ids = torch.tensor(input_ids)
-            input_ids = input_ids.reshape([1, input_id_len])
-            output = self.shark_model.forward(
-                input_ids, past_key_values=past_key_values, is_first=is_first
-            )
-
-        _logits = output["logits"]
-        _past_key_values = output["past_key_values"]
-        _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
-        _detok = self.tokenizer.decode(_token)
-
-        ret_dict = {
-            "token": _token,
-            "detok": _detok,
-            "past_key_values": _past_key_values,
-        }
-
-        print(f" token : {_token} | detok : {_detok}")
-
-        return ret_dict
-
 
 class ShardedVicuna(VicunaBase):
     # Class representing Sharded Vicuna Model
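
The deleted block above was a second definition of generate_new_token inside the same class. In Python, a later def with the same name rebinds the class attribute and silently discards the earlier definition, so this non-sharded variant would have shadowed the sharded generate_new_token(self, params, sharded=True) kept at line 419 — likely the bug this commit fixes. A minimal sketch of the pitfall, with illustrative class and method names (not from vicuna.py):

class Decoder:
    def step(self, params, sharded=True):
        # First definition: the variant the class is meant to expose.
        return "sharded path"

    def step(self, params):
        # Same name defined again: this rebinds Decoder.step, and the
        # definition above is discarded without any warning.
        return "non-sharded path"

d = Decoder()
print(d.step({}))            # -> "non-sharded path"; the sharded variant is gone
# d.step({}, sharded=True)   # TypeError: unexpected keyword argument 'sharded'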
@@ -976,7 +938,7 @@ def compile_to_vmfb_one_model(
                 "--iree-codegen-check-ir-before-llvm-conversion=false",
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
-                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             ],
         )
         module.load_module(vmfb_path)
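
The only change in this hunk (and in the two matching hunks below) is a trailing comma after the last flag. At runtime this is a no-op, but it is conventional Python style — and what formatters such as black enforce — because it keeps future one-flag additions to single-line diffs and guards against accidental implicit string concatenation. A small sketch of the hazard it prevents; the "--hypothetical-new-flag" entry is illustrative, not a real IREE flag:

# Adjacent string literals are concatenated at parse time, so a missing
# comma silently merges two flags into one malformed argument:
flags_missing_comma = [
    "--iree-opt-const-expr-hoisting=False"   # comma forgotten here
    "--hypothetical-new-flag=1"              # merged into the line above
]
assert len(flags_missing_comma) == 1         # one string, not two

flags_with_trailing_comma = [
    "--iree-opt-const-expr-hoisting=False",
    "--hypothetical-new-flag=1",             # trailing comma: safe to extend
]
assert len(flags_with_trailing_comma) == 2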
@@ -1044,7 +1006,7 @@ def compile_to_vmfb_one_model4(
                 "--iree-codegen-check-ir-before-llvm-conversion=false",
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
-                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             ],
         )
         module.load_module(vmfb_path)
@@ -1640,7 +1602,7 @@ def compile(self, download_vmfb=False):
                 "--iree-codegen-check-ir-before-llvm-conversion=false",
                 "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
                 "--iree-opt-const-expr-hoisting=False",
-                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
+                "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
             ],
         )
         print("Saved vic vmfb at ", str(path))
@@ -1792,8 +1754,9 @@ def autocomplete(self, prompt):
         system_message,
         history,
         model=model_list[args.model_name],
-        device=args.device,
+        devices=args.device,
         precision=args.precision,
+        config_file=None,
         cli=args.cli,
     )
 )[0]