Skip to content

Commit

Permalink
Unsharded Vicuna: Fix Memory Error compiling mlir for lmsys/vicuna-7b-v1.3 fp16 with 64 GiB (#1702)
Browse files Browse the repository at this point in the history
  • Loading branch information
one-lithe-rune authored Aug 1, 2023
1 parent 98fb6c5 commit 6bb329c
Showing 1 changed file with 51 additions and 22 deletions.
73 changes: 51 additions & 22 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import json
import re
import gc
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
Expand Down Expand Up @@ -40,9 +41,6 @@
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

if __name__ == "__main__":
import gc


parser = argparse.ArgumentParser(
prog="vicuna runner",
Expand Down Expand Up @@ -114,10 +112,11 @@
"--cache_vicunas",
default=False,
action=argparse.BooleanOptionalAction,
help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk"
help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk",
)


# fmt: off
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
Expand All @@ -141,6 +140,7 @@ def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale,
brevitas〇matmul_rhs_group_quant〡shape,
brevitas〇matmul_rhs_group_quant〡dtype,
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
# fmt: on


class VicunaBase(SharkLLMBase):
Expand Down Expand Up @@ -176,11 +176,14 @@ def combine_mlir_scripts(
self, first_vicuna_mlir, second_vicuna_mlir, output_name
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBIG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
f1 = []
f2 = []

print(f"[DEBUG] processing first vircuna mlir")
first_vicuna_mlir = first_vicuna_mlir.splitlines()
while first_vicuna_mlir:
line = first_vicuna_mlir.pop(0)
Expand All @@ -193,6 +196,7 @@ def combine_mlir_scripts(
f1.append(line)
f1 = f1[:-1]
del first_vicuna_mlir
gc.collect()

for i, map_line in enumerate(maps1):
map_var = map_line.split(" ")[0]
Expand All @@ -203,6 +207,7 @@ def combine_mlir_scripts(
for func_line in f1
]

print(f"[DEBUG] processing second vircuna mlir")
second_vicuna_mlir = second_vicuna_mlir.splitlines()
while second_vicuna_mlir:
line = second_vicuna_mlir.pop(0)
Expand All @@ -216,6 +221,8 @@ def combine_mlir_scripts(
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
f2 = f2[:-1]
del second_vicuna_mlir
gc.collect()

for i, map_line in enumerate(maps2):
map_var = map_line.split(" ")[0]
Expand All @@ -236,6 +243,7 @@ def combine_mlir_scripts(
global_var_loading1 = []
global_var_loading2 = []

print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
while constants:
Expand Down Expand Up @@ -279,6 +287,7 @@ def combine_mlir_scripts(
)
new_f1, new_f2 = [], []

print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
Expand All @@ -287,6 +296,7 @@ def combine_mlir_scripts(
else:
new_f1.append(line)

print(f"[DEBUG] processing f2")
for line in f2:
if "func.func" in line:
new_f2.append(line)
Expand All @@ -305,27 +315,43 @@ def combine_mlir_scripts(

f1 = new_f1
f2 = new_f2

del new_f1
del new_f2
gc.collect()

print(
[
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
for x in [maps1, maps2, global_vars, f1, f2]
]
)
whole_string = "\n".join(
maps1
+ maps2
+ [module_start]
+ global_vars
+ f1
+ f2
+ [module_end]
)

f_ = open(output_name, "w+")
f_.write(whole_string)
f_.close()

return whole_string
# Write the pieces to the output file line by line rather than assembling
# the whole string, to prevent OOM with 64GiB RAM when encoding the file.

print(f"[DEBUG] Saving mlir to {output_name}")
with open(output_name, "w+") as f_:
f_.writelines(line + "\n" for line in maps1)
f_.writelines(line + "\n" for line in maps2)
f_.writelines(line + "\n" for line in [module_start])
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
f_.writelines(line + "\n" for line in [module_end])

del maps1
del maps2
del module_start
del global_vars
del f1
del f2
del module_end
gc.collect()

print(f"[DEBUG] Reading combined mlir back in")
with open(output_name, "rb") as f:
return f.read()

def generate_new_token(self, params, sharded=True):
is_first = params["is_first"]
Expand Down Expand Up @@ -1182,11 +1208,10 @@ def compile(self, download_vmfb=False):
else:
compilation_prompt = "".join(["0" for _ in range(17)])


if Path(f'first_{self.precision}.mlir').exists():
if Path(f"first_{self.precision}.mlir").exists():
print(f"loading first_{self.precision}.mlir")
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
first_module = f.read()
first_module = f.read()
else:
# generate first vicuna
compilation_input_ids = self.tokenizer(
Expand Down Expand Up @@ -1251,6 +1276,9 @@ def compile(self, download_vmfb=False):
verbose=False,
)
del ts_graph
del firstVicunaCompileInput
gc.collect()

print(
"[DEBUG] successfully generated first vicuna linalg mlir"
)
Expand Down Expand Up @@ -1335,6 +1363,8 @@ def compile(self, download_vmfb=False):
verbose=False,
)
del ts_graph
del secondVicunaCompileInput
gc.collect()
print(
"[DEBUG] successfully generated second vicuna linalg mlir"
)
Expand Down Expand Up @@ -1381,7 +1411,6 @@ def decode_tokens(self, res_tokens):

def generate(self, prompt, cli=True):
# TODO: refactor for cleaner integration
import gc
if self.shark_model is None:
self.compile()
res_tokens = []
Expand Down

0 comments on commit 6bb329c

Please sign in to comment.