Fixed compilation issue with quantized sharding models
Elias Joseph committed Nov 20, 2023
1 parent 56ed782 · commit bf97bb4
Showing 1 changed file with 19 additions and 1 deletion.
apps/language_models/scripts/vicuna.py (20 changes: 19 additions & 1 deletion)
@@ -554,13 +554,17 @@ def write_in_dynamic_inputs0(self, module, dynamic_input_size):
         return new_module
 
     def write_in_dynamic_inputs1(self, module, dynamic_input_size):
+        if self.precision == "fp32":
+            fprecision = "32"
+        else:
+            fprecision = "16"
         new_lines = []
         for line in module.splitlines():
             if "dim_42 =" in line:
                 continue
             if f"%c{dynamic_input_size}_i64 =" in line:
                 new_lines.append(
-                    "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
+                    f"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf{fprecision}>"
                 )
                 new_lines.append(
                     f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
@@ -605,9 +609,11 @@ def compile_vicuna_layer(
             past_key_value0,
             past_key_value1,
         )
+        is_f16 = self.precision in ["fp16", "int4"]
         mlir_bytecode = import_with_fx(
             vicuna_layer,
             model_inputs,
+            is_f16=is_f16,
             precision=self.precision,
             f16_input_mask=[False, False],
             mlir_type="torchscript",
@@ -658,9 +664,11 @@ def compile_vicuna_layer4(
             pkv70,
             pkv71,
         )
+        is_f16 = self.precision in ["fp16", "int4"]
         mlir_bytecode = import_with_fx(
             vicuna_layer,
             model_inputs,
+            is_f16=is_f16,
             precision=self.precision,
             f16_input_mask=[False, False],
             mlir_type="torchscript",
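
These two hunks make the same change in both layer-compilation paths: the is_f16 flag passed to import_with_fx is now derived from the precision, and int4 models take the f16 path alongside fp16 ones (the assumption being that int4-quantized layers still carry half-precision activations). A tiny standalone illustration of the flag logic:

def wants_f16_import(precision: str) -> bool:
    # int4 quantization is assumed to need the same half-precision
    # handling at import time as a plain fp16 model.
    return precision in ["fp16", "int4"]

for precision in ["fp32", "fp16", "int4"]:
    print(f"{precision}: is_f16={wants_f16_import(precision)}")
# fp32: is_f16=False
# fp16: is_f16=True
# int4: is_f16=True
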
@@ -836,6 +844,15 @@ def compile_to_vmfb_one_model(
         layers1,
         device="cpu",
     ):
+        if self.precision != "fp32":
+            inputs0 = tuple(
+                inpt.to(torch.float16) if inpt.dtype == torch.float32 else inpt
+                for inpt in inputs0
+            )
+            inputs1 = tuple(
+                inpt.to(torch.float16) if inpt.dtype == torch.float32 else inpt
+                for inpt in inputs1
+            )
         mlirs, modules = [], []
         assert len(layers0) == len(layers1)
         for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))):
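
This hunk looks like the heart of the fix named in the commit message: before compiling, any float32 example inputs are downcast to float16 for non-fp32 precisions, so quantized shards are traced with the dtypes they actually run with, while integer inputs such as token ids pass through unchanged. A runnable sketch of the conversion in isolation (the helper name is an assumption):

import torch

def halve_float_inputs(inputs, precision):
    # Leave fp32 models untouched; otherwise downcast only the float32
    # tensors and keep integer tensors (e.g. token ids) as they are.
    if precision == "fp32":
        return inputs
    return tuple(
        inpt.to(torch.float16) if inpt.dtype == torch.float32 else inpt
        for inpt in inputs
    )

example_inputs = (torch.randn(1, 4), torch.tensor([[1, 2, 3]]))
converted = halve_float_inputs(example_inputs, "int4")
print([t.dtype for t in converted])  # [torch.float16, torch.int64]
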
@@ -899,6 +916,7 @@ def compile_to_vmfb_one_model(
             use_tracing=False,
             verbose=False,
         )
+
         print(f"[DEBUG] converting torch to linalg")
         run_pipeline_with_repro_report(
             module0,
