Skip to content

Commit

Permalink
2024-11-11 nightly release (0a20d72)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Nov 11, 2024
1 parent 6b9cc29 commit a453011
Show file tree
Hide file tree
Showing 14 changed files with 275 additions and 343 deletions.
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from executorch.backends.arm._passes.decompose_layernorm_pass import (
DecomposeLayerNormPass,
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
DecomposeSoftmaxesPass,
Expand Down Expand Up @@ -74,6 +75,7 @@ def transform_to_backend_pipeline(
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
for spec in compile_spec:
if spec.key == "permute_memory_format":
memory_format = spec.value.decode()
Expand Down
112 changes: 112 additions & 0 deletions backends/arm/_passes/decompose_linear_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
"""
This pass decomposes linear into a Conv2D with the required view operations.
linear(x, weights, bias) becomes:
x_reshaped = view(x)
weights_reshaped = view(weights)
conv2d = conv2d(x_reshaped, weights_reshaped, bias)
output = view(conv2d)
It also inserts q/dq pairs if the linear node was quantized.
"""

def call(self, graph_module):
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target != exir_ops.edge.aten.linear.default:
continue
args = node.args
input = args[0]
weights = args[1]
bias = args[2] if len(args) > 2 else None
output_shape = get_first_fake_tensor(node).shape
input_shape = get_first_fake_tensor(input).shape
weights_shape = get_first_fake_tensor(weights).shape
batches = int(np.prod(input_shape[:-1])) if len(input_shape) > 1 else 1
# input has shape (..., Ci)
input_reshaped_shape = [batches, input_shape[-1], 1, 1]
# weights have shape (Co, Ci)
weights_reshaped_shape = [weights_shape[0], weights_shape[1], 1, 1]

with graph_module.graph.inserting_before(node):
quantize = input.op == "call_function" and input.target == dq_op
q_params = input.args[1:] if quantize else None
# Reshape input to 4D with shape (N, Ci, 1, 1)
input_reshaped = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(input, input_reshaped_shape),
kwargs={},
quantize=quantize,
q_params=q_params,
)

quantize = weights.op == "call_function" and weights.target == dq_op
q_params = weights.args[1:] if quantize else None
# Reshape weights to 4D with shape (Co, Ci, 1, 1)
weights_reshaped = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(weights, weights_reshaped_shape),
kwargs={},
quantize=quantize,
q_params=q_params,
)

consumer_node = list(node.users)[0]
quantize = (
consumer_node.op == "call_function" and consumer_node.target == q_op
)
q_params = consumer_node.args[1:] if quantize else None
conv = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.convolution.default,
args=(
input_reshaped,
weights_reshaped,
bias,
[1, 1], # strides
[0, 0], # padding
[1, 1], # dilation
False, # transposed
[0, 0], # output padding
1, # groups
),
kwargs={},
quantize=quantize,
q_params=q_params,
)

with graph_module.graph.inserting_after(conv):
# Reshape output to same rank as original input with shape (..., Co)
# No need to insert q/dq pair as Conv2D node above has inserted them if
# required.
output = create_node(
graph=graph_module.graph,
op_target=exir_ops.edge.aten.view_copy.default,
args=(conv, list(output_shape)),
kwargs={},
)

node.replace_all_uses_with(output)
graph_module.graph.erase_node(node)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
13 changes: 11 additions & 2 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import operator
import os
from typing import cast, final, List
from typing import Callable, cast, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
Expand Down Expand Up @@ -39,7 +39,6 @@ class TOSASupportedOperators(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
supported = node.op == "call_function" and node.target in [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.addmm.default,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.bmm.default,
Expand All @@ -49,6 +48,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.mul.Tensor,
Expand Down Expand Up @@ -137,3 +137,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)

def ops_to_not_decompose(
self,
ep: ExportedProgram,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
ops_to_not_decompose = [
torch.ops.aten.linear.default,
]
return (ops_to_not_decompose, None)
1 change: 0 additions & 1 deletion backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from . import ( # noqa
node_visitor,
op_add,
op_addmm,
op_avg_pool2d,
op_batch_norm,
op_bmm,
Expand Down
6 changes: 2 additions & 4 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ def define_node(
input_nodes, tosa_graph
)

# Preapre sub output tensor
broadcasted_shape = tutils.broadcast_shapes(
rescaled_inputs[0].shape, rescaled_inputs[0].shape
)
# Prepare add output tensor
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
add_output = output
Expand Down
148 changes: 0 additions & 148 deletions backends/arm/operators/op_addmm.py

This file was deleted.

8 changes: 0 additions & 8 deletions backends/arm/operators/op_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import is_permute_node_before_addmm
from serializer.tosa_serializer import TosaOp


Expand Down Expand Up @@ -81,13 +80,6 @@ def define_node(
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_permute_node_before_addmm(node):
## Simply add an identityOp
tosa_graph.addOperator(
TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
)
return

# The permutation vector describes a permutation P in default Pytorch dim_order.
# For rank 4, the default dim_order NCHW.
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)
Expand Down
Loading

0 comments on commit a453011

Please sign in to comment.