Fix Falcon GPTQ Pipeline
vivekkhandelwal1 committed Oct 11, 2023
1 parent 0a618e1 commit b83d32f
Showing 2 changed files with 15 additions and 23 deletions.
16 changes: 11 additions & 5 deletions apps/language_models/src/pipelines/falcon_pipeline.py
@@ -9,7 +9,7 @@
 from shark.shark_downloader import download_public_file
 from shark.shark_importer import import_with_fx, save_mlir
 from shark.shark_inference import SharkInference
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
 from transformers.generation import (
     GenerationConfig,
     LogitsProcessorList,
@@ -118,11 +118,17 @@ def get_src_model(self):
             "torch_dtype": torch.float,
             "trust_remote_code": True,
             "token": self.hf_auth_token,
-            "device_map": "cpu" if args.device == "cpu" else "cuda:0",
         }
+        if self.precision == "int4":
+            quantization_config = GPTQConfig(bits=4, disable_exllama=True)
+            kwargs["quantization_config"] = quantization_config
+            kwargs["load_gptq_on_cpu"] = True
+            kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
         falcon_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path, **kwargs
         )
+        if self.precision == "int4":
+            falcon_model = falcon_model.to(torch.float32)
         return falcon_model

     def compile(self):
@@ -194,7 +200,7 @@ def compile(self):
         ts_graph = import_with_fx(
             model,
             falconCompileInput,
-            is_f16=self.precision == "fp16",
+            is_f16=self.precision in ["fp16", "int4"],
             f16_input_mask=[False, False],
             mlir_type="torchscript",
             is_gptq=self.precision == "int4",
@@ -229,7 +235,7 @@ def compile(self):
             mlir_dialect="linalg",
         )
         path = shark_module.save_module(
-            self.falcon_vmfb_path,
+            self.falcon_vmfb_path.parent.absolute(),
             self.falcon_vmfb_path.stem,
             extra_args=[
                 "--iree-vm-target-truncate-unsupported-floats",
@@ -417,7 +423,7 @@ def generate_new_token(self):
                 (model_inputs["input_ids"], model_inputs["attention_mask"]),
             )
         )
-        if self.precision == "fp16":
+        if self.precision in ["fp16", "int4"]:
             outputs = outputs.to(dtype=torch.float32)
         next_token_logits = outputs
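For context, the int4 branch added above boils down to loading a pre-quantized GPTQ Falcon checkpoint through transformers' GPTQConfig, keeping it on CPU, and then casting the remaining floating-point tensors to fp32. A rough standalone sketch follows; the checkpoint id is illustrative, the pipeline's hf_auth_token and load_gptq_on_cpu kwargs are omitted, and auto-gptq/optimum must be installed for GPTQ checkpoints to load.

import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# Illustrative checkpoint; the pipeline loads self.hf_model_path instead.
model_id = "TheBloke/falcon-7b-instruct-GPTQ"

kwargs = {
    "torch_dtype": torch.float,
    "trust_remote_code": True,
    # Mirrors the new int4 branch: 4-bit GPTQ with the exllama kernels disabled
    # so the weights can stay on CPU (newer transformers releases spell this
    # use_exllama=False).
    "quantization_config": GPTQConfig(bits=4, disable_exllama=True),
    "device_map": "cpu",
}
falcon_model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
# As in the diff, cast the non-quantized tensors up to fp32 (the pipeline does
# this right before importing the model with FX).
falcon_model = falcon_model.to(torch.float32)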
22 changes: 4 additions & 18 deletions shark/shark_importer.py
@@ -491,23 +491,7 @@ def gptq_transforms(fx_g):
                     node.args[4],
                 )

-        # Downcasting the result of native_layer_norm back to fp16.
-        if node.name.startswith("getitem"):
-            with fx_g.graph.inserting_before(node):
-                if node.args[0].target in [
-                    torch.ops.aten.native_layer_norm
-                ]:
-                    new_node = fx_g.graph.call_function(
-                        torch.ops.aten._to_copy,
-                        args=(node,),
-                        kwargs={"dtype": torch.float16},
-                    )
-                    node.append(new_node)
-                    node.replace_all_uses_with(new_node)
-                    new_node.args = (node,)
-                    new_node.kwargs = {"dtype": torch.float16}
-
-        # Inputs and outputs of aten.mm should be upcasted to fp32.
+        # Inputs of aten.mm should be upcasted to fp32.
         if node.target in [torch.ops.aten.mm]:
             with fx_g.graph.inserting_before(node):
                 new_node_arg0 = fx_g.graph.call_function(
@@ -522,6 +506,7 @@ def gptq_transforms(fx_g):
                 )
                 node.args = (new_node_arg0, new_node_arg1)

+        # Outputs of aten.mm should be downcasted to fp16.
         if type(node.args[0]) == torch.fx.node.Node and node.args[
             0
         ].target in [torch.ops.aten.mm]:
@@ -537,7 +522,7 @@ def gptq_transforms(fx_g):
                 new_node.args = (tmp,)
                 new_node.kwargs = {"dtype": torch.float16}

-        # Inputs and outputs of aten._softmax should be upcasted to fp32.
+        # Inputs of aten._softmax should be upcasted to fp32.
         if node.target in [torch.ops.aten._softmax]:
             with fx_g.graph.inserting_before(node):
                 new_node_arg0 = fx_g.graph.call_function(
@@ -547,6 +532,7 @@ def gptq_transforms(fx_g):
                 )
                 node.args = (new_node_arg0, node.args[1], node.args[2])

+        # Outputs of aten._softmax should be downcasted to fp16.
         if (
             type(node.args[0]) == torch.fx.node.Node
             and node.args[0].target in [torch.ops.aten._softmax]
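The transforms kept in gptq_transforms all follow one pattern: wrap aten.mm and aten._softmax in aten._to_copy casts so the op itself runs in fp32 while its consumers still see fp16. Below is a self-contained sketch of that pattern on a toy graph; the helper f, MM_TARGETS, and the use of make_fx here are illustrative assumptions, not part of shark_importer.py.

import torch
from torch.fx.experimental.proxy_tensor import make_fx


# Toy graph: an aten.mm followed by a consumer op.
def f(a, b):
    return torch.mm(a, b) + 1


# The example inputs only shape the trace; the rewrite keys on the op, not the dtype.
fx_g = make_fx(f)(torch.randn(2, 2), torch.randn(2, 2))

MM_TARGETS = (torch.ops.aten.mm, torch.ops.aten.mm.default)

for node in fx_g.graph.nodes:
    # Upcast both operands of aten.mm so the matmul itself runs in fp32.
    if node.target in MM_TARGETS:
        with fx_g.graph.inserting_before(node):
            upcast_args = tuple(
                fx_g.graph.call_function(
                    torch.ops.aten._to_copy,
                    args=(arg,),
                    kwargs={"dtype": torch.float32},
                )
                for arg in node.args
            )
        node.args = upcast_args
    # Downcast the fp32 result of aten.mm back to fp16 where it is consumed.
    if (
        node.args
        and isinstance(node.args[0], torch.fx.Node)
        and node.args[0].target in MM_TARGETS
    ):
        with fx_g.graph.inserting_before(node):
            cast_back = fx_g.graph.call_function(
                torch.ops.aten._to_copy,
                args=(node.args[0],),
                kwargs={"dtype": torch.float16},
            )
        node.args = (cast_back,) + tuple(node.args[1:])

fx_g.graph.lint()
fx_g.recompile()

# fp16 inputs take the fp32 path through aten.mm and come back out as fp16.
out = fx_g(torch.randn(2, 2, dtype=torch.float16),
           torch.randn(2, 2, dtype=torch.float16))
print(out.dtype)  # torch.float16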
