diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml index fa22fb1a89..07f5c91985 100644 --- a/.github/workflows/test-models.yml +++ b/.github/workflows/test-models.yml @@ -115,6 +115,7 @@ jobs: pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv + python build_tools/vicuna_testing.py - name: Validate Models on NVIDIA GPU if: matrix.suite == 'cuda' diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 4d800c2c9a..86b64958c4 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -53,7 +53,7 @@ ) parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") parser.add_argument( - "--first_vicuna_vmfb_path", default=None, help="path to first vicuna vmfb" + "--vicuna_vmfb_path", default=None, help="path to vicuna vmfb" ) parser.add_argument( "-s", @@ -64,19 +64,9 @@ ) # TODO: sharded config parser.add_argument( - "--second_vicuna_vmfb_path", + "--vicuna_mlir_path", default=None, - help="path to second vicuna vmfb", -) -parser.add_argument( - "--first_vicuna_mlir_path", - default=None, - help="path to first vicuna mlir file", -) -parser.add_argument( - "--second_vicuna_mlir_path", - default=None, - help="path to second vicuna mlir", + help="path to vicuna mlir file", ) parser.add_argument( "--load_mlir_from_shark_tank", @@ -147,108 +137,35 @@ def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, brevitas〇matmul_rhs_group_quant〡has_value_semantics] -class ShardedVicuna(SharkLLMBase): - # Class representing Sharded Vicuna Model +class VicunaBase(SharkLLMBase): def __init__( self, model_name, hf_model_path="TheBloke/vicuna-7B-1.1-HF", max_num_tokens=512, - device="cuda", - precision="fp32", - config_json=None, - weight_group_size=128, + device="cpu", + precision="int8" ) -> None: super().__init__(model_name, hf_model_path, max_num_tokens) self.max_sequence_length = 256 self.device = device self.precision = precision - self.tokenizer = self.get_tokenizer() - self.config = config_json - self.weight_group_size = weight_group_size - self.shark_model = self.compile(device=device) def get_tokenizer(self): - kwargs = {} - if self.model_name == "llama2": - kwargs = { - "use_auth_token": "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" - } - if self.model_name == "codegen": - tokenizer = AutoTokenizer.from_pretrained( - self.hf_model_path, - trust_remote_code=True, - ) - else: - tokenizer = AutoTokenizer.from_pretrained( - self.hf_model_path, - use_fast=False, - **kwargs, - ) + # Retrieve the tokenizer from Huggingface + tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_path, use_fast=False + ) return tokenizer def get_src_model(self): # Retrieve the torch model from Huggingface kwargs = {"torch_dtype": torch.float} - if self.model_name == "llama2": - kwargs["use_auth_token"] = "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" vicuna_model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, - **kwargs, + self.hf_model_path, **kwargs ) return vicuna_model - def write_in_dynamic_inputs0(self, module, dynamic_input_size): - # Current solution for ensuring mlir files support dynamic inputs - # TODO find a 
more elegant way to implement this - new_lines = [] - for line in module.splitlines(): - line = re.sub(f"{dynamic_input_size}x", "?x", line) - if "?x" in line: - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - line = re.sub(f" {dynamic_input_size},", " %dim,", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line - ) - if "arith.cmpi" in line: - line = re.sub(f"c{dynamic_input_size}", "dim", line) - new_lines.append(line) - new_module = "\n".join(new_lines) - return new_module - - def write_in_dynamic_inputs1(self, module, dynamic_input_size): - new_lines = [] - for line in module.splitlines(): - if "dim_42 =" in line: - continue - if f"%c{dynamic_input_size}_i64 =" in line: - new_lines.append( - "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>" - ) - new_lines.append( - f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64" - ) - continue - line = re.sub(f"{dynamic_input_size}x", "?x", line) - line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line) - if "?x" in line: - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dim_42)", line - ) - line = re.sub(f" {dynamic_input_size},", " %dim_42,", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim_42\)", - "tensor.empty(%dim_42, %dim_42)", - line, - ) - if "arith.cmpi" in line: - line = re.sub(f"c{dynamic_input_size}", "dim_42", line) - new_lines.append(line) - new_module = "\n".join(new_lines) - return new_module - def combine_mlir_scripts( self, first_vicuna_mlir, second_vicuna_mlir, output_name ): @@ -257,8 +174,9 @@ def combine_mlir_scripts( constants = set() f1 = [] f2 = [] - - for line in first_vicuna_mlir.splitlines(): + first_vicuna_mlir = first_vicuna_mlir.splitlines() + while first_vicuna_mlir: + line = first_vicuna_mlir.pop(0) if re.search("#map\d*\s*=", line): maps1.append(line) elif re.search("arith.constant", line): @@ -267,6 +185,7 @@ def combine_mlir_scripts( line = re.sub("forward", "first_vicuna_forward", line) f1.append(line) f1 = f1[:-1] + del first_vicuna_mlir for i, map_line in enumerate(maps1): map_var = map_line.split(" ")[0] @@ -277,7 +196,9 @@ def combine_mlir_scripts( for func_line in f1 ] - for line in second_vicuna_mlir.splitlines(): + second_vicuna_mlir = second_vicuna_mlir.splitlines() + while second_vicuna_mlir: + line = second_vicuna_mlir.pop(0) if re.search("#map\d*\s*=", line): maps2.append(line) elif "global_seed" in line: @@ -305,31 +226,50 @@ def combine_mlir_scripts( global_vars = [] vnames = [] - vdtypes = [] global_var_loading1 = [] global_var_loading2 = [] - for constant in list(constants): + counter = 0 + constants = list(constants) + while constants: + constant = constants.pop(0) vname, vbody = constant.split("=") vname = re.sub("%", "", vname) vname = vname.strip() vbody = re.sub("arith.constant", "", vbody) vbody = vbody.strip() - vdtype = vbody.split(":")[1].strip() + if len(vbody.split(":"))<2: + print(constant) + vdtype = vbody.split(":")[-1].strip() fixed_vdtype = vdtype - vdtypes.append(vdtype) - vdtype = re.sub("\d{1,}x", "?x", vdtype) + if "c1_i64" in vname: + print(constant) + counter+=1 + if counter==2: + counter=0 + print("detected duplicate") + continue vnames.append(vname) - global_vars.append( - f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}" - ) - global_var_loading1.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" - ) - global_var_loading2.append( - f"\t\t%{vname} = 
ml_program.global_load_const @{vname} : {fixed_vdtype}" - ) - + if "true" not in vname: + global_vars.append( + f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}" + ) + global_var_loading1.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" + ) + global_var_loading2.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" + ) + else: + global_vars.append( + f"ml_program.global public @{vname}({vbody}) : i1" + ) + global_var_loading1.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" + ) + global_var_loading2.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" + ) new_f1, new_f2 = [], [] for line in f1: @@ -343,14 +283,19 @@ def combine_mlir_scripts( for line in f2: if "func.func" in line: new_f2.append(line) - for global_var in global_var_loading1: + for global_var in global_var_loading2: + if "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in global_var: + print(global_var) new_f2.append(global_var) else: - new_f2.append(line) + if "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in line: + new_f2.append("%"+line) + else: + new_f2.append(line) f1 = new_f1 f2 = new_f2 - + print(["c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x for x in [maps1, maps2, global_vars, f1, f2]]) whole_string = "\n".join( maps1 + maps2 @@ -367,6 +312,163 @@ def combine_mlir_scripts( return whole_string + + def generate_new_token(self, params, sharded=True): + is_first = params["is_first"] + if is_first: + prompt = params["prompt"] + input_ids = self.tokenizer(prompt).input_ids + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + if sharded: + output = self.shark_model.forward(input_ids, is_first=is_first) + else: + output = self.shark_model("first_vicuna_forward", (input_ids,)) + out_tensor = torch.tensor(output[1:]) + + else: + token = params["token"] + past_key_values = params["past_key_values"] + input_ids = [token] + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + if sharded: + output = self.shark_model.forward( + input_ids, past_key_values=past_key_values, is_first=is_first + ) + else: + token = token.to(torch.int64).reshape([1,1]) + second_input = (token,) + tuple(past_key_values) + output = self.shark_model("second_vicuna_forward", second_input) + + + if sharded: + _logits = output["logits"] + _past_key_values = output["past_key_values"] + _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) + else: + print(len(output)) + _logits = torch.tensor(output[0]) + _past_key_values = torch.tensor(output[1:]) + _token = torch.argmax(_logits[:, -1, :], dim=1) + + skip_sp_tok = True if self.model_name == "codegen" else False + _detok = self.tokenizer.decode(_token, skip_special_tokens=skip_sp_tok) + ret_dict = { + "token": _token, + "detok": _detok, + "logits": _logits, + "past_key_values": _past_key_values, + } + + print(f" token : {_token} | detok : {_detok}") + + return ret_dict + +class ShardedVicuna(VicunaBase): + # Class representing Sharded Vicuna Model + def __init__( + self, + model_name, + hf_model_path="TheBloke/vicuna-7B-1.1-HF", + max_num_tokens=512, + device="cuda", + precision="fp32", + config_json=None, + weight_group_size=128, + ) -> None: + super().__init__(model_name, hf_model_path, max_num_tokens) + self.max_sequence_length = 256 + self.device = device + self.precision = precision + self.tokenizer = self.get_tokenizer() + self.config = config_json 
+ self.weight_group_size = weight_group_size + self.shark_model = self.compile(device=device) + + def get_tokenizer(self): + kwargs = {} + if self.model_name == "llama2": + kwargs = { + "use_auth_token": "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" + } + if self.model_name == "codegen": + tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_path, + trust_remote_code=True, + ) + else: + tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_path, + use_fast=False, + **kwargs, + ) + return tokenizer + + def get_src_model(self): + # Retrieve the torch model from Huggingface + kwargs = {"torch_dtype": torch.float} + if self.model_name == "llama2": + kwargs["use_auth_token"] = "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" + vicuna_model = AutoModelForCausalLM.from_pretrained( + self.hf_model_path, + **kwargs, + ) + return vicuna_model + + def write_in_dynamic_inputs0(self, module, dynamic_input_size): + # Current solution for ensuring mlir files support dynamic inputs + # TODO find a more elegant way to implement this + new_lines = [] + for line in module.splitlines(): + line = re.sub(f"{dynamic_input_size}x", "?x", line) + if "?x" in line: + line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) + line = re.sub(f" {dynamic_input_size},", " %dim,", line) + if "tensor.empty" in line and "?x?" in line: + line = re.sub( + "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line + ) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + + def write_in_dynamic_inputs1(self, module, dynamic_input_size): + new_lines = [] + for line in module.splitlines(): + if "dim_42 =" in line: + continue + if f"%c{dynamic_input_size}_i64 =" in line: + new_lines.append( + "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>" + ) + new_lines.append( + f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64" + ) + continue + line = re.sub(f"{dynamic_input_size}x", "?x", line) + line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line) + if "?x" in line: + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dim_42)", line + ) + line = re.sub(f" {dynamic_input_size},", " %dim_42,", line) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub( + "tensor.empty\(%dim_42\)", + "tensor.empty(%dim_42, %dim_42)", + line, + ) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim_42", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + def compile_vicuna_layer( self, vicuna_layer, @@ -843,59 +945,23 @@ def generate(self, prompt, cli=True): result_output = self.tokenizer.decode(tokens_generated) return result_output - def generate_new_token(self, params): - is_first = params["is_first"] - if is_first: - prompt = params["prompt"] - input_ids = self.tokenizer(prompt).input_ids - input_id_len = len(input_ids) - input_ids = torch.tensor(input_ids) - input_ids = input_ids.reshape([1, input_id_len]) - output = self.shark_model.forward(input_ids, is_first=is_first) - else: - token = params["token"] - past_key_values = params["past_key_values"] - input_ids = [token] - input_id_len = len(input_ids) - input_ids = torch.tensor(input_ids) - input_ids = input_ids.reshape([1, input_id_len]) - output = self.shark_model.forward( - input_ids, past_key_values=past_key_values, is_first=is_first - ) - - _logits = output["logits"] - _past_key_values = output["past_key_values"] - _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) - _detok = self.tokenizer.decode(_token) - - ret_dict = { - "token": _token, - "detok": _detok, - "past_key_values": _past_key_values, - } - - print(f" token : {_token} | detok : {_detok}") - - return ret_dict def autocomplete(self, prompt): # use First vic alone to complete a story / prompt / sentence. pass -class UnshardedVicuna(SharkLLMBase): +class UnshardedVicuna(VicunaBase): def __init__( self, model_name, hf_model_path="TheBloke/vicuna-7B-1.1-HF", hf_auth_token: str = None, max_num_tokens=512, - device="cuda", - precision="fp32", - first_vicuna_mlir_path=None, - second_vicuna_mlir_path=None, - first_vicuna_vmfb_path=None, - second_vicuna_vmfb_path=None, + device="cpu", + precision="int8", + vicuna_mlir_path=None, + vicuna_vmfb_path=None, load_mlir_from_shark_tank=True, low_device_memory=False, weight_group_size=128, @@ -916,36 +982,24 @@ def __init__( self.device = device self.precision = precision self.download_vmfb = download_vmfb - self.first_vicuna_vmfb_path = first_vicuna_vmfb_path - self.second_vicuna_vmfb_path = second_vicuna_vmfb_path - self.first_vicuna_mlir_path = first_vicuna_mlir_path - self.second_vicuna_mlir_path = second_vicuna_mlir_path + self.vicuna_vmfb_path = vicuna_vmfb_path + self.vicuna_mlir_path = vicuna_mlir_path self.load_mlir_from_shark_tank = load_mlir_from_shark_tank self.low_device_memory = low_device_memory self.weight_group_size = weight_group_size - self.first_vic = None - self.second_vic = None - if self.first_vicuna_mlir_path == None: - self.first_vicuna_mlir_path = self.get_model_path() - if self.second_vicuna_mlir_path == None: - self.second_vicuna_mlir_path = self.get_model_path("second") - if self.first_vicuna_vmfb_path == None: - self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb") - if self.second_vicuna_vmfb_path == None: - self.second_vicuna_vmfb_path = self.get_model_path( - "second", "vmfb" - ) + if self.vicuna_mlir_path == None: + self.vicuna_mlir_path = self.get_model_path() + if self.vicuna_vmfb_path == None: + self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb") self.tokenizer = self.get_tokenizer() - self.shark_model = self.compile() + self.compile() - def get_model_path(self, model_number="first", suffix="mlir"): + def get_model_path(self, suffix="mlir"): safe_device = 
self.device.split("-")[0]
         if suffix == "mlir":
-            return Path(
-                f"{model_number}_{self.model_name}_{self.precision}.{suffix}"
-            )
+            return Path(f"vicuna_{self.precision}.{suffix}")
         return Path(
-            f"{model_number}_{self.model_name}_{self.precision}_{safe_device}.{suffix}"
+            f"vicuna_{self.precision}_{safe_device}.{suffix}"
         )
 
     def get_tokenizer(self):
@@ -974,39 +1028,138 @@ def get_src_model(self):
         )
         return vicuna_model
 
-    def compile_first_vicuna(self):
-        vmfb = get_vmfb_from_path(
-            self.first_vicuna_vmfb_path, self.device, "tm_tensor"
+    def write_in_dynamic_inputs0(self, module, dynamic_input_size):
+        print("[DEBUG] writing dynamic inputs to first vicuna.")
+        # Current solution for ensuring mlir files support dynamic inputs
+        # TODO: find a more elegant way to implement this
+        new_lines = []
+        module = module.splitlines()
+        # Use a while loop with pop() so the module is consumed in place
+        # instead of being duplicated in memory.
+        while module:
+            line = module.pop(0)
+            line = re.sub(f"{dynamic_input_size}x", "?x", line)
+            if "?x" in line:
+                line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
+            line = re.sub(f" {dynamic_input_size},", " %dim,", line)
+            if "tensor.empty" in line and "?x?" in line:
+                line = re.sub(
+                    "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
+                )
+            if "arith.cmpi" in line:
+                line = re.sub(f"c{dynamic_input_size}", "dim", line)
+            if "%0 = tensor.empty(%dim) : tensor" in line:
+                new_lines.append(
+                    "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
+                )
+            if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
+                continue
+
+            new_lines.append(line)
+        return "\n".join(new_lines)
+
+    def write_in_dynamic_inputs1(self, module):
+        print("[DEBUG] writing dynamic inputs to second vicuna.")
+
+        def remove_constant_dim(line):
+            if "c19_i64" in line:
+                line = re.sub("c19_i64", "dim_i64", line)
+            if "19x" in line:
+                line = re.sub("19x", "?x", line)
+                line = re.sub(
+                    "tensor.empty\(\)", "tensor.empty(%dim)", line
+                )
+            if "tensor.empty" in line and "?x?" in line:
+                line = re.sub(
+                    "tensor.empty\(%dim\)",
+                    "tensor.empty(%dim, %dim)",
+                    line,
+                )
+            if "arith.cmpi" in line:
+                line = re.sub("c19", "dim", line)
+            if " 19," in line:
+                line = re.sub(" 19,", " %dim,", line)
+            if "20x" in line:
+                line = re.sub("20x", "?x", line)
+                line = re.sub(
+                    "tensor.empty\(\)", "tensor.empty(%dimp1)", line
+                )
+            if " 20," in line:
+                line = re.sub(" 20,", " %dimp1,", line)
+            return line
+
+        module = module.splitlines()
+        new_lines = []
+        # Use a while loop with pop() to avoid creating a copy of module
+        while module:
+            line = module.pop(0)
+            if "%c19_i64 = arith.constant 19 : i64" in line:
+                new_lines.append("%c2 = arith.constant 2 : index")
+                new_lines.append(f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>")
+                new_lines.append("%dim_i64 = arith.index_cast %dim_4_int : index to i64")
+                continue
+            if "%c2 = arith.constant 2 : index" in line:
+                continue
+            if "%c20_i64 = arith.constant 20 : i64" in line:
+                new_lines.append("%c1_i64 = arith.constant 1 : i64")
+                # The "%" prefix is deliberately omitted here:
+                # combine_mlir_scripts prepends it when stitching the
+                # two modules together.
+                new_lines.append("c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
+                new_lines.append("%dimp1 = arith.index_cast %c20_i64 : i64 to index")
+                continue
+            line = remove_constant_dim(line)
+            new_lines.append(line)
+
+        return "\n".join(new_lines)
+
+    def compile(self, download_vmfb=False):
+        # Testing : DO NOT Download Vmfbs if not found.
Modify later + # download vmfbs for A100 + supported_devices = ["cuda", "cpu-sync", "cpu-task", "cpu"] + if ( + not self.vicuna_vmfb_path.exists() + and self.device in supported_devices + and self.precision in ["fp32", "fp16", "int8"] + ): + if (self.device == "cuda" and self.precision == "fp16") or ( + self.device in ["cpu-sync", "cpu-task"] + and self.precision == "int8" and download_vmfb + ): + download_public_file( + f"gs://shark_tank/vicuna/unsharded/vmfb/{self.vicuna_vmfb_path.name}", + self.vicuna_vmfb_path.absolute(), + single_file=True, + ) + else: + pass + + self.shark_model = get_vmfb_from_path( + self.vicuna_vmfb_path, self.device, "tm_tensor" ) - if vmfb is not None: - return vmfb + if self.shark_model is not None: + return None - # Compilation path needs some more work before it is functional print( - f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n" - f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}" + f"[DEBUG] vmfb not found at {self.vicuna_vmfb_path.absolute()}. Trying to work with\n" + f"[DEBUG] mlir path { self.vicuna_mlir_path} {'exists' if self.vicuna_mlir_path.exists() else 'does not exist'}" ) - if self.first_vicuna_mlir_path.exists(): - with open(self.first_vicuna_mlir_path, "rb") as f: - bytecode = f.read() + if self.vicuna_mlir_path.exists(): + with open(self.vicuna_mlir_path, "rb") as f: + combined_module = f.read() else: mlir_generated = False if self.load_mlir_from_shark_tank: if self.precision in ["fp32", "fp16", "int8", "int4"]: # download MLIR from shark_tank download_public_file( - f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.first_vicuna_mlir_path.name}", - self.first_vicuna_mlir_path.absolute(), + f"gs://shark_tank/vicuna/unsharded/mlir/{self.vicuna_mlir_path.name}", + self.vicuna_mlir_path.absolute(), single_file=True, ) - if self.first_vicuna_mlir_path.exists(): - with open(self.first_vicuna_mlir_path, "rb") as f: + if self.vicuna_mlir_path.exists(): + with open(self.vicuna_mlir_path, "rb") as f: bytecode = f.read() mlir_generated = True else: - print( - f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}" - " after downloading! Generating mlir on device." + raise ValueError( + f"MLIR not found at {self.vicuna_mlir_path.absolute()}" + " after downloading! 
Please check path and try again" ) else: print( @@ -1020,6 +1173,13 @@ def compile_first_vicuna(self): compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')" else: compilation_prompt = "".join(["0" for _ in range(17)]) + combined_module = None + if Path("first.mlir").exists(): + print("loading first.mlir") + with open(Path("first.mlir"), "r") as f: + first_module = f.read() + else: + compilation_prompt = "".join(["0" for _ in range(17)]) compilation_input_ids = self.tokenizer( compilation_prompt, return_tensors="pt", @@ -1053,10 +1213,10 @@ def compile_first_vicuna(self): firstVicunaCompileInput[0], dynamic_axes=[1] ) firstVicunaCompileInput = tuple(firstVicunaCompileInput) - + first_module = None print(f"[DEBUG] generating torch mlir") if self.precision in ["int4", "int8"]: - module = torch_mlir.compile( + first_module = torch_mlir.compile( ts_graph, [*firstVicunaCompileInput], output_type=torch_mlir.OutputType.TORCH, @@ -1067,12 +1227,12 @@ def compile_first_vicuna(self): ) print(f"[DEBUG] converting torch to linalg") run_pipeline_with_repro_report( - module, + first_module, "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)", description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ) else: - module = torch_mlir.compile( + first_module = torch_mlir.compile( ts_graph, [*firstVicunaCompileInput], torch_mlir.OutputType.LINALG_ON_TENSORS, @@ -1081,112 +1241,17 @@ def compile_first_vicuna(self): ) del ts_graph - def remove_constant_dim(line): - if "19x" in line: - line = re.sub("19x", "?x", line) - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dim)", line - ) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", - "tensor.empty(%dim, %dim)", - line, - ) - if "arith.cmpi" in line: - line = re.sub("c19", "dim", line) - if " 19," in line: - line = re.sub(" 19,", " %dim,", line) - return line - - module = str(module) - new_lines = [] - - print(f"[DEBUG] rewriting torch_mlir file") - for line in module.splitlines(): - line = remove_constant_dim(line) - if "%0 = tensor.empty(%dim) : tensor" in line: - new_lines.append( - "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" - ) - if ( - "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" - in line - ): - continue - - new_lines.append(line) - - module = "\n".join(new_lines) - print(f"[DEBUG] converting to bytecode") - del new_lines - module = module.encode("UTF-8") - module = BytesIO(module) - bytecode = module.read() - del module - - print(f"[DEBUG] writing mlir to file") - f_ = open(self.first_vicuna_mlir_path, "wb") - f_.write(bytecode) - f_.close() - - shark_module = SharkInference( - mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" - ) - path = shark_module.save_module( - self.first_vicuna_vmfb_path.parent.absolute(), - self.first_vicuna_vmfb_path.stem, - extra_args=[ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ], - ) - print("Saved first vic vmfb at ", str(path)) - shark_module.load_module(path) - - return shark_module + + first_module = self.write_in_dynamic_inputs0(str(first_module), dynamic_input_size=19) - def compile_second_vicuna(self): - vmfb = get_vmfb_from_path( - self.second_vicuna_vmfb_path, self.device, "tm_tensor" - ) - if vmfb is not None: - return vmfb - - # Compilation path needs some more work before 
it is functional - print( - f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}" - ) - if self.second_vicuna_mlir_path.exists(): - with open(self.second_vicuna_mlir_path, "rb") as f: - bytecode = f.read() - else: - mlir_generated = False - if self.load_mlir_from_shark_tank: - if self.precision in ["fp32", "fp16", "int8", "int4"]: - # download MLIR from shark_tank - download_public_file( - f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.second_vicuna_mlir_path.name}", - self.second_vicuna_mlir_path.absolute(), - single_file=True, - ) - if self.second_vicuna_mlir_path.exists(): - with open(self.second_vicuna_mlir_path, "rb") as f: - bytecode = f.read() - mlir_generated = True - else: - print( - f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}" - " after downloading! Generating mlir on device." - ) - else: - print( - "Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device." - ) + with open("first.mlir", "w+") as f: + f.write(first_module) - if not mlir_generated: + if Path("second.mlir").exists(): + print("loading second.mlir") + with open(Path("second.mlir"), "r") as f: + second_module = f.read() + else: compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) pkv = tuple( (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) @@ -1228,7 +1293,7 @@ def compile_second_vicuna(self): print(f"[DEBUG] generating torch mlir") if self.precision in ["int4", "int8"]: - module = torch_mlir.compile( + second_module = torch_mlir.compile( ts_graph, [*secondVicunaCompileInput], output_type=torch_mlir.OutputType.TORCH, @@ -1239,162 +1304,46 @@ def compile_second_vicuna(self): ) print(f"[DEBUG] converting torch to linalg") run_pipeline_with_repro_report( - module, + second_module, "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)", description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ) else: - module = torch_mlir.compile( + second_module = torch_mlir.compile( ts_graph, [*secondVicunaCompileInput], torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=False, verbose=False, ) + print("[DEBUG] successfully converted second vicuna to linalg.") + second_module = str(second_module) + second_module = self.write_in_dynamic_inputs1(second_module) + with open("second.mlir", "w+") as f: + f.write(second_module) + + combined_module = self.combine_mlir_scripts(first_module, second_module, self.vicuna_mlir_path) + del first_module, second_module - def remove_constant_dim(line): - if "c19_i64" in line: - line = re.sub("c19_i64", "dim_i64", line) - if "19x" in line: - line = re.sub("19x", "?x", line) - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dim)", line - ) - if "tensor.empty" in line and "?x?" 
in line: - line = re.sub( - "tensor.empty\(%dim\)", - "tensor.empty(%dim, %dim)", - line, - ) - if "arith.cmpi" in line: - line = re.sub("c19", "dim", line) - if " 19," in line: - line = re.sub(" 19,", " %dim,", line) - if "20x" in line: - line = re.sub("20x", "?x", line) - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dimp1)", line - ) - if " 20," in line: - line = re.sub(" 20,", " %dimp1,", line) - return line - - module_str = str(module) - new_lines = [] - - print(f"[DEBUG] rewriting torch_mlir file") - for line in module_str.splitlines(): - if "%c19_i64 = arith.constant 19 : i64" in line: - new_lines.append("%c2 = arith.constant 2 : index") - new_lines.append( - f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>" - ) - new_lines.append( - "%dim_i64 = arith.index_cast %dim_4_int : index to i64" - ) - continue - if "%c2 = arith.constant 2 : index" in line: - continue - if "%c20_i64 = arith.constant 20 : i64" in line: - new_lines.append("%c1_i64 = arith.constant 1 : i64") - new_lines.append( - "%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" - ) - new_lines.append( - "%dimp1 = arith.index_cast %c20_i64 : i64 to index" - ) - continue - line = remove_constant_dim(line) - new_lines.append(line) - - module_str = "\n".join(new_lines) - print(f"[DEBUG] converting to bytecode") - bytecode = module_str.encode("UTF-8") - bytecode_stream = BytesIO(bytecode) - bytecode = bytecode_stream.read() - - print(f"[DEBUG] writing mlir to file") - f_ = open(self.second_vicuna_mlir_path, "wb") - f_.write(bytecode) - f_.close() + shark_module = SharkInference( - mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" + mlir_module=combined_module, device=self.device, mlir_dialect="tm_tensor" ) - path = shark_module.save_module( - self.second_vicuna_vmfb_path.parent.absolute(), - self.second_vicuna_vmfb_path.stem, + self.vicuna_vmfb_path.parent.absolute(), + self.vicuna_vmfb_path.stem, extra_args=[ - "--iree-hal-dump-executable-sources-to=ies", "--iree-vm-target-truncate-unsupported-floats", "--iree-codegen-check-ir-before-llvm-conversion=false", "--iree-vm-bytecode-module-output-format=flatbuffer-binary", ], ) - print("Saved second vic vmfb at ", str(path)) - shark_module.load_module(self.second_vicuna_vmfb_path) - - # self.shark_module = shark_module - return shark_module - - def compile(self): - # Cannot load both the models in the memory at once - # due to memory constraints, hence on demand compilation - # is being used until the space is enough for both models - - # Testing : DO NOT Download Vmfbs if not found. 
Modify later - # download vmfbs for A100 - supported_devices = ["cuda", "cpu-sync", "cpu-task", "cpu"] - if ( - not self.first_vicuna_vmfb_path.exists() - and self.device in supported_devices - and self.precision in ["fp32", "fp16", "int8"] - ): - if (self.device == "cuda" and self.precision == "fp16") or ( - self.device in ["cpu-sync", "cpu-task"] - and self.precision == "int8" - and self.download_vmfb - ): - download_public_file( - f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}", - self.first_vicuna_vmfb_path.absolute(), - single_file=True, - ) - else: - pass + print("Saved vic vmfb at ", str(path)) + shark_module.load_module(path) - else: - # get first vic - # TODO: Remove after testing to avoid memory overload - # fvic_shark_model = self.compile_first_vicuna() - pass - if ( - not self.second_vicuna_vmfb_path.exists() - and self.device in supported_devices - and self.precision in ["fp32", "fp16", "int8"] - ): - if (self.device == "cuda" and self.precision == "fp16") or ( - self.device in ["cpu-sync", "cpu-task"] - and self.precision == "int8" - and self.download_vmfb - ): - download_public_file( - f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}", - self.second_vicuna_vmfb_path.absolute(), - single_file=True, - ) - else: - pass - else: - # get second vic - # TODO: Remove after testing to avoid memory overload - # svic_shark_model = self.compile_second_vicuna() - pass + self.shark_module = shark_module - return None - # return tuple of shark_modules once mem is supported - # return fvic_shark_model, svic_shark_model def decode_tokens(self, res_tokens): for i in range(len(res_tokens)): @@ -1410,26 +1359,19 @@ def decode_tokens(self, res_tokens): def generate(self, prompt, cli=True): # TODO: refactor for cleaner integration import gc - - if not self.low_device_memory: - if self.first_vic == None: - self.first_vic = self.compile_first_vicuna() - if self.second_vic == None: - self.second_vic = self.compile_second_vicuna() + res_tokens = [] params = { "prompt": prompt, "is_first": True, - "fv": self.compile_first_vicuna() - if self.first_vic == None - else self.first_vic, + "fv": self.shark_model } - generated_token_op = self.generate_new_token(params=params) + generated_token_op = self.generate_new_token(params=params, sharded=False) token = generated_token_op["token"] logits = generated_token_op["logits"] - pkv = generated_token_op["pkv"] + pkv = generated_token_op["past_key_values"] detok = generated_token_op["detok"] yield detok @@ -1437,28 +1379,21 @@ def generate(self, prompt, cli=True): if cli: print(f"Assistant: {detok}", end=" ", flush=True) - # Clear First Vic from Memory (main and cuda) - if self.low_device_memory: - del params - torch.cuda.empty_cache() - gc.collect() for _ in range(self.max_num_tokens - 2): params = { - "prompt": None, + "token": token, "is_first": False, "logits": logits, - "pkv": pkv, - "sv": self.compile_second_vicuna() - if self.second_vic == None - else self.second_vic, + "past_key_values": pkv, + "sv": self.shark_model } - generated_token_op = self.generate_new_token(params=params) + generated_token_op = self.generate_new_token(params=params, sharded=False) token = generated_token_op["token"] logits = generated_token_op["logits"] - pkv = generated_token_op["pkv"] + pkv = generated_token_op["past_key_values"] detok = generated_token_op["detok"] if token == 2 and self.model_name != "codegen": @@ -1475,96 +1410,11 @@ def generate(self, prompt, cli=True): part_str = 
self.decode_tokens(res_tokens) yield part_str - if self.low_device_memory: - del params - torch.cuda.empty_cache() - gc.collect() res_str = self.decode_tokens(res_tokens) # print(f"[DEBUG] final output : \n{res_str}") yield res_str - def generate_new_token(self, params, debug=False): - def forward_first(first_vic, prompt, cache_outputs=False): - input_ids = self.tokenizer(prompt).input_ids - input_id_len = len(input_ids) - input_ids = torch.tensor(input_ids) - input_ids = input_ids.reshape([1, input_id_len]) - firstVicunaInput = (input_ids,) - assert first_vic is not None - output_first_vicuna = first_vic("forward", firstVicunaInput) - output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:]) - logits_first_vicuna = torch.tensor(output_first_vicuna[0]) - if cache_outputs: - torch.save( - logits_first_vicuna, "logits_first_vicuna_tensor.pt" - ) - torch.save( - output_first_vicuna_tensor, "output_first_vicuna_tensor.pt" - ) - token = torch.argmax( - torch.tensor(logits_first_vicuna)[:, -1, :], dim=1 - ) - return token, logits_first_vicuna, output_first_vicuna_tensor - - def forward_second(sec_vic, inputs=None, load_inputs=False): - if inputs is not None: - logits = inputs[0] - pkv = inputs[1:] - elif load_inputs: - pkv = torch.load("output_first_vicuna_tensor.pt") - pkv = tuple(torch.tensor(x) for x in pkv) - logits = torch.load("logits_first_vicuna_tensor.pt") - else: - print( - "Either inputs must be given, or load_inputs must be true" - ) - return None - token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1) - token = token.to(torch.int64).reshape([1, 1]) - secondVicunaInput = (token,) + tuple(pkv) - - secondVicunaOutput = sec_vic("forward", secondVicunaInput) - new_pkv = secondVicunaOutput[1:] - new_logits = secondVicunaOutput[0] - new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1) - return new_token, new_logits, new_pkv - - is_first = params["is_first"] - - if is_first: - prompt = params["prompt"] - fv = params["fv"] - token, logits, pkv = forward_first( - fv, # self.shark_model[0], - prompt=prompt, - cache_outputs=False, - ) - else: - _logits = params["logits"] - _pkv = params["pkv"] - inputs = (_logits,) + tuple(_pkv) - sv = params["sv"] - token, logits, pkv = forward_second( - sv, # self.shark_model[1], - inputs=inputs, - load_inputs=False, - ) - - skip_sp_tok = True if self.model_name == "codegen" else False - detok = self.tokenizer.decode(token, skip_special_tokens=skip_sp_tok) - if debug: - print( - f"[DEBUG] is_first: {is_first} |" - f" token : {token} | detok : {detok}" - ) - ret_dict = { - "token": token, - "logits": logits, - "pkv": pkv, - "detok": detok, - } - return ret_dict def autocomplete(self, prompt): # use First vic alone to complete a story / prompt / sentence. 
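With both forwards compiled into a single vmfb, token generation becomes a two-entry-point loop over one module. Below is a minimal sketch of that flow (not part of the diff): `shark_module` is assumed to be the loaded combined `SharkInference` module and `tokenizer` the Hugging Face tokenizer; the greedy argmax sampling and EOS token id 2 mirror what `generate_new_token` and `generate` do above.

import torch

def greedy_decode(shark_module, tokenizer, prompt, max_new_tokens=32):
    # Prefill: "first_vicuna_forward" consumes the whole prompt and
    # returns logits (output[0]) plus the KV cache (output[1:]).
    input_ids = torch.tensor(tokenizer(prompt).input_ids).reshape([1, -1])
    output = shark_module("first_vicuna_forward", (input_ids,))
    logits = torch.tensor(output[0])
    pkv = [torch.tensor(x) for x in output[1:]]
    token = torch.argmax(logits[:, -1, :], dim=1)
    tokens = [int(token)]
    # Decode: "second_vicuna_forward" takes one token plus the cache per step.
    for _ in range(max_new_tokens - 1):
        inputs = (token.to(torch.int64).reshape([1, 1]),) + tuple(pkv)
        output = shark_module("second_vicuna_forward", inputs)
        logits = torch.tensor(output[0])
        pkv = [torch.tensor(x) for x in output[1:]]
        token = torch.argmax(logits[:, -1, :], dim=1)
        if int(token) == 2:  # </s> in the Vicuna vocabulary
            break
        tokens.append(int(token))
    return tokenizer.decode(tokens)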
@@ -1576,36 +1426,23 @@ def autocomplete(self, prompt):
     vic = None
     if not args.sharded:
-        first_vic_mlir_path = (
-            None
-            if args.first_vicuna_mlir_path is None
-            else Path(args.first_vicuna_mlir_path)
-        )
-        second_vic_mlir_path = (
+        vic_mlir_path = (
             None
-            if args.second_vicuna_mlir_path is None
-            else Path(args.second_vicuna_mlir_path)
+            if args.vicuna_mlir_path is None
+            else Path(args.vicuna_mlir_path)
         )
-        first_vic_vmfb_path = (
+        vic_vmfb_path = (
             None
-            if args.first_vicuna_vmfb_path is None
-            else Path(args.first_vicuna_vmfb_path)
+            if args.vicuna_vmfb_path is None
+            else Path(args.vicuna_vmfb_path)
         )
-        second_vic_vmfb_path = (
-            None
-            if args.second_vicuna_vmfb_path is None
-            else Path(args.second_vicuna_vmfb_path)
-        )
-
         vic = UnshardedVicuna(
             model_name=args.model_name,
             hf_auth_token=args.hf_auth_token,
             device=args.device,
             precision=args.precision,
-            first_vicuna_mlir_path=first_vic_mlir_path,
-            second_vicuna_mlir_path=second_vic_mlir_path,
-            first_vicuna_vmfb_path=first_vic_vmfb_path,
-            second_vicuna_vmfb_path=second_vic_vmfb_path,
+            vicuna_mlir_path=vic_mlir_path,
+            vicuna_vmfb_path=vic_vmfb_path,
             load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
             weight_group_size=args.weight_group_size,
             download_vmfb=args.download_vmfb,
diff --git a/build_tools/vicuna_testing.py b/build_tools/vicuna_testing.py
new file mode 100644
index 0000000000..8b71ae4857
--- /dev/null
+++ b/build_tools/vicuna_testing.py
@@ -0,0 +1,20 @@
+from apps.language_models.scripts import vicuna
+
+
+def test_loop():
+    # Smoke-test compilation of the unsharded Vicuna model for each
+    # supported precision/device combination.
+    precisions = ["fp16", "int8", "int4"]
+    devices = ["cpu"]
+    for precision in precisions:
+        for device in devices:
+            # model_name is a required positional argument of UnshardedVicuna.
+            model = vicuna.UnshardedVicuna(
+                "vicuna", device=device, precision=precision
+            )
+            model.compile()
+            del model
+
+
+if __name__ == "__main__":
+    test_loop()
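For completeness, a minimal sketch of driving the refactored `UnshardedVicuna` end to end (not part of the diff; it assumes the default "vicuna" model name and the int8/CPU file names that the new `get_model_path()` derives):

from pathlib import Path

from apps.language_models.scripts.vicuna import UnshardedVicuna

# The constructor resolves paths, downloads or builds the combined
# MLIR/vmfb as needed, and compiles it via self.compile().
vic = UnshardedVicuna(
    "vicuna",
    device="cpu-task",
    precision="int8",
    vicuna_mlir_path=Path("vicuna_int8.mlir"),      # optional; default from get_model_path()
    vicuna_vmfb_path=Path("vicuna_int8_cpu.vmfb"),  # optional; default from get_model_path()
)

# generate() is a generator: it yields partial detokenized strings and,
# as its final value, the fully decoded response.
result = None
for result in vic.generate("What is the capital of Canada?", cli=False):
    pass
print(result)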