Skip to content

Commit

Permalink
add support for external weights
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Dec 12, 2023
1 parent 3cc643b commit f0d3d42
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 220 deletions.
146 changes: 113 additions & 33 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,171 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
import time
from shark.iree_utils.compile_utils import (
get_iree_compiled_module,
load_vmfb_using_mmap,
)
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
from itertools import chain
import gc
import os
import torch
from transformers import AutoTokenizer

llm_model_map = {
"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
}
}


class LanguageModel:
def __init__(
self, model_name, hf_auth_token=None, device=None, precision="fp32"
self,
model_name,
hf_auth_token=None,
device=None,
precision="fp32",
external_weights=None,
external_weight_file=None,
use_system_prompt=True,
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()

self.vmfb_name = get_resource_path("llm.vmfb.tempfile")
self.device = device
self.precision = precision
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.compile()
self.external_weight_file = external_weight_file
self.use_system_prompt = use_system_prompt
self.global_iter = 0
if os.path.exists(self.vmfb_name):
self.iree_module_dict = dict()
(
self.iree_module_dict["vmfb"],
self.iree_module_dict["config"],
self.iree_module_dict["temp_file_to_unlink"],
) = load_vmfb_using_mmap(
self.vmfb_name,
device,
device_idx=0,
rt_flags=[],
external_weight_file=external_weight_file,
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
external_weight_file=external_weight_file,
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.compile()
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_name,
use_fast=False,
use_auth_token=hf_auth_token,
)
self.compile()

def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
self.tempfile_name,
device=self.device,
mmap=True,
frontend="torch",
external_weight_file=self.external_weight_file,
write_to=self.vmfb_name,
)
# TODO: delete the temp file

def sanitize_prompt(self, prompt):
print(prompt)
if isinstance(prompt, list):
prompt = list(chain.from_iterable(prompt))
prompt = " ".join([x for x in prompt if isinstance(x, str)])
prompt = prompt.replace("\n", " ")
prompt = prompt.replace("\t", " ")
prompt = prompt.replace("\r", " ")
if self.use_system_prompt and self.global_iter == 0:
prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt
prompt += " [/INST]"
print(prompt)
return prompt

def chat(self, prompt):
prompt = self.sanitize_prompt(prompt)

input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids

def format_out(results):
return torch.tensor(results.to_host()[0][0])

history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
st_time = time.time()
if iter == 0:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device, input_tensor
)
]
token = self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
)
else:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"].device,
token,
)
]
token = self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
)

history.append(token)
yield self.tokenizer.decode(history)
total_time = time.time() - st_time
history.append(format_out(token))
yield self.tokenizer.decode(history), total_time

if token == llm_model_map["llama2_7b"]["stop_token"]:
if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]:
break

for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
yield result_output
self.global_iter += 1
return result_output, total_time


if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
device="cpu-task",
external_weights="safetensors",
external_weight_file="llama2_7b.safetensors",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
for i in lm.chat("hi, what are you?"):
print(i)
2 changes: 1 addition & 1 deletion apps/shark_studio/web/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs):
)

with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
Expand Down
Loading

0 comments on commit f0d3d42

Please sign in to comment.