Commit f354a96: new backend tests for lifted graphs

Summary:
Copies all the tests from test_backends.py into a new test_backends_lifted.py that uses lifted graphs instead.

This is the first step in migrating completely over to torch.export.

Reviewed By: cccclai

Differential Revision: D47887945

fbshipit-source-id: 539771ab04389f4f605a8c39ad892e6ba9673369
mcr229 authored and facebook-github-bot committed Jul 31, 2023
1 parent fb602f8 commit f354a96
Showing 5 changed files with 1,499 additions and 46 deletions.
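
As context for the diffs below, here is a minimal sketch (not part of this commit) of the difference the new tests exercise: capturing the same callable as an unlifted graph versus a lifted graph. The exir.capture / exir.CaptureConfig usage and the .to_edge().exported_program.graph_module chain mirror the code in the diff; the import path and the toy inputs are assumptions.

    import torch

    from executorch import exir  # import path assumed


    def sub(x, y):
        return torch.sub(x, y)


    inputs = (torch.ones(1, 32), torch.ones(1, 32))

    # Unlifted capture: parameters and buffers remain attributes of the graph module.
    pattern_sub = (
        exir.capture(sub, inputs, exir.CaptureConfig(pt2_mode=True))
        .to_edge()
        .exported_program.graph_module
    )

    # Lifted capture (the torch.export-style graph these tests migrate to):
    # enable_aot=True lifts parameters and buffers into explicit graph inputs.
    pattern_sub_lifted = (
        exir.capture(sub, inputs, exir.CaptureConfig(pt2_mode=True, enable_aot=True))
        .to_edge()
        .exported_program.graph_module
    )
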
31 changes: 31 additions & 0 deletions backends/test/TARGETS
@@ -112,6 +112,37 @@ python_unittest(
],
)

python_unittest(
name = "test_backends_lifted",
srcs = [
"test_backends_lifted.py",
],
supports_static_listing = True,
deps = [
":backend_with_compiler_demo",
":hta_partitioner_demo",
":op_partitioner_demo",
":qnn_backend_demo",
"//caffe2:torch",
"//caffe2/functorch:functorch_src",
"//executorch/backends:backend_api",
"//executorch/backends:compile_spec_schema",
"//executorch/backends:partitioner",
"//executorch/exir:delegate",
"//executorch/exir:graph_module",
"//executorch/exir:lib",
"//executorch/exir:lowered_backend_module",
"//executorch/exir:print_program",
"//executorch/exir:schema",
"//executorch/exir/dialects:lib",
"//executorch/extension/pybindings:portable", # @manual
"//executorch/extension/pytree:pylib",
"//executorch/kernels/portable:custom_ops_generated_lib",
"//executorch/kernels/quantized:custom_ops_generated_lib",
"//executorch/runtime/executor/test:test_backend_compiler_lib",
],
)

python_unittest(
name = "test_graph_partition",
srcs = [
71 changes: 48 additions & 23 deletions backends/test/hta_partitioner_demo.py
@@ -58,19 +58,37 @@ def forward(self, x_raw, h, c):
input_h = torch.ones([1, 32])
input_c = torch.ones([1, 32])

pattern_lstm_conv_lifted = (
exir.capture(
LSTMConvPattern(),
(input_x, input_h, input_c),
exir.CaptureConfig(pt2_mode=True, enable_aot=True),
)
.to_edge()
.exported_program.graph_module
)
pattern_lstm_conv = (
exir.capture(
LSTMConvPattern(),
(input_x, input_h, input_c),
exir.CaptureConfig(pt2_mode=True),
)
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
.to_edge()
.exported_program.graph_module
)

def sub(x, y):
return torch.sub(x, y)

pattern_sub_lifted = (
exir.capture(
sub,
(input_x, input_h),
exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=False),
)
.to_edge(exir.EdgeCompileConfig(_use_edge_ops=True))
.exported_program.graph_module
)
pattern_sub = (
exir.capture(
sub,
@@ -80,7 +98,12 @@ def sub(x, y):
.to_edge()
.exported_program.graph_module
)
self.patterns = [pattern_lstm_conv.graph, pattern_sub.graph]
self.patterns = [
pattern_lstm_conv_lifted.graph,
pattern_lstm_conv.graph,
pattern_sub_lifted.graph,
pattern_sub.graph,
]

backend_id = QnnBackend.__name__
self.delegation_spec = DelegationSpec(backend_id, [])
@@ -145,28 +168,18 @@ def generate_partition_list(self, graph_module) -> List[Partition]:
]
"""
partitions_from_all_pattern = [
generate_pattern_op_partitions(graph_module, patterns=[pattern])
for pattern in self.patterns
]

# Check if all partitions are exclusive, this partitions don't support inclusive partitions.
is_exclusive = self.is_exclusive(partitions_from_all_pattern)

assert (
is_exclusive
), "There exists inclusive partitions. Currently the fuse method only handle exclusive partitions."
partitions_from_all_pattern = generate_pattern_op_partitions(
graph_module, self.patterns
)

# Assign a unique id for each partition
partition_id = 0

# If want to support inclusive partitions, the logic can be done here to merge partitions etc.
flat_proposed_partitions_with_unique_id = []
for partitions_from_one_pattern in partitions_from_all_pattern:
for partition in partitions_from_one_pattern:
partition.id = partition_id
flat_proposed_partitions_with_unique_id.append(partition)
partition_id += 1
for partition in partitions_from_all_pattern:
partition.id = partition_id
flat_proposed_partitions_with_unique_id.append(partition)
partition_id += 1

return flat_proposed_partitions_with_unique_id
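
The refactor above replaces the per-pattern partitioning loop and the exclusivity assert with a single consolidated generate_pattern_op_partitions call followed by a flat id-assignment loop. Below is a hedged, self-contained restatement of that flow; the partition generator is passed in as a callable because its import path is not shown in this diff.

    from typing import Any, Callable, List


    def flat_partitions_with_ids(
        graph_module: Any,
        pattern_graphs: List[Any],
        generate_partitions: Callable[[Any, List[Any]], List[Any]],
    ) -> List[Any]:
        # One call over all patterns, e.g.
        # generate_pattern_op_partitions(graph_module, self.patterns) as above.
        partitions = generate_partitions(graph_module, pattern_graphs)
        # Assign a unique, sequential id to each returned partition.
        for partition_id, partition in enumerate(partitions):
            partition.id = partition_id
        return partitions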

@@ -213,16 +226,28 @@ def forward(self, x_raw, h, c):
input_h = torch.ones([1, 32])
input_c = torch.ones([1, 32])

pattern_lstm_conv = (
pattern_lstm_conv_lifted = (
exir.capture(
LSTMConvPattern(),
(input_x, input_h, input_c),
exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=False),
exir.CaptureConfig(pt2_mode=True, enable_aot=True),
)
.to_edge()
.exported_program.graph_module
)
pattern_lstm_conv_unlifted = (
exir.capture(
LSTMConvPattern(),
(input_x, input_h, input_c),
exir.CaptureConfig(pt2_mode=True),
)
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
.to_edge()
.exported_program.graph_module
)
self.patterns = [pattern_lstm_conv.graph]
self.patterns = [
pattern_lstm_conv_lifted.graph,
pattern_lstm_conv_unlifted.graph,
]
# Only (lstm + conv) pattern is lowerable

backend_id = QnnBackend.__name__
29 changes: 6 additions & 23 deletions backends/test/test_backends.py
@@ -627,19 +627,11 @@ def forward(self, x_raw, h, c):

traced = exir.capture(
composite_m, inputs, exir.CaptureConfig(pt2_mode=True)
).to_edge(
exir.EdgeCompileConfig(
_check_ir_validity=False,
)
)
).to_edge()

program_without_delegates = (
exir.capture(CompositeModel(3), inputs)
.to_edge(
exir.EdgeCompileConfig(
_check_ir_validity=False,
)
)
.to_edge()
.to_executorch(
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
)
@@ -741,20 +733,16 @@ def forward(self, x_raw, h, c):
traced = exir.capture(
composite_m,
inputs,
exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=False),
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
exir.CaptureConfig(pt2_mode=True),
).to_edge()

program_without_delegates = (
exir.capture(
CompositeModel(3),
(input_x, input_h, input_c),
exir.CaptureConfig(pt2_mode=True),
)
.to_edge(
exir.EdgeCompileConfig(
_check_ir_validity=False,
)
)
.to_edge()
.to_executorch(
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
)
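
For reference, a hedged sketch of the capture-to-ExecuTorch pipeline these updated tests run, now without the EdgeCompileConfig(_check_ir_validity=False) override; the config names match the diff above, while the toy module, inputs, and import path are assumptions.

    import torch

    from executorch import exir  # import path assumed


    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return x + 1


    inputs = (torch.ones(2, 2),)

    executorch_program = (
        exir.capture(TinyModel(), inputs, exir.CaptureConfig(pt2_mode=True))
        # The diff drops the EdgeCompileConfig(_check_ir_validity=False) override here.
        .to_edge()
        .to_executorch(
            config=exir.ExecutorchBackendConfig(extract_segments=False),
        )
    )
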
@@ -998,13 +986,8 @@ def test_quantized_with_delegate(self) -> None:
exir.CaptureConfig(
pt2_mode=True,
enable_aot=True,
_unlift=True,
),
).to_edge(
exir.EdgeCompileConfig(
_check_ir_validity=False,
)
)
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
converted_linear_gm.exported_program.graph_module.code
)
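
The FileCheck assertion above scans the generated graph-module code as plain text. A small hedged sketch of the same pattern on a stand-in string follows; in the test it runs on converted_linear_gm.exported_program.graph_module.code.

    from torch.testing import FileCheck

    # Stand-in for graph_module.code; the real test inspects the captured,
    # quantized linear module.
    example_code = """
    quantize_per_tensor_default = ...
    quantize_per_tensor_default_1 = ...
    quantize_per_tensor_default_2 = ...
    addmm = torch.ops.aten.addmm.default(...)
    """

    # check_count requires the substring to appear the given number of times;
    # check requires "addmm" to appear after those matches.
    FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(example_code)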
