Skip to content

Commit

Permalink
Add test script for RelPosAttention
Browse files Browse the repository at this point in the history
  • Loading branch information
Masao-Someki committed Sep 18, 2022
1 parent dfa5043 commit 0ffafce
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
from .op_test_utils import check_op_type_count

test_cases = [
['encoder', 'transformer', 4, 256, 3, 0, False],
['encoder', 'transformer', 4, 256, 'Attention', 3, False],
# ['encoder', 'contextual_block_transformer', 4, 256, 3, 0, False],
['encoder', 'transformer', 4, 256, 3, 0, True],
['encoder', 'transformer', 4, 256, 'Attention', 3, True],
['encoder', 'conformer_rpe_latest', 4, 256, 'RelPosAttention', 3, True],
# ['encoder', 'contextual_block_transformer', 4, 256, 3, 0, True],
['decoder', 'transformer', 4, 256, 0, 6, True],
['lm', 'transformer_pe', 4, 256, 0, 3, True],
['lm', 'transformer', 4, 256, 0, 2, True],
['decoder', 'transformer', 4, 256, 'CrossAttention', 6, True],
['lm', 'transformer_pe', 4, 256, 'CrossAttention', 3, True],
['lm', 'transformer', 4, 256, 'CrossAttention', 2, True],
]

@pytest.mark.parametrize('model_type, model_name, n_head, h_size, n_att, n_cross_att, use_custom_ort', test_cases)
def test_optimize(model_type, model_name, n_head, h_size, n_att, n_cross_att, use_custom_ort, model_export):
@pytest.mark.parametrize('model_type, model_name, n_head, h_size, node_name, node_num, use_custom_ort', test_cases)
def test_optimize(model_type, model_name, n_head, h_size, node_name, node_num, use_custom_ort, model_export):
export_dir = model_export.cache_dir / 'test' / \
model_type / f'cache_{model_name}'
output_dir = model_export.cache_dir / 'test' / \
Expand All @@ -45,11 +46,6 @@ def test_optimize(model_type, model_name, n_head, h_size, n_att, n_cross_att, us
)

# load the optimized model and check if the number of fused nodes is correct.
nodes = {}
if n_att > 0:
nodes['Attention'] = n_att
if n_cross_att > 0:
nodes['CrossAttention'] = n_cross_att

nodes = {node_name : node_num}
check_op_type_count(str(output_dir / model_name), **nodes)

0 comments on commit 0ffafce

Please sign in to comment.