added functionality for int8 vicuna and 4 shards (#1712)
combined vicuna_4_shards.py and vicuna.py to reduce code duplication

Co-authored-by: Elias Joseph <elias@nod-labs.com>
Eliasj42 and Elias Joseph authored Aug 4, 2023
1 parent 7fe57eb commit ed484b8
Showing 4 changed files with 1,154 additions and 32 deletions.
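
For orientation before the diff: the new compressed int8 path added here can be driven roughly as sketched below. This is a minimal, unverified sketch; the import path, the model_name / hf_model_path / max_num_tokens values and the prompt are illustrative assumptions rather than values taken from this commit, while the precision and compressed keywords come from the ShardedVicuna constructor in the diff.

from apps.language_models.scripts.vicuna import ShardedVicuna  # assumed import path (file changed in this commit)

vic = ShardedVicuna(
    model_name="vicuna",                         # assumed value
    hf_model_path="TheBloke/vicuna-7B-1.1-HF",   # assumed; this repo id appears in the compressed branch below
    max_num_tokens=512,                          # assumed value
    device="cpu",
    precision="int8",                            # selects the 8-bit weight-quantization branch
    compressed=True,                             # new flag: shard the 32 decoder layers as 4 groups of 8
)

# generate() is now a generator and yields the decoded text produced so far
for partial_text in vic.generate("What is model sharding?"):
    print(partial_text)
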
241 changes: 229 additions & 12 deletions apps/language_models/scripts/vicuna.py
@@ -26,6 +26,14 @@
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna4 import (
LlamaModel,
EightLayerLayerSV,
EightLayerLayerFV,
CompiledEightLayerLayerSV,
CompiledEightLayerLayer,
forward_compressed,
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna,
@@ -410,6 +418,44 @@ def generate_new_token(self, params, sharded=True):

return ret_dict

def generate_new_token(self, params):
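# Prefill vs. decode: the first call tokenizes the whole prompt, later calls feed
# only the newest token together with the cached past_key_values.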
is_first = params["is_first"]
if is_first:
prompt = params["prompt"]
input_ids = self.tokenizer(prompt).input_ids
# crop input_ids
# input_ids = input_ids[len(input_ids) - 20 :]
############
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(input_ids, is_first=is_first)
else:
token = params["token"]
past_key_values = params["past_key_values"]
input_ids = [token]
input_id_len = len(input_ids)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.reshape([1, input_id_len])
output = self.shark_model.forward(
input_ids, past_key_values=past_key_values, is_first=is_first
)

_logits = output["logits"]
_past_key_values = output["past_key_values"]
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
_detok = self.tokenizer.decode(_token)

ret_dict = {
"token": _token,
"detok": _detok,
"past_key_values": _past_key_values,
}

print(f" token : {_token} | detok : {_detok}")

return ret_dict


class ShardedVicuna(VicunaBase):
# Class representing Sharded Vicuna Model
@@ -422,6 +468,7 @@ def __init__(
precision="fp32",
config_json=None,
weight_group_size=128,
compressed=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
Expand All @@ -430,7 +477,9 @@ def __init__(
self.tokenizer = self.get_tokenizer()
self.config = config_json
self.weight_group_size = weight_group_size
self.compressed = compressed
self.shark_model = self.compile(device=device)


def get_tokenizer(self):
kwargs = {}
@@ -542,6 +591,59 @@ def compile_vicuna_layer(
)
return mlir_bytecode

def compile_vicuna_layer4(
self,
vicuna_layer,
hidden_states,
attention_mask,
position_ids,
past_key_values=None,
):
# Compile a hidden decoder layer of vicuna
if past_key_values is None:
model_inputs = (hidden_states, attention_mask, position_ids)
else:
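# One (key, value) cache pair per decoder layer in the 8-layer shard; unpack and
# flatten them into positional tensors for the torchscript import below.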
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_values

model_inputs = (
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
)
mlir_bytecode = import_with_fx(
vicuna_layer,
model_inputs,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
return mlir_bytecode

def get_device_index(self, layer_string):
# Get the device index from the config file
# In the event that different device indices are assigned to
@@ -858,28 +960,119 @@ def compile_to_vmfb_one_model(
modules.append(module)
return mlirs, modules

def get_sharded_model(self, device="cpu"):
def compile_to_vmfb_one_model4(
self, inputs0, layers0, inputs1, layers1, device="cpu"
):
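# Compile the four 8-layer shards: reuse a cached {idx}_full.mlir / {idx}_full.vmfb
# when present, otherwise fetch or build it and save the result.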
mlirs, modules = [], []
assert len(layers0) == len(layers1)
for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))):
mlir_path = Path(f"{idx}_full.mlir")
vmfb_path = Path(f"{idx}_full.vmfb")
# if vmfb_path.exists():
# continue
if mlir_path.exists():
# print(f"Found layer {idx} mlir")
f_ = open(mlir_path, "rb")
bytecode = f_.read()
f_.close()
mlirs.append(bytecode)
else:
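# No local MLIR for this shard; fetch the pre-generated artifact from the
# shark_tank GCS bucket.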
command = f"gsutil cp gs://shark_tank/elias/compressed_sv/{idx}_full.mlir {idx}_full.mlir"

subprocess.check_call(command.split())

f_ = open(f"{idx}_full.mlir", "rb")
bytecode = f_.read()
f_.close()
mlirs.append(bytecode)

if vmfb_path.exists():
# print(f"Found layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
None,
device=device,
device_idx=0,
mlir_dialect="tm_tensor",
mmap=True,
)
module.load_module(vmfb_path)
else:
print(f"Compiling layer {idx} vmfb")
device_idx = self.get_device_index(
f"first_vicuna.model.model.layers.{idx}[\s.$]"
)
module = SharkInference(
mlirs[idx],
device=device,
device_idx=0,
mlir_dialect="tm_tensor",
mmap=True,
)
module.save_module(
module_name=f"{idx}_full",
extra_args=[
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
],
)
module.load_module(vmfb_path)
modules.append(module)
return mlirs, modules

def get_sharded_model(self, device="cpu", compressed=False):
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an incredibly hacky process
# please don't change it
SAMPLE_INPUT_LEN = 137
vicuna_model = self.get_src_model()
if compressed:
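# Compressed mode: swap in the vicuna4 LlamaModel wrapper so the
# eight-layers-per-shard forward path can be used.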
vicuna_model.model = LlamaModel.from_pretrained(
"TheBloke/vicuna-7B-1.1-HF"
)

if self.precision in ["int4", "int8"]:
print("Applying weight quantization..")
weight_bit_width = 4 if self.precision == "int4" else 8
quantize_model(
get_model_impl(vicuna_model).layers,
dtype=torch.float32,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_granularity="per_group",
weight_group_size=self.weight_group_size,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")

placeholder_pkv_segment = tuple(
(
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
for _ in range(8)
)
placeholder_pkv_full = tuple(
(
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
)
for _ in range(32)
)

placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
@@ -930,20 +1123,40 @@ def get_sharded_model(self, device="cpu"):
device_idx=device_idx,
)

layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]
if not compressed:
layers0 = [
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
]
layers1 = [
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
]

else:
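# Compressed mode: group the 32 decoder layers into four blocks of eight,
# once for the prefill (FV) pass and once for the decode (SV) pass.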
layers00 = EightLayerLayerFV(vicuna_model.model.layers[0:8])
layers01 = EightLayerLayerFV(vicuna_model.model.layers[8:16])
layers02 = EightLayerLayerFV(vicuna_model.model.layers[16:24])
layers03 = EightLayerLayerFV(vicuna_model.model.layers[24:32])
layers10 = EightLayerLayerSV(vicuna_model.model.layers[0:8])
layers11 = EightLayerLayerSV(vicuna_model.model.layers[8:16])
layers12 = EightLayerLayerSV(vicuna_model.model.layers[16:24])
layers13 = EightLayerLayerSV(vicuna_model.model.layers[24:32])
layers0 = [layers00, layers01, layers02, layers03]
layers1 = [layers10, layers11, layers12, layers13]

_, modules = self.compile_to_vmfb_one_model(
placeholder_input0,
layers0,
placeholder_input1,
layers1,
device=device,
)
shark_layers = [CompiledVicunaLayer(m) for m in modules]

if not compressed:
shark_layers = [CompiledVicunaLayer(m) for m in modules]
else:
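# Each compiled module now stands in for eight decoder layers; attach them so
# the compressed forward can dispatch to them.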
shark_layers = [CompiledEightLayerLayer(m) for m in modules]
vicuna_model.model.compressedlayers = shark_layers

sharded_model = ShardedVicunaModel(
vicuna_model,
@@ -955,11 +1168,13 @@
return sharded_model

def compile(self, device="cpu"):
return self.get_sharded_model(device=device)
return self.get_sharded_model(device=device, compressed=self.compressed)

def generate(self, prompt, cli=True):
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration

history = []
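# history accumulates the generated token ids so the decoded text so far can be
# yielded after every step.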

tokens_generated = []
_past_key_values = None
_token = None
@@ -977,6 +1192,8 @@
_token = generated_token_op["token"]
_past_key_values = generated_token_op["past_key_values"]
_detok = generated_token_op["detok"]
history.append(_token)
yield self.tokenizer.decode(history)

if _token == 2:
break
@@ -987,7 +1204,7 @@
if type(tokens_generated[i]) != int:
tokens_generated[i] = int(tokens_generated[i][0])
result_output = self.tokenizer.decode(tokens_generated)
return result_output
yield result_output

def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.