2024-10-12 nightly release (1f2b9aa)
pytorchbot committed on Oct 12, 2024 · commit 6a1d9c3 · 1 parent: 5db774d
Showing 49 changed files with 736 additions and 269 deletions.
@@ -0,0 +1,65 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeEinsum(ExportPass):
    """
    Decompose einsum so that quantization annotation works properly.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target == torch.ops.aten.einsum.default:
                decomposed_module = make_fx(
                    node.target,
                    tracing_mode="fake",
                )(node.args[0], [arg.meta["val"] for arg in node.args[1]])

                with graph.inserting_before(node):
                    # remap is used to map original node values to new node values,
                    # which ensures that references to nodes are correctly updated
                    # in the new graph.
                    remap = {}
                    # Unlike other nodes, einsum's args[0] is the einsum equation
                    # string, while its input nodes are stored in args[1].
                    for i, arg in enumerate(node.args[1]):
                        remap[f"arg1_{i+1}"] = arg

                    for decomposed_node in decomposed_module.graph.nodes:
                        # This is the args[0] equation string, which is no longer
                        # required after decomposition.
                        if "arg0" in decomposed_node.name:
                            continue

                        # No need to copy the existing 'output' node; instead,
                        # rewire users of the original einsum node to the
                        # decomposed result via remap.
                        if decomposed_node.op == "output":
                            for user in node.users.copy():
                                user.replace_input_with(
                                    node,
                                    remap[decomposed_node.args[0][0]],
                                )
                        # No need to copy existing placeholders either.
                        elif decomposed_node.op == "placeholder":
                            # Re-key the remap entry from the placeholder's
                            # name (a string) to the graph node itself.
                            remap[decomposed_node] = remap.pop(decomposed_node.name)
                        else:
                            remap[decomposed_node] = graph.node_copy(
                                decomposed_node,
                                arg_transform=lambda x, remap=remap: remap[x],
                            )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
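
For context, a hedged usage sketch (not part of this commit) of how the pass can be exercised. It assumes an exported graph in which aten.einsum.default is still present, as in pre-dispatch export; the toy BMM module, the shapes, and the use of torch.export.export are illustrative assumptions, not taken from this diff.

import torch


class BMM(torch.nn.Module):
    def forward(self, x, y):
        # einsum expressing a batched matrix multiply
        return torch.einsum("bij,bjk->bik", x, y)


# Assumption: this export flavor keeps aten.einsum.default in the graph.
ep = torch.export.export(BMM(), (torch.randn(2, 3, 4), torch.randn(2, 4, 5)))
gm = ep.module()
result = DecomposeEinsum().call(gm)
# After the pass, no einsum node should remain in the rewritten graph.
assert all(
    n.target != torch.ops.aten.einsum.default
    for n in result.graph_module.graph.nodes
)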
@@ -0,0 +1,107 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpTopK, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class TopK(NodeVisitor):
    target = ["aten.topk.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        k = cast(int, node.args[1])

        if len(node.args) > 2:
            dim = cast(int, node.args[2])
            if dim < 0:
                dim = dim % len(input_tensor.shape)
            if QCOM_AXIS_ORDER in node.meta:
                dim = node.meta[QCOM_AXIS_ORDER].index(dim)
            if dim != len(input_tensor.shape) - 1:
                warnings.warn(
                    "[QNN Delegate Op Builder]: QNN currently only supports topk along the last (channel) dimension.",
                    stacklevel=1,
                )
                return

        topk_input_tensors = [input_tensor_wrapper]

        output_val_tensor = self.get_tensor(node, node, 0)
        output_idx_tensor = self.get_tensor(node, node, 1).to(torch.int32)

        # QNN constraint: topk output_0 must have the same quant config as the input.
        node.meta["quant_attrs"] = input_node.meta.get("quant_attrs")
        output_val_tensor_wrapper = self.define_tensor(
            node,
            output_val_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # topk output_1 is the index tensor; do not quantize it.
        node.meta.pop("quant_attrs", None)
        output_index_tensor_wrapper = self.define_tensor(
            node,
            output_idx_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
            wrapper_idx=1,
        )
        topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper]

        topk_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTopK.op_name,
        )
        topk_op.AddInputTensors(topk_input_tensors)
        topk_op.AddOutputTensors(topk_output_tensors)

        topk_op.AddScalarParam(
            OpTopK.param_k,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(k)},
        )

        # As of QNN 2.26, the QNN HTP backend only allows this value to be 1
        # (largest=True); any other value fails op validation.
        if len(node.args) > 3:
            largest = cast(bool, node.args[3])
            topk_op.AddScalarParam(
                OpTopK.param_largest,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
                {QCOM_DATA: largest},
            )

        return topk_op
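
To illustrate the dim handling above, a small hedged sketch (not from the commit): it mimics how a negative dim wraps modulo the tensor rank and is then re-indexed through a layout permutation. The axis_order value stands in for node.meta[QCOM_AXIS_ORDER] and is an assumption for illustration.

rank = 4
dim = -3                     # e.g. torch.topk(x, k, dim=-3) on an NCHW tensor
dim = dim % rank             # negative dim wraps around: -3 -> 1 (channel axis)
axis_order = (0, 2, 3, 1)    # assumed permutation after an NCHW -> NHWC layout change
dim = axis_order.index(dim)  # channel now sits at axis 3 in the stored layout
assert dim == rank - 1       # last axis, so the QNN topk constraint is satisfied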