diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py
index 7280da73c2..f4aa8224a7 100644
--- a/apps/language_models/scripts/vicuna.py
+++ b/apps/language_models/scripts/vicuna.py
@@ -510,6 +512,8 @@ def __init__(
         n_devices=None,
     ) -> None:
         self.hf_auth_token = hf_auth_token
+        self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b": 5120}
+        self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b": 40}
         super().__init__(
             model_name,
             hf_model_path,
@@ -711,6 +715,27 @@ def get_device_index(self, layer_string):
             device_idx = max(idx_votes, key=idx_votes.get)
         return device_idx
 
+    def write_dynamic_inputs_lmhead(self, ir, sample_input_length):
+        if self.precision in ["fp16", "int4"]:
+            precision_str = "f16"
+        else:
+            precision_str = "f32"
+        lines = ir.splitlines()
+        new_lines = []
+        for line in lines:
+            if "%cst_0 =" in line:
+                new_lines.append(line)
+                new_lines.append("%c1 = arith.constant 1 : index")
+                new_lines.append(f"%dim = tensor.dim %arg0, %c1 : tensor<1x?x{self.hidden_state_size_dict[self.model_name]}x{precision_str}>")
+            else:
+                line = re.sub(f"{sample_input_length}x", "?x", line)
+                if "?x" in line:
+                    line = re.sub(r"tensor\.empty\(\)", "tensor.empty(%dim)", line)
+                new_lines.append(line)
+
+        return "\n".join(new_lines)
+
     def compile_lmhead(
         self,
         lmh,
@@ -775,14 +800,21 @@ def compile_lmhead(
             use_tracing=False,
             verbose=False,
         )
-
+        """
         bytecode_stream = BytesIO()
         module.operation.write_bytecode(bytecode_stream)
         bytecode = bytecode_stream.getvalue()
         f_ = open(mlir_path, "wb")
         f_.write(bytecode)
         f_.close()
+        """
+        module = str(module)
+        if self.precision in ["int4", "fp16"]:
+            module = self.write_dynamic_inputs_lmhead(module, 137)
         filepath = Path(f"{self.dir_name}/lmhead.mlir")
+        f_ = open(mlir_path, "w+")
+        f_.write(module)
+        f_.close()
         # download_public_file(
         #     "gs://shark_tank/elias/compressed_sv/lmhead.mlir",
         #     filepath.absolute(),
@@ -795,7 +827,7 @@ def compile_lmhead(
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
@@ -883,7 +915,7 @@ def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
@@ -964,7 +996,7 @@ def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
             device=device,
             mlir_dialect="tm_tensor",
             device_idx=device_idx,
-            mmap=False,
+            mmap=True,
         )
         if vmfb_path.exists():
             shark_module.load_module(vmfb_path)
@@ -1163,12 +1195,13 @@ def compile_to_vmfb_one_model(
                 device_idx = idx % self.n_devices
             else:
                 device_idx = None
+            print(device_idx, self.n_devices)
             module = SharkInference(
                 None,
                 device=device,
                 device_idx=device_idx,
                 mlir_dialect="tm_tensor",
-                mmap=False,
+                mmap=True,
             )
             module.load_module(vmfb_path)
         else:
@@ -1180,13 +1213,13 @@ def compile_to_vmfb_one_model(
             if self.n_devices is not None:
                 device_idx = idx % self.n_devices
             else:
-                device_idx = 0
+                device_idx = None
             module = SharkInference(
                 mlirs[idx],
                 device=device,
                 device_idx=device_idx,
                 mlir_dialect="tm_tensor",
-                mmap=False,
+                mmap=True,
             )
             module.save_module(
                 module_name=f"{self.dir_name}/{idx}_full",
@@ -1238,7 +1271,7 @@ def compile_to_vmfb_one_model4(
             if self.n_devices is not None:
                device_idx = idx % self.n_devices
             else:
-                device_idx = 0
+                device_idx = None
             module = SharkInference(
                 None,
                 device=device,
@@ -1256,7 +1289,7 @@ def compile_to_vmfb_one_model4(
             if self.n_devices is not None:
                 device_idx = idx % self.n_devices
             else:
-                device_idx = 0
+                device_idx = None
             module = SharkInference(
                 mlirs[idx],
                 device=device,
@@ -1320,41 +1353,42 @@ def get_sharded_model(self, device="cpu", compressed=False):
         placeholder_pkv_segment = tuple(
             (
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
             )
             for _ in range(8)
         )
         placeholder_pkv_full = tuple(
             (
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
-                torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
+                torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
             )
-            for _ in range(32)
+            for _ in range(self.n_layers_dict[self.model_name])
         )
         placeholder_input0 = (
-            torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
+            torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
             torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
             torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64),
         )
         placeholder_input1 = (
-            torch.zeros([1, 1, 4096]),
+            torch.zeros([1, 1, self.hidden_state_size_dict[self.model_name]]),
             torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]),
             torch.zeros([1, 1], dtype=torch.int64),
-            torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
-            torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
+            torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
+            torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]),
         )
         norm = VicunaNorm(vicuna_model.model.norm)
         device_idx = self.get_device_index(
             r"vicuna\.model\.model\.norm(?:\.|\s|$)"
         )
-        print(device_idx)
+        # Hardcode device_idx for non-layer vmfbs
+        device_idx = 0
         norm = self.compile_norm(
             norm,
-            torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
+            torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
             device=self.device,
             device_idx=device_idx,
         )
@@ -1363,7 +1397,8 @@ def get_sharded_model(self, device="cpu", compressed=False):
         device_idx = self.get_device_index(
             r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)"
         )
-        print(device_idx)
+        # Hardcode device_idx for non-layer vmfbs
+        device_idx = 0
         embeddings = self.compile_embedding(
             embeddings,
             (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)),
@@ -1375,10 +1410,11 @@ def get_sharded_model(self, device="cpu", compressed=False):
         device_idx = self.get_device_index(
             r"vicuna\.model\.lm_head(?:\.|\s|$)"
         )
-        print(device_idx)
+        # Hardcode device_idx for non-layer vmfbs
+        device_idx = 0
         lmhead = self.compile_lmhead(
             lmhead,
-            torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
+            torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]),
             device=self.device,
             device_idx=device_idx,
         )
@@ -1452,13 +1488,13 @@ def generate(self, prompt, cli=False):
                 generated_token_op = self.generate_new_token(params=params)
 
-                prefill_time = time.time() - decode_st_time
+                decode_time = (time.time() - decode_st_time) * 1000
                 _token = generated_token_op["token"]
                 _past_key_values = generated_token_op["past_key_values"]
                 _detok = generated_token_op["detok"]
                 history.append(_token)
 
-                yield self.tokenizer.decode(history), None, prefill_time
+                yield self.tokenizer.decode(history), None, decode_time
                 if _token == 2:
                     break
@@ -1667,12 +1703,7 @@ def remove_constant_dim(line):
         new_lines = []
 
         # Using a while loop and the pop method to avoid creating a copy of module
-        if "llama2_13b" in self.model_name:
-            pkv_tensor_shape = "tensor<1x40x?x128x"
-        elif "llama2_70b" in self.model_name:
-            pkv_tensor_shape = "tensor<1x8x?x128x"
-        else:
-            pkv_tensor_shape = "tensor<1x32x?x128x"
+        pkv_tensor_shape = f"tensor<1x{self.n_layers_dict[self.model_name]}x?x128x"
         if self.precision in ["fp16", "int4", "int8"]:
             pkv_tensor_shape += "f16>"
         else:
@@ -2066,14 +2097,14 @@ def generate(self, prompt, cli):
             generated_token_op = self.generate_new_token(
                 params=params, sharded=False, cli=cli
             )
-            prefill_time = time.time() - prefill_st_time
+            prefill_time_ms = (time.time() - prefill_st_time) * 1000
             token = generated_token_op["token"]
             if "cpu" not in self.device:
                 logits = generated_token_op["logits"]
             pkv = generated_token_op["past_key_values"]
             detok = generated_token_op["detok"]
-            yield detok, None, prefill_time
+            yield detok, None, prefill_time_ms
             res_tokens.append(token)
             if cli:
@@ -2408,8 +2439,7 @@ def avg_and_stdev(data):
             vic.shark_model.shark_runner.iree_config.device.flush_profiling()
         if msg is None:
             if is_first:
-                # Note that the prefill time is in seconds, and all the decoded tokens in ms.
-                prefill_time_ms = exec_time * 1000
+                prefill_time_ms = exec_time
                 is_first = False
             else:
                 token_times_ms.append(exec_time)
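The `write_dynamic_inputs_lmhead` helper added above makes the compiled lm-head accept an arbitrary sequence length by rewriting the MLIR text directly: every static `137x` dimension becomes `?x`, and any `tensor.empty()` that now produces a dynamic shape is handed the runtime dimension `%dim`. Below is a minimal standalone sketch of that substitution on a toy one-line IR string; `make_seq_len_dynamic` and the constants are illustrative stand-ins, not code from this patch:

```python
import re

SAMPLE_INPUT_LEN = 137  # static sample length baked into the traced IR
HIDDEN_SIZE = 4096      # vicuna / llama2_7b hidden size, per the dicts above


def make_seq_len_dynamic(ir: str) -> str:
    # Mirror the core rewrite: swap the static sample length for '?'
    # and pass the runtime dim value to tensor.empty.
    new_lines = []
    for line in ir.splitlines():
        line = re.sub(f"{SAMPLE_INPUT_LEN}x", "?x", line)
        if "?x" in line:
            line = re.sub(r"tensor\.empty\(\)", "tensor.empty(%dim)", line)
        new_lines.append(line)
    return "\n".join(new_lines)


toy_ir = f"%1 = tensor.empty() : tensor<1x{SAMPLE_INPUT_LEN}x{HIDDEN_SIZE}xf16>"
print(make_seq_len_dynamic(toy_ir))
# %1 = tensor.empty(%dim) : tensor<1x?x4096xf16>
```

The real helper additionally splices `%c1 = arith.constant 1 : index` and a `tensor.dim %arg0, %c1` after the `%cst_0` definition, so `%dim` is defined before any rewritten `tensor.empty(%dim)` refers to it.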
diff --git a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py
index 54796fccfa..85aefb7d57 100644
--- a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py
+++ b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py
@@ -1,4 +1,5 @@
 import torch
+import time
 
 
 class FirstVicunaLayer(torch.nn.Module):
@@ -110,9 +111,11 @@ def __init__(self, shark_module):
         self.model = shark_module
 
     def forward(self, hidden_states):
-        hidden_states = hidden_states.detach()
+        hidden_states_sample = hidden_states.detach()
+
         output = self.model("forward", (hidden_states,))
         output = torch.tensor(output)
+
         return output
@@ -136,8 +139,9 @@ def forward(self, hidden_states):
             hidden_states.detach()
         except:
             pass
-        output = self.model("forward", (hidden_states,))
+        output = self.model("forward", (hidden_states,), send_to_host=True)
         output = torch.tensor(output)
+
         return output
@@ -158,8 +162,9 @@ def __init__(self, shark_module):
 
     def forward(self, input_ids):
         input_ids.detach()
-        output = self.model("forward", (input_ids,))
+        output = self.model("forward", (input_ids,), send_to_host=True)
         output = torch.tensor(output)
+
         return output
@@ -178,9 +183,10 @@ def forward(
         use_cache=True,
     ):
         if past_key_value is None:
-            hidden_states = hidden_states.detach()
-            attention_mask = attention_mask.detach()
-            position_ids = position_ids.detach()
+            # hidden_states = hidden_states.detach()
+            # attention_mask = attention_mask.detach()
+            # position_ids = position_ids.detach()
+
             output = self.model(
                 "first_vicuna_forward",
                 (
@@ -188,11 +194,17 @@
                     attention_mask,
                     position_ids,
                 ),
+                send_to_host=True,
             )
+
+            ### send_to_host=True
             output0 = torch.tensor(output[0])
             output1 = torch.tensor(output[1])
             output2 = torch.tensor(output[2])
+
+            ### send_to_host=False
+            # output0 = output[0]
+            # output1 = output[1]
+            # output2 = output[2]
 
             return (
                 output0,
@@ -202,11 +214,12 @@
                 ),
             )
         else:
-            hidden_states = hidden_states.detach()
-            attention_mask = attention_mask.detach()
-            position_ids = position_ids.detach()
-            pkv0 = past_key_value[0].detach()
-            pkv1 = past_key_value[1].detach()
+            # hidden_states = hidden_states.detach()
+            # attention_mask = attention_mask.detach()
+            # position_ids = position_ids.detach()
+            # pkv0 = past_key_value[0].detach()
+            pkv0 = past_key_value[0]
+            pkv1 = past_key_value[1]
             output = self.model(
                 "second_vicuna_forward",
                 (
@@ -216,11 +229,17 @@
                     pkv0,
                     pkv1,
                 ),
+                send_to_host=True,
            )
+
+            ### send_to_host=True
             output0 = torch.tensor(output[0])
             output1 = torch.tensor(output[1])
             output2 = torch.tensor(output[2])
+
+            ### send_to_host=False
+            # output0 = output[0]
+            # output1 = output[1]
+            # output2 = output[2]
 
             return (
                 output0,
diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py
index b5f87378c4..6cfe369426 100644
--- a/shark/iree_utils/compile_utils.py
+++ b/shark/iree_utils/compile_utils.py
@@ -355,11 +355,15 @@ def get_iree_module(
         device = iree_device_map(device)
         print("registering device id: ", device_idx)
         haldriver = ireert.get_driver(device)
+        hal_device_id = haldriver.query_available_devices()[device_idx][
+            "device_id"
+        ]
         haldevice = haldriver.create_device(
-            haldriver.query_available_devices()[device_idx]["device_id"],
+            hal_device_id,
             allocators=shark_args.device_allocator,
         )
         config = ireert.Config(device=haldevice)
+        config.id = hal_device_id
     else:
         config = get_iree_runtime_config(device)
     vm_module = ireert.VmModule.from_buffer(
@@ -398,15 +402,16 @@ def load_vmfb_using_mmap(
         haldriver = ireert.get_driver(device)
         dl.log(f"ireert.get_driver()")
+        hal_device_id = haldriver.query_available_devices()[device_idx][
+            "device_id"
+        ]
         haldevice = haldriver.create_device(
-            haldriver.query_available_devices()[device_idx]["device_id"],
+            hal_device_id,
             allocators=shark_args.device_allocator,
         )
         dl.log(f"ireert.create_device()")
         config = ireert.Config(device=haldevice)
-        config.id = haldriver.query_available_devices()[device_idx][
-            "device_id"
-        ]
+        config.id = hal_device_id
         dl.log(f"ireert.Config()")
     else:
         config = get_iree_runtime_config(device)
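Both `get_iree_module` and `load_vmfb_using_mmap` now query the HAL device list once, reuse the id for `create_device`, and record it on the runtime config. A minimal sketch of that pattern in isolation, assuming the `iree.runtime` Python package is installed; `make_device_config` and the `local-task` driver name are illustrative choices, and the `allocators` argument is omitted:

```python
from iree import runtime as ireert


def make_device_config(driver_name: str, device_idx: int) -> "ireert.Config":
    haldriver = ireert.get_driver(driver_name)
    # Query the device list once and reuse the id, instead of calling
    # query_available_devices() a second time as the pre-patch code did.
    hal_device_id = haldriver.query_available_devices()[device_idx]["device_id"]
    haldevice = haldriver.create_device(hal_device_id)
    config = ireert.Config(device=haldevice)
    config.id = hal_device_id  # lets multi-device callers tell configs apart
    return config


config = make_device_config("local-task", 0)
```

Factoring out `hal_device_id` removes the duplicated `query_available_devices()` calls visible in the pre-patch code, and stashing it on `config.id` gives the sharded model a uniform way to identify which HAL device each module was loaded on.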