Add LLaMa2-int4-fp16 support (#1782)
vivekkhandelwal1 authored Aug 22, 2023
1 parent b87efe7 commit 05889a8
Showing 2 changed files with 79 additions and 55 deletions.
130 changes: 77 additions & 53 deletions apps/language_models/scripts/vicuna.py
@@ -195,7 +195,8 @@ def combine_mlir_scripts(
print(f"[DEBIG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
constants_1 = set()
constants_2 = set()
f1 = []
f2 = []

@@ -206,7 +207,7 @@ def combine_mlir_scripts(
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
constants_1.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
@@ -232,7 +233,7 @@ def combine_mlir_scripts(
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
constants_2.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
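
As context for the two hunks above, here is a minimal sketch (not the repo's exact helper) of the line classification they perform when merging the two Vicuna MLIR modules: #map aliases and arith.constant lines are collected separately from the function body, and the forward symbol is renamed per module. The classify_lines helper and the sample lines are illustrative only.

import re

def classify_lines(lines, forward_name):
    maps, constants, body = [], set(), []
    for line in lines:
        if re.search(r"#map\d*\s*=", line):
            maps.append(line)
        elif re.search("arith.constant", line):
            constants.add(line)  # a set deduplicates repeated constants within one module
        elif not re.search("module", line):
            body.append(re.sub("forward", forward_name, line))
    return maps, constants, body

maps1, constants_1, f1 = classify_lines(
    [
        "#map = affine_map<(d0) -> (d0)>",
        "%c1_i64 = arith.constant 1 : i64",
        "func.func @forward(%arg0: i64) -> i64 {",
    ],
    "first_vicuna_forward",
)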
@@ -255,15 +256,25 @@ def combine_mlir_scripts(
module_end = "}"

global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
global_var_loading1 = dict()
global_var_loading2 = dict()

print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
# in both 1 and 2
constants = [(e, "") for e in list(constants_1 & constants_2)]
# only in 1
constants.extend(
[(e, "_1") for e in list(constants_1.difference(constants_2))]
)
# only in 2
constants.extend(
[(e, "_2") for e in list(constants_2.difference(constants_1))]
)
del constants_1, constants_2
gc.collect()

while constants:
constant = constants.pop(0)
constant, vname_suf = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
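
The partitioning introduced above can be read in isolation as follows; the sample constants here are made up for illustration. Constants shared by both modules keep their original name, while constants unique to one module get a "_1" or "_2" suffix so the generated global names cannot collide after the merge.

constants_1 = {"%c0 = arith.constant 0 : index", "%c1_i64 = arith.constant 1 : i64"}
constants_2 = {"%c0 = arith.constant 0 : index", "%c2 = arith.constant 2 : index"}

constants = [(c, "") for c in constants_1 & constants_2]      # in both modules
constants += [(c, "_1") for c in constants_1 - constants_2]   # only in the first
constants += [(c, "_2") for c in constants_2 - constants_1]   # only in the second

for constant, vname_suf in constants:
    vname, vbody = (part.strip() for part in constant.split("=", 1))
    print(vname.lstrip("%") + vname_suf, "<-", vbody)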
@@ -273,41 +284,42 @@ def combine_mlir_scripts(
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
f"util.global private @{vname}{vname_suf} {noinline} = {vbody} : {fixed_vdtype}"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
f"util.global private @{vname}{vname_suf} = {vbody} : i1"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : i1"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : i1"
] = ""

del constants
gc.collect()

new_f1, new_f2 = [], []

print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
for global_var in global_var_loading1.keys():
new_f1.append(global_var)
else:
new_f1.append(line)
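
A condensed sketch of what the hunk above does with each constant (the identifiers are illustrative, and the emitted strings merely follow the templates in the hunk rather than being validated MLIR): the constant is re-declared as a private util.global, and a matching util.global_load line is queued in a dict, whose keys act as an insertion-ordered, duplicate-free list, to be spliced in right after the func.func header of the corresponding forward function.

def make_global(vname, vbody, vdtype, vname_suf):
    noinline = "{noinline}" if "tensor" in vdtype else ""
    decl = f"util.global private @{vname}{vname_suf} {noinline} = {vbody} : {vdtype}"
    load = f"\t\t%{vname} = util.global_load @{vname}{vname_suf} : {vdtype}"
    return decl, load

global_vars, global_var_loading1 = [], dict()
decl, load = make_global("c1_i64", "arith.constant 1 : i64", "i64", "")
global_vars.append(decl)
global_var_loading1[load] = ""  # dict keys: ordered and deduplicated

new_f1 = []
for line in ["func.func @first_vicuna_forward(%arg0: i64) -> i64 {", "  return %arg0 : i64", "}"]:
    new_f1.append(line)
    if "func.func" in line:
        new_f1.extend(global_var_loading1.keys())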
@@ -316,18 +328,15 @@ def combine_mlir_scripts(
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
for global_var in global_var_loading2.keys():
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
):
print(global_var)
new_f2.append(global_var)
else:
if "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in line:
new_f2.append("%" + line)
else:
new_f2.append(line)
new_f2.append(line)

f1 = new_f1
f2 = new_f2
@@ -441,7 +450,12 @@ def __init__(
compressed=False,
extra_args_cmd=[],
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
super().__init__(
model_name,
hf_model_path,
max_num_tokens,
extra_args_cmd=extra_args_cmd,
)
self.max_sequence_length = 256
self.device = device
self.precision = precision
Expand Down Expand Up @@ -945,7 +959,8 @@ def compile_to_vmfb_one_model(
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
] + self.extra_args,
]
+ self.extra_args,
)
module.load_module(vmfb_path)
modules.append(module)
Expand Down Expand Up @@ -1011,7 +1026,8 @@ def compile_to_vmfb_one_model4(
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
] + self.extra_args,
]
+ self.extra_args,
)
module.load_module(vmfb_path)
modules.append(module)
Expand Down Expand Up @@ -1225,7 +1241,12 @@ def __init__(
cache_vicunas=False,
extra_args_cmd=[],
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
super().__init__(
model_name,
hf_model_path,
max_num_tokens,
extra_args_cmd=extra_args_cmd,
)
if "llama2" in self.model_name and hf_auth_token == None:
raise ValueError(
"HF auth token required. Pass it using --hf_auth_token flag."
Expand Down Expand Up @@ -1360,7 +1381,7 @@ def remove_constant_dim(line):
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append(
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
"%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
)
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
Expand Down Expand Up @@ -1447,7 +1468,9 @@ def compile(self, download_vmfb=False):
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
is_f16=True
if self.precision in ["fp16", "int4"]
else False,
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
@@ -1468,9 +1491,7 @@ def compile(self, download_vmfb=False):
ts_graph,
[*firstVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=[
"quant.matmul_rhs_group_quant"
],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
Expand Down Expand Up @@ -1528,13 +1549,15 @@ def compile(self, download_vmfb=False):
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
is_f16=True
if self.precision in ["fp16", "int4"]
else False,
precision=self.precision,
f16_input_mask=[False] + [True] * 64,
mlir_type="torchscript",
)
del model
if self.precision == "fp16":
if self.precision in ["fp16", "int4"]:
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
@@ -1555,9 +1578,7 @@ def compile(self, download_vmfb=False):
ts_graph,
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=[
"quant.matmul_rhs_group_quant"
],
backend_legal_ops=["quant.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
Expand Down Expand Up @@ -1606,7 +1627,8 @@ def compile(self, download_vmfb=False):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
] + self.extra_args,
]
+ self.extra_args,
)
print("Saved vic vmfb at ", str(path))
shark_module.load_module(path)
Expand Down Expand Up @@ -1681,6 +1703,7 @@ def autocomplete(self, prompt):
# use First vic alone to complete a story / prompt / sentence.
pass


# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
Expand Down Expand Up @@ -1730,6 +1753,7 @@ def autocomplete(self, prompt):
"codegen": "",
}


def create_prompt(model_name, history):
global start_message
system_message = start_message[model_name]
Expand Down Expand Up @@ -1820,5 +1844,5 @@ def create_prompt(model_name, history):
prompt = create_prompt(args.model_name, history)
for text, msg in vic.generate(prompt, cli=True):
if "formatted" in msg:
print("Response:",text)
print("Response:", text)
history[-1][1] = text
4 changes: 2 additions & 2 deletions apps/language_models/src/model_wrappers/vicuna_model.py
@@ -26,7 +26,7 @@ def __init__(
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=torch.float32,
dtype=torch.float16 if precision == "int4" else torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
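
Both wrapper changes above implement the same rule; a standalone sketch of that precision handling follows (the helper name is made up): int4 weight quantization now runs the model in float16, while the int8 path keeps float32.

import torch

def quant_config(precision: str):
    # Mirrors the change above: 4-bit weights pair with fp16 compute,
    # 8-bit weights keep fp32 compute.
    weight_bit_width = 4 if precision == "int4" else 8
    dtype = torch.float16 if precision == "int4" else torch.float32
    return dtype, weight_bit_width

assert quant_config("int4") == (torch.float16, 4)
assert quant_config("int8") == (torch.float32, 8)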
