Skip to content

Commit

Permalink
Split up vector_swizzle tests into smaller chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
lbushi25 committed Jun 13, 2024
1 parent 4043316 commit fe175df
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 26 deletions.
71 changes: 58 additions & 13 deletions tests/common/common_python_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from collections import defaultdict
from string import Template
from itertools import product
from math import ceil, floor

class Data:
signs = [True, False]
Expand Down Expand Up @@ -384,6 +385,8 @@ def get_space_count(line):
def add_spaces_to_lines(count, string):
"""Adds a number of spaces to the start of each line"""
all_lines = string.splitlines(True)
if not all_lines:
return ''
new_string = all_lines[0]
for i in range(1, len(all_lines)):
new_string += ' ' * count + all_lines[i]
Expand Down Expand Up @@ -619,7 +622,7 @@ def substitute_swizzles_templates(type_str, size, index_subset, value_subset, co
test_string)
return string

def gen_swizzle_test(type_str, convert_type_str, as_type_str, size):
def gen_swizzle_test(type_str, convert_type_str, as_type_str, size, num_batches, batch_index):
string = ''
if size > 4:
test_string = SwizzleData.swizzle_full_test_template.substitute(
Expand Down Expand Up @@ -654,26 +657,68 @@ def gen_swizzle_test(type_str, convert_type_str, as_type_str, size):
', '.join(Data.swizzle_elem_list_dict[size][:size]) + '>',
test_string)
return string
# size <=4

# Case when size <=4
# The test files generated for swizzles of vectors of size <= 4 are enormous and are hurting
# compilation times of the suite so we batch the tests according to two command line arguments
# in num_batches and batch_index that will dictate how many tests we can put in a single test file.
# Specifically, the test cases are to be split in num_batches different groups aka batches
# and the batch_index tells the script which batch in particular we want to output to a test file during this run.
# Both of these arguments, num_batches and batch_index, are controlled by the cmake test generation script.

total_tests = 0
for length in range(size, size + 1):
for index_subset, value_subset in zip(
product(
Data.swizzle_xyzw_list_dict[size][:size],
repeat=length),
product(Data.vals_list_dict[size][:size], repeat=length)):
string += substitute_swizzles_templates(type_str, size,
index_subset, value_subset, convert_type_str, as_type_str)
total_tests += 1
batch_size = ceil(total_tests / num_batches)
cur_index = 0
cur_batch = 0
for length in range(size, size + 1):
for index_subset, value_subset in zip(
product(
Data.swizzle_xyzw_list_dict[size][:size],
repeat=length),
product(Data.vals_list_dict[size][:size], repeat=length)):
cur_batch = floor(cur_index / batch_size)
if cur_batch > batch_index:
break
if cur_batch == batch_index:
string += substitute_swizzles_templates(type_str, size,
index_subset, value_subset, convert_type_str, as_type_str)
cur_index += 1

# Same logic as above repeated for the case when size == 4
if size == 4:
total_tests = 0
for length in range(size, size + 1):
for index_subset, value_subset in zip(
product(
Data.swizzle_rgba_list_dict[size][:size],
repeat=length),
product(
Data.vals_list_dict[size][:size], repeat=length)):
string += substitute_swizzles_templates(type_str, size,
index_subset, value_subset, convert_type_str, as_type_str)
total_tests += 1
batch_size = ceil(total_tests / num_batches)
cur_index = 0
cur_batch = 0
for length in range(size, size + 1):
for index_subset, value_subset in zip(
product(
Data.swizzle_rgba_list_dict[size][:size],
repeat=length),
product(
Data.vals_list_dict[size][:size], repeat=length)):
cur_batch = floor(cur_index / batch_size)
if cur_batch > batch_index:
break
if cur_batch == batch_index:
string += substitute_swizzles_templates(type_str, size,
index_subset, value_subset, convert_type_str, as_type_str)
cur_index += 1
return string


Expand Down Expand Up @@ -724,7 +769,7 @@ def get_reverse_type(type_str):
# Reason for the TODO above is that this function and several more it calls are
# not really common and only used to generate vector_swizzles test.
# FIXME: The test (main template and others) should be updated to use Catch2
def make_swizzles_tests(type_str, input_file, output_file):
def make_swizzles_tests(type_str, input_file, output_file, num_batches, batch_index):
if type_str == 'bool':
Data.vals_list_dict = cast_to_bool(Data.vals_list_dict)

Expand All @@ -733,15 +778,15 @@ def make_swizzles_tests(type_str, input_file, output_file):
convert_type_str = get_reverse_type(type_str)
as_type_str = get_reverse_type(type_str)
swizzles[0] = gen_swizzle_test(type_str, convert_type_str,
as_type_str, 1)
as_type_str, 1, num_batches, batch_index)
swizzles[1] = gen_swizzle_test(type_str, convert_type_str,
as_type_str, 2)
as_type_str, 2, num_batches, batch_index)
swizzles[2] = gen_swizzle_test(type_str, convert_type_str,
as_type_str, 3)
as_type_str, 3, num_batches, batch_index)
swizzles[3] = gen_swizzle_test(type_str, convert_type_str,
as_type_str, 4)
as_type_str, 4, num_batches, batch_index)
swizzles[4] = gen_swizzle_test(type_str, convert_type_str,
as_type_str, 8)
as_type_str, 8, num_batches, batch_index)
swizzles[5] = gen_swizzle_test(type_str, convert_type_str,
as_type_str, 16)
as_type_str, 16, num_batches, batch_index)
write_swizzle_source_file(swizzles, input_file, output_file, type_str)
2 changes: 1 addition & 1 deletion tests/common/vector_swizzles.template
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class TEST_NAME : public util::test_base {
}
};

util::test_proxy<TEST_NAME> proxy;
inline util::test_proxy<TEST_NAME> proxy;

} /* namespace vector_swizzles_$TYPE_NAME__ */
$ENDIF
25 changes: 14 additions & 11 deletions tests/vector_swizzles/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
set(TEST_CASES_LIST "")

set(TYPE_LIST "")
set(NUM_BATCHES "32")
get_std_type(TYPE_LIST)
get_no_vec_alias_type(TYPE_LIST)
get_fixed_width_type(TYPE_LIST)

half_double_filter(TYPE_LIST)

foreach(TY IN LISTS TYPE_LIST)
set(OUT_FILE "vector_swizzles_${TY}.cpp")
STRING(REGEX REPLACE ":" "_" OUT_FILE ${OUT_FILE})
STRING(REGEX REPLACE " " "_" OUT_FILE ${OUT_FILE})
STRING(REGEX REPLACE "std__" "" OUT_FILE ${OUT_FILE})
foreach(BATCH_INDEX RANGE 1 ${NUM_BATCHES})
set(OUT_FILE "vector_swizzles_${TY}_batch_${BATCH_INDEX}.cpp")
STRING(REGEX REPLACE ":" "_" OUT_FILE ${OUT_FILE})
STRING(REGEX REPLACE " " "_" OUT_FILE ${OUT_FILE})
STRING(REGEX REPLACE "std__" "" OUT_FILE ${OUT_FILE})

# Invoke our generator
# the path to the generated cpp file will be added to TEST_CASES_LIST
generate_cts_test(TESTS TEST_CASES_LIST
GENERATOR "generate_vector_swizzles.py"
OUTPUT ${OUT_FILE}
INPUT "../common/vector_swizzles.template"
EXTRA_ARGS -type "${TY}")
# Invoke our generator
# the path to the generated cpp file will be added to TEST_CASES_LIST
generate_cts_test(TESTS TEST_CASES_LIST
GENERATOR "generate_vector_swizzles.py"
OUTPUT ${OUT_FILE}
INPUT "../common/vector_swizzles.template"
EXTRA_ARGS -type "${TY}" -num_batches ${NUM_BATCHES} -batch_index ${BATCH_INDEX})
endforeach()
endforeach()

add_cts_test(${TEST_CASES_LIST})
14 changes: 13 additions & 1 deletion tests/vector_swizzles/generate_vector_swizzles.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def main():
required=True,
choices=get_types(),
help='Type to generate the test for')
argparser.add_argument(
'-num_batches',
dest='num_batches',
required=True,
type=int,
help='Number of batches to split the test cases into')
argparser.add_argument(
'-batch_index',
dest='batch_index',
required=True,
type=int,
help='Batch index of the test batch to write to the output file.')
argparser.add_argument(
'-o',
required=True,
Expand All @@ -48,7 +60,7 @@ def main():
help='CTS test output')
args = argparser.parse_args()

make_swizzles_tests(args.ty, args.template, args.output)
make_swizzles_tests(args.ty, args.template, args.output, args.num_batches, args.batch_index - 1)


if __name__ == '__main__':
Expand Down

0 comments on commit fe175df

Please sign in to comment.