From ed484b8253550b88b59ec42393c789b88adaf3b6 Mon Sep 17 00:00:00 2001 From: Eliasj42 <46754803+Eliasj42@users.noreply.github.com> Date: Fri, 4 Aug 2023 12:05:05 -0700 Subject: [PATCH] added functionality for int8 vicuna and 4 shards (#1712) combined vicuna_4_shards.py and vicuna.py to reduce code duplication Co-authored-by: Elias Joseph --- apps/language_models/scripts/vicuna.py | 241 ++++- .../src/model_wrappers/vicuna4.py | 879 ++++++++++++++++++ .../model_wrappers/vicuna_sharded_model.py | 7 +- apps/stable_diffusion/web/ui/stablelm_ui.py | 59 +- 4 files changed, 1154 insertions(+), 32 deletions(-) create mode 100644 apps/language_models/src/model_wrappers/vicuna4.py diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 0ba3062320..55914d7e51 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -26,6 +26,14 @@ VicunaNorm, VicunaNormCompiled, ) +from apps.language_models.src.model_wrappers.vicuna4 import( + LlamaModel, + EightLayerLayerSV, + EightLayerLayerFV, + CompiledEightLayerLayerSV, + CompiledEightLayerLayer, + forward_compressed, +) from apps.language_models.src.model_wrappers.vicuna_model import ( FirstVicuna, SecondVicuna, @@ -410,6 +418,44 @@ def generate_new_token(self, params, sharded=True): return ret_dict + def generate_new_token(self, params): + is_first = params["is_first"] + if is_first: + prompt = params["prompt"] + input_ids = self.tokenizer(prompt).input_ids + # crop input_ids + # input_ids = input_ids[len(input_ids) - 20 :] + ############ + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + output = self.shark_model.forward(input_ids, is_first=is_first) + else: + token = params["token"] + past_key_values = params["past_key_values"] + input_ids = [token] + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + output = self.shark_model.forward( + input_ids, past_key_values=past_key_values, is_first=is_first + ) + + _logits = output["logits"] + _past_key_values = output["past_key_values"] + _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) + _detok = self.tokenizer.decode(_token) + + ret_dict = { + "token": _token, + "detok": _detok, + "past_key_values": _past_key_values, + } + + print(f" token : {_token} | detok : {_detok}") + + return ret_dict + class ShardedVicuna(VicunaBase): # Class representing Sharded Vicuna Model @@ -422,6 +468,7 @@ def __init__( precision="fp32", config_json=None, weight_group_size=128, + compressed=False, ) -> None: super().__init__(model_name, hf_model_path, max_num_tokens) self.max_sequence_length = 256 @@ -430,7 +477,9 @@ def __init__( self.tokenizer = self.get_tokenizer() self.config = config_json self.weight_group_size = weight_group_size + self.compressed=compressed self.shark_model = self.compile(device=device) + def get_tokenizer(self): kwargs = {} @@ -542,6 +591,59 @@ def compile_vicuna_layer( ) return mlir_bytecode + def compile_vicuna_layer4( + self, + vicuna_layer, + hidden_states, + attention_mask, + position_ids, + past_key_values=None, + ): + # Compile a hidden decoder layer of vicuna + if past_key_values is None: + model_inputs = (hidden_states, attention_mask, position_ids) + else: + ( + (pkv00, pkv01), + (pkv10, pkv11), + (pkv20, pkv21), + (pkv30, pkv31), + (pkv40, pkv41), + (pkv50, pkv51), + (pkv60, pkv61), + (pkv70, pkv71), + ) = past_key_values + + model_inputs = ( + hidden_states, + 
attention_mask, + position_ids, + pkv00, + pkv01, + pkv10, + pkv11, + pkv20, + pkv21, + pkv30, + pkv31, + pkv40, + pkv41, + pkv50, + pkv51, + pkv60, + pkv61, + pkv70, + pkv71, + ) + mlir_bytecode = import_with_fx( + vicuna_layer, + model_inputs, + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + return mlir_bytecode + def get_device_index(self, layer_string): # Get the device index from the config file # In the event that different device indices are assigned to @@ -858,11 +960,80 @@ def compile_to_vmfb_one_model( modules.append(module) return mlirs, modules - def get_sharded_model(self, device="cpu"): + def compile_to_vmfb_one_model4( + self, inputs0, layers0, inputs1, layers1, device="cpu" + ): + mlirs, modules = [], [] + assert len(layers0) == len(layers1) + for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))): + mlir_path = Path(f"{idx}_full.mlir") + vmfb_path = Path(f"{idx}_full.vmfb") + # if vmfb_path.exists(): + # continue + if mlir_path.exists(): + # print(f"Found layer {idx} mlir") + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + mlirs.append(bytecode) + else: + command = f"gsutil cp gs://shark_tank/elias/compressed_sv/{idx}_full.mlir {idx}_full.mlir" + + subprocess.check_call(command.split()) + + f_ = open(f"{idx}_full.mlir", "rb") + bytecode = f_.read() + f_.close() + mlirs.append(bytecode) + + + + if vmfb_path.exists(): + # print(f"Found layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + None, + device=device, + device_idx=0, + mlir_dialect="tm_tensor", + mmap=True, + ) + module.load_module(vmfb_path) + else: + print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + mlirs[idx], + device=device, + device_idx=0, + mlir_dialect="tm_tensor", + mmap=True, + ) + module.save_module( + module_name=f"{idx}_full", + extra_args=[ + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + module.load_module(vmfb_path) + modules.append(module) + return mlirs, modules + + def get_sharded_model(self, device="cpu", compressed=False): # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess # please don't change it SAMPLE_INPUT_LEN = 137 vicuna_model = self.get_src_model() + if compressed: + vicuna_model.model = LlamaModel.from_pretrained( + "TheBloke/vicuna-7B-1.1-HF" + ) if self.precision in ["int4", "int8"]: print("Applying weight quantization..") @@ -870,16 +1041,38 @@ def get_sharded_model(self, device="cpu"): quantize_model( get_model_impl(vicuna_model).layers, dtype=torch.float32, + weight_quant_type="asym", weight_bit_width=weight_bit_width, weight_param_method="stats", weight_scale_precision="float", - weight_quant_type="asym", weight_quant_granularity="per_group", weight_group_size=self.weight_group_size, quantize_weight_zero_point=False, + input_bit_width=None, + input_scale_type="float", + input_param_method="stats", + input_quant_type="asym", + input_quant_granularity="per_tensor", + quantize_input_zero_point=False, + seqlen=2048, ) print("Weight quantization applied.") + placeholder_pkv_segment = tuple( + ( + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + ) + for _ in range(8) + ) + 
placeholder_pkv_full = tuple( + ( + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + ) + for _ in range(32) + ) + placeholder_input0 = ( torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]), @@ -930,12 +1123,27 @@ def get_sharded_model(self, device="cpu"): device_idx=device_idx, ) - layers0 = [ - FirstVicunaLayer(layer) for layer in vicuna_model.model.layers - ] - layers1 = [ - SecondVicunaLayer(layer) for layer in vicuna_model.model.layers - ] + if not compressed: + + layers0 = [ + FirstVicunaLayer(layer) for layer in vicuna_model.model.layers + ] + layers1 = [ + SecondVicunaLayer(layer) for layer in vicuna_model.model.layers + ] + + else: + layers00 = EightLayerLayerFV(vicuna_model.model.layers[0:8]) + layers01 = EightLayerLayerFV(vicuna_model.model.layers[8:16]) + layers02 = EightLayerLayerFV(vicuna_model.model.layers[16:24]) + layers03 = EightLayerLayerFV(vicuna_model.model.layers[24:32]) + layers10 = EightLayerLayerSV(vicuna_model.model.layers[0:8]) + layers11 = EightLayerLayerSV(vicuna_model.model.layers[8:16]) + layers12 = EightLayerLayerSV(vicuna_model.model.layers[16:24]) + layers13 = EightLayerLayerSV(vicuna_model.model.layers[24:32]) + layers0 = [layers00, layers01, layers02, layers03] + layers1 = [layers10, layers11, layers12, layers13] + _, modules = self.compile_to_vmfb_one_model( placeholder_input0, layers0, @@ -943,7 +1151,12 @@ def get_sharded_model(self, device="cpu"): layers1, device=device, ) - shark_layers = [CompiledVicunaLayer(m) for m in modules] + + if not compressed: + shark_layers = [CompiledVicunaLayer(m) for m in modules] + else: + shark_layers = [CompiledEightLayerLayer(m) for m in modules] + vicuna_model.model.compressedlayers = shark_layers sharded_model = ShardedVicunaModel( vicuna_model, @@ -955,11 +1168,13 @@ def get_sharded_model(self, device="cpu"): return sharded_model def compile(self, device="cpu"): - return self.get_sharded_model(device=device) + return self.get_sharded_model(device=device, compressed=self.compressed) - def generate(self, prompt, cli=True): + def generate(self, prompt, cli=False): # TODO: refactor for cleaner integration + history = [] + tokens_generated = [] _past_key_values = None _token = None @@ -977,6 +1192,8 @@ def generate(self, prompt, cli=True): _token = generated_token_op["token"] _past_key_values = generated_token_op["past_key_values"] _detok = generated_token_op["detok"] + history.append(_token) + yield self.tokenizer.decode(history) if _token == 2: break @@ -987,7 +1204,7 @@ def generate(self, prompt, cli=True): if type(tokens_generated[i]) != int: tokens_generated[i] = int(tokens_generated[i][0]) result_output = self.tokenizer.decode(tokens_generated) - return result_output + yield result_output def autocomplete(self, prompt): # use First vic alone to complete a story / prompt / sentence. 
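
Note: in compressed mode the 32 decoder layers are grouped into four 8-layer shards, and each shard's past_key_values are passed to the compiled modules as 16 flat tensors (one key and one value per layer, matching the [1, 32, SAMPLE_INPUT_LEN, 128] placeholders above). The helpers below are an illustrative sketch of that flatten/unflatten convention; they are not part of this patch and the names are hypothetical.

# Hypothetical helpers (illustration only, not in this patch): the 8-layer
# shard wrappers flatten ((k0, v0), ..., (k7, v7)) into [k0, v0, ..., k7, v7]
# before calling the compiled module and rebuild the pairs on the way back.
from typing import List, Tuple
import torch


def flatten_pkv(
    pairs: Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
) -> List[torch.Tensor]:
    # ((k0, v0), ..., (k7, v7)) -> [k0, v0, ..., k7, v7]
    return [t for pair in pairs for t in pair]


def unflatten_pkv(
    flat: List[torch.Tensor],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
    # [k0, v0, ..., k7, v7] -> ((k0, v0), ..., (k7, v7))
    return tuple((flat[i], flat[i + 1]) for i in range(0, len(flat), 2))


# Round trip with the placeholder shapes used for tracing above
# (batch=1, num_heads=32, seq_len=SAMPLE_INPUT_LEN=137, head_dim=128):
pairs = tuple(
    (torch.zeros([1, 32, 137, 128]), torch.zeros([1, 32, 137, 128]))
    for _ in range(8)
)
rebuilt = unflatten_pkv(flatten_pkv(pairs))
assert all(k is k2 and v is v2 for (k, v), (k2, v2) in zip(pairs, rebuilt))
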
diff --git a/apps/language_models/src/model_wrappers/vicuna4.py b/apps/language_models/src/model_wrappers/vicuna4.py new file mode 100644 index 0000000000..508ab75995 --- /dev/null +++ b/apps/language_models/src/model_wrappers/vicuna4.py @@ -0,0 +1,879 @@ +import argparse +import json +import re +from io import BytesIO +from pathlib import Path +from tqdm import tqdm +from typing import List, Optional, Tuple, Union +import numpy as np +import iree.runtime +import itertools +import subprocess + +import torch +import torch_mlir +from torch_mlir import TensorPlaceholder +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + LlamaPreTrainedModel, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase +from apps.language_models.src.model_wrappers.vicuna_sharded_model import ( + FirstVicunaLayer, + SecondVicunaLayer, + CompiledVicunaLayer, + ShardedVicunaModel, + LMHead, + LMHeadCompiled, + VicunaEmbedding, + VicunaEmbeddingCompiled, + VicunaNorm, + VicunaNormCompiled, +) +from apps.language_models.src.model_wrappers.vicuna_model import ( + FirstVicuna, + SecondVicuna, +) +from apps.language_models.utils import ( + get_vmfb_from_path, +) +from shark.shark_downloader import download_public_file +from shark.shark_importer import get_f16_inputs +from shark.shark_importer import import_with_fx +from shark.shark_inference import SharkInference + +from brevitas_examples.llm.llm_quant.quantize import quantize_model +from brevitas_examples.llm.llm_quant.run_utils import get_model_impl +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + _make_causal_mask, + _expand_mask, +) +from torch import nn +from time import time + + +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + t1 = time() + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) + + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = ( + seq_length_with_past + past_key_values_length + ) + + if position_ids is None: + device = ( + input_ids.device + if input_ids is not None + else inputs_embeds.device + ) + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + 
device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.compressedlayers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[8 * idx : 8 * (idx + 1)] + if past_key_values is not None + else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer.forward( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[1:],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + try: + hidden_states = np.asarray(hidden_states, hidden_states.dtype) + except: + _ = 10 + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + next_cache = tuple(itertools.chain.from_iterable(next_cache)) + print(f"Token generated in {time() - t1} seconds") + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class EightLayerLayerSV(torch.nn.Module): + def __init__(self, layers): + super().__init__() + assert len(layers) == 8 + self.layers = layers + + def forward( + self, + hidden_states, + attention_mask, + position_ids, + pkv00, + pkv01, + pkv10, + pkv11, + pkv20, + pkv21, + pkv30, + pkv31, + pkv40, + pkv41, + pkv50, + pkv51, + pkv60, + pkv61, + pkv70, + pkv71, + ): + pkvs = [ + (pkv00, pkv01), + (pkv10, pkv11), + (pkv20, pkv21), + (pkv30, pkv31), + (pkv40, pkv41), + (pkv50, pkv51), + (pkv60, pkv61), + (pkv70, pkv71), + ] + new_pkvs = [] + for layer, pkv in zip(self.layers, pkvs): + outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=( + pkv[0], + pkv[1], + ), + use_cache=True, + ) + + hidden_states = outputs[0] + new_pkvs.append( + ( + outputs[-1][0], + outputs[-1][1], + ) + ) + ( + (new_pkv00, new_pkv01), + (new_pkv10, new_pkv11), + (new_pkv20, new_pkv21), + (new_pkv30, new_pkv31), + (new_pkv40, new_pkv41), + (new_pkv50, 
new_pkv51),
+            (new_pkv60, new_pkv61),
+            (new_pkv70, new_pkv71),
+        ) = new_pkvs
+        return (
+            hidden_states,
+            new_pkv00,
+            new_pkv01,
+            new_pkv10,
+            new_pkv11,
+            new_pkv20,
+            new_pkv21,
+            new_pkv30,
+            new_pkv31,
+            new_pkv40,
+            new_pkv41,
+            new_pkv50,
+            new_pkv51,
+            new_pkv60,
+            new_pkv61,
+            new_pkv70,
+            new_pkv71,
+        )
+
+
+class EightLayerLayerFV(torch.nn.Module):
+    def __init__(self, layers):
+        super().__init__()
+        assert len(layers) == 8
+        self.layers = layers
+
+    def forward(self, hidden_states, attention_mask, position_ids):
+        new_pkvs = []
+        for layer in self.layers:
+            outputs = layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=None,
+                use_cache=True,
+            )
+
+            hidden_states = outputs[0]
+            new_pkvs.append(
+                (
+                    outputs[-1][0],
+                    outputs[-1][1],
+                )
+            )
+        (
+            (new_pkv00, new_pkv01),
+            (new_pkv10, new_pkv11),
+            (new_pkv20, new_pkv21),
+            (new_pkv30, new_pkv31),
+            (new_pkv40, new_pkv41),
+            (new_pkv50, new_pkv51),
+            (new_pkv60, new_pkv61),
+            (new_pkv70, new_pkv71),
+        ) = new_pkvs
+        return (
+            hidden_states,
+            new_pkv00,
+            new_pkv01,
+            new_pkv10,
+            new_pkv11,
+            new_pkv20,
+            new_pkv21,
+            new_pkv30,
+            new_pkv31,
+            new_pkv40,
+            new_pkv41,
+            new_pkv50,
+            new_pkv51,
+            new_pkv60,
+            new_pkv61,
+            new_pkv70,
+            new_pkv71,
+        )
+
+
+class CompiledEightLayerLayerSV(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        position_ids,
+        past_key_value,
+        output_attentions=False,
+        use_cache=True,
+    ):
+        hidden_states = hidden_states.detach()
+        attention_mask = attention_mask.detach()
+        position_ids = position_ids.detach()
+        (
+            (pkv00, pkv01),
+            (pkv10, pkv11),
+            (pkv20, pkv21),
+            (pkv30, pkv31),
+            (pkv40, pkv41),
+            (pkv50, pkv51),
+            (pkv60, pkv61),
+            (pkv70, pkv71),
+        ) = past_key_value
+        pkv00 = pkv00.detach()
+        pkv01 = pkv01.detach()
+        pkv10 = pkv10.detach()
+        pkv11 = pkv11.detach()
+        pkv20 = pkv20.detach()
+        pkv21 = pkv21.detach()
+        pkv30 = pkv30.detach()
+        pkv31 = pkv31.detach()
+        pkv40 = pkv40.detach()
+        pkv41 = pkv41.detach()
+        pkv50 = pkv50.detach()
+        pkv51 = pkv51.detach()
+        pkv60 = pkv60.detach()
+        pkv61 = pkv61.detach()
+        pkv70 = pkv70.detach()
+        pkv71 = pkv71.detach()
+
+        output = self.model(
+            "forward",
+            (
+                hidden_states,
+                attention_mask,
+                position_ids,
+                pkv00,
+                pkv01,
+                pkv10,
+                pkv11,
+                pkv20,
+                pkv21,
+                pkv30,
+                pkv31,
+                pkv40,
+                pkv41,
+                pkv50,
+                pkv51,
+                pkv60,
+                pkv61,
+                pkv70,
+                pkv71,
+            ),
+            send_to_host=False,
+        )
+        return (
+            output[0],
+            (output[1][0], output[1][1]),
+            (output[2][0], output[2][1]),
+            (output[3][0], output[3][1]),
+            (output[4][0], output[4][1]),
+            (output[5][0], output[5][1]),
+            (output[6][0], output[6][1]),
+            (output[7][0], output[7][1]),
+            (output[8][0], output[8][1]),
+        )
+
+
+def forward_compressed(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+):
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if
use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = ( + input_ids.device if input_ids is not None else inputs_embeds.device + ) + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.compressedlayers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = ( + past_key_values[8 * idx : 8 * (idx + 1)] + if past_key_values is not None + else None + ) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += ( + layer_outputs[2 if output_attentions else 1], + ) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class CompiledEightLayerLayer(torch.nn.Module): + def __init__(self, model): + super().__init__() + 
self.model = model + + def forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value=None, + output_attentions=False, + use_cache=True, + ): + t2 = time() + if past_key_value is None: + try: + hidden_states = np.asarray(hidden_states, hidden_states.dtype) + except: + pass + attention_mask = attention_mask.detach() + position_ids = position_ids.detach() + t1 = time() + + output = self.model( + "first_vicuna_forward", + (hidden_states, attention_mask, position_ids), + send_to_host=False, + ) + output2 = ( + output[0], + ( + output[1], + output[2], + ), + ( + output[3], + output[4], + ), + ( + output[5], + output[6], + ), + ( + output[7], + output[8], + ), + ( + output[9], + output[10], + ), + ( + output[11], + output[12], + ), + ( + output[13], + output[14], + ), + ( + output[15], + output[16], + ), + ) + return output2 + else: + ( + (pkv00, pkv01), + (pkv10, pkv11), + (pkv20, pkv21), + (pkv30, pkv31), + (pkv40, pkv41), + (pkv50, pkv51), + (pkv60, pkv61), + (pkv70, pkv71), + ) = past_key_value + + try: + hidden_states = hidden_states.detach() + attention_mask = attention_mask.detach() + position_ids = position_ids.detach() + pkv00 = pkv00.detach() + pkv01 = pkv01.detach() + pkv10 = pkv10.detach() + pkv11 = pkv11.detach() + pkv20 = pkv20.detach() + pkv21 = pkv21.detach() + pkv30 = pkv30.detach() + pkv31 = pkv31.detach() + pkv40 = pkv40.detach() + pkv41 = pkv41.detach() + pkv50 = pkv50.detach() + pkv51 = pkv51.detach() + pkv60 = pkv60.detach() + pkv61 = pkv61.detach() + pkv70 = pkv70.detach() + pkv71 = pkv71.detach() + except: + x = 10 + + t1 = time() + if type(hidden_states) == iree.runtime.array_interop.DeviceArray: + hidden_states = np.array(hidden_states, hidden_states.dtype) + hidden_states = torch.tensor(hidden_states) + hidden_states = hidden_states.detach() + + output = self.model( + "second_vicuna_forward", + ( + hidden_states, + attention_mask, + position_ids, + pkv00, + pkv01, + pkv10, + pkv11, + pkv20, + pkv21, + pkv30, + pkv31, + pkv40, + pkv41, + pkv50, + pkv51, + pkv60, + pkv61, + pkv70, + pkv71, + ), + send_to_host=False, + ) + print(f"{time() - t1}") + del pkv00 + del pkv01 + del pkv10 + del pkv11 + del pkv20 + del pkv21 + del pkv30 + del pkv31 + del pkv40 + del pkv41 + del pkv50 + del pkv51 + del pkv60 + del pkv61 + del pkv70 + del pkv71 + output2 = ( + output[0], + ( + output[1], + output[2], + ), + ( + output[3], + output[4], + ), + ( + output[5], + output[6], + ), + ( + output[7], + output[8], + ), + ( + output[9], + output[10], + ), + ( + output[11], + output[12], + ), + ( + output[13], + output[14], + ), + ( + output[15], + output[16], + ), + ) + return output2 diff --git a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py index 1a46c3b50f..54796fccfa 100644 --- a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py @@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module): def __init__(self, model, layers, lmhead, embedding, norm): super().__init__() self.model = model - assert len(layers) == len(model.model.layers) + # assert len(layers) == len(model.model.layers) self.model.model.config.use_cache = True self.model.model.config.output_attentions = False self.layers = layers @@ -132,7 +132,10 @@ def __init__(self, shark_module): self.model = shark_module def forward(self, hidden_states): - hidden_states.detach() + try: + hidden_states.detach() + except: + pass output = 
self.model("forward", (hidden_states,)) output = torch.tensor(output) return output diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py index f31344d715..26a024ecfa 100644 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ b/apps/stable_diffusion/web/ui/stablelm_ui.py @@ -27,6 +27,7 @@ def user(message, history): "codegen": "Salesforce/codegen25-7b-multi", "vicuna1p3": "lmsys/vicuna-7b-v1.3", "vicuna": "TheBloke/vicuna-7B-1.1-HF", + "vicuna4": "TheBloke/vicuna-7B-1.1-HF", "StableLM": "stabilityai/stablelm-tuned-alpha-3b", } @@ -66,6 +67,11 @@ def user(message, history): "The assistant gives helpful, detailed, and polite answers to the user's " "questions.\n" ), + "vicuna4": ( + "A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's " + "questions.\n" + ), "vicuna1p3": ( "A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's " @@ -81,6 +87,7 @@ def create_prompt(model_name, history): if model_name in [ "StableLM", "vicuna", + "vicuna4", "vicuna1p3", "llama2_7b", "llama2_70b", @@ -123,15 +130,20 @@ def chat( if model_name in [ "vicuna", + "vicuna4", "vicuna1p3", "codegen", "llama2_7b", "llama2_70b", ]: - from apps.language_models.scripts.vicuna import ( - UnshardedVicuna, - ShardedVicuna, - ) + if model_name == "vicuna4": + from apps.language_models.scripts.vicuna import ( + ShardedVicuna as Vicuna, + ) + else: + from apps.language_models.scripts.vicuna import ( + UnshardedVicuna as Vicuna, + ) from apps.stable_diffusion.src import args if vicuna_model == 0: @@ -148,28 +160,39 @@ def chat( print("unrecognized device") max_toks = 128 if model_name == "codegen" else 512 - if len(devices) == 1 and config_file is None: - vicuna_model = UnshardedVicuna( + if model_name == "vicuna4": + vicuna_model = Vicuna( model_name, hf_model_path=model_path, - hf_auth_token=args.hf_auth_token, device=device, precision=precision, max_num_tokens=max_toks, + compressed=True, ) else: - if config_file is not None: - config_file = open(config_file) - config_json = json.load(config_file) - config_file.close() + if len(devices) == 1 and config_file is None: + vicuna_model = Vicuna( + model_name, + hf_model_path=model_path, + hf_auth_token=args.hf_auth_token, + device=device, + precision=precision, + max_num_tokens=max_toks, + ) else: - config_json = None - vicuna_model = ShardedVicuna( - model_name, - device=device, - precision=precision, - config_json=config_json, - ) + if config_file is not None: + config_file = open(config_file) + config_json = json.load(config_file) + config_file.close() + else: + config_json = None + vicuna_model = Vicuna( + model_name, + device=device, + precision=precision, + config_json=config_json, + ) + prompt = create_prompt(model_name, history) for partial_text in vicuna_model.generate(prompt, cli=cli):
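
For reference, a minimal usage sketch of the new compressed path wired up in stablelm_ui.py above. The device, precision, and prompt values are illustrative; the keyword arguments follow the ShardedVicuna signature added by this patch, and generate() now yields partially decoded text instead of returning it once.

# Minimal sketch (illustrative values, not part of this patch): the "vicuna4"
# UI entry maps to TheBloke/vicuna-7B-1.1-HF and builds ShardedVicuna with
# compressed=True, i.e. four compiled 8-layer shards instead of 32 layers.
from apps.language_models.scripts.vicuna import ShardedVicuna

vicuna_model = ShardedVicuna(
    "vicuna4",
    hf_model_path="TheBloke/vicuna-7B-1.1-HF",
    device="cpu",          # or the device string picked in the UI
    precision="int8",      # int8/int4 trigger the Brevitas weight quantization path
    max_num_tokens=512,
    compressed=True,       # new flag: use the 4-shard (8 layers per shard) model
)

prompt = "What is a language model?"
for partial_text in vicuna_model.generate(prompt, cli=False):
    print(partial_text)
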