Dan shark studio (#1970)
* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2 people authored and monorimet committed Nov 14, 2023
1 parent c216348 commit a3deeec
Showing 7 changed files with 1,041 additions and 5 deletions.
72 changes: 72 additions & 0 deletions apps/shark_studio/api/llm.py
@@ -0,0 +1,72 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
import gc
import os
import torch

llm_model_map = {
    "llama2_7b": {
        "initializer": stateless_llama.export_transformer_model,
        "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
        "stop_token": 2,
        "max_tokens": 4096,
    },
}


class LanguageModel:
    def __init__(
        self, model_name, hf_auth_token=None, device=None, precision="fp32"
    ):
        print(llm_model_map[model_name])
        self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
        # Export the model as torch-dialect IR and grab the matching tokenizer.
        self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"](
            self.hf_model_name, hf_auth_token, compile_to="torch"
        )
        # Write the IR to a temp file for the compiler, then drop the
        # in-memory copy to free RAM.
        self.tempfile_name = get_resource_path("llm.torch.tempfile")
        with open(self.tempfile_name, "w+") as f:
            f.write(self.torch_ir)
        del self.torch_ir
        gc.collect()

        self.device = device
        self.precision = precision
        self.max_tokens = llm_model_map[model_name]["max_tokens"]
        self.stop_token = llm_model_map[model_name]["stop_token"]
        self.iree_module_dict = None
        self.compile()

    def compile(self) -> None:
        # This comes with keys: "vmfb", "config", and "temp_file_to_unlink".
        self.iree_module_dict = get_iree_compiled_module(
            self.tempfile_name, device=self.device, frontend="torch"
        )
        # TODO: delete the temp file

    def chat(self, prompt):
        history = []
        for step in range(self.max_tokens):
            input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids
            device_inputs = [
                ireert.asdevicearray(self.iree_module_dict["config"], input_tensor)
            ]
            # run_initialize consumes the prompt and sets up model state;
            # run_forward produces each subsequent token.
            if step == 0:
                token = torch.tensor(
                    self.iree_module_dict["vmfb"]["run_initialize"](
                        *device_inputs
                    ).to_host()[0][0]
                )
            else:
                token = torch.tensor(
                    self.iree_module_dict["vmfb"]["run_forward"](
                        *device_inputs
                    ).to_host()[0][0]
                )

            history.append(token)
            # Stream a partial decode after every generated token.
            yield self.tokenizer.decode(history)

            if token == self.stop_token:
                break

        # decode() expects plain ints, not 0-d tensors.
        for i in range(len(history)):
            if not isinstance(history[i], int):
                history[i] = int(history[i])
        result_output = self.tokenizer.decode(history)
        yield result_output


if __name__ == "__main__":
    # Read the Hugging Face token from the environment rather than
    # hardcoding a secret in source.
    lm = LanguageModel(
        "llama2_7b",
        hf_auth_token=os.environ.get("HF_AUTH_TOKEN"),
        device="cpu-task",
    )
    print("model loaded")
    for i in lm.chat("Hello, I am a robot."):
        print(i)
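Since chat() yields the full decode of everything generated so far on each step, a caller that wants token-at-a-time output can print only the new suffix of each yield. A minimal sketch, assuming HF_AUTH_TOKEN is set in the environment (the model key and device string follow the demo above):

import os

from apps.shark_studio.api.llm import LanguageModel

lm = LanguageModel(
    "llama2_7b",
    hf_auth_token=os.environ.get("HF_AUTH_TOKEN"),
    device="cpu-task",
)

printed = ""
for partial in lm.chat("Hello, I am a robot."):
    # Each yield is the whole decode so far; emit just the new tail.
    print(partial[len(printed):], end="", flush=True)
    printed = partial
print()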
13 changes: 13 additions & 0 deletions apps/shark_studio/api/utils.py
@@ -0,0 +1,13 @@
import os
import sys


def get_available_devices():
    # Stub: only the CPU "task" driver is exposed for now.
    return ["cpu-task"]


def get_resource_path(relative_path):
    """Get absolute path to a resource; works for dev and for PyInstaller."""
    base_path = getattr(
        sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
    )
    return os.path.join(base_path, relative_path)
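For reference, get_resource_path anchors a relative path at sys._MEIPASS when running from a frozen PyInstaller bundle, and at the directory containing utils.py otherwise, so resources resolve identically in both environments. A quick illustration (the printed path is hypothetical):

from apps.shark_studio.api.utils import get_resource_path

# Dev checkout: resolves next to utils.py; frozen app: resolves under the
# PyInstaller unpack directory (sys._MEIPASS).
print(get_resource_path("llm.torch.tempfile"))
# e.g. /path/to/SHARK/apps/shark_studio/api/llm.torch.tempfile  (hypothetical)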
