Add support for Falcon GPTQ
vivekkhandelwal1 committed Oct 11, 2023
1 parent a731eb6 commit 0a618e1
Showing 2 changed files with 66 additions and 4 deletions.
9 changes: 7 additions & 2 deletions apps/language_models/src/pipelines/falcon_pipeline.py
@@ -32,7 +32,7 @@
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -235,7 +235,12 @@ def compile(self):
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
- ],
+ ]
+ + [
+ "--iree-llvmcpu-use-fast-min-max-ops",
+ ]
+ if self.precision == "int4"
+ else [],
debug=self.debug,
)
print("Saved falcon vmfb at ", str(path))
61 changes: 59 additions & 2 deletions shark/shark_importer.py
@@ -500,12 +500,69 @@ def gptq_transforms(fx_g):
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node,),
kwargs={"dtype": torch.float32},
kwargs={"dtype": torch.float16},
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float32}
new_node.kwargs = {"dtype": torch.float16}

+ # Inputs and outputs of aten.mm should be upcasted to fp32.
+ if node.target in [torch.ops.aten.mm]:
+ with fx_g.graph.inserting_before(node):
+ new_node_arg0 = fx_g.graph.call_function(
+ torch.ops.prims.convert_element_type,
+ args=(node.args[0], torch.float32),
+ kwargs={},
+ )
+ new_node_arg1 = fx_g.graph.call_function(
+ torch.ops.prims.convert_element_type,
+ args=(node.args[1], torch.float32),
+ kwargs={},
+ )
+ node.args = (new_node_arg0, new_node_arg1)
+
+ if type(node.args[0]) == torch.fx.node.Node and node.args[0].target in [torch.ops.aten.mm]:
+ with fx_g.graph.inserting_before(node):
+ tmp = node.args[0]
+ new_node = fx_g.graph.call_function(
+ torch.ops.aten._to_copy,
+ args=(node.args[0],),
+ kwargs={"dtype": torch.float16},
+ )
+ node.args[0].append(new_node)
+ node.args[0].replace_all_uses_with(new_node)
+ new_node.args = (tmp,)
+ new_node.kwargs = {"dtype": torch.float16}
+
+ # Inputs and outputs of aten._softmax should be upcasted to fp32.
+ if node.target in [torch.ops.aten._softmax]:
+ with fx_g.graph.inserting_before(node):
+ new_node_arg0 = fx_g.graph.call_function(
+ torch.ops.prims.convert_element_type,
+ args=(node.args[0], torch.float32),
+ kwargs={},
+ )
+ node.args = (new_node_arg0, node.args[1], node.args[2])
+
+ if (
+ type(node.args[0]) == torch.fx.node.Node
+ and node.args[0].target in [torch.ops.aten._softmax]
+ and node.target in [torch.ops.aten.expand]
+ ):
+ with fx_g.graph.inserting_before(node):
+ tmp = node.args[0]
+ new_node = fx_g.graph.call_function(
+ torch.ops.aten._to_copy,
+ args=(node.args[0],),
+ kwargs={"dtype": torch.float16},
+ )
+ node.args[0].append(new_node)
+ node.args[0].replace_all_uses_with(new_node)
+ new_node.args = (tmp,)
+ new_node.kwargs = {"dtype": torch.float16}
+
fx_g.graph.lint()

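The added block above applies one torch.fx rewrite pattern throughout: upcast the inputs of a reduced-precision op to fp32 before it runs, then cast its result back down to fp16 for its consumers. Below is a self-contained sketch of that pattern as an illustration only; it is not the SHARK code, and names such as TinyMatmul and upcast_mm_to_fp32 are made up.

import torch
import torch.fx as fx


class TinyMatmul(torch.nn.Module):
    def forward(self, a, b):
        return torch.mm(a, b)


def upcast_mm_to_fp32(gm: fx.GraphModule) -> fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.mm:
            # Upcast both operands to fp32 just before the matmul.
            with gm.graph.inserting_before(node):
                lhs32 = gm.graph.call_function(
                    torch.ops.prims.convert_element_type,
                    args=(node.args[0], torch.float32),
                )
                rhs32 = gm.graph.call_function(
                    torch.ops.prims.convert_element_type,
                    args=(node.args[1], torch.float32),
                )
            node.args = (lhs32, rhs32)
            # Downcast the result for all users, then point the downcast back
            # at the matmul (the same replace_all_uses_with / re-set-args
            # dance used in gptq_transforms above).
            with gm.graph.inserting_after(node):
                back16 = gm.graph.call_function(
                    torch.ops.aten._to_copy,
                    args=(node,),
                    kwargs={"dtype": torch.float16},
                )
            node.replace_all_uses_with(back16)
            back16.args = (node,)
    gm.graph.lint()
    gm.recompile()
    return gm


gm = upcast_mm_to_fp32(fx.symbolic_trace(TinyMatmul()))
a = torch.randn(4, 8, dtype=torch.float16)
b = torch.randn(8, 4, dtype=torch.float16)
print(gm(a, b).dtype)  # torch.float16; the matmul itself ran in fp32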
