Skip to content

Commit

Permalink
Fix formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Nov 14, 2023
1 parent a3deeec commit 4302b23
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 199 deletions.
65 changes: 42 additions & 23 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,24 @@
import torch

llm_model_map = {
"llama2_7b": {"initializer":stateless_llama.export_transformer_model,
"hf_model_name":"meta-llama/Llama-2-7b-chat-hf",
"stop_token":2,
"max_tokens":4096,
}

"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
}
}


class LanguageModel():
def __init__(self, model_name, hf_auth_token=None, device=None, precision="fp32"):
class LanguageModel:
def __init__(
self, model_name, hf_auth_token=None, device=None, precision="fp32"
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](self.hf_model_name, hf_auth_token, compile_to="torch")
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
Expand All @@ -33,20 +37,35 @@ def __init__(self, model_name, hf_auth_token=None, device=None, precision="fp32"
self.compile()

def compile(self) -> None:
#this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(self.tempfile_name, device=self.device, frontend="torch")
#TODO: delete the temp file
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
)
# TODO: delete the temp file

def chat(self, prompt):

history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
device_inputs = [ireert.asdevicearray(self.iree_module_dict["config"], input_tensor)]
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
if iter == 0:
token = torch.tensor(self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs).to_host()[0][0])
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
)
else:
token = torch.tensor(self.iree_module_dict["vmfb"]["run_forward"](*device_inputs).to_host()[0][0])
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
)

history.append(token)
yield self.tokenizer.decode(history)
Expand All @@ -61,12 +80,12 @@ def chat(self, prompt):
yield result_output






if __name__ == "__main__":
lm = LanguageModel("llama2_7b", hf_auth_token="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", device="cpu-task")
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
device="cpu-task",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)
print(i)
3 changes: 2 additions & 1 deletion apps/shark_studio/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
def get_available_devices():
return ["cpu-task"]


def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
return os.path.join(base_path, relative_path)
Loading

0 comments on commit 4302b23

Please sign in to comment.