diff --git a/apps/language_models/src/pipelines/falcon_pipeline.py b/apps/language_models/src/pipelines/falcon_pipeline.py
index 26fd0f20de..f1452d98fd 100644
--- a/apps/language_models/src/pipelines/falcon_pipeline.py
+++ b/apps/language_models/src/pipelines/falcon_pipeline.py
@@ -9,7 +9,7 @@
 from shark.shark_downloader import download_public_file
 from shark.shark_importer import import_with_fx, save_mlir
 from shark.shark_inference import SharkInference
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
 from transformers.generation import (
     GenerationConfig,
     LogitsProcessorList,
@@ -118,11 +118,17 @@ def get_src_model(self):
             "torch_dtype": torch.float,
             "trust_remote_code": True,
             "token": self.hf_auth_token,
-            "device_map": "cpu" if args.device == "cpu" else "cuda:0",
         }
+        if self.precision == "int4":
+            quantization_config = GPTQConfig(bits=4, disable_exllama=True)
+            kwargs["quantization_config"] = quantization_config
+            kwargs["load_gptq_on_cpu"] = True
+        kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
         falcon_model = AutoModelForCausalLM.from_pretrained(
             self.hf_model_path, **kwargs
         )
+        if self.precision == "int4":
+            falcon_model = falcon_model.to(torch.float32)
         return falcon_model
 
     def compile(self):
@@ -194,7 +200,7 @@ def compile(self):
         ts_graph = import_with_fx(
             model,
             falconCompileInput,
-            is_f16=self.precision == "fp16",
+            is_f16=self.precision in ["fp16", "int4"],
             f16_input_mask=[False, False],
             mlir_type="torchscript",
             is_gptq=self.precision == "int4",
@@ -229,7 +235,7 @@ def compile(self):
             mlir_dialect="linalg",
         )
         path = shark_module.save_module(
-            self.falcon_vmfb_path,
+            self.falcon_vmfb_path.parent.absolute(),
             self.falcon_vmfb_path.stem,
             extra_args=[
                 "--iree-vm-target-truncate-unsupported-floats",
@@ -417,7 +423,7 @@ def generate_new_token(self):
                 (model_inputs["input_ids"], model_inputs["attention_mask"]),
             )
         )
-        if self.precision == "fp16":
+        if self.precision in ["fp16", "int4"]:
             outputs = outputs.to(dtype=torch.float32)
         next_token_logits = outputs
 
diff --git a/shark/shark_importer.py b/shark/shark_importer.py
index 26abbd11ee..7082bf3813 100644
--- a/shark/shark_importer.py
+++ b/shark/shark_importer.py
@@ -491,23 +491,7 @@ def gptq_transforms(fx_g):
                     node.args[4],
                 )
 
-        # Downcasting the result of native_layer_norm back to fp16.
-        if node.name.startswith("getitem"):
-            with fx_g.graph.inserting_before(node):
-                if node.args[0].target in [
-                    torch.ops.aten.native_layer_norm
-                ]:
-                    new_node = fx_g.graph.call_function(
-                        torch.ops.aten._to_copy,
-                        args=(node,),
-                        kwargs={"dtype": torch.float16},
-                    )
-                    node.append(new_node)
-                    node.replace_all_uses_with(new_node)
-                    new_node.args = (node,)
-                    new_node.kwargs = {"dtype": torch.float16}
-
-        # Inputs and outputs of aten.mm should be upcasted to fp32.
+        # Inputs of aten.mm should be upcasted to fp32.
         if node.target in [torch.ops.aten.mm]:
             with fx_g.graph.inserting_before(node):
                 new_node_arg0 = fx_g.graph.call_function(
@@ -522,6 +506,7 @@
                 )
                 node.args = (new_node_arg0, new_node_arg1)
 
+        # Outputs of aten.mm should be downcasted to fp16.
         if type(node.args[0]) == torch.fx.node.Node and node.args[
             0
         ].target in [torch.ops.aten.mm]:
@@ -537,7 +522,7 @@
                 new_node.args = (tmp,)
                 new_node.kwargs = {"dtype": torch.float16}
 
-        # Inputs and outputs of aten._softmax should be upcasted to fp32.
+        # Inputs of aten._softmax should be upcasted to fp32.
         if node.target in [torch.ops.aten._softmax]:
             with fx_g.graph.inserting_before(node):
                 new_node_arg0 = fx_g.graph.call_function(
@@ -547,6 +532,7 @@
                 )
                 node.args = (new_node_arg0, node.args[1], node.args[2])
 
+        # Outputs of aten._softmax should be downcasted to fp16.
         if (
             type(node.args[0]) == torch.fx.node.Node
             and node.args[0].target in [torch.ops.aten._softmax]
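
Reviewer note: as context for the int4 branch added to get_src_model() above, here is a hedged, standalone sketch of the same loading flow. The helper name load_falcon_gptq_int4 and its arguments are illustrative only; the kwargs simply mirror the ones in this patch, they are not a new API.

import torch
from transformers import AutoModelForCausalLM, GPTQConfig


def load_falcon_gptq_int4(model_id, token=None, device="cpu"):
    # Mirror the kwargs assembled in get_src_model() for the int4 path.
    kwargs = {
        "torch_dtype": torch.float,
        "trust_remote_code": True,
        "token": token,
        # Exllama kernels are CUDA-only, so they are disabled here to allow
        # loading the GPTQ checkpoint on CPU.
        "quantization_config": GPTQConfig(bits=4, disable_exllama=True),
        "load_gptq_on_cpu": True,
        "device_map": "cpu" if device == "cpu" else "cuda:0",
    }
    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    # Match get_src_model(): cast floating-point params/buffers to fp32
    # before the model is handed to the FX/torchscript import.
    return model.to(torch.float32)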
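
The gptq_transforms hunks above show only fragments of the pass. Below is a minimal, self-contained sketch of the aten.mm rewrite it performs (upcast both operands to fp32, run the matmul, cast the result back to fp16); the aten._softmax handling follows the same shape. The function name upcast_mm_io and the use of Graph.inserting_after are illustrative choices, not code from this patch.

import torch
import torch.fx as fx


def upcast_mm_io(fx_g: fx.GraphModule) -> fx.GraphModule:
    for node in fx_g.graph.nodes:
        if node.target in [torch.ops.aten.mm]:
            # Upcast both mm operands to fp32 before the matmul.
            with fx_g.graph.inserting_before(node):
                arg0 = fx_g.graph.call_function(
                    torch.ops.aten._to_copy,
                    args=(node.args[0],),
                    kwargs={"dtype": torch.float32},
                )
                arg1 = fx_g.graph.call_function(
                    torch.ops.aten._to_copy,
                    args=(node.args[1],),
                    kwargs={"dtype": torch.float32},
                )
                node.args = (arg0, arg1)
            # Downcast the mm result back to fp16 for downstream users.
            with fx_g.graph.inserting_after(node):
                cast_back = fx_g.graph.call_function(
                    torch.ops.aten._to_copy,
                    args=(node,),
                    kwargs={"dtype": torch.float16},
                )
                node.replace_all_uses_with(cast_back)
                # replace_all_uses_with also rewrote cast_back's own input,
                # so point it back at the original mm node.
                cast_back.args = (node,)
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g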