Skip to content

Commit

Permalink
[Arm tester] Run delegate nodes using tosa_reference_model
Browse files Browse the repository at this point in the history
Instead of executing just one delegate,
we execute the graph and dispatch the
delegate nodes to the tosa_reference_model.
This makes the tester more flexible and
enables running tests with multiple
delegates, multiple outputs etc.

Signed-off-by: Erik Lundell <erik.lundell@arm.com>
Change-Id: I5fb0192da2d187f29c025e0c51238a9cb7d7ec90
  • Loading branch information
Erik-Lundell committed Jan 8, 2025
1 parent 08770b7 commit 781ad6d
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 198 deletions.
51 changes: 0 additions & 51 deletions backends/arm/_passes/tag_io_quant_pass.py

This file was deleted.

16 changes: 15 additions & 1 deletion backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -202,6 +202,20 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
return False


def is_quantize_io(compile_specs: List[CompileSpec]) -> bool:
    """Report whether the compile specs request IO quantization.

    Args:
        compile_specs: Compile specs to scan; each entry is read through its
            ``key`` attribute and, when the key matches, its ``value.decode()``.

    Returns:
        True if any spec has the key ``"quantize_io"`` and a value that
        decodes to the string ``"True"``; False otherwise (including for an
        empty list).
    """
    # ``and`` short-circuits, so the value is only decoded for matching keys.
    return any(
        entry.key == "quantize_io" and entry.value.decode() == "True"
        for entry in compile_specs
    )


def get_tosa_version(compile_spec: List[CompileSpec]) -> TosaSpecification:
    """Build the TOSA specification named by the compile specs.

    Args:
        compile_spec: Compile specs to scan for an entry whose ``key`` is
            ``"tosa_version"``.

    Returns:
        The result of ``TosaSpecification.create_from_string`` applied to the
        decoded value of the first matching entry.

    Raises:
        RuntimeError: If no entry has the key ``"tosa_version"``.
    """
    # Only the first matching entry is used; later duplicates are ignored.
    version_entry = next(
        (entry for entry in compile_spec if entry.key == "tosa_version"),
        None,
    )
    if version_entry is None:
        raise RuntimeError("Could not find TOSA version in CompileSpec")
    return TosaSpecification.create_from_string(version_entry.value.decode())


def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
for spec in compile_spec:
if spec.key == "debug_artifact_path":
Expand Down
63 changes: 48 additions & 15 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,8 +10,10 @@
from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmBackend # usort: skip
from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass
from executorch.backends.arm.arm_backend import (
ArmBackend,
is_quantize_io,
) # usort: skip
from executorch.backends.arm.operator_support.tosa_supported_operators import (
TOSASupportedOperators,
)
Expand All @@ -23,7 +25,7 @@
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.passes import PassManager
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

Expand All @@ -35,6 +37,22 @@
logger.setLevel(logging.INFO)


def is_quant_node(node: torch.fx.node.Node) -> bool:
    """Return True when ``node.target`` is an edge-dialect quantize op.

    The recognized targets are the ``quantized_decomposed`` overloads
    ``quantize_per_channel.default``, ``quantize_per_tensor.default`` and
    ``quantize_per_tensor.tensor``.
    """
    quantized_decomposed = exir_ops.edge.quantized_decomposed
    quantize_targets = {
        quantized_decomposed.quantize_per_channel.default,
        quantized_decomposed.quantize_per_tensor.default,
        quantized_decomposed.quantize_per_tensor.tensor,
    }
    return node.target in quantize_targets


def is_dequant_node(node: torch.fx.node.Node) -> bool:
    """Return True when ``node.target`` is an edge-dialect dequantize op.

    The recognized targets are the ``quantized_decomposed`` overloads
    ``dequantize_per_channel.default``, ``dequantize_per_tensor.default`` and
    ``dequantize_per_tensor.tensor``.
    """
    quantized_decomposed = exir_ops.edge.quantized_decomposed
    dequantize_targets = {
        quantized_decomposed.dequantize_per_channel.default,
        quantized_decomposed.dequantize_per_tensor.default,
        quantized_decomposed.dequantize_per_tensor.tensor,
    }
    return node.target in dequantize_targets


@final
class ArmPartitioner(Partitioner):
def __init__(self, compile_spec: List[CompileSpec]) -> None:
Expand All @@ -43,6 +61,7 @@ def __init__(self, compile_spec: List[CompileSpec]) -> None:
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# Run the CapabilityBasedPartitioner to return the largest possible
# subgraphs containing the nodes with the tags

logger.info("ArmPartitioner::partition")
partition_tags = {}

Expand All @@ -52,28 +71,42 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

logger.info(f"Partitioning for {tosa_spec}")

for spec in self.delegation_spec.compile_specs:
if spec.key == "quantize_io" and spec.value.decode() == "True":
# Exclude IO quantization from the partition
passes = PassManager(
passes=[
TagIOQuantPass(),
]
)
passes(exported_program.graph_module)

capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
TOSASupportedOperators(tosa_spec),
allows_single_node_partition=True,
)
partition_list = capability_partitioner.propose_partitions()
for partition in partition_list:
tag = f"tag{partition.id}"

def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
return (
"delegation_tag" in node.meta and node.meta["delegation_tag"] == tag
)

for node in partition.nodes:
tag = f"tag{partition.id}"
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

if not is_quantize_io(self.delegation_spec.compile_specs):
continue

# De-tag the outermost q-nodes (looking upwards) and dq-nodes (looking downwards).
# A node is de-tagged if at least one of its inputs/outputs is not part of the partition.
for node in partition.nodes:
if is_quant_node(node):
for input in node.all_input_nodes:
if not is_partitioned(input):
del node.meta["delegation_tag"]
break

if is_dequant_node(node):
for user in node.users:
if not is_partitioned(user):
del node.meta["delegation_tag"]
break

tag_constant_data(exported_program)

return PartitionResult(
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -50,7 +50,6 @@ def maybe_get_tosa_collate_path() -> str | None:
tosa_test_base = os.path.join(tosa_test_base, "tosa-mi")
else:
tosa_test_base = os.path.join(tosa_test_base, "other")

return os.path.join(tosa_test_base, test_class, test_name)

return None
Expand Down Expand Up @@ -83,6 +82,7 @@ def get_tosa_compile_spec_unbuilt(
.tosa_compile_spec(tosa_version)
.set_permute_memory_format(permute_memory_to_nhwc)
.dump_intermediate_artifacts_to(custom_path)
.set_quantize_io(True)
)

return compile_spec_builder
Expand Down
57 changes: 57 additions & 0 deletions backends/arm/test/misc/test_multiple_delegates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester


class TestMultipleDelegates(unittest.TestCase):
    """Run a module through ArmTester and expect two delegate calls.

    Both tests check that the lowered graph contains exactly two
    ``executorch_call_delegate`` calls before comparing outputs.
    """

    class MultipleDelegatesModule(torch.nn.Module):
        """Computes ``sin(x + y) * (x + y)`` for two equally shaped inputs."""

        # Example inputs are created once at class-definition time and shared
        # by every instance.
        inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))

        def get_inputs(self):
            """Return the shared example input tuple."""
            return self.inputs

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            # NOTE(review): presumably sin is what splits the graph into two
            # delegates - confirm against the supported-operator list.
            z = x + y
            s = torch.sin(z)
            return s * z

    def test_tosa_MI(self):
        """Lower with the TOSA-0.80+MI spec, without quantization."""
        model = self.MultipleDelegatesModule()
        example_inputs = model.get_inputs()
        expected_calls = {"torch.ops.higher_order.executorch_call_delegate": 2}
        tester = ArmTester(
            model,
            example_inputs=example_inputs,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
        )
        (
            tester.export()
            .to_edge_transform_and_lower()
            .check_count(expected_calls)
            .to_executorch()
            .run_method_and_compare_outputs(inputs=example_inputs)
        )

    def test_tosa_BI(self):
        """Quantize, then lower with the TOSA-0.80+BI spec; compare with qtol=1.0."""
        model = self.MultipleDelegatesModule()
        example_inputs = model.get_inputs()
        expected_calls = {"torch.ops.higher_order.executorch_call_delegate": 2}
        tester = ArmTester(
            model,
            example_inputs=example_inputs,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
        )
        (
            tester.quantize()
            .export()
            .to_edge_transform_and_lower()
            .check_count(expected_calls)
            .to_executorch()
            .run_method_and_compare_outputs(inputs=example_inputs, qtol=1.0)
        )
53 changes: 53 additions & 0 deletions backends/arm/test/misc/test_multiple_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester


class TestMultipleOutputs(unittest.TestCase):
    """Run a two-output module through ArmTester and compare its outputs."""

    class MultipleOutputsModule(torch.nn.Module):
        """Returns a tuple: ``x * y`` and ``x`` summed over its last dim."""

        # Example inputs are created once at class-definition time and shared
        # by every instance.
        inputs = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))

        def get_inputs(self):
            """Return the shared example input tuple."""
            return self.inputs

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            product = x * y
            # keepdim=True keeps the reduced last dimension with size 1.
            row_sum = x.sum(dim=-1, keepdim=True)
            return (product, row_sum)

    def test_tosa_MI_pipeline(self):
        """Lower with the TOSA-0.80+MI spec, without quantization."""
        model = self.MultipleOutputsModule()
        example_inputs = model.get_inputs()
        tester = ArmTester(
            model,
            example_inputs=example_inputs,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
        )
        (
            tester.export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .run_method_and_compare_outputs(inputs=example_inputs)
        )

    def test_tosa_BI_pipeline(self):
        """Quantize, then lower with the TOSA-0.80+BI spec; compare with qtol=1.0."""
        model = self.MultipleOutputsModule()
        example_inputs = model.get_inputs()
        tester = ArmTester(
            model,
            example_inputs=example_inputs,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
        )
        (
            tester.quantize()
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .run_method_and_compare_outputs(inputs=example_inputs, qtol=1.0)
        )
64 changes: 0 additions & 64 deletions backends/arm/test/passes/test_tag_io_quant_pass.py

This file was deleted.

Loading

0 comments on commit 781ad6d

Please sign in to comment.