Qualcomm AI Engine Direct - Add unit test for Spill-Fill buffer (#7518)
Add unit test to validate the size of the Spill-Fill buffer.
winskuo-quic authored Jan 9, 2025
1 parent 84e377a commit 25a94ef
Showing 3 changed files with 54 additions and 3 deletions.
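For context, here is a minimal standalone sketch of the flow the new tests exercise: lower a model with large intermediate activations through the QNN partitioner with multi-context compilation enabled, then query the shared spill-fill buffer size. The import paths and the QcomChipset enum are assumptions based on the ExecuTorch source tree and may differ across versions; SM8550 is an arbitrary HTP-capable target (the tests themselves pick the chipset from a table).

import torch

# Assumed import paths; module names can vary across ExecuTorch versions.
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    update_spill_fill_size,
)
from executorch.exir.backend.backend_api import to_backend


class LargeTensorLinear(torch.nn.Module):  # same model added in models.py below
    def __init__(self):
        super().__init__()
        hidden_dim = 4096
        self.linear1 = torch.nn.Linear(512, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, 512)

    def forward(self, x):
        x1 = self.linear1(x) + self.linear1(x)
        return self.linear2(x1)


module = LargeTensorLinear()
sample_input = (torch.randn(1, 256, 512),)
edge_prog = capture_program(module, sample_input)

# Spill-fill buffer sharing only comes into play when the graph is split
# into multiple HTP contexts.
backend_options = generate_htp_compiler_spec(use_fp16=True, use_multi_contexts=True)
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8550,  # assumed target; any HTP-capable SoC works
    backend_options=backend_options,
)
edge_prog.exported_program = to_backend(
    edge_prog.exported_program, QnnPartitioner(compiler_specs)
)

# Returns the largest spill-fill buffer size found across the lowered
# contexts after writing it back into each module's compile spec.
max_sf_size = update_spill_fill_size(edge_prog.exported_program)
assert max_sf_size != 0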
12 changes: 12 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -596,6 +596,18 @@ def forward(self, input_pos, k_val):
return k_out


class LargeTensorLinear(torch.nn.Module):
def __init__(self):
super().__init__()
hidden_dim = 4096
self.linear1 = torch.nn.Linear(512, hidden_dim)
self.linear2 = torch.nn.Linear(hidden_dim, 512)

def forward(self, x):
x1 = self.linear1(x) + self.linear1(x)
return self.linear2(x1)


class LayerNorm(torch.nn.Module):
def __init__(self):
super().__init__()
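A note on the sizing of LargeTensorLinear: with the (1, 256, 512) sample input used by the new tests, each linear1 activation holds 1 × 256 × 4096 ≈ 1.05 M elements (about 2 MiB in FP16), and computing self.linear1(x) + self.linear1(x) keeps two of them live at once, presumably enough to overflow the HTP's on-chip memory so that activations spill into the spill-fill buffer.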
37 changes: 37 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -1581,6 +1581,24 @@ def test_qnn_backend_skip_node_op(self):
skip_node_op_set={"aten.add.Tensor"},
)

def test_qnn_backend_spill_fill_buffer_size(self):
module = LargeTensorLinear() # noqa: F405
sample_input = (torch.randn(1, 256, 512),)
edge_prog = capture_program(module, sample_input)

backend_options = generate_htp_compiler_spec(
use_fp16=True,
use_multi_contexts=True,
)
compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=self.chipset_table[TestQNN.model],
backend_options=backend_options,
)
partitioner = QnnPartitioner(compiler_specs)
edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
max_sf_size = update_spill_fill_size(edge_prog.exported_program)
self.assertNotEqual(0, max_sf_size)

def test_qnn_backend_multi_contexts(self):
module = SimpleModel() # noqa: F405
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
@@ -2011,6 +2029,25 @@ def calibrator(gm):
).to_executorch()
self.verify_output(module, sample_input, exec_prog)

def test_qnn_backend_spill_fill_buffer_size(self):
module = LargeTensorLinear() # noqa: F405
sample_input = (torch.randn(1, 256, 512),)
module = self.get_qdq_module(module, sample_input)
edge_prog = capture_program(module, sample_input)

backend_options = generate_htp_compiler_spec(
use_fp16=False,
use_multi_contexts=True,
)
compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=self.chipset_table[TestQNN.model],
backend_options=backend_options,
)
partitioner = QnnPartitioner(compiler_specs)
edge_prog.exported_program = to_backend(edge_prog.exported_program, partitioner)
max_sf_size = update_spill_fill_size(edge_prog.exported_program)
self.assertNotEqual(0, max_sf_size)

def test_qnn_backend_graph_level_mixed_precision(self):
module = SimpleModel() # noqa: F405
sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
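The two new tests are identical except for precision: the first compiles the model in FP16 (use_fp16=True), while the second quantizes it via get_qdq_module first and compiles with use_fp16=False. Both assert only that the computed spill-fill size is nonzero; a zero result under use_multi_contexts=True would mean no shared spill-fill buffer was configured at all.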
8 changes: 5 additions & 3 deletions backends/qualcomm/utils/utils.py
@@ -269,15 +269,17 @@ def set_spec(module, options):
options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size
set_spec(module, options)

max_sf_size, modules_map = 0, {}
if isinstance(exported_program, list):
max_sf_size, modules_map = 0, {}
for prog in exported_program:
max_sf_buf_size, module_map = get_program_info(prog)
max_sf_size = max(max_sf_size, max_sf_buf_size)
modules_map.update(module_map)
update_program(max_sf_size, modules_map)
else:
update_program(*get_program_info(exported_program))
max_sf_size, module_map = get_program_info(exported_program)
update_program(max_sf_size, module_map)

return max_sf_size


def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
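The supporting change in update_spill_fill_size hoists the max_sf_size accumulator out of the list branch and computes it in the single-program branch as well, so the function now returns the resolved maximum in both cases instead of only writing it into the compile specs. That return value is what the new tests assert on.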
