Add support for StableLM-3B model (#2019)
* Add support for StableLM-3B model

* Add support for Quantized StableLM-3B model

* Update stablelm_pipeline.py
vivekkhandelwal1 authored Dec 12, 2023
1 parent bf70e80 commit 3cc643b
Showing 1 changed file with 137 additions and 24 deletions.
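
The commit adds a standalone CLI entry point to the pipeline. As a rough usage sketch (the flag values here are illustrative and not part of the commit), the new quantized path could be exercised with:

python apps/language_models/src/pipelines/stablelm_pipeline.py --precision int4 --device cpu

The fp16/fp32 paths additionally need a Hugging Face token, e.g. --hf_auth_token <your_token>, since the script requires one for those precisions.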
161 changes: 137 additions & 24 deletions apps/language_models/src/pipelines/stablelm_pipeline.py
@@ -4,13 +4,49 @@
from io import BytesIO
from pathlib import Path
from apps.language_models.utils import (
get_torch_mlir_module_bytecode,
get_vmfb_from_path,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.stablelm_model import (
StableLMModel,
)
import argparse

parser = argparse.ArgumentParser(
prog="stablelm runner",
description="runs a StableLM model",
)

parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
"--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb"
)
parser.add_argument(
"--stablelm_mlir_path",
default=None,
help="path to StableLM's mlir file",
)
parser.add_argument(
"--use_precompiled_model",
default=True,
action=argparse.BooleanOptionalAction,
help="use the precompiled vmfb",
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=True,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for stablelm-3B model.",
)


class StopOnTokens(StoppingCriteria):
@@ -29,14 +65,22 @@ def __init__(
self,
model_name,
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
max_num_tokens=512,
max_num_tokens=256,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
if precision != "int4" and args.hf_auth_token is None:
raise ValueError(
"""HF auth token is required for StableLM-3B. Pass it using the
--hf_auth_token flag. You can request access to the model on its
Hugging Face page (e.g. https://huggingface.co/stabilityai/stablelm-zephyr-3b)."""
)
self.hf_auth_token = args.hf_auth_token

self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
Expand All @@ -50,9 +94,23 @@ def shouldStop(self, tokens):
return False

def get_src_model(self):
kwargs = {}
if self.precision == "int4":
self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ"
from transformers import GPTQConfig

quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["device_map"] = "cpu"
print("[DEBUG] Loading Model")
model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, torch_dtype=torch.float32
self.hf_model_path,
trust_remote_code=True,
torch_dtype=torch.float32,
use_auth_token=self.hf_auth_token,
**kwargs,
)
print("[DEBUG] Model loaded successfully")
return model

def get_model_inputs(self):
@@ -61,9 +119,7 @@ def get_model_inputs(self):
return input_ids, attention_mask

def compile(self):
tmp_model_name = (
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
)
tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}"

# device = "cuda" # "cpu"
# TODO: vmfb and mlir name should include precision and device
@@ -83,13 +139,19 @@
print(
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
if not mlir_path.exists():
model = StableLMModel(self.get_src_model())
model_inputs = self.get_model_inputs()
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
from shark.shark_importer import import_with_fx

ts_graph = import_with_fx(
model,
model_inputs,
is_f16=True if self.precision in ["fp16"] else False,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
module = torch_mlir.compile(
ts_graph,
[*model_inputs],
@@ -100,15 +162,16 @@
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
f_ = open(tmp_model_name + ".mlir", "wb")
f_.write(bytecode)
print("Saved mlir")
f_.close()
f_ = open(mlir_path, "wb")
f_.write(bytecode)
print("Saved mlir at: ", mlir_path)
f_.close()
del bytecode

from shark.shark_inference import SharkInference

shark_module = SharkInference(
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor"
)
shark_module.compile()

@@ -120,14 +183,22 @@ def compile(self):
return shark_module

def get_tokenizer(self):
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
tok = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_auth_token=self.hf_auth_token,
)
tok.add_special_tokens({"pad_token": "<PAD>"})
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
return tok

def generate(self, prompt):
words_list = []
import time

start = time.time()
count = 0
for i in range(self.max_num_tokens):
count = count + 1
params = {
"new_text": prompt,
}
@@ -145,6 +216,12 @@ def generate(self, prompt):
if detok == "":
break
prompt = prompt + detok
end = time.time()
print(
"\n\nTime taken is {:.2f} tokens/second\n".format(
count / (end - start)
)
)
return words_list

def generate_new_token(self, params):
@@ -178,10 +255,46 @@ def generate_new_token(self, params):
return ret_dict


# Initialize a StopOnTokens object
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
if __name__ == "__main__":
args = parser.parse_args()

stable_lm = SharkStableLM(
model_name="stablelm_zephyr_3b",
hf_model_path="stabilityai/stablelm-zephyr-3b",
device=args.device,
precision=args.precision,
)

default_prompt_text = "The weather is always wonderful"
continue_execution = True

print("\n-----\nScript executing for the following config: \n")
print("StableLM Model: ", stable_lm.hf_model_path)
print("Precision: ", args.precision)
print("Device: ", args.device)

while continue_execution:
use_default_prompt = input(
"\nDo you wish to use the default prompt text? Y/N ?: "
)
if use_default_prompt in ["Y", "y"]:
prompt = default_prompt_text
else:
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)

res_str = stable_lm.generate(prompt)
torch.cuda.empty_cache()
import gc

gc.collect()
print(
"\n\n-----\nHere's the complete formatted result: \n\n",
prompt + "".join(res_str),
)
continue_execution = input(
"\nDo you wish to run the script one more time? Y/N ?: "
)
continue_execution = continue_execution in ["Y", "y"]
