Skip to content

Commit

Permalink
Update on "Use c10 version of half/bfloat16 in executorch"
Browse files Browse the repository at this point in the history
Accomplished by importing relevant files from c10 into
executorch/runtime/core/portable_type/c10, and then using `using` in
the top-level ExecuTorch headers. This approach should keep the
ExecuTorch build hermetic for embedded use cases. In the future, we
should add a CI job to ensure the c10 files stay identical to the
PyTorch ones.

Differential Revision: [D66106969](https://our.internmc.facebook.com/intern/diff/D66106969/)

[ghstack-poisoned]
  • Loading branch information
swolchok committed Jan 8, 2025
2 parents e4d21ff + 91bad0a commit afd982e
Show file tree
Hide file tree
Showing 158 changed files with 2,240 additions and 1,303 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/buck2.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2024-05-15
2024-12-16
5 changes: 3 additions & 2 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ include_patterns = [
'profiler/**/*.py',
'runtime/**/*.py',
'scripts/**/*.py',
# 'test/**/*.py',
# 'util/**/*.py',
'test/**/*.py',
'util/**/*.py',
'*.py',
]
exclude_patterns = [
Expand All @@ -325,6 +325,7 @@ command = [
'--config=.mypy.ini',
'--show-disable',
'--',
'--explicit-package-bases',
'@{{PATHSFILE}}'
]
init_command = [
Expand Down
12 changes: 11 additions & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ files =
profiler,
runtime,
scripts,
test,
util

mypy_path = executorch

[mypy-executorch.backends.*]
follow_untyped_imports = True

[mypy-executorch.codegen.*]
follow_untyped_imports = True

Expand All @@ -46,6 +50,12 @@ follow_untyped_imports = True
[mypy-executorch.runtime.*]
follow_untyped_imports = True

[mypy-executorch.test.*]
follow_untyped_imports = True

[mypy-functorch.*]
follow_untyped_imports = True

[mypy-requests.*]
follow_untyped_imports = True

Expand Down Expand Up @@ -80,4 +90,4 @@ ignore_missing_imports = True
ignore_missing_imports = True

[mypy-zstd]
ignore_missing_imports = True
ignore_missing_imports = True
89 changes: 89 additions & 0 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class AnnotateDecomposedMatmulPass(ExportPass):
"""
torch.matmul can be decomposed in many ways, for instance:
dq -> matmul -> q can become
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
difficult. This helper function find all matmul partitions and annotate its
matmul-op (can be mm or bmm).
"""

def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions = get_source_partitions(
graph_module.graph,
[
torch.matmul,
],
None,
)
matmul_partitions = list(
itertools.chain.from_iterable(matmul_partitions.values())
)
matmul_targets = {
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.bmm.default,
}
for partition in matmul_partitions:
quantized_input = all(
input_node.target == dq_op for input_node in partition.input_nodes
)
matmul_node = [
node for node in partition.nodes if node.target in matmul_targets
][0]
if quantized_input:
matmul_args = matmul_node.all_input_nodes
for i in range(len(matmul_args)):
input_node = partition.input_nodes[i]
matmul_input_node = matmul_args[i]
# Remove partition input dq-node
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
graph_module.graph.erase_node(input_node)
input_node_qargs = input_node.args[1:]
with graph_module.graph.inserting_before(matmul_node):
# Create new dq-node before matmul
dq_node = create_node(
graph=graph_module.graph,
op_target=dq_op,
)
dq_node.args = (matmul_input_node, *input_node_qargs)
matmul_node.replace_input_with(matmul_input_node, dq_node)

partition_output = list(partition.output_nodes[0].users)[0]
quantized_output = partition_output.target == q_op
if quantized_output:
output_node_qargs = partition_output.args[1:]
with graph_module.graph.inserting_after(matmul_node):
# Create q-node after matmul
q_node = create_node(
graph=graph_module.graph,
op_target=q_op,
)
matmul_node.replace_all_uses_with(q_node)
q_node.args = (matmul_node, *output_node_qargs)
# Remove partition output q-node
partition_output.replace_all_uses_with(
partition_output.all_input_nodes[0]
)
graph_module.graph.erase_node(partition_output)

# retrace the graph to update the fake tensor types
graph_module = super().call(graph_module).graph_module

graph_module.recompile()
return PassResult(graph_module, True)
58 changes: 46 additions & 12 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm._passes.annotate_decomposed_matmul import (
AnnotateDecomposedMatmulPass,
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
Expand All @@ -32,7 +35,9 @@
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
QuantizeFullArgument,
RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
Expand Down Expand Up @@ -67,24 +72,15 @@ def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
):
"""Apply passes before transforming program to backend"""
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(DecomposeLinearPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(DecomposeVarPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
# TODO MLETORCH-558
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(
FoldAndAnnotateQParamsPass(
Expand All @@ -93,11 +89,49 @@ def transform_to_backend_pipeline(
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.mm.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.view_copy.default,
]
)
)
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
Expand Down
22 changes: 2 additions & 20 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_param_tensor,
insert_q_dq_pair,
is_param_node,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand All @@ -27,10 +25,8 @@ class Conv1dUnsqueezePass(ExportPass):
supports 2d and 3d convolution. This is done by modifying the graph to do the
following:
1) unsqueeze the convolution's input from 3d to 4d
2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
3) perform a conv2d (with a modified version of the original conv1d args)
4) squeeze the output back down to 3d.
5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
2) perform a conv2d (with a modified version of the original conv1d args)
3) squeeze the output back down to 3d.
"""

def __init__(self, exported_program: ExportedProgram) -> None:
Expand Down Expand Up @@ -94,8 +90,6 @@ def call(self, graph_module: torch.fx.GraphModule):
continue

kernel_node = node.args[1]
if kernel_node.target == dq_op:
kernel_node = kernel_node.args[0]

if not is_param_node(self.exported_program, kernel_node):
raise AssertionError(
Expand Down Expand Up @@ -131,11 +125,6 @@ def call(self, graph_module: torch.fx.GraphModule):
)
node.replace_input_with(input_node, unsqueeze_before)

# If Quantized we must insert unsqueeze --> q --> dq --> node
if input_node.target == dq_op:
q_params = input_node.args[1:]
insert_q_dq_pair(graph, unsqueeze_before, q_params)

with graph.inserting_after(node):
squeeze_after = create_node(
graph,
Expand All @@ -151,13 +140,6 @@ def call(self, graph_module: torch.fx.GraphModule):
for user in original_users:
user.replace_input_with(node, squeeze_after)

# If quantized, insert conv2d --> q --> dq --> squeeze
if all(
original_user.target == q_op for original_user in original_users
):
q_params = original_users[0].args[1:]
insert_q_dq_pair(graph, node, q_params)

graph_module.recompile()
# Since we are overriding "call", we need to call the parent's "call"
# to retrace the graph and regenerate metadata
Expand Down
Loading

0 comments on commit afd982e

Please sign in to comment.