diff --git a/.gitmodules b/.gitmodules index 71ff854bb0..844cd91789 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,9 +28,6 @@ [submodule "backends/xnnpack/third-party/pthreadpool"] path = backends/xnnpack/third-party/pthreadpool url = https://github.com/Maratyszcza/pthreadpool.git -[submodule "examples/third-party/fbjni"] - path = examples/third-party/fbjni - url = https://github.com/facebookincubator/fbjni.git [submodule "extension/llm/third-party/abseil-cpp"] path = extension/llm/third-party/abseil-cpp url = https://github.com/abseil/abseil-cpp.git diff --git a/backends/qualcomm/_passes/decompose_einsum.py b/backends/qualcomm/_passes/decompose_einsum.py new file mode 100644 index 0000000000..c1924838b3 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_einsum.py @@ -0,0 +1,65 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.experimental.proxy_tensor import make_fx + + +class DecomposeEinsum(ExportPass): + """ + Decompose einsum for quantization annotation to work properly. + """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if node.target == torch.ops.aten.einsum.default: + decomposed_module = make_fx( + node.target, + tracing_mode="fake", + )(node.args[0], [arg.meta["val"] for arg in node.args[1]]) + + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correclty updated in the new graph + remap = {} + # Different from other nodes, einsum args[0] is the einsum equation, + # while input nodes are stored in args[1] + for i, arg in enumerate(node.args[1]): + remap[f"arg1_{i+1}"] = arg + + for decomposed_node in decomposed_module.graph.nodes: + # This is the arg[0] equation string, which is not required anymore after decomposition + if "arg0" in decomposed_node.name: + continue + + # no need to copy existent 'output' + if decomposed_node.op == "output": + for user in node.users.copy(): + # remap + user.replace_input_with( + node, + remap[decomposed_node.args[0][0]], + ) + # no need to copy existent placeholders + elif decomposed_node.op == "placeholder": + # replace node map from string to graph node + remap[decomposed_node] = remap.pop(decomposed_node.name) + else: + remap[decomposed_node] = graph.node_copy( + decomposed_node, + arg_transform=lambda x, remap=remap: remap[x], + ) + + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/insert_requantize.py b/backends/qualcomm/_passes/insert_requantize.py index 417d3b85b0..5291edeb9f 100644 --- a/backends/qualcomm/_passes/insert_requantize.py +++ b/backends/qualcomm/_passes/insert_requantize.py @@ -28,6 +28,7 @@ class InsertRequantize(ExportPass): # we don't use the 2nd output, 2nd output is an integer, etc. multi_output_op_ignore_set = { exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.topk.default, } def __init__( diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index bdee2c8196..829c11fda4 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -65,6 +65,7 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.topk.default, exir_ops.edge.aten._to_copy.default, exir_ops.edge.aten.split_with_sizes.default, *q_ops, diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 79c02e2207..74fd58a3ec 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -53,6 +53,7 @@ op_sum_int_list, op_tanh, op_to, + op_topk, op_transpose, op_unsqueeze, op_upsample_bilinear2d, @@ -107,6 +108,7 @@ op_sub, op_sum_int_list, op_tanh, + op_topk, op_to, op_transpose, op_unsqueeze, diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py index 2f7e773b4f..5ad3fc36c9 100644 --- a/backends/qualcomm/builders/op_avg_pool2d.py +++ b/backends/qualcomm/builders/op_avg_pool2d.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import cast, Dict, List import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -85,7 +86,10 @@ def define_node( if len(node.args) > 6: divisor_override = cast(int, node.args[6]) if divisor_override != pooling_region: - print("Not support divisor_override which is not equal to pooling region.") + warnings.warn( + "[QNN Delegate Op Builder]: Not support divisor_override which is not equal to pooling region.", + stacklevel=1, + ) return avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper( diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index bb68b24289..cf18690498 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import cast, Dict, List import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -43,8 +44,9 @@ def define_node( ) if len(list_of_tensors) != len(list_of_tensor_wrappers): - print( - "The number or input tensors is not equal to the number of input tensor wrappers." + warnings.warn( + "[QNN Delegate Op Builder]: The number or input tensors is not equal to the number of input tensor wrappers.", + stacklevel=1, ) return diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index b6e70c374e..30207a0392 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import cast, Dict, List import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -189,12 +190,18 @@ def _define_conv1d( # args[6] = transposed if cast(bool, node.args[6]): - print("Currently, No support for transposed convolution") + warnings.warn( + "[QNN Delegate Op Builder]: Currently, No support for transposed convolution.", + stacklevel=1, + ) return # args[7] = output padding if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])): - print("QNN does not support output padding") + warnings.warn( + "[QNN Delegate Op Builder]: QNN does not support output padding.", + stacklevel=1, + ) return stride_shape = [len(stride)] diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py index dec352fef7..3f5c266cdd 100644 --- a/backends/qualcomm/builders/op_expand.py +++ b/backends/qualcomm/builders/op_expand.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import cast, Dict, List import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -52,8 +53,9 @@ def define_node( output_dims = len(output_tensor.size()) if input_dims < output_dims: - print( - f"The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}." + warnings.warn( + f"[QNN Delegate Op Builder]: The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.", + stacklevel=1, ) return diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 18f5b76310..635e12d2ee 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -44,7 +45,10 @@ def define_node( len(normalized_shapes) != 1 and normalized_shapes[0] != input_tensor.shape[-1] ): - print("Only supports normalization with last input dimension") + warnings.warn( + "[QNN Delegate Op Builder]: Only supports normalization with last input dimension.", + stacklevel=1, + ) return axis = [len(input_tensor.shape) - 1] axis_shape = [len(axis)] diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index 17afb21c6d..e4f16d4473 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -70,8 +71,9 @@ def define_node( # TODO remove this when qnn sdk support if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}): - print( - f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet." + warnings.warn( + f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. per channel bias quantization is not support yet.", + stacklevel=1, ) bias_tensor = get_parameter(bias_node, self.edge_program) bias_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py index 586556621b..27f14889bf 100644 --- a/backends/qualcomm/builders/op_max_pool2d.py +++ b/backends/qualcomm/builders/op_max_pool2d.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import cast, Dict, List import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -42,8 +43,9 @@ def define_node( if user.target.__name__ == "getitem": getitem_index = user.args[1] if getitem_index != 0: - print( - f"Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}" + warnings.warn( + f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}", + stacklevel=1, ) return @@ -78,8 +80,9 @@ def define_node( if len(node.args) > 4: dilation = cast(List[int], node.args[4]) if not (dilation == 1 or dilation == [1, 1]): - print( - f"Not support dilation argument for max pool2d, but got {dilation}" + warnings.warn( + f"[QNN Delegate Op Builder]: Not support dilation argument for max pool2d, but got {dilation}", + stacklevel=1, ) return diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py index e99b1f47ba..3a5101b12d 100644 --- a/backends/qualcomm/builders/op_rms_norm.py +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -47,7 +48,10 @@ def define_node( len(normalized_shapes) != 1 and normalized_shapes[0] != input_tensor.shape[-1] ): - print("Only supports normalization with last input dimension") + warnings.warn( + "[QNN Delegate Op Builder]: Only supports normalization with last input dimension.", + stacklevel=1, + ) return axes = [node.args[0].meta["val"].dim() - 1] axes_shape = [len(axes)] diff --git a/backends/qualcomm/builders/op_topk.py b/backends/qualcomm/builders/op_topk.py new file mode 100644 index 0000000000..84c29925f2 --- /dev/null +++ b/backends/qualcomm/builders/op_topk.py @@ -0,0 +1,107 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import warnings +from typing import cast, Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpTopK, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class TopK(NodeVisitor): + target = ["aten.topk.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=True, + ) + + k = cast(int, node.args[1]) + + if len(node.args) > 2: + dim = cast(int, node.args[2]) + if dim < 0: + dim = dim % len(input_tensor.shape) + if QCOM_AXIS_ORDER in node.meta: + dim = node.meta[QCOM_AXIS_ORDER].index(dim) + if dim != len(input_tensor.shape) - 1: + warnings.warn( + "[QNN Delegate Op Builder]: QNN currently only supports channel as dimension for topK.", + stacklevel=1, + ) + return + + topk_input_tensors = [input_tensor_wrapper] + + output_val_tensor = self.get_tensor(node, node, 0) + output_idx_tensor = self.get_tensor(node, node, 1).to(torch.int32) + + # QNN constraint, topk output_0 requires having the same quant config as input + node.meta["quant_attrs"] = input_node.meta.get("quant_attrs") + output_val_tensor_wrapper = self.define_tensor( + node, + output_val_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + + # topk output_1 is index, do not quantize it. + node.meta.pop("quant_attrs", None) + output_index_tensor_wrapper = self.define_tensor( + node, + output_idx_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + wrapper_idx=1, + ) + topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper] + + topk_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTopK.op_name, + ) + topk_op.AddInputTensors(topk_input_tensors) + topk_op.AddOutputTensors(topk_output_tensors) + + topk_op.AddScalarParam( + OpTopK.param_k, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {"data": np.uint32(k)}, + ) + + # As of QNN 2.26, QNN HTP backend only allows users to set this value to 1, or else it will fail at op validation + if len(node.args) > 3: + largest = cast(bool, node.args[3]) + topk_op.AddScalarParam( + OpTopK.param_largest, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {QCOM_DATA: largest}, + ) + + return topk_op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 9c589c7678..307ad4a0c4 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -352,6 +352,13 @@ class OpTile: param_multiples: str = "multiples" +@dataclass(init=False, frozen=True) +class OpTopK: + op_name: str = "TopK" + param_k: str = "k" + param_largest: str = "largest" + + @dataclass(init=False, frozen=True) class OpTranspose: op_name: str = "Transpose" diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index ce8fab5e6f..3c8faee9c0 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Sequence, Set import torch +from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import ( RecomposePixelUnshuffle, @@ -190,6 +191,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule: model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module model = DecomposeScaledDotProductAttention()(model).graph_module model = DecomposeSilu()(model).graph_module + model = DecomposeEinsum()(model).graph_module model = ReplaceInfBuffer()(model).graph_module return model diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index 62407decce..d1ea35fa19 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -415,8 +415,8 @@ def _is_annotated(nodes: List[Node]): return annotated -def _is_input_float_tensor(node: Node): - """Check if the input is not a float tensor, so that we can skip quantization for the node +def _is_float_tensor(node: Node): + """Check if the node's tensor is a float tensor, so that we can skip quantization for the node since observers only works with float Tensors """ if ( @@ -474,7 +474,7 @@ def annotate_single_in_single_out( assert isinstance(input_act, Node) input_qspec_map[input_act] = quantization_config.input_activation - if _is_input_float_tensor(node): + if _is_float_tensor(node): node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=quantization_config.output_activation, @@ -482,22 +482,30 @@ def annotate_single_in_single_out( ) +@register_annotator([torch.ops.aten.topk.default]) +def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + # We can use single_in_single_out since we don't want to quantize indices output + annotate_single_in_single_out(node, quantization_config) + + def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): return input_act_qspec = quantization_config.input_activation output_act_qspec = ( - quantization_config.output_activation if _is_input_float_tensor(node) else None + quantization_config.output_activation if _is_float_tensor(node) else None ) input_qspec_map = {} input_act0 = node.args[0] - if _is_input_float_tensor(input_act0): + if _is_float_tensor(input_act0): input_qspec_map[input_act0] = input_act_qspec input_act1 = node.args[1] - if _is_input_float_tensor(input_act1): + if _is_float_tensor(input_act1): input_qspec_map[input_act1] = input_act_qspec node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( @@ -579,7 +587,7 @@ def _derive_div_qparams_fn( ) input_qspec_map = {} input_act0 = node.args[0] - if _is_input_float_tensor(input_act0): + if _is_float_tensor(input_act0): input_qspec_map[input_act0] = input_act_qspec node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( @@ -864,7 +872,7 @@ def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> Non qscheme=torch.torch.per_tensor_affine, ) - if _is_input_float_tensor(node): + if _is_float_tensor(node): node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=out_act_quantization_spec, @@ -1143,8 +1151,12 @@ def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> @register_annotator([operator.getitem]) def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None: - _annotate_output_qspec(node, quantization_config.output_activation) - _mark_nodes_as_annotated([node]) + if _is_annotated([node]): + return + + if _is_float_tensor(node): + _annotate_output_qspec(node, quantization_config.output_activation) + _mark_nodes_as_annotated([node]) @register_annotator([torch.ops.aten.layer_norm.default]) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index ee3d6cf93a..0b3553fd59 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -416,6 +416,17 @@ def forward(self, x): return torch.sum(self.first(x), dim=(2, 3), keepdim=False) +class Conv2dTopK(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3) + + def forward(self, x): + x = self.conv(x) + topk_values, topk_indices = torch.topk(x, 5, dim=1) + return topk_values + + class Div(torch.nn.Module): def __init__(self): super().__init__() @@ -440,6 +451,30 @@ def forward(self, x): return x / 10 +class EinsumBilinear(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, bn, anm, bm): + return torch.einsum("bn,anm,bm->ba", bn, anm, bm) + + +class EinsumOuterProduct(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, i, j): + return torch.einsum("i,j->ij", i, j) + + +class EinsumOuterProductRelu(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, i, j): + return torch.relu(torch.einsum("i,j->ij", i, j)) + + class Embedding(torch.nn.Module): def __init__(self): super().__init__() @@ -978,6 +1013,16 @@ def forward(self, x): return torch.tanh(x) +class TopKandIndex(torch.nn.Module): + def __init__(self): + super().__init__() + self.idx_source = torch.rand(10, 3) + + def forward(self, x): + a, b = torch.topk(x, 3) + return a + self.idx_source[b] + + class Unbind(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index bb90e0eb58..2e1bd0eff3 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -37,6 +37,11 @@ skip_annotation, ) +from executorch.examples.models.llama2.llama_transformer import ( + ModelArgs, + MOEFeedForward, +) + from executorch.examples.qualcomm.utils import setup_common_args_and_variables from executorch.backends.qualcomm.tests.models import * # noqa: F403 @@ -140,6 +145,28 @@ def test_qnn_backend_conv_transpose2d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_einsum_outer_product(self): + module = EinsumOuterProduct() # noqa: F405 + x = torch.randn(5) + y = torch.randn(4) + sample_input = ( + x, + y, + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_einsum_bilinear(self): + module = EinsumBilinear() # noqa: F405 + bn = torch.randn(2, 5) + anm = torch.randn(3, 5, 4) + bm = torch.randn(2, 4) + sample_input = ( + bn, + anm, + bm, + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_element_wise_add(self): test_comb = [ { @@ -546,6 +573,34 @@ def test_qnn_backend_conv2d_sum_reduce_dim(self): sample_input = (torch.randn([1, 1, 3, 3]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_topk(self): + module = Conv2dTopK() # noqa: F405 + sample_input = (torch.randn(1, 3, 32, 32),) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_einsum_outer_product_relu(self): + module = EinsumOuterProductRelu() # noqa: F405 + x = torch.randn(5) + y = torch.randn(4) + sample_input = ( + x, + y, + ) + self.lower_module_and_test_output(module, sample_input) + + @unittest.skip("Fail because of bad accuracy") + def test_qnn_backend_moe_feed_forward(self): + args = ModelArgs() + args.dim = 32 + args.n_heads = 8 + args.n_layers = 2 + self.head_dim = args.dim // args.n_heads + module = MOEFeedForward(args) # noqa: F405 + sample_input = ( + torch.randint(low=0, high=100, size=(1, 32), dtype=torch.float32), + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_pixel_unshuffle_math_equivalent(self): module = PixelUnshuffleMathEquivalent(2) # noqa: F405 sample_input = (torch.rand(2, 2, 6, 6),) @@ -561,6 +616,11 @@ def test_qnn_backend_simple_model(self): sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_topk_and_index(self): + module = TopKandIndex() # noqa: F405 + sample_input = (torch.randn(3, 10),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_view_permute_matmul(self): module = ViewPermuteMatMul() # noqa: F405 torch.manual_seed(8) @@ -749,6 +809,30 @@ def test_qnn_backend_conv_transpose2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_einsum_outer_product(self): + module = EinsumOuterProduct() # noqa: F405 + x = torch.randn(5) + y = torch.randn(4) + sample_input = ( + x, + y, + ) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_einsum_bilinear(self): + module = EinsumBilinear() # noqa: F405 + bn = torch.randn(2, 5) + anm = torch.randn(3, 5, 4) + bm = torch.randn(2, 4) + sample_input = ( + bn, + anm, + bm, + ) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_element_wise_add(self): test_comb = [ { @@ -1211,6 +1295,37 @@ def test_qnn_backend_conv2d_sum_reduce_dim(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_topk(self): + module = Conv2dTopK() # noqa: F405 + sample_input = (torch.randn(1, 3, 32, 32),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_einsum_outer_product_relu(self): + module = EinsumOuterProductRelu() # noqa: F405 + x = torch.randn(5) + y = torch.randn(4) + sample_input = ( + x, + y, + ) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + + @unittest.skip("UT pass before QNN 2.26, segfault during partitioner") + def test_qnn_backend_moe_feed_forward(self): + args = ModelArgs() + args.dim = 32 + args.n_heads = 8 + args.n_layers = 2 + self.head_dim = args.dim // args.n_heads + module = MOEFeedForward(args) # noqa: F405 + sample_input = ( + torch.randint(low=0, high=100, size=(1, 32), dtype=torch.float32), + ) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_pixel_unshuffle_math_equivalent(self): module = PixelUnshuffleMathEquivalent(2) # noqa: F405 sample_input = (torch.rand(2, 2, 6, 6),) @@ -1229,6 +1344,12 @@ def test_qnn_backend_simple_model(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_topk_and_index(self): + module = TopKandIndex() # noqa: F405 + sample_input = (torch.randn(3, 10),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_view_permute_matmul(self): module = ViewPermuteMatMul() # noqa: F405 torch.manual_seed(8) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index b8230abdc2..c58da42e84 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -166,7 +166,7 @@ def _save_model_and_expected_output( ref_outputs = [] if isinstance(ref_output, collections.OrderedDict): ref_outputs.append(ref_output["out"].detach()) - elif isinstance(ref_output, tuple): + elif isinstance(ref_output, (list, tuple)): for output in ref_output: ref_outputs.append(output.detach()) else: diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 7da1ccb4ed..d93f7fcb4b 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -52,6 +52,7 @@ from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( _soc_info_table, + HtpArch, QcomChipset, QnnExecuTorchBackendOptions, QnnExecuTorchBackendType, @@ -854,6 +855,16 @@ def generate_qnn_executorch_compiler_spec( ] +def get_soc_to_arch_map(): + return { + "SSG2115P": HtpArch.V73, + "SM8650": HtpArch.V75, + "SM8550": HtpArch.V73, + "SM8475": HtpArch.V69, + "SM8450": HtpArch.V69, + } + + def get_soc_to_chipset_map(): return { "SSG2115P": QcomChipset.SSG2115P, diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS index 25a2d4846f..dd49512c08 100644 --- a/backends/vulkan/TARGETS +++ b/backends/vulkan/TARGETS @@ -25,6 +25,7 @@ runtime.python_library( "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:fuse_batch_norm_with_conv", "//executorch/backends/transforms:fuse_conv_with_clamp", + "//executorch/backends/transforms:fuse_dequant_linear", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:mean_to_sum_div", "//executorch/backends/transforms:remove_clone_ops", diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index 587485a5c9..4d0858953b 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -45,6 +45,13 @@ def __contains__(self, op): PRIM_OPS = [ operator.getitem, + # Quantization related ops will be fused via graph passes + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, ] SUPPORTS_DYNAMIC_SHAPE = [ diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index d2378274f6..ed566a30cc 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -13,6 +13,7 @@ FuseBatchNormWithConvPass, ) from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass +from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.mean_to_sum_div import MeanToSumDiv from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform @@ -59,6 +60,7 @@ def preprocess( # noqa: C901 passes = [ RemoveCloneOpsTransform(), AddmmToLinearTransform(), + FuseDequantLinearPass(), FuseViewCopyTransform(), FuseBatchNormWithConvPass(program), FuseClampPass(), diff --git a/docs/source/Doxyfile b/docs/source/Doxyfile index e662105b83..0d60bf51c7 100644 --- a/docs/source/Doxyfile +++ b/docs/source/Doxyfile @@ -943,7 +943,8 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../runtime/executor/memory_manager.h \ +INPUT = ../devtools/bundled_program/bundled_program.h \ + ../runtime/executor/memory_manager.h \ ../runtime/executor/method.h \ ../runtime/executor/method_meta.h \ ../runtime/executor/program.h \ diff --git a/docs/source/android-prebuilt-library.md b/docs/source/android-prebuilt-library.md index ee6c2cd65e..5a2d319893 100644 --- a/docs/source/android-prebuilt-library.md +++ b/docs/source/android-prebuilt-library.md @@ -4,18 +4,17 @@ We provide two prebuilt Android libraries (AAR), `executorch.aar` for generic us ## Contents of libraries - `executorch.aar` - - [Java library](https://github.com/pytorch/executorch/tree/release/0.3/extension/android/src/main/java/org/pytorch/executorch) - - JNI contains the JNI binding for [NativePeer.java](https://github.com/pytorch/executorch/blob/release/0.3/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java) and ExecuTorch native library, including core ExecuTorch runtime libraries, XNNPACK backend, Portable kernels, Optimized kernels, and Quantized kernels. + - [Java library](https://github.com/pytorch/executorch/tree/main/extension/android/src/main/java/org/pytorch/executorch) + - JNI contains the JNI binding for [NativePeer.java](https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/NativePeer.java) and ExecuTorch native library, including core ExecuTorch runtime libraries, XNNPACK backend, Portable kernels, Optimized kernels, and Quantized kernels. - Comes with two ABI variants, arm64-v8a and x86_64. - `executorch_llama.aar` - - [Java library](https://github.com/pytorch/executorch/tree/release/0.3/extension/android/src/main/java/org/pytorch/executorch) (Note: it contains the same Java classes as the previous Java, but it does not contain the JNI binding for generic Module/NativePeer Java code). - - JNI contains the JNI binding for [LlamaModule.java](https://github.com/pytorch/executorch/blob/release/0.3/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java) and ExecuTorch native library, including core ExecuTorch runtime libraries, XNNPACK backend, Portable kernels, Optimized kernels, Quantized kernels, and LLAMA-specific Custom ops library. + - [Java library](https://github.com/pytorch/executorch/tree/main/extension/android/src/main/java/org/pytorch/executorch) (Note: it contains the same Java classes as the previous Java, but it does not contain the JNI binding for generic Module/NativePeer Java code). + - JNI contains the JNI binding for [LlamaModule.java](https://github.com/pytorch/executorch/blob/main/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java) and ExecuTorch native library, including core ExecuTorch runtime libraries, XNNPACK backend, Portable kernels, Optimized kernels, Quantized kernels, and LLAMA-specific Custom ops library. - Comes with two ABI variants, arm64-v8a and x86_64. ## Downloading AAR -[executorch.aar](https://ossci-android.s3.us-west-1.amazonaws.com/executorch/release/0.3/executorch.aar) -[executorch_llama.aar](https://ossci-android.s3.us-west-1.amazonaws.com/executorch/release/0.3/executorch-llama.aar) -[SHA1SUMS](https://ossci-android.s3.us-west-1.amazonaws.com/executorch/release/0.3/SHA1SUMS) +[executorch.aar](https://ossci-android.s3.amazonaws.com/executorch/release/executorch-241002/executorch.aar) +[executorch.aar.sha256sums](https://ossci-android.s3.amazonaws.com/executorch/release/executorch-241002/executorch.aar.sha256sums) ## Using prebuilt libraries @@ -24,7 +23,7 @@ To add the Java library to your app, simply download the AAR, and add it to your In your app working directory, such as example executorch/examples/demo-apps/android/LlamaDemo, ``` mkdir -p app/libs -curl https://ossci-android.s3.us-west-1.amazonaws.com/executorch/release/0.3/executorch-llama.aar -o app/libs/executorch.aar +curl https://ossci-android.s3.amazonaws.com/executorch/release/executorch-241002/executorch.aar -o app/libs/executorch-llama.aar ``` And include it in gradle: diff --git a/docs/source/build-run-coreml.md b/docs/source/build-run-coreml.md index 9751dc066f..45a7ecafce 100644 --- a/docs/source/build-run-coreml.md +++ b/docs/source/build-run-coreml.md @@ -147,11 +147,10 @@ libsqlite3.tbd 7. Update the code to load the program from the Application's bundle. ``` objective-c -using namespace torch::executor; - NSURL *model_url = [NBundle.mainBundle URLForResource:@"mv3_coreml_all" extension:@"pte"]; -Result loader = util::FileDataLoader::from(model_url.path.UTF8String); +Result loader = + executorch::extension::FileDataLoader::from(model_url.path.UTF8String); ``` 8. Use [Xcode](https://developer.apple.com/documentation/xcode/building-and-running-an-app#Build-run-and-debug-your-app) to deploy the application on the device. diff --git a/docs/source/bundled-io.md b/docs/source/bundled-io.md index 74c86cb8cc..08961db776 100644 --- a/docs/source/bundled-io.md +++ b/docs/source/bundled-io.md @@ -201,14 +201,14 @@ This stage mainly focuses on executing the model with the bundled inputs and and ### Get ExecuTorch Program Pointer from `BundledProgram` Buffer We need the pointer to ExecuTorch program to do the execution. To unify the process of loading and executing `BundledProgram` and Program flatbuffer, we create an API: -:::{dropdown} `GetProgramData` +:::{dropdown} `get_program_data` ```{eval-rst} -.. doxygenfunction:: torch::executor::bundled_program::GetProgramData +.. doxygenfunction:: ::executorch::bundled_program::get_program_data ``` ::: -Here's an example of how to use the `GetProgramData` API: +Here's an example of how to use the `get_program_data` API: ```c++ // Assume that the user has read the contents of the file into file_data using // whatever method works best for their application. The file could contain @@ -216,36 +216,36 @@ Here's an example of how to use the `GetProgramData` API: void* file_data = ...; size_t file_data_len = ...; -// If file_data contains a BundledProgram, GetProgramData() will return a +// If file_data contains a BundledProgram, get_program_data() will return a // pointer to the Program data embedded inside it. Otherwise it will return // file_data, which already pointed to Program data. const void* program_ptr; size_t program_len; -status = torch::executor::bundled_program::GetProgramData( +status = executorch::bundled_program::get_program_data( file_data, file_data_len, &program_ptr, &program_len); ET_CHECK_MSG( status == Error::Ok, - "GetProgramData() failed with status 0x%" PRIx32, + "get_program_data() failed with status 0x%" PRIx32, status); ``` ### Load Bundled Input to Method -To execute the program on the bundled input, we need to load the bundled input into the method. Here we provided an API called `torch::executor::bundled_program::LoadBundledInput`: +To execute the program on the bundled input, we need to load the bundled input into the method. Here we provided an API called `executorch::bundled_program::load_bundled_input`: -:::{dropdown} `LoadBundledInput` +:::{dropdown} `load_bundled_input` ```{eval-rst} -.. doxygenfunction:: torch::executor::bundled_program::LoadBundledInput +.. doxygenfunction:: ::executorch::bundled_program::load_bundled_input ``` ::: ### Verify the Method's Output. -We call `torch::executor::bundled_program::VerifyResultWithBundledExpectedOutput` to verify the method's output with bundled expected outputs. Here's the details of this API: +We call `executorch::bundled_program::verify_method_outputs` to verify the method's output with bundled expected outputs. Here's the details of this API: -:::{dropdown} `VerifyResultWithBundledExpectedOutput` +:::{dropdown} `verify_method_outputs` ```{eval-rst} -.. doxygenfunction:: torch::executor::bundled_program::VerifyResultWithBundledExpectedOutput +.. doxygenfunction:: ::executorch::bundled_program::verify_method_outputs ``` ::: @@ -266,13 +266,13 @@ ET_CHECK_MSG( method.error()); // Load testset_idx-th input in the buffer to plan -status = torch::executor::bundled_program::LoadBundledInput( +status = executorch::bundled_program::load_bundled_input( *method, program_data.bundled_program_data(), FLAGS_testset_idx); ET_CHECK_MSG( status == Error::Ok, - "LoadBundledInput failed with status 0x%" PRIx32, + "load_bundled_input failed with status 0x%" PRIx32, status); // Execute the plan @@ -283,7 +283,7 @@ ET_CHECK_MSG( status); // Verify the result. -status = torch::executor::bundled_program::VerifyResultWithBundledExpectedOutput( +status = executorch::bundled_program::verify_method_outputs( *method, program_data.bundled_program_data(), FLAGS_testset_idx, diff --git a/docs/source/concepts.md b/docs/source/concepts.md index c085505b61..289ecda6d8 100644 --- a/docs/source/concepts.md +++ b/docs/source/concepts.md @@ -26,7 +26,7 @@ The goal of ATen dialect is to capture users’ programs as faithfully as possib ## ATen mode -ATen mode uses the ATen implementation of Tensor (`at::Tensor`) and related types, such as `ScalarType`, from the PyTorch core. This is in contrast to portable mode, which uses ExecuTorch’s smaller implementation of tensor (`torch::executor::Tensor`) and related types, such as `torch::executor::ScalarType`. +ATen mode uses the ATen implementation of Tensor (`at::Tensor`) and related types, such as `ScalarType`, from the PyTorch core. This is in contrast to ETensor mode, which uses ExecuTorch’s smaller implementation of tensor (`executorch::runtime::etensor::Tensor`) and related types, such as `executorch::runtime::etensor::ScalarType`. - ATen kernels that rely on the full `at::Tensor` API are usable in this configuration. - ATen kernels tend to do dynamic memory allocation and often have extra flexibility (and thus overhead) to handle cases not needed by mobile/embedded clients. e.g., CUDA support, sparse tensor support, and dtype promotion. - Note: ATen mode is currently a WIP. @@ -244,10 +244,10 @@ Kernels that support a subset of tensor dtypes and/or dim orders. Parts of a model may be delegated to run on an optimized backend. The partitioner splits the graph into the appropriate sub-networks and tags them for delegation. -## Portable mode (lean mode) +## ETensor mode -Portable mode uses ExecuTorch’s smaller implementation of tensor (`torch::executor::Tensor`) along with related types (`torch::executor::ScalarType`, etc.). This is in contrast to ATen mode, which uses the ATen implementation of Tensor (`at::Tensor`) and related types (`ScalarType`, etc.) -- `torch::executor::Tensor`, also known as ETensor, is a source-compatible subset of `at::Tensor`. Code written against ETensor can build against `at::Tensor`. +ETensor mode uses ExecuTorch’s smaller implementation of tensor (`executorch::runtime::etensor::Tensor`) along with related types (`executorch::runtime::etensor::ScalarType`, etc.). This is in contrast to ATen mode, which uses the ATen implementation of Tensor (`at::Tensor`) and related types (`ScalarType`, etc.) +- `executorch::runtime::etensor::Tensor`, also known as ETensor, is a source-compatible subset of `at::Tensor`. Code written against ETensor can build against `at::Tensor`. - ETensor does not own or allocate memory on its own. To support dynamic shapes, kernels can allocate Tensor data using the MemoryAllocator provided by the client. ## Portable kernels diff --git a/docs/source/etdump.md b/docs/source/etdump.md index 42391cf40e..13957f1c01 100644 --- a/docs/source/etdump.md +++ b/docs/source/etdump.md @@ -15,7 +15,7 @@ Generating an ETDump is a relatively straightforward process. Users can follow t 2. ***Create*** an Instance of the ETDumpGen class and pass it into the `load_method` call that is invoked in the runtime. ```C++ -torch::executor::ETDumpGen etdump_gen = torch::executor::ETDumpGen(); +executorch::etdump::ETDumpGen etdump_gen; Result method = program->load_method(method_name, &memory_manager, &etdump_gen); ``` diff --git a/docs/source/executorch-runtime-api-reference.rst b/docs/source/executorch-runtime-api-reference.rst index ef9a99e593..5bec597987 100644 --- a/docs/source/executorch-runtime-api-reference.rst +++ b/docs/source/executorch-runtime-api-reference.rst @@ -11,25 +11,25 @@ For detailed information on how APIs evolve and the deprecation process, please Model Loading and Execution --------------------------- -.. doxygenclass:: executorch::runtime::DataLoader +.. doxygenclass:: executorch::runtime::Program :members: -.. doxygenclass:: executorch::runtime::MemoryAllocator +.. doxygenclass:: executorch::runtime::Method :members: -.. doxygenclass:: executorch::runtime::HierarchicalAllocator +.. doxygenclass:: executorch::runtime::MethodMeta :members: -.. doxygenclass:: executorch::runtime::MemoryManager +.. doxygenclass:: executorch::runtime::DataLoader :members: -.. doxygenclass:: executorch::runtime::Program +.. doxygenclass:: executorch::runtime::MemoryAllocator :members: -.. doxygenclass:: executorch::runtime::Method +.. doxygenclass:: executorch::runtime::HierarchicalAllocator :members: -.. doxygenclass:: executorch::runtime::MethodMeta +.. doxygenclass:: executorch::runtime::MemoryManager :members: Values @@ -38,5 +38,5 @@ Values .. doxygenstruct:: executorch::runtime::EValue :members: -.. doxygenclass:: executorch::aten::Tensor +.. doxygenclass:: executorch::runtime::etensor::Tensor :members: diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index 71b13598ed..2cdf13ca65 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -43,14 +43,13 @@ cd et-nanogpt # Clone the ExecuTorch repository and submodules. mkdir third-party -git clone -b release/0.2 https://github.com/pytorch/executorch.git third-party/executorch +git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch cd third-party/executorch git submodule update --init # Create a conda environment and install requirements. conda create -yn executorch python=3.10.0 conda activate executorch -pip install cmake zstd ./install_requirements.sh cd ../.. @@ -77,12 +76,11 @@ pyenv activate executorch # Clone the ExecuTorch repository and submodules. mkdir third-party -git clone -b release/0.2 https://github.com/pytorch/executorch.git third-party/executorch +git clone -b release/0.4 https://github.com/pytorch/executorch.git third-party/executorch cd third-party/executorch git submodule update --init # Install requirements. -pip install cmake zstd PYTHON_EXECUTABLE=python ./install_requirements.sh cd ../.. @@ -208,8 +206,8 @@ Create a file called main.cpp with the following contents: #include #include -using exec_aten::ScalarType; -using exec_aten::Tensor; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; using executorch::extension::from_blob; using executorch::extension::Module; using executorch::runtime::EValue; @@ -235,56 +233,56 @@ std::string generate( BasicSampler& sampler, size_t max_input_length, size_t max_output_length) { + // Convert the input text into a list of integers (tokens) that represents it, + // using the string-to-token mapping that the model was trained on. Each token + // is an integer that represents a word or part of a word. + std::vector input_tokens = tokenizer.encode(prompt); + std::vector output_tokens; + + for (auto i = 0u; i < max_output_length; i++) { + // Convert the input_tokens from a vector of int64_t to EValue. EValue is a + // unified data type in the ExecuTorch runtime. + auto inputs = from_blob( + input_tokens.data(), + {1, static_cast(input_tokens.size())}, + ScalarType::Long); + + // Run the model. It will return a tensor of logits (log-probabilities). + auto logits_evalue = llm_model.forward(inputs); + + // Convert the output logits from EValue to std::vector, which is what the + // sampler expects. + Tensor logits_tensor = logits_evalue.get()[0].toTensor(); + std::vector logits( + logits_tensor.data_ptr(), + logits_tensor.data_ptr() + logits_tensor.numel()); + + // Sample the next token from the logits. + int64_t next_token = sampler.sample(logits); + + // Break if we reached the end of the text. + if (next_token == ENDOFTEXT_TOKEN) { + break; + } + + // Add the next token to the output. + output_tokens.push_back(next_token); + + std::cout << tokenizer.decode({next_token}); + std::cout.flush(); - // Convert the input text into a list of integers (tokens) that represents - // it, using the string-to-token mapping that the model was trained on. - // Each token is an integer that represents a word or part of a word. - std::vector input_tokens = tokenizer.encode(prompt); - std::vector output_tokens; - - for (auto i = 0u; i < max_output_length; i++) { - // Convert the input_tokens from a vector of int64_t to EValue. - // EValue is a unified data type in the ExecuTorch runtime. - auto inputs = from_blob( - input_tokens.data(), - {1, static_cast(input_tokens.size())}, - ScalarType::Long); - - // Run the model. It will return a tensor of logits (log-probabilities). - auto logits_evalue = llm_model.forward(inputs); - - // Convert the output logits from EValue to std::vector, which is what - // the sampler expects. - Tensor logits_tensor = logits_evalue.get()[0].toTensor(); - std::vector logits(logits_tensor.data_ptr(), - logits_tensor.data_ptr() + logits_tensor.numel()); - - // Sample the next token from the logits. - int64_t next_token = sampler.sample(logits); - - // Break if we reached the end of the text. - if (next_token == ENDOFTEXT_TOKEN) { - break; - } - - // Add the next token to the output. - output_tokens.push_back(next_token); - - std::cout << tokenizer.decode({ next_token }); - std::cout.flush(); - - // Update next input. - input_tokens.push_back(next_token); - if (input_tokens.size() > max_input_length) { - input_tokens.erase(input_tokens.begin()); - } + // Update next input. + input_tokens.push_back(next_token); + if (input_tokens.size() > max_input_length) { + input_tokens.erase(input_tokens.begin()); } + } - std::cout << std::endl; + std::cout << std::endl; - // Convert the output tokens into a human-readable string. - std::string output_string = tokenizer.decode(output_tokens); - return output_string; + // Convert the output tokens into a human-readable string. + std::string output_string = tokenizer.decode(output_tokens); + return output_string; } ``` @@ -309,32 +307,32 @@ penalties for repeated tokens, and biases to prioritize or de-prioritize specifi ```cpp // main.cpp -using namespace torch::executor; - int main() { - // Set up the prompt. This provides the seed text for the model to elaborate. - std::cout << "Enter model prompt: "; - std::string prompt; - std::getline(std::cin, prompt); - - // The tokenizer is used to convert between tokens (used by the model) and - // human-readable strings. - BasicTokenizer tokenizer("vocab.json"); - - // The sampler is used to sample the next token from the logits. - BasicSampler sampler = BasicSampler(); - - // Load the exported nanoGPT program, which was generated via the previous steps. - Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors); - - const auto max_input_tokens = 1024; - const auto max_output_tokens = 30; - std::cout << prompt; - generate(model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens); + // Set up the prompt. This provides the seed text for the model to elaborate. + std::cout << "Enter model prompt: "; + std::string prompt; + std::getline(std::cin, prompt); + + // The tokenizer is used to convert between tokens (used by the model) and + // human-readable strings. + BasicTokenizer tokenizer("vocab.json"); + + // The sampler is used to sample the next token from the logits. + BasicSampler sampler = BasicSampler(); + + // Load the exported nanoGPT program, which was generated via the previous + // steps. + Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors); + + const auto max_input_tokens = 1024; + const auto max_output_tokens = 30; + std::cout << prompt; + generate( + model, prompt, tokenizer, sampler, max_input_tokens, max_output_tokens); } ``` -Finally, download the following files into the same directory as main.h: +Finally, download the following files into the same directory as main.cpp: ``` curl -O https://raw.githubusercontent.com/pytorch/executorch/main/examples/llm_manual/basic_sampler.h @@ -368,17 +366,19 @@ option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON) # Include the executorch subdirectory. add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch - ${CMAKE_BINARY_DIR}/third-party/executorch) + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch + ${CMAKE_BINARY_DIR}/executorch +) add_executable(nanogpt_runner main.cpp) target_link_libraries( - nanogpt_runner - PRIVATE - executorch - extension_module_static # Provides the Module class - extension_tensor # Provides the TensorPtr class - optimized_native_cpu_ops_lib) # Provides baseline cross-platform kernels + nanogpt_runner + PRIVATE executorch + extension_module_static # Provides the Module class + extension_tensor # Provides the TensorPtr class + optimized_native_cpu_ops_lib # Provides baseline cross-platform + # kernels +) ``` At this point, the working directory should contain the following files: @@ -524,20 +524,20 @@ option(EXECUTORCH_BUILD_XNNPACK "" ON) # Build with Xnnpack backend # Include the executorch subdirectory. add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch - ${CMAKE_BINARY_DIR}/executorch) - -# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) + ${CMAKE_CURRENT_SOURCE_DIR}/third-party/executorch + ${CMAKE_BINARY_DIR}/executorch +) add_executable(nanogpt_runner main.cpp) target_link_libraries( - nanogpt_runner - PRIVATE - executorch - extension_module_static # Provides the Module class - extension_tensor # Provides the TensorPtr class - optimized_native_cpu_ops_lib # Provides baseline cross-platform kernels - xnnpack_backend) # Provides the XNNPACK CPU acceleration backend + nanogpt_runner + PRIVATE executorch + extension_module_static # Provides the Module class + extension_tensor # Provides the TensorPtr class + optimized_native_cpu_ops_lib # Provides baseline cross-platform + # kernels + xnnpack_backend # Provides the XNNPACK CPU acceleration backend +) ``` Keep the rest of the code the same. For more details refer to [Exporting diff --git a/docs/source/running-a-model-cpp-tutorial.md b/docs/source/running-a-model-cpp-tutorial.md index edd5d95086..70e17c8222 100644 --- a/docs/source/running-a-model-cpp-tutorial.md +++ b/docs/source/running-a-model-cpp-tutorial.md @@ -24,14 +24,25 @@ Users can define their own `DataLoader`s to fit the needs of their particular sy For the `FileDataLoader` all we need to do is provide a file path to the constructor. ``` cpp -using namespace torch::executor; - -Result loader = - util::FileDataLoader::from("/tmp/model.pte"); +using executorch::aten::Tensor; +using executorch::aten::TensorImpl; +using executorch::extension::FileDataLoader; +using executorch::extension::MallocMemoryAllocator; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::HierarchicalAllocator; +using executorch::runtime::MemoryManager; +using executorch::runtime::Method; +using executorch::runtime::MethodMeta; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::Span; + +Result loader = + FileDataLoader::from("/tmp/model.pte"); assert(loader.ok()); -Result program = - torch::executor::Program::load(&loader.get()); +Result program = Program::load(&loader.get()); assert(program.ok()); ``` @@ -48,14 +59,13 @@ One of the principles of ExecuTorch is giving users control over where the memor For this example we will retrieve the size of the planned memory arenas dynamically from the `Program`, but for heapless environments users could retrieve this information from the `Program` ahead of time and allocate the arena statically. We will also be using a malloc based allocator for the method allocator. ``` cpp - -// Method names map back to Python nn.Module method names. Most users will only have the singular method "forward". +// Method names map back to Python nn.Module method names. Most users will only +// have the singular method "forward". const char* method_name = "forward"; // MethodMeta is a lightweight structure that lets us gather metadata -// information about a specific method. In this case we are looking to -// get the required size of the memory planned buffers for the method -// "forward". +// information about a specific method. In this case we are looking to get the +// required size of the memory planned buffers for the method "forward". Result method_meta = program->method_meta(method_name); assert(method_meta.ok()); @@ -64,7 +74,8 @@ std::vector> planned_arenas; // Passed to the allocator size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers(); -// It is possible to have multiple layers in our memory hierarchy; for example, SRAM and DRAM. +// It is possible to have multiple layers in our memory hierarchy; for example, +// SRAM and DRAM. for (size_t id = 0; id < num_memory_planned_buffers; ++id) { // .get() will always succeed because id < num_memory_planned_buffers. size_t buffer_size = @@ -75,12 +86,12 @@ for (size_t id = 0; id < num_memory_planned_buffers; ++id) { HierarchicalAllocator planned_memory( {planned_arenas.data(), planned_arenas.size()}); -// Version of MemoryAllocator that uses malloc to handle allocations -// rather then a fixed buffer. -util::MallocMemoryAllocator method_allocator; +// Version of MemoryAllocator that uses malloc to handle allocations rather then +// a fixed buffer. +MallocMemoryAllocator method_allocator; -// Assemble all of the allocators into the MemoryManager that the Executor -// will use. +// Assemble all of the allocators into the MemoryManager that the Executor will +// use. MemoryManager memory_manager(&method_allocator, &planned_memory); ``` diff --git a/examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh b/examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh index 61c86a92f7..34cf910746 100644 --- a/examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh +++ b/examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh @@ -7,14 +7,14 @@ set -eu -AAR_URL="https://ossci-android.s3.us-west-1.amazonaws.com/executorch/release/0.3/executorch-llama.aar" -AAR_SHASUM="f06cc1606e5e05f00fd0ae721f5d37d56124fd28" - +AAR_URL="https://ossci-android.s3.us-west-1.amazonaws.com/executorch/release/executorch-241002/executorch.aar" +AAR_SHASUM_URL="https://ossci-android.s3.us-west-1.amazonaws.com/executorch/release/executorch-241002/executorch.aar.sha256sums" LIBS_PATH="$(dirname "$0")/app/libs" -AAR_PATH="${LIBS_PATH}/executorch-llama.aar" mkdir -p "$LIBS_PATH" -if [[ ! -f "${AAR_PATH}" || "${AAR_SHASUM}" != $(shasum "${AAR_PATH}" | awk '{print $1}') ]]; then - curl "${AAR_URL}" -o "${AAR_PATH}" -fi +pushd "$LIBS_PATH" +curl -O "${AAR_SHASUM_URL}" +sed -i -e 's/executorch.aar/executorch-llama.aar/g' executorch.aar.sha256sums +shasum --check --status executorch.aar.sha256sums || curl "${AAR_URL}" -o executorch-llama.aar +popd diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift index 00fc4f6f54..69ece27f67 100644 --- a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift +++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift @@ -176,6 +176,7 @@ struct ContentView: View { .padding([.leading, .trailing, .bottom], 10) .sheet(isPresented: $isImagePickerPresented, onDismiss: addSelectedImageMessage) { ImagePicker(selectedImage: $selectedImage, sourceType: imagePickerSourceType) + .id(imagePickerSourceType.rawValue) } } .navigationBarTitle(title, displayMode: .inline) diff --git a/examples/llm_manual/CMakeLists.txt b/examples/llm_manual/CMakeLists.txt index e5054a683a..1283eb548e 100644 --- a/examples/llm_manual/CMakeLists.txt +++ b/examples/llm_manual/CMakeLists.txt @@ -23,8 +23,6 @@ add_subdirectory( ${CMAKE_BINARY_DIR}/executorch ) -# include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) - add_executable(nanogpt_runner main.cpp) target_link_libraries( nanogpt_runner @@ -33,5 +31,5 @@ target_link_libraries( extension_tensor # Provides the TensorPtr class optimized_native_cpu_ops_lib # Provides baseline cross-platform # kernels - xnnpack_backend -) # Provides the XNNPACK CPU acceleration backend + xnnpack_backend # Provides the XNNPACK CPU acceleration backend +) diff --git a/examples/llm_manual/main.cpp b/examples/llm_manual/main.cpp index 5bd318983b..76492513f9 100644 --- a/examples/llm_manual/main.cpp +++ b/examples/llm_manual/main.cpp @@ -17,15 +17,15 @@ #include #include -using exec_aten::ScalarType; -using exec_aten::Tensor; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; using executorch::extension::from_blob; using executorch::extension::Module; using executorch::runtime::EValue; using executorch::runtime::Result; // The value of the gpt2 `<|endoftext|>` token. -#define ENDOFTEXT 50256 +#define ENDOFTEXT_TOKEN 50256 std::string generate( Module& llm_model, @@ -34,15 +34,15 @@ std::string generate( BasicSampler& sampler, size_t max_input_length, size_t max_output_length) { - // Convert the input text into a list of integers (tokens) that represents - // it, using the string-to-token mapping that the model was trained on. - // Each token is an integer that represents a word or part of a word. + // Convert the input text into a list of integers (tokens) that represents it, + // using the string-to-token mapping that the model was trained on. Each token + // is an integer that represents a word or part of a word. std::vector input_tokens = tokenizer.encode(prompt); std::vector output_tokens; for (auto i = 0u; i < max_output_length; i++) { - // Convert the input_tokens from a vector of int64_t to EValue. - // EValue is a unified data type in the ExecuTorch runtime. + // Convert the input_tokens from a vector of int64_t to EValue. EValue is a + // unified data type in the ExecuTorch runtime. auto inputs = from_blob( input_tokens.data(), {1, static_cast(input_tokens.size())}, @@ -51,8 +51,8 @@ std::string generate( // Run the model. It will return a tensor of logits (log-probabilities). auto logits_evalue = llm_model.forward(inputs); - // Convert the output logits from EValue to std::vector, which is what - // the sampler expects. + // Convert the output logits from EValue to std::vector, which is what the + // sampler expects. Tensor logits_tensor = logits_evalue.get()[0].toTensor(); std::vector logits( logits_tensor.data_ptr(), @@ -62,7 +62,7 @@ std::string generate( int64_t next_token = sampler.sample(logits); // Break if we reached the end of the text. - if (next_token == ENDOFTEXT) { + if (next_token == ENDOFTEXT_TOKEN) { break; } @@ -86,11 +86,9 @@ std::string generate( return output_string; } -// main.cpp - int main() { // Set up the prompt. This provides the seed text for the model to elaborate. - std::cout << "Prompt: "; + std::cout << "Enter model prompt: "; std::string prompt; std::getline(std::cin, prompt); diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md index 023955ab32..1a6fe99fc4 100644 --- a/examples/models/llama2/README.md +++ b/examples/models/llama2/README.md @@ -49,7 +49,7 @@ We employed 4-bit groupwise per token dynamic quantization of all the linear lay We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Please note that LM Eval reports perplexity normalized by word count instead of token count. You may see different perplexity for WikiText from other sources if they implement it differntly. More details could be found [here](https://github.com/EleutherAI/lm-evaluation-harness/issues/2301). -Below are the results for two different groupsizes, with max_seq_len 2048, and 1000 samples. +Below are the results for two different groupsizes, with max_seq_length 2048, and limit 1000. |Model | Baseline (FP32) | Groupwise 4-bit (128) | Groupwise 4-bit (256) |--------|-----------------| ---------------------- | --------------- @@ -280,12 +280,32 @@ tokenizer.path=/tokenizer.model > Forewarning: Model evaluation without a GPU may take a long time, especially on larger models. -Using the same arguments from above +We use [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate model accuracy. + +For base models, use the following example command to calculate its perplexity based on WikiText. ``` -python -m examples.models.llama2.eval_llama -c -p -t -d fp32 --max_seq_len --limit +python -m examples.models.llama2.eval_llama \ + -c \ + -p \ + -t \ + -kv \ + -d \ + --max_seq_len \ + --limit ``` -The Wikitext results generated above used: `{max_seq_len: 2048, limit: 1000}` +For instruct models, use the following example command to calculate its MMLU score. +``` +python -m examples.models.llama2.eval_llama \ + -c \ + -p \ + -t \ + -kv \ + -d \ + --tasks mmlu \ + --num_fewshot 5 \ + --max_seq_len +``` ## Step 4: Run on your computer to validate diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py index 3061d290bd..57258fdbc1 100644 --- a/examples/models/llama2/eval_llama_lib.py +++ b/examples/models/llama2/eval_llama_lib.py @@ -21,8 +21,9 @@ ) from executorch.extension.llm.tokenizer.utils import get_tokenizer from lm_eval.api.model import LM +from lm_eval.evaluator import simple_evaluate -from .evaluate.eager_eval import EagerEvalWrapper, evaluate_model +from .evaluate.eager_eval import EagerEvalWrapper from .export_llama_lib import ( _prepare_for_llama_export, @@ -246,9 +247,19 @@ def build_args_parser() -> argparse.ArgumentParser: help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( - "--limit", type=int, default=5, help="number of samples to evalulate" + "--limit", + type=int, + default=None, + help="number of samples to evalulate. If not set, evaluate all samples", + ) + parser.add_argument( + "-f", + "--num_fewshot", + type=int, + default=None, + metavar="N", + help="Number of examples in few-shot context", ) - # Add additional args specific to eval via an ET Runner # Note: For initial integration, the tokenizer.model is also required parser.add_argument( @@ -281,11 +292,13 @@ def eval_llama( eval_wrapper = gen_eval_wrapper(model_name, args) # Evaluate the model - eval_results = evaluate_model( - eval_wrapper, - args.tasks, # pyre-ignore - args.limit, # pyre-ignore - ) + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=args.tasks, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks` + num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot` + limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit` + ) for task, res in eval_results["results"].items(): print(f"{task}: {res}") diff --git a/examples/models/llama2/evaluate/__init__.py b/examples/models/llama2/evaluate/__init__.py index c7f7d2be6e..f1ba33ccbb 100644 --- a/examples/models/llama2/evaluate/__init__.py +++ b/examples/models/llama2/evaluate/__init__.py @@ -4,9 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .eager_eval import EagerEvalWrapper, evaluate_model +from .eager_eval import EagerEvalWrapper __all__ = [ - "evaluate_model", "EagerEvalWrapper", ] diff --git a/examples/models/llama2/evaluate/eager_eval.py b/examples/models/llama2/evaluate/eager_eval.py index cd845f577e..784112e052 100644 --- a/examples/models/llama2/evaluate/eager_eval.py +++ b/examples/models/llama2/evaluate/eager_eval.py @@ -7,17 +7,13 @@ from typing import Optional, Union -import lm_eval import torch from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken from executorch.extension.llm.tokenizer.tokenizer import ( Tokenizer as SentencePieceTokenizer, ) -from lm_eval.api.model import LM -from lm_eval.evaluator import evaluate from lm_eval.models.huggingface import HFLM as eval_wrapper -from lm_eval.tasks import get_task_dict from torch import nn @@ -79,39 +75,3 @@ def _model_call(self, inps): def _model_generate(self, context, max_length, eos_token_id): raise Exception("unimplemented") - - -@torch.no_grad() -def evaluate_model( - eval_wrapper: LM, - tasks: Optional[list] = None, - limit: Optional[int] = None, -) -> dict: - """ - Evaluates a language model on a specified task using the lm-evaluation-harness library. - - Args: - eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation - tasks: Optional[list]: The names of the evaluation tasks to perform. - limit (Optional[int]): The maximum number of samples to evaluate (None for all available). - - Returns: - eval_results (dict): A dictionary of evaluation results for the specified task(s). - """ - - if tasks is None: - tasks = ["wikitext"] - - if "hendrycks_test" in tasks: - tasks.remove("hendrycks_test") - tasks += list( - lm_eval.tasks.hendrycks_test.create_all_tasks().keys() # pyre-ignore - ) - task_dict = get_task_dict(tasks) - - eval_results = evaluate( - eval_wrapper, - task_dict, - limit=limit, - ) - return eval_results diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index dce864adfe..cdac837fd2 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -41,6 +41,7 @@ get_pt2e_quantization_params, get_pt2e_quantizers, get_qnn_quantizer, + get_vulkan_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace @@ -147,6 +148,7 @@ def build_args_parser() -> argparse.ArgumentParser: "coreml_8a_c4w", "coreml_baseline_8a_c8w", "coreml_baseline_8a_c4w", + "vulkan_8w", ], help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", ) @@ -548,6 +550,12 @@ def get_quantizer_and_quant_params(args): assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize) quantizers.append(coreml_quantizer) + if args.vulkan and args.pt2e_quantize: + assert ( + len(quantizers) == 0 + ), "Should not enable both vulkan and other quantizers" + vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize) + quantizers.append(vulkan_quantizer) logging.info(f"Applying quantizers: {quantizers}") return pt2e_quant_params, quantizers, quant_dtype diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 2e6cb348b0..f10babc5bb 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -30,7 +30,7 @@ capture_program, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, - get_soc_to_chipset_map, + get_soc_to_arch_map, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge from executorch.exir.backend.backend_api import to_backend @@ -83,7 +83,7 @@ def __init__( self.dump_intermediate_outputs = dump_intermediate_outputs self.debug_output_path = f"{self.workspace}/debug_output.bin" self.output_folder = f"{self.workspace}/outputs" - self.soc_model = get_soc_to_chipset_map()[soc_model] + self.soc_model = get_soc_to_arch_map()[soc_model] self.error_only = error_only self.shared_buffer = shared_buffer self.runner = runner diff --git a/examples/third-party/fbjni b/examples/third-party/fbjni deleted file mode 160000 index 52a14f0daa..0000000000 --- a/examples/third-party/fbjni +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 52a14f0daa889a20d8984798b8d96eb03cebd334 diff --git a/exir/program/test/test_fake_program.py b/exir/program/test/test_fake_program.py index edccde43c2..15959efde4 100644 --- a/exir/program/test/test_fake_program.py +++ b/exir/program/test/test_fake_program.py @@ -62,8 +62,8 @@ def test_fake_program(self) -> None: self.assertEqual(exported_program.verifier, fake_program.verifier) self.assertEqual(id(exported_program.verifier), id(fake_program.verifier)) - # Fake program uses fake tensors for the state dict. Size should be smaller. - self.assertLess( + # Fake program uses fake tensors for the state dict. Size should be not be larger. + self.assertLessEqual( sys.getsizeof(fake_program.state_dict), sys.getsizeof(exported_program.state_dict), ) diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 075e4f21a4..cc08071611 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -21,9 +21,40 @@ include(${EXECUTORCH_ROOT}/build/Utils.cmake) set(_common_compile_options -Wno-deprecated-declarations -fPIC) set(_common_include_directories ${EXECUTORCH_ROOT}/..) -add_subdirectory( - ${EXECUTORCH_ROOT}/examples/third-party/fbjni - ${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni +# We need to download fbjni library from maven, and use its "prefab" library +# and headers, and link executorch library against that fbjni library. +# We don't know which NDK is used to compile fbjni, and we need to link our +# executorch library to the version which Android APK links against for runtime +# to ensure the libc++ dependencies are consistent. +# WARNING # +# Users need to use the SAME fbjni version here and in app gradle dependency +# for runtime compatibility! +if(NOT FBJNI_VERSION) + set(FBJNI_VERSION 0.5.1) +endif() + +set(FBJNI_AAR_URL https://repo1.maven.org/maven2/com/facebook/fbjni/fbjni/${FBJNI_VERSION}/fbjni-${FBJNI_VERSION}.aar) +set(FBJNI_DOWNLOAD_PATH ${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni/fbjni.aar) + +if(NOT EXISTS "${FBJNI_DOWNLOAD_PATH}") + file(DOWNLOAD "${FBJNI_AAR_URL}" "${FBJNI_DOWNLOAD_PATH}") +endif() + +add_custom_command( + OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni/prefab/modules/fbjni/include/" "${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni/prefab/modules/fbjni/libs/android.${ANDROID_ABI}/libfbjni.so" + COMMAND unzip -o ${FBJNI_DOWNLOAD_PATH} -d ${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni + DEPENDS "${FBJNI_DOWNLOAD_PATH}" +) + +add_custom_target( + fbjni_prefab + DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni/prefab/modules/fbjni/include/" "${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni/prefab/modules/fbjni/libs/android.${ANDROID_ABI}/libfbjni.so" +) + +add_library(fbjni SHARED IMPORTED) +add_dependencies(fbjni fbjni_prefab) +set_target_properties(fbjni PROPERTIES + IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni/prefab/modules/fbjni/libs/android.${ANDROID_ABI}/libfbjni.so" ) set(executorch_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../lib/cmake/ExecuTorch) @@ -128,6 +159,7 @@ endif() target_include_directories( executorch_jni PRIVATE ${_common_include_directories} + "${CMAKE_CURRENT_BINARY_DIR}/third-party/fbjni/prefab/modules/fbjni/include/" ) target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) diff --git a/extension/llm/export/TARGETS b/extension/llm/export/TARGETS index e4ade20228..866cfe56ea 100644 --- a/extension/llm/export/TARGETS +++ b/extension/llm/export/TARGETS @@ -31,6 +31,7 @@ runtime.python_library( "//executorch/backends/qualcomm/quantizer:quantizer", "//executorch/backends/transforms:duplicate_dynamic_quant_chain", "//executorch/backends/vulkan/partitioner:vulkan_partitioner", + "//executorch/backends/vulkan/quantizer:vulkan_quantizer", "//executorch/backends/xnnpack/partition:xnnpack_partitioner", "//executorch/exir:lib", "//executorch/exir/backend:backend_details", diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 11d92f32f5..37d90f8595 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -29,10 +29,10 @@ from executorch.extension.export_util.utils import export_to_edge, save_pte_program from executorch.extension.llm.tokenizer.utils import get_tokenizer -from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer import Quantizer from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torch.export import export_for_training from torch.nn.attention import SDPBackend FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -193,12 +193,12 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager": strict=True, ).module() else: - self.pre_autograd_graph_module = capture_pre_autograd_graph( + self.pre_autograd_graph_module = export_for_training( self.model, self.example_inputs, kwargs=self.example_kwarg_inputs, dynamic_shapes=dynamic_shape, - ) + ).module() return self diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 30701e4fa5..fd368d73f1 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -260,3 +260,22 @@ def get_coreml_quantizer(pt2e_quantize: str): raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}") return quantizer + + +def get_vulkan_quantizer(pt2e_quantize: str): + from executorch.backends.vulkan.quantizer.vulkan_quantizer import ( + get_weight_quantization_config, + VulkanQuantizer, + ) + + if pt2e_quantize == "vulkan_8w": + config = get_weight_quantization_config( + is_per_channel=True, + weight_qmin=-128, + weight_qmax=127, + ) + else: + raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}") + + quantizer = VulkanQuantizer().set_global(config) + return quantizer