Skip to content

Commit

Permalink
Update on "[ET-VK] Changing all conv 2d pw ints from uint16 to int si…
Browse files Browse the repository at this point in the history
…nce it slightly improves perf."

This diff changes all integers in conv 2d pw op shader from uint16 to int in the Vulkan backend of Executorch. The change is made to improve performance since the shader does not appear to be register bound.

Differential Revision: [D67906023](https://our.internmc.facebook.com/intern/diff/D67906023/)

[ghstack-poisoned]
  • Loading branch information
trivedivivek committed Jan 8, 2025
2 parents 7469edb + 0019161 commit 0fb5b0d
Show file tree
Hide file tree
Showing 140 changed files with 2,096 additions and 1,265 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/buck2.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2024-05-15
2024-12-16
89 changes: 89 additions & 0 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class AnnotateDecomposedMatmulPass(ExportPass):
"""
torch.matmul can be decomposed in many ways, for instance:
dq -> matmul -> q can become
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
difficult. This helper function find all matmul partitions and annotate its
matmul-op (can be mm or bmm).
"""

def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions = get_source_partitions(
graph_module.graph,
[
torch.matmul,
],
None,
)
matmul_partitions = list(
itertools.chain.from_iterable(matmul_partitions.values())
)
matmul_targets = {
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.bmm.default,
}
for partition in matmul_partitions:
quantized_input = all(
input_node.target == dq_op for input_node in partition.input_nodes
)
matmul_node = [
node for node in partition.nodes if node.target in matmul_targets
][0]
if quantized_input:
matmul_args = matmul_node.all_input_nodes
for i in range(len(matmul_args)):
input_node = partition.input_nodes[i]
matmul_input_node = matmul_args[i]
# Remove partition input dq-node
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
graph_module.graph.erase_node(input_node)
input_node_qargs = input_node.args[1:]
with graph_module.graph.inserting_before(matmul_node):
# Create new dq-node before matmul
dq_node = create_node(
graph=graph_module.graph,
op_target=dq_op,
)
dq_node.args = (matmul_input_node, *input_node_qargs)
matmul_node.replace_input_with(matmul_input_node, dq_node)

partition_output = list(partition.output_nodes[0].users)[0]
quantized_output = partition_output.target == q_op
if quantized_output:
output_node_qargs = partition_output.args[1:]
with graph_module.graph.inserting_after(matmul_node):
# Create q-node after matmul
q_node = create_node(
graph=graph_module.graph,
op_target=q_op,
)
matmul_node.replace_all_uses_with(q_node)
q_node.args = (matmul_node, *output_node_qargs)
# Remove partition output q-node
partition_output.replace_all_uses_with(
partition_output.all_input_nodes[0]
)
graph_module.graph.erase_node(partition_output)

# retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module

graph_module.recompile()
return PassResult(graph_module, True)
58 changes: 46 additions & 12 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
AnnotateDecomposedMatmulPass,
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
Expand All @@ -32,7 +35,9 @@
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
QuantizeFullArgument,
RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
Expand Down Expand Up @@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
):
"""Apply passes before transforming program to backend"""
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(DecomposeLinearPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(DecomposeVarPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
# TODO MLETORCH-558
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(
FoldAndAnnotateQParamsPass(
Expand All @@ -93,11 +89,49 @@ def transform_to_backend_pipeline(
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.view_copy.default,
]
)
)
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
Expand Down
22 changes: 2 additions & 20 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_param_tensor,
insert_q_dq_pair,
is_param_node,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand All @@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass):
supports 2d and 3d convolution. This is done by modifying the graph to do the
following:
1) unsqueeze the convolution's input from 3d to 4d
2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
3) perform a conv2d (with a modified version of the original conv1d args)
4) squeeze the output back down to 3d.
5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
2) perform a conv2d (with a modified version of the original conv1d args)
3) squeeze the output back down to 3d.
"""

def __init__(self, exported_program: ExportedProgram) -> None:
Expand Down Expand Up @@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule):
continue

kernel_node = node.args[1]
if kernel_node.target == dq_op:
kernel_node = kernel_node.args[0]

if not is_param_node(self.exported_program, kernel_node):
raise AssertionError(
Expand Down Expand Up @@ -131,11 +125,6 @@ def call(self, graph_module: torch.fx.GraphModule):
)
node.replace_input_with(input_node, unsqueeze_before)

# If Quantized we must insert unsqueeze --> q --> dq --> node
if input_node.target == dq_op:
q_params = input_node.args[1:]
insert_q_dq_pair(graph, unsqueeze_before, q_params)

with graph.inserting_after(node):
squeeze_after = create_node(
graph,
Expand All @@ -151,13 +140,6 @@ def call(self, graph_module: torch.fx.GraphModule):
for user in original_users:
user.replace_input_with(node, squeeze_after)

# If quantized, insert conv2d --> q --> dq --> squeeze
if all(
original_user.target == q_op for original_user in original_users
):
q_params = original_users[0].args[1:]
insert_q_dq_pair(graph, node, q_params)

graph_module.recompile()
# Since we are overriding "call", we need to call the parent's "call"
# to retrace the graph and regenerate metadata
Expand Down
Loading

0 comments on commit 0fb5b0d

Please sign in to comment.