2024-10-31 nightly release (f813b6a)
pytorchbot committed Oct 31, 2024
1 parent c464a71 commit 7072212
Showing 75 changed files with 3,987 additions and 434 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/_android.yml
@@ -66,6 +66,16 @@ jobs:
# avoid permission issue
sudo chown -R "${USER}" /opt/android
- name: Download Artifacts
shell: bash
run: |
set -eux
curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug.apk
curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug-androidTest.apk
curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/fp32-xnnpack-custom/model.zip
unzip model.zip
mv *.pte model.pte
- name: Gradle cache
uses: gradle/actions/setup-gradle@v3

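The step above fetches the demo APKs and the exported model from the gha-artifacts S3 bucket, keyed by repository and run ID, then normalizes whatever .pte the archive contains to model.pte. A minimal sketch of the same fetch, assuming only the bucket layout visible in the URLs above (the repo and run ID values here are placeholders):

    import urllib.request

    repo = "owner/repo"      # placeholder for ${{ github.repository }}
    run_id = "1234567890"    # placeholder for ${{ github.run_id }}
    base = f"https://gha-artifacts.s3.amazonaws.com/{repo}/{run_id}/artifacts"
    for path in ("llm_demo/app-debug.apk",
                 "llm_demo/app-debug-androidTest.apk",
                 "fp32-xnnpack-custom/model.zip"):
        # Save each artifact under its basename, like curl -O does.
        urllib.request.urlretrieve(f"{base}/{path}", path.rsplit("/", 1)[-1])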
7 changes: 6 additions & 1 deletion .github/workflows/pull.yml
Expand Up @@ -99,6 +99,8 @@ jobs:
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 900
upload-artifact: android-models
upload-artifact-to-s3: true
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
@@ -107,13 +109,15 @@
DTYPE=${{ matrix.dtype }}
BUILD_TOOL="cmake"
MODE=${{ matrix.mode }}
ARTIFACTS_DIR_NAME="artifacts-to-be-uploaded/${DTYPE}-${MODE}"
ARTIFACTS_DIR_NAME="${ARTIFACTS_DIR_NAME/+/-}"
# Setup executorch
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "${BUILD_TOOL}"
# Install requirements for export_llama
PYTHON_EXECUTABLE=python bash examples/models/llama/install_requirements.sh
# Test llama2
- PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
+ PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}" "${ARTIFACTS_DIR_NAME}"
test-llama-runner-linux-android:
name: test-llama-runner-linux-android
@@ -320,6 +324,7 @@ jobs:
android:
uses: ./.github/workflows/_android.yml
needs: test-llama-runner-linux

unittest:
uses: ./.github/workflows/_unittest.yml
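Two of the additions above work together: the upload-artifact/upload-artifact-to-s3 inputs publish each matrix job's outputs, and the ${ARTIFACTS_DIR_NAME/+/-} expansion sanitizes the directory name, since a mode value can contain a "+". Bash's single-slash form replaces only the first match. A small Python illustration of the same sanitization (the mode string is an assumed example, chosen to line up with the fp32-xnnpack-custom path the Android job downloads):

    dtype, mode = "fp32", "xnnpack+custom"  # assumed matrix values
    artifacts_dir = f"artifacts-to-be-uploaded/{dtype}-{mode}"
    # Mirrors ${ARTIFACTS_DIR_NAME/+/-}: bash's single-slash form replaces
    # only the first occurrence, hence count=1 here.
    artifacts_dir = artifacts_dir.replace("+", "-", 1)
    print(artifacts_dir)  # artifacts-to-be-uploaded/fp32-xnnpack-custom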
5 changes: 4 additions & 1 deletion backends/apple/coreml/compiler/coreml_preprocess.py
@@ -425,12 +425,15 @@ def preprocess(
CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs)
)

# Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
# get_compiled_model_path() requires a loaded model.
skip_model_load = model_type != CoreMLBackend.MODEL_TYPE.COMPILED_MODEL
mlmodel = ct.convert(
model=edge_program,
source="pytorch",
convert_to="mlprogram",
pass_pipeline=ct.PassPipeline.DEFAULT,
-        skip_model_load=True,
+        skip_model_load=skip_model_load,
compute_precision=model_compute_precision,
minimum_deployment_target=minimum_deployment_target,
compute_units=compute_units,
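skip_model_load is an existing coremltools option on ct.convert: skipping the in-process load makes conversion faster, but a loaded model is required for APIs such as get_compiled_model_path(), hence the exception for COMPILED_MODEL above. A self-contained sketch of the flag on a toy model (illustrative only, not the backend's code path):

    import coremltools as ct
    import torch

    class AddMul(torch.nn.Module):
        def forward(self, x, y):
            return x * y + y

    traced = torch.jit.trace(AddMul(), (torch.randn(2), torch.randn(2)))
    mlmodel = ct.convert(
        traced,
        convert_to="mlprogram",
        inputs=[ct.TensorType(shape=(2,)), ct.TensorType(shape=(2,))],
        # With skip_model_load=True the converted model is not instantiated
        # in-process, so anything that needs a loaded model is unavailable.
        skip_model_load=True,
    )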
2 changes: 1 addition & 1 deletion backends/apple/coreml/runtime/delegate/backend_delegate.h
@@ -28,7 +28,7 @@ class BackendDelegate {
// Max models cache size in bytes.
size_t max_models_cache_size = 10 * size_t(1024) * size_t(1024) * size_t(1024);
// If set to `true`, delegate pre-warms the most recently used asset.
-  bool should_prewarm_asset = true;
+  bool should_prewarm_asset = false;
// If set to `true`, delegate pre-warms the model in `init`.
bool should_prewarm_model = true;
};
@@ -3,7 +3,7 @@
<plist version="1.0">
<dict>
<key>shouldPrewarmAsset</key>
-	<true/>
+	<false/>
<key>shouldPrewarmModel</key>
<true/>
<key>maxAssetsSizeInBytes</key>
@@ -209,6 +209,12 @@ - (void)testStateProgramExecute {
}
#endif

- (void)testAddMulCompiledProgramExecute {
NSURL *modelURL = [[self class] bundledResourceWithName:@"add_mul_compiled_coreml_all" extension:@"pte"];
XCTAssertNotNil(modelURL);
[self executeModelAtURL:modelURL nLoads:1 nExecutions:2];
}

- (void)executeMultipleModelsConcurrently:(NSArray<NSURL *> *)modelURLs
nLoads:(NSUInteger)nLoads
nExecutions:(NSUInteger)nExecutions
@@ -8,6 +8,7 @@

/* Begin PBXBuildFile section */
8307EB8A2C9262060011AE6D /* state_coreml_all.pte in Resources */ = {isa = PBXBuildFile; fileRef = 8307EB892C9262060011AE6D /* state_coreml_all.pte */; };
838CA6872CD1965700462190 /* add_mul_compiled_coreml_all.pte in Resources */ = {isa = PBXBuildFile; fileRef = 838CA6862CD1965700462190 /* add_mul_compiled_coreml_all.pte */; };
83BB78A02C65DA7300274ED7 /* ETCoreMLModelDebugInfo.mm in Sources */ = {isa = PBXBuildFile; fileRef = 83BB789F2C65DA7300274ED7 /* ETCoreMLModelDebugInfo.mm */; };
83BB78BF2C66AAAE00274ED7 /* add_mul_coreml_all.bin in Resources */ = {isa = PBXBuildFile; fileRef = 83BB78BD2C66AAAE00274ED7 /* add_mul_coreml_all.bin */; };
83BB78C02C66AAAE00274ED7 /* add_mul_coreml_all.pte in Resources */ = {isa = PBXBuildFile; fileRef = 83BB78BE2C66AAAE00274ED7 /* add_mul_coreml_all.pte */; };
@@ -122,6 +123,7 @@

/* Begin PBXFileReference section */
8307EB892C9262060011AE6D /* state_coreml_all.pte */ = {isa = PBXFileReference; lastKnownFileType = file; name = state_coreml_all.pte; path = ../test/models/state_coreml_all.pte; sourceTree = "<group>"; };
838CA6862CD1965700462190 /* add_mul_compiled_coreml_all.pte */ = {isa = PBXFileReference; lastKnownFileType = file; name = add_mul_compiled_coreml_all.pte; path = ../test/models/add_mul_compiled_coreml_all.pte; sourceTree = "<group>"; };
83BB789E2C65DA7300274ED7 /* ETCoreMLModelDebugInfo.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; name = ETCoreMLModelDebugInfo.h; path = ../sdk/ETCoreMLModelDebugInfo.h; sourceTree = "<group>"; };
83BB789F2C65DA7300274ED7 /* ETCoreMLModelDebugInfo.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; name = ETCoreMLModelDebugInfo.mm; path = ../sdk/ETCoreMLModelDebugInfo.mm; sourceTree = "<group>"; };
83BB78BD2C66AAAE00274ED7 /* add_mul_coreml_all.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; name = add_mul_coreml_all.bin; path = ../test/models/add_mul_coreml_all.bin; sourceTree = "<group>"; };
@@ -606,6 +608,7 @@
C98551992AD2542D009143F9 /* mul_coreml_all.bin */,
C985519C2AD2542D009143F9 /* mul_coreml_all.pte */,
C985519B2AD2542D009143F9 /* mv3_coreml_all.bin */,
838CA6862CD1965700462190 /* add_mul_compiled_coreml_all.pte */,
C98551982AD2542D009143F9 /* mv3_coreml_all.pte */,
83BB78BD2C66AAAE00274ED7 /* add_mul_coreml_all.bin */,
83BB78BE2C66AAAE00274ED7 /* add_mul_coreml_all.pte */,
@@ -680,6 +683,7 @@
C985519E2AD2542D009143F9 /* mv3_coreml_all.pte in Resources */,
C98551A02AD2542D009143F9 /* add_coreml_all.bin in Resources */,
C98551A22AD2542D009143F9 /* mul_coreml_all.pte in Resources */,
838CA6872CD1965700462190 /* add_mul_compiled_coreml_all.pte in Resources */,
8307EB8A2C9262060011AE6D /* state_coreml_all.pte in Resources */,
C98551A32AD2542D009143F9 /* add_coreml_all.pte in Resources */,
);
9 changes: 9 additions & 0 deletions backends/apple/coreml/scripts/generate_test_models.sh
@@ -31,3 +31,12 @@ done

echo "Executorch: Generating stateful model"
python3 "$SCRIPT_DIR_PATH/../runtime/test/export_stateful_model.py"

COMPILE_MODELS=("add_mul")
echo "Executorch: Generating compiled model"
for MODEL in "${COMPILE_MODELS[@]}"
do
echo "Executorch: Generating compiled $MODEL model"
python3 -m examples.apple.coreml.scripts.export --model_name "$MODEL" --compile
mv -f "$MODEL""_compiled_coreml_all.pte" "$COREML_DIR_PATH/runtime/test/models"
done
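The loop should emit add_mul_compiled_coreml_all.pte into runtime/test/models, matching the add_mul_compiled_coreml_all.pte resource registered in the Xcode project earlier in this commit. An equivalent one-off invocation, sketched in Python around the CLI shown above:

    import subprocess

    # Produces add_mul_compiled_coreml_all.pte in the working directory,
    # which the script above then moves into runtime/test/models.
    subprocess.run(
        ["python3", "-m", "examples.apple.coreml.scripts.export",
         "--model_name", "add_mul", "--compile"],
        check=True,
    )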
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -12,6 +12,7 @@
AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
ConvertExpandCopyToRepeatPass,
)
@@ -69,6 +70,7 @@ def transform_to_backend_pipeline(
self.add_pass(DecomposeDivPass())
self.add_pass(InsertSqueezeAfterSumPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
47 changes: 47 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -1,3 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
@@ -9,11 +10,57 @@
import torch
import torch.fx

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_buffer,
is_lifted_tensor_constant,
is_param,
)
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor


def is_get_attr_node(node: torch.fx.Node) -> bool:
"""
Returns True if the given node is a get_attr node for a tensor of the model.
"""
return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
return (
is_get_attr_node(node)
or is_param(exp_prog, node)
or is_buffer(exp_prog, node)
or is_lifted_tensor_constant(exp_prog, node)
)


def get_param_tensor(
exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
if node is None:
return None
elif is_param(exp_prog, node):
return get_param(exp_prog, node)
elif is_buffer(exp_prog, node):
return get_buffer(exp_prog, node)
elif is_lifted_tensor_constant(exp_prog, node):
return get_lifted_tensor_constant(exp_prog, node)
elif is_get_attr_node(node):
# This is a hack to support both lifted and unlifted graphs
try:
return getattr(node.graph.owning_module, node.target)
except AttributeError:
return getattr(exp_prog.graph_module, node.target)
raise RuntimeError(f"unsupported param type, {node.op}.")


def create_node(
graph: torch.fx.Graph,
op_target: OpOverload,
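A hedged usage sketch for the new helpers, assuming the import path introduced by this diff: given an ExportedProgram, they resolve a placeholder node to its backing tensor whether it was lifted as a parameter, a buffer, or a tensor constant.

    import torch
    from executorch.backends.arm._passes.arm_pass_utils import (
        get_param_tensor,
        is_param_node,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.nn.Parameter(torch.ones(3))

        def forward(self, x):
            return x * self.w

    ep = torch.export.export(M(), (torch.randn(3),))
    for node in ep.graph.nodes:
        # The lifted parameter appears as a placeholder; resolve it.
        if node.op == "placeholder" and is_param_node(ep, node):
            print(node.name, get_param_tensor(ep, node).shape)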
164 changes: 164 additions & 0 deletions backends/arm/_passes/conv1d_unsqueeze_pass.py
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_param_tensor,
insert_q_dq_pair,
is_param_node,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class Conv1dUnsqueezePass(ExportPass):
"""
This pass is used to change conv1d ops into conv2d since TOSA only
supports 2d and 3d convolution. This is done by modifying the graph to do the
following:
1) unsqueeze the convolution's input from 3d to 4d
2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
3) perform a conv2d (with a modified version of the original conv1d args)
4) squeeze the output back down to 3d.
5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
"""

def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.exported_program = exported_program

def unsqueeze_kernel_weights(self, kernel_node):
"""
Unsqueezes the weights of a conv1d to make it 4 dimensional.
Args:
kernel_node: the weights of conv1d node to be unsqueezed
"""
kernel_param_3d = get_param_tensor(self.exported_program, kernel_node)
if kernel_param_3d is None:
raise AssertionError("Expected param tensor for the kernel node")

kernel_param_4d = torch.nn.Parameter(
data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
requires_grad=False,
)

if torch._export.utils.is_param(self.exported_program, kernel_node):
parameter_name = self.exported_program.graph_signature.inputs_to_parameters[
kernel_node.name
]
self.exported_program.state_dict[parameter_name] = kernel_param_4d
kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
elif torch._export.utils.is_buffer(self.exported_program, kernel_node):
buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
kernel_node.name
]
self.exported_program.state_dict[buffer_name] = kernel_param_4d
kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
elif torch._export.utils.is_lifted_tensor_constant(
self.exported_program, kernel_node
):
buffer_name = (
self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
kernel_node.name
]
)
self.exported_program.constants[buffer_name] = kernel_param_4d
kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
else:
setattr(
kernel_node.graph.owning_module,
kernel_node.target,
kernel_param_4d,
)

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
node_list = list(graph.nodes)
for node in node_list:
if node.op == "call_function":
if node.target == exir_ops.edge.aten.convolution.default:
stride = list(node.args[3])
if len(stride) != 1:
# skip conv if it is not 1d
continue

kernel_node = node.args[1]
if kernel_node.target == dq_op:
kernel_node = kernel_node.args[0]

if not is_param_node(self.exported_program, kernel_node):
raise AssertionError(
"Expected op for convolution weight node to be a get_attr node or a parameter"
)

# (a) Modify graph such that the conv changes from 1d to 2d
self.unsqueeze_kernel_weights(kernel_node)

# (b) Extend stride, padding, and dilation for extra dim
node.args = (
node.args[0],
node.args[1],
node.args[2],
node.args[3] + [1], # stride
node.args[4] + [0], # padding
node.args[5] + [1], # dilation
node.args[6],
node.args[7] + [0],
node.args[8],
)

# (c) Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d)
# unsqueeze -> conv2d -> squeeze
with graph.inserting_before(node):
input_node = node.args[0]
unsqueeze_before = create_node(
graph, exir_ops.edge.aten.unsqueeze_copy.default
)
unsqueeze_before.args = (
input_node, # Input is node's original input
-1, # Last Dimension
)
node.replace_input_with(input_node, unsqueeze_before)

# If Quantized we must insert unsqueeze --> q --> dq --> node
if input_node.target == dq_op:
q_params = input_node.args[1:]
insert_q_dq_pair(graph, unsqueeze_before, q_params)

with graph.inserting_after(node):
squeeze_after = create_node(
graph,
exir_ops.edge.aten.squeeze_copy.dims,
)
squeeze_after.args = (
node, # Input is the conv node
[-1], # Last dimension
)
original_users = [
user for user in node.users if user != squeeze_after
]
for user in original_users:
user.replace_input_with(node, squeeze_after)

# If quantized, insert conv2d --> q --> dq --> squeeze
if all(
original_user.target == q_op for original_user in original_users
):
q_params = original_users[0].args[1:]
insert_q_dq_pair(graph, node, q_params)

graph_module.recompile()
# Since we are overriding "call", we need to call the parent's "call"
# to retrace the graph and regenerate metadata
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
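The rewrite the docstring describes is easy to sanity-check in plain PyTorch: a conv1d equals a conv2d whose input and kernel gain a trailing singleton dimension and whose stride/padding/dilation gain one extra entry, with the result squeezed back to 3d. A self-contained sketch (shapes arbitrary):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 8, 32)  # (N, C_in, L) conv1d input
    w = torch.randn(16, 8, 3)  # (C_out, C_in, K) conv1d kernel
    ref = F.conv1d(x, w, stride=1, padding=1, dilation=1)

    # Mirror the pass: unsqueeze input and kernel to 4d, extend the conv
    # parameters with a trailing entry, then squeeze the output back to 3d.
    out = F.conv2d(
        x.unsqueeze(-1),   # (N, C_in, L, 1)
        w.unsqueeze(-1),   # (C_out, C_in, K, 1)
        stride=(1, 1),     # stride + [1]
        padding=(1, 0),    # padding + [0]
        dilation=(1, 1),   # dilation + [1]
    ).squeeze(-1)          # back to (N, C_out, L)

    assert torch.allclose(ref, out, atol=1e-6)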
1 change: 1 addition & 0 deletions backends/arm/_passes/convert_split_to_slice.py
@@ -70,4 +70,5 @@ def call(self, graph_module: torch.fx.GraphModule):
output_node.replace_all_uses_with(slice_node)
graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
(Remaining changed files not shown.)