From 1d5140d5e98c18e73ce576673aa34022ad6d804f Mon Sep 17 00:00:00 2001 From: Marsella8 <45826022+Marsella8@users.noreply.github.com> Date: Thu, 10 Oct 2024 07:53:36 -0700 Subject: [PATCH] Utils: Refactor and Test Updates (#1464) * removed GraphInternal * minor changes * Added views.cc testing * Fixed containers testing * views.cc fix * test reorganizing * fmt * test rearranging * rearranging * graph fixes * undirected graph test fix * Updated docs * Test updates + bug fixes * minor fix * PR fixes * moved graph-testing to other PR * minor fixes * Remove unnecessary includes * Post-merge fixes * Small bugfix and remove some unnecessary includes --------- Co-authored-by: Pietro Max Marsella Co-authored-by: Colin Unger Co-authored-by: Colin Unger --- .../machine_mapping/machine_mapping.cc | 1 - lib/kernels/include/kernels/accessor.h | 1 - .../include/kernels/initializer_kernels.h | 1 - lib/kernels/test/src/test_dropout.cc | 1 - lib/pcg/include/pcg/optimizer_attrs.h | 2 +- lib/pcg/src/pcg/operator_task_space.cc | 3 +- lib/runtime/src/model.cc | 1 - .../src/task_invocation_compilation.cc | 1 - .../task_spec/task_invocation_compilation.cc | 1 - .../src/task_spec/tensor_args_format.cc | 2 +- .../src/substitutions/pcg_pattern_match.cc | 4 +- .../output_expr_to_result_sub_pcg_mapping.cc | 4 +- .../unlabelled/multidigraph_pattern_match.cc | 56 --- ...rge_bidicts.h => merge_disjoint_bidicts.h} | 18 +- lib/utils/include/utils/containers.decl.h | 88 ---- lib/utils/include/utils/containers.h | 233 ----------- .../utils/containers/are_all_distinct.h | 16 + .../include/utils/containers/compare_by.h | 15 + .../include/utils/containers/get_first.h | 21 - .../include/utils/containers/get_one_of.h | 20 + lib/utils/include/utils/containers/index_of.h | 25 ++ .../include/utils/containers/is_submapeq_of.h | 18 + .../include/utils/containers/is_subseteq_of.h | 10 +- .../utils/containers/is_superseteq_of.h | 17 + lib/utils/include/utils/containers/map_keys.h | 14 + lib/utils/include/utils/containers/maximum.h | 13 +- lib/utils/include/utils/containers/product.h | 3 + .../include/utils/containers/product_where.h | 26 ++ .../utils/containers/reversed_container.h | 85 ++++ lib/utils/include/utils/containers/subvec.h | 15 +- lib/utils/include/utils/containers/sum.h | 13 +- .../include/utils/containers/sum_where.h | 24 ++ .../include/utils/containers/value_all.h | 34 ++ .../include/utils/containers/vector_split.h | 10 +- lib/utils/include/utils/fmt/expected.h | 2 +- lib/utils/include/utils/fmt/map.h | 2 +- lib/utils/include/utils/fmt/multiset.h | 2 +- lib/utils/include/utils/fmt/optional.h | 12 +- lib/utils/include/utils/fmt/pair.h | 2 +- lib/utils/include/utils/fmt/set.h | 2 +- lib/utils/include/utils/fmt/unordered_map.h | 2 +- .../include/utils/fmt/unordered_multiset.h | 2 +- lib/utils/include/utils/fmt/unordered_set.h | 2 +- lib/utils/include/utils/fmt/variant.h | 2 +- lib/utils/include/utils/fmt/vector.h | 2 +- lib/utils/include/utils/graph/cow_ptr_t.h | 2 - .../instances/hashmap_undirected_graph.h | 0 lib/utils/include/utils/join_strings.h | 9 +- lib/utils/include/utils/rapidcheck/variant.h | 19 + lib/utils/include/utils/stack_string.h | 7 + lib/utils/include/utils/unique.h | 13 - lib/utils/include/utils/variant.h | 11 - lib/utils/src/containers.cc | 1 - .../utils/bidict/algorithms/merge_bidicts.cc | 1 - .../algorithms/merge_disjoint_bidicts.cc | 1 + .../src/utils/cli/cli_get_help_message.cc | 3 +- .../src/utils/containers/are_all_distinct.cc | 1 + lib/utils/src/utils/containers/compare_by.cc | 1 + lib/utils/src/utils/containers/get_first.cc | 1 - lib/utils/src/utils/containers/get_one_of.cc | 1 + lib/utils/src/utils/containers/index_of.cc | 1 + .../src/utils/containers/is_submapeq_of.cc | 1 + .../src/utils/containers/is_superseteq_of.cc | 1 + .../src/utils/containers/product_where.cc | 1 + .../utils/containers/reversed_container.cc | 1 + lib/utils/src/utils/containers/sum_where.cc | 1 + lib/utils/src/utils/containers/value_all.cc | 1 + lib/utils/src/utils/fmt/optional.cc | 8 + .../algorithms/find_isomorphism.cc | 4 +- .../get_cbc_decomposition.cc | 4 - .../is_complete_bipartite_digraph.cc | 1 - .../algorithms/get_imm_post_dominator.cc | 4 +- .../algorithms/find_isomorphism.cc | 4 +- .../algorithms/find_isomorphisms.cc | 2 +- lib/utils/src/utils/rapidcheck/variant.cc | 10 + .../include/test/utils/doctest/fmt/optional.h | 5 + .../src/test/utils/doctest/fmt/optional.cc | 8 + lib/utils/test/src/test_algorithms.cc | 246 ----------- lib/utils/test/src/test_containers.cc | 393 ------------------ lib/utils/test/src/test_multidigraph.cc | 94 ----- lib/utils/test/src/test_random_utils.cc | 67 --- lib/utils/test/src/test_type_index.cc | 35 -- lib/utils/test/src/test_undirected_graph.cc | 62 --- .../algorithms/merge_disjoint_bidicts.cc | 42 ++ lib/utils/test/src/utils/bidict/bidict.cc | 36 +- lib/utils/test/src/utils/containers/all_of.cc | 13 + .../src/utils/containers/are_all_distinct.cc | 24 ++ .../test/src/utils/containers/are_disjoint.cc | 31 ++ .../test/src/utils/containers/compare_by.cc | 19 + .../test/src/utils/containers/contains.cc | 13 + lib/utils/test/src/utils/containers/count.cc | 13 + .../test/src/utils/containers/enumerate.cc | 25 +- .../test/src/utils/containers/filter_keys.cc | 18 + lib/utils/test/src/utils/containers/find.cc | 54 +++ .../test/src/utils/containers/flatmap.cc | 34 ++ .../test/src/utils/containers/get_one_of.cc | 19 + .../test/src/utils/containers/index_of.cc | 24 ++ .../src/utils/containers/is_submapeq_of.cc | 40 ++ .../src/utils/containers/is_subseteq_of.cc | 16 + .../src/utils/containers/is_superseteq_of.cc | 25 ++ lib/utils/test/src/utils/containers/keys.cc | 18 + .../test/src/utils/containers/map_keys.cc | 26 ++ .../test/src/utils/containers/map_values.cc | 17 + .../test/src/utils/containers/maximum.cc | 53 +-- .../test/src/utils/containers/merge_maps.cc | 30 ++ .../test/src/utils/containers/product.cc | 32 ++ .../src/utils/containers/product_where.cc | 34 ++ .../src/utils/containers/restrict_keys.cc | 17 + .../test/src/utils/containers/set_union.cc | 16 + .../test/src/utils/containers/sorted_by.cc | 35 ++ lib/utils/test/src/utils/containers/subvec.cc | 62 +++ .../test/src/utils/containers/sum_where.cc | 36 ++ .../test/src/utils/containers/value_all.cc | 24 ++ lib/utils/test/src/utils/containers/values.cc | 19 + .../test/src/utils/containers/vector_split.cc | 39 ++ .../deduplicated_priority_queue.cc} | 2 + .../disjoint_set.cc} | 19 +- .../{test_dot_file.cc => utils/dot_file.cc} | 0 .../utils/graph/multidigraph/multidigraph.cc | 171 ++++++++ lib/utils/test/src/utils/hash/map.cc | 20 + lib/utils/test/src/utils/hash/set.cc | 20 + lib/utils/test/src/utils/hash/tuple.cc | 21 + .../hash/unordered_map.cc} | 8 +- .../test/src/utils/hash/unordered_set.cc | 20 + lib/utils/test/src/utils/hash/vector.cc | 20 + lib/utils/test/src/utils/join_strings.cc | 47 +++ lib/utils/test/src/utils/random_utils.cc | 60 +++ .../test/src/utils/rapidcheck/variant.cc | 13 + .../record_formatter.cc} | 0 .../{test_sequence.cc => utils/sequence.cc} | 0 .../{test_stack_map.cc => utils/stack_map.cc} | 15 +- .../stack_string.cc} | 6 +- .../stack_vector.cc} | 31 +- .../src/{test_tuple.cc => utils/tuple.cc} | 27 +- lib/utils/test/src/utils/type_index.cc | 34 ++ .../src/{test_variant.cc => utils/variant.cc} | 11 +- .../src/{test_vector.cc => utils/vector.cc} | 4 + 137 files changed, 1791 insertions(+), 1515 deletions(-) delete mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc rename lib/utils/include/utils/bidict/algorithms/{merge_bidicts.h => merge_disjoint_bidicts.h} (52%) delete mode 100644 lib/utils/include/utils/containers.decl.h delete mode 100644 lib/utils/include/utils/containers.h create mode 100644 lib/utils/include/utils/containers/are_all_distinct.h create mode 100644 lib/utils/include/utils/containers/compare_by.h delete mode 100644 lib/utils/include/utils/containers/get_first.h create mode 100644 lib/utils/include/utils/containers/get_one_of.h create mode 100644 lib/utils/include/utils/containers/index_of.h create mode 100644 lib/utils/include/utils/containers/is_submapeq_of.h create mode 100644 lib/utils/include/utils/containers/is_superseteq_of.h create mode 100644 lib/utils/include/utils/containers/product_where.h create mode 100644 lib/utils/include/utils/containers/reversed_container.h create mode 100644 lib/utils/include/utils/containers/sum_where.h create mode 100644 lib/utils/include/utils/containers/value_all.h rename lib/utils/{src => include}/utils/graph/instances/hashmap_undirected_graph.h (100%) create mode 100644 lib/utils/include/utils/rapidcheck/variant.h delete mode 100644 lib/utils/include/utils/unique.h delete mode 100644 lib/utils/src/containers.cc delete mode 100644 lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc create mode 100644 lib/utils/src/utils/containers/are_all_distinct.cc create mode 100644 lib/utils/src/utils/containers/compare_by.cc delete mode 100644 lib/utils/src/utils/containers/get_first.cc create mode 100644 lib/utils/src/utils/containers/get_one_of.cc create mode 100644 lib/utils/src/utils/containers/index_of.cc create mode 100644 lib/utils/src/utils/containers/is_submapeq_of.cc create mode 100644 lib/utils/src/utils/containers/is_superseteq_of.cc create mode 100644 lib/utils/src/utils/containers/product_where.cc create mode 100644 lib/utils/src/utils/containers/reversed_container.cc create mode 100644 lib/utils/src/utils/containers/sum_where.cc create mode 100644 lib/utils/src/utils/containers/value_all.cc create mode 100644 lib/utils/src/utils/rapidcheck/variant.cc delete mode 100644 lib/utils/test/src/test_algorithms.cc delete mode 100644 lib/utils/test/src/test_containers.cc delete mode 100644 lib/utils/test/src/test_multidigraph.cc delete mode 100644 lib/utils/test/src/test_random_utils.cc delete mode 100644 lib/utils/test/src/test_type_index.cc delete mode 100644 lib/utils/test/src/test_undirected_graph.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc create mode 100644 lib/utils/test/src/utils/containers/all_of.cc create mode 100644 lib/utils/test/src/utils/containers/are_all_distinct.cc create mode 100644 lib/utils/test/src/utils/containers/are_disjoint.cc create mode 100644 lib/utils/test/src/utils/containers/compare_by.cc create mode 100644 lib/utils/test/src/utils/containers/contains.cc create mode 100644 lib/utils/test/src/utils/containers/count.cc create mode 100644 lib/utils/test/src/utils/containers/filter_keys.cc create mode 100644 lib/utils/test/src/utils/containers/find.cc create mode 100644 lib/utils/test/src/utils/containers/get_one_of.cc create mode 100644 lib/utils/test/src/utils/containers/index_of.cc create mode 100644 lib/utils/test/src/utils/containers/is_submapeq_of.cc create mode 100644 lib/utils/test/src/utils/containers/is_subseteq_of.cc create mode 100644 lib/utils/test/src/utils/containers/is_superseteq_of.cc create mode 100644 lib/utils/test/src/utils/containers/keys.cc create mode 100644 lib/utils/test/src/utils/containers/map_keys.cc create mode 100644 lib/utils/test/src/utils/containers/map_values.cc create mode 100644 lib/utils/test/src/utils/containers/merge_maps.cc create mode 100644 lib/utils/test/src/utils/containers/product.cc create mode 100644 lib/utils/test/src/utils/containers/product_where.cc create mode 100644 lib/utils/test/src/utils/containers/restrict_keys.cc create mode 100644 lib/utils/test/src/utils/containers/set_union.cc create mode 100644 lib/utils/test/src/utils/containers/sorted_by.cc create mode 100644 lib/utils/test/src/utils/containers/subvec.cc create mode 100644 lib/utils/test/src/utils/containers/sum_where.cc create mode 100644 lib/utils/test/src/utils/containers/value_all.cc create mode 100644 lib/utils/test/src/utils/containers/values.cc create mode 100644 lib/utils/test/src/utils/containers/vector_split.cc rename lib/utils/test/src/{test_deduplicated_priority_queue.cc => utils/deduplicated_priority_queue.cc} (96%) rename lib/utils/test/src/{test_disjoint_set.cc => utils/disjoint_set.cc} (73%) rename lib/utils/test/src/{test_dot_file.cc => utils/dot_file.cc} (100%) create mode 100644 lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc create mode 100644 lib/utils/test/src/utils/hash/map.cc create mode 100644 lib/utils/test/src/utils/hash/set.cc create mode 100644 lib/utils/test/src/utils/hash/tuple.cc rename lib/utils/test/src/{test_hash.cc => utils/hash/unordered_map.cc} (69%) create mode 100644 lib/utils/test/src/utils/hash/unordered_set.cc create mode 100644 lib/utils/test/src/utils/hash/vector.cc create mode 100644 lib/utils/test/src/utils/join_strings.cc create mode 100644 lib/utils/test/src/utils/random_utils.cc create mode 100644 lib/utils/test/src/utils/rapidcheck/variant.cc rename lib/utils/test/src/{test_format.cc => utils/record_formatter.cc} (100%) rename lib/utils/test/src/{test_sequence.cc => utils/sequence.cc} (100%) rename lib/utils/test/src/{test_stack_map.cc => utils/stack_map.cc} (73%) rename lib/utils/test/src/{test_stack_string.cc => utils/stack_string.cc} (95%) rename lib/utils/test/src/{test_stack_vector.cc => utils/stack_vector.cc} (68%) rename lib/utils/test/src/{test_tuple.cc => utils/tuple.cc} (73%) create mode 100644 lib/utils/test/src/utils/type_index.cc rename lib/utils/test/src/{test_variant.cc => utils/variant.cc} (91%) rename lib/utils/test/src/{test_vector.cc => utils/vector.cc} (91%) diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 6f350d8773..57e82684e9 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -1,5 +1,4 @@ #include "compiler/machine_mapping/machine_mapping.h" -#include "utils/containers.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps.h" diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 5fbcd91a06..39da65c3be 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -7,7 +7,6 @@ #include "op-attrs/datatype.h" #include "utils/exception.h" #include "utils/required.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 52609a303f..9840e457e6 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -4,7 +4,6 @@ #include "accessor.h" #include "kernels/cpu.h" #include "op-attrs/datatype_value.dtg.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/test/src/test_dropout.cc b/lib/kernels/test/src/test_dropout.cc index 981bc611d8..81f3c7183a 100644 --- a/lib/kernels/test/src/test_dropout.cc +++ b/lib/kernels/test/src/test_dropout.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "kernels/dropout_kernels.h" #include "test_utils.h" -#include "utils/containers.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h index 4bac74b999..3e787503d6 100644 --- a/lib/pcg/include/pcg/optimizer_attrs.h +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -3,7 +3,7 @@ #include "pcg/optimizers/adam_optimizer_attrs.h" #include "pcg/optimizers/sgd_optimizer_attrs.h" -#include "utils/variant.h" +#include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/pcg/src/pcg/operator_task_space.cc index 02522ae411..2538cb4ea0 100644 --- a/lib/pcg/src/pcg/operator_task_space.cc +++ b/lib/pcg/src/pcg/operator_task_space.cc @@ -5,6 +5,7 @@ #include "utils/containers/range.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { @@ -25,7 +26,7 @@ std::unordered_set TaskSpaceCoordinate get_task_space_maximum_coordinate(OperatorTaskSpace const &task) { - return maximum(get_task_space_coordinates(task)).value(); + return maximum(get_task_space_coordinates(task)); } size_t num_dims(OperatorTaskSpace const &task) { diff --git a/lib/runtime/src/model.cc b/lib/runtime/src/model.cc index a655bcb050..22f0f2e98d 100644 --- a/lib/runtime/src/model.cc +++ b/lib/runtime/src/model.cc @@ -27,7 +27,6 @@ #include "parallel_tensor_mapping.h" #include "task_spec/task_argument_accessor.h" #include "test_utils.h" -#include "utils/containers.h" #include "utils/random_utils.h" #include #include diff --git a/lib/runtime/src/task_invocation_compilation.cc b/lib/runtime/src/task_invocation_compilation.cc index bfeb2be6d4..ff281370e7 100644 --- a/lib/runtime/src/task_invocation_compilation.cc +++ b/lib/runtime/src/task_invocation_compilation.cc @@ -1,5 +1,4 @@ #include "task_invocation_compilation.h" -#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/runtime/src/task_spec/task_invocation_compilation.cc b/lib/runtime/src/task_spec/task_invocation_compilation.cc index bfeb2be6d4..ff281370e7 100644 --- a/lib/runtime/src/task_spec/task_invocation_compilation.cc +++ b/lib/runtime/src/task_spec/task_invocation_compilation.cc @@ -1,5 +1,4 @@ #include "task_invocation_compilation.h" -#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/runtime/src/task_spec/tensor_args_format.cc b/lib/runtime/src/task_spec/tensor_args_format.cc index 22bd021996..8fb7fc029b 100644 --- a/lib/runtime/src/task_spec/tensor_args_format.cc +++ b/lib/runtime/src/task_spec/tensor_args_format.cc @@ -1,5 +1,5 @@ #include "tensor_args_format.h" -#include "utils/containers.h" +#include "utils/containers/flatmap.h" namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index f1f4e31d57..b701be65cf 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -2,7 +2,7 @@ #include "substitutions/pcg_pattern.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" -#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/containers/map_values.h" #include "utils/containers/zip.h" @@ -27,7 +27,7 @@ bidict bidict_from_keys_and_values(pattern_node_outputs, matched_layer_output_tensors); - result = merge_bidicts(result, mapping); + result = merge_disjoint_bidicts(result, mapping); } return result; diff --git a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc index 083334f0db..22e6a9f333 100644 --- a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc +++ b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc @@ -2,7 +2,7 @@ #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" -#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" namespace FlexFlow { @@ -23,7 +23,7 @@ bidict mapping_for_layer = bidict_from_keys_and_values( layer_outputs, output_graph_expr_outputs); - result = merge_bidicts(result, mapping_for_layer); + result = merge_disjoint_bidicts(result, mapping_for_layer); } return result; diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc deleted file mode 100644 index 8ce60fab4f..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc +++ /dev/null @@ -1,56 +0,0 @@ -#include "substitutions/unlabelled/multidigraph_pattern_match.h" -// #include "substitutions/unlabelled/edge_splits.h" -// #include "substitutions/unlabelled/pattern_edge.h" -#include "utils/containers.h" - -namespace FlexFlow { - -// MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { -// return MultiDiGraphPatternMatch{ -// bidict{}, -// bidict{}, -// }; -// } - -// std::optional -// unsplit_matches(MultiDiGraphPatternMatch const &prefix, -// MultiDiGraphPatternMatch const &postfix, -// UnlabelledPatternEdgeSplits const &edge_splits) { -// -// MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); -// -// std::unordered_set handled; -// for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { -// ClosedPatternEdge closed_edge = std::get(coi); -// OutputPatternEdge output_edge = std::get(coi); -// InputPatternEdge input_edge = std::get(coi); -// -// handled.insert(pattern_edge_from_output_edge(output_edge)); -// handled.insert(pattern_edge_from_input_edge(input_edge)); -// -// OpenMultiDiEdge output_graph_edge = -// prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); -// OpenMultiDiEdge input_graph_edge = -// postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); -// if (output_graph_edge == input_graph_edge) { -// result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), -// output_graph_edge); -// } else { -// return std::nullopt; -// } -// } -// -// for (auto const &kv : -// merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { -// if (!contains(handled, kv.first)) { -// result.edge_assignment.equate(kv.first, kv.second); -// } -// } -// -// result.node_assignment = -// merge_maps(prefix.node_assignment, postfix.node_assignment); -// -// return result; -// } - -} // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h similarity index 52% rename from lib/utils/include/utils/bidict/algorithms/merge_bidicts.h rename to lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index d388e35d75..97e7334c26 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,17 +1,25 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" #include "utils/bidict/bidict.h" #include "utils/containers/are_disjoint.h" +#include "utils/exception.h" namespace FlexFlow { template -bidict merge_bidicts(bidict const &lhs, bidict const &rhs) { - assert(are_disjoint(left_entries(lhs), left_entries(rhs))); - assert(are_disjoint(right_entries(lhs), right_entries(rhs))); +bidict merge_disjoint_bidicts(bidict const &lhs, + bidict const &rhs) { + if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); + } + if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); + } bidict result; for (auto const &kv : lhs) { diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h deleted file mode 100644 index cb652a9e69..0000000000 --- a/lib/utils/include/utils/containers.decl.h +++ /dev/null @@ -1,88 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H - -#include "utils/bidict/bidict.h" -#include "utils/containers/get_element_type.h" -#include "utils/required_core.h" -#include "utils/type_traits_core.h" -#include -#include -#include - -namespace FlexFlow { - -template -Element sum_where(Container const &container, ConditionF const &condition); - -template -Element product_where(Container const &container, ConditionF const &condition); - -template -bool contains_l(bidict const &m, K const &k); - -template -bool contains_r(bidict const &m, V const &v); - -template -std::unordered_map filter_values(std::unordered_map const &m, - F const &f); - -template -std::optional index_of(Container const &c, Element const &e); - -template -std::unordered_map restrict_keys(std::unordered_map const &m, - std::unordered_set const &mask); - -template -std::optional at_idx(std::vector const &v, size_t idx); - -template -std::function lookup_in(std::unordered_map const &m); - -template -std::function lookup_in_l(bidict const &m); - -template -std::function lookup_in_r(bidict const &m); - -template -bool is_supserseteq_of(std::unordered_set const &l, - std::unordered_set const &r); - -template -std::unordered_set - map_over_unordered_set(std::function const &f, - std::unordered_set const &input); - -template -std::optional maybe_get_only(C const &c); - -template -std::optional optional_all_of(Container const &, Function const &); - -template -std::function compare_by(F const &f); - -template -T reversed(T const &t); - -template -std::vector value_all(std::vector> const &v); - -template -std::unordered_set value_all(std::unordered_set> const &v); - -template -struct reversed_container_t; - -template -reversed_container_t reversed_container(C const &c); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h deleted file mode 100644 index 0e3b1fc0bd..0000000000 --- a/lib/utils/include/utils/containers.h +++ /dev/null @@ -1,233 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL - -#include "containers.decl.h" -#include "required_core.h" -#include "type_traits_core.h" -#include "utils/bidict/bidict.h" -#include "utils/containers/contains.h" -#include "utils/containers/extend.h" -#include "utils/containers/extend_vector.h" -#include "utils/containers/filter.h" -#include "utils/containers/intersection.h" -#include "utils/containers/is_subseteq_of.h" -#include "utils/containers/sorted.h" -#include "utils/containers/transform.h" -#include "utils/containers/vector_transform.h" -#include "utils/exception.h" -#include "utils/hash/pair.h" -#include "utils/type_traits.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace FlexFlow { - -template -Element sum_where(Container const &container, ConditionF const &condition) { - Element result = 0; - for (Element const &element : container) { - if (condition(element)) { - result += element; - } - } - return result; -} - -template -Element product_where(Container const &container, ConditionF const &condition) { - Element result = 1; - for (Element const &element : container) { - if (condition(element)) { - result *= element; - } - } - return result; -} - -template -bool contains_l(bidict const &m, K const &k) { - return m.contains_l(k); -} - -template -bool contains_r(bidict const &m, V const &v) { - return m.contains_r(v); -} - -template -bool is_submap(std::unordered_map const &m, - std::unordered_map const &sub) { - return restrict_keys(m, keys(sub)) == sub; -} - -template -std::optional index_of(Container const &c, Element const &e) { - auto it = std::find(c.cbegin(), c.cend(), e); - if (it == c.cend()) { - return std::nullopt; - } else { - return std::distance(c.cbegin(), it); - } -} - -template -std::function lookup_in(std::unordered_map const &m) { - return [&m](K const &k) -> V { return m.at(k); }; -} - -template -std::function lookup_in_l(bidict const &m) { - return [&m](L const &l) -> R { return m.at_l(l); }; -} - -template -std::function lookup_in_r(bidict const &m) { - return [&m](R const &r) -> L { return m.at_r(r); }; -} - -template -bool is_supserseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { - return is_subseteq_of(r, l); -} - -template -std::unordered_set - map_over_unordered_set(std::function const &f, - std::unordered_set const &input) { - std::unordered_set result; - std::transform( - input.cbegin(), input.cend(), std::inserter(result, result.begin()), f); - return result; -} - -template -std::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - std::optional condition = func(element); - if (!condition.has_value()) { - return std::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -template -std::function compare_by(F const &f) { - return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; -} - -template -std::vector value_all(std::vector> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); - }); - }); -} - -template -std::unordered_set value_all(std::unordered_set> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); - }); - }); -} - -template -struct reversed_container_t { - reversed_container_t() = delete; - reversed_container_t(C const &c) : container(c) {} - - reversed_container_t(reversed_container_t const &) = delete; - reversed_container_t(reversed_container_t &&) = delete; - reversed_container_t &operator=(reversed_container_t const &) = delete; - reversed_container_t &operator=(reversed_container_t &&) = delete; - - using iterator = typename C::reverse_iterator; - using const_iterator = typename C::const_reverse_iterator; - using reverse_iterator = typename C::iterator; - using const_reverse_iterator = typename C::const_iterator; - using value_type = typename C::value_type; - using pointer = typename C::pointer; - using const_pointer = typename C::const_pointer; - using reference = typename C::reference; - using const_reference = typename C::const_reference; - - iterator begin() { - return this->container.rend(); - } - - iterator end() { - return this->container.rbegin(); - } - - const_iterator cbegin() const { - return this->container.crend(); - } - - const_iterator cend() const { - return this->container.crbegin(); - } - - const_iterator begin() const { - return this->cbegin(); - } - - const_iterator end() const { - return this->cend(); - } - - reverse_iterator rbegin() { - return this->container.begin(); - } - - reverse_iterator rend() { - return this->container.end(); - } - - const_reverse_iterator crbegin() const { - return this->container.cbegin(); - } - - const_reverse_iterator crend() const { - return this->container.cend(); - } - - const_reverse_iterator rbegin() const { - return this->crbegin(); - } - - const_reverse_iterator rend() const { - return this->crend(); - } - -private: - C const &container; -}; - -template -reversed_container_t reversed_container(C const &c) { - return reversed_container_t(c); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/are_all_distinct.h b/lib/utils/include/utils/containers/are_all_distinct.h new file mode 100644 index 0000000000..d02845ba16 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_distinct.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_DISTINCT_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_DISTINCT_H + +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +template +bool are_all_distinct(C const &c) { + return unordered_set_of(c).size() == unordered_multiset_of(c).size(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/compare_by.h b/lib/utils/include/utils/containers/compare_by.h new file mode 100644 index 0000000000..d6cb7f48cd --- /dev/null +++ b/lib/utils/include/utils/containers/compare_by.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_COMPARE_BY_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_COMPARE_BY_H + +#include + +namespace FlexFlow { + +template +std::function compare_by(F const &f) { + return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_first.h b/lib/utils/include/utils/containers/get_first.h deleted file mode 100644 index a616c44c20..0000000000 --- a/lib/utils/include/utils/containers/get_first.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H - -#include -#include - -namespace FlexFlow { - -template -T get_first(std::unordered_set const &s) { - return *s.cbegin(); -} - -template -T get_first(std::set const &s) { - return *s.cbegin(); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/get_one_of.h b/lib/utils/include/utils/containers/get_one_of.h new file mode 100644 index 0000000000..47c46fb1d6 --- /dev/null +++ b/lib/utils/include/utils/containers/get_one_of.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H + +#include "utils/exception.h" +#include "utils/fmt/unordered_set.h" +#include +namespace FlexFlow { + +template +T get_one_of(std::unordered_set const &s) { + if (s.empty()) { + throw mk_runtime_error(fmt::format( + "get_one_of expected non-empty container but receieved {}", s)); + } + return *s.cbegin(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/index_of.h b/lib/utils/include/utils/containers/index_of.h new file mode 100644 index 0000000000..1792490d0c --- /dev/null +++ b/lib/utils/include/utils/containers/index_of.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INDEX_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INDEX_OF_H + +#include +#include + +namespace FlexFlow { + +/** + * @details If multiple `e` are present within the container, the function + * returns the index of the first appearance + **/ +template +std::optional index_of(Container const &c, Element const &e) { + auto it = std::find(c.cbegin(), c.cend(), e); + if (it == c.cend()) { + return std::nullopt; + } else { + return std::distance(c.cbegin(), it); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/is_submapeq_of.h b/lib/utils/include/utils/containers/is_submapeq_of.h new file mode 100644 index 0000000000..03cb5ccd78 --- /dev/null +++ b/lib/utils/include/utils/containers/is_submapeq_of.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H + +#include "utils/containers/keys.h" +#include "utils/containers/restrict_keys.h" +#include + +namespace FlexFlow { + +template +bool is_submapeq_of(std::unordered_map const &sub, + std::unordered_map const &m) { + return restrict_keys(m, keys(sub)) == sub; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/is_subseteq_of.h b/lib/utils/include/utils/containers/is_subseteq_of.h index 26543ca75b..705c092962 100644 --- a/lib/utils/include/utils/containers/is_subseteq_of.h +++ b/lib/utils/include/utils/containers/is_subseteq_of.h @@ -7,14 +7,14 @@ namespace FlexFlow { template -bool is_subseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { - if (l.size() > r.size()) { +bool is_subseteq_of(std::unordered_set const &sub, + std::unordered_set const &super) { + if (sub.size() > super.size()) { return false; } - for (auto const &ll : l) { - if (!contains(r, ll)) { + for (auto const &s : sub) { + if (!contains(super, s)) { return false; } } diff --git a/lib/utils/include/utils/containers/is_superseteq_of.h b/lib/utils/include/utils/containers/is_superseteq_of.h new file mode 100644 index 0000000000..23b16d92f9 --- /dev/null +++ b/lib/utils/include/utils/containers/is_superseteq_of.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUPERSETEQ_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUPERSETEQ_OF_H + +#include "utils/containers/is_subseteq_of.h" +#include + +namespace FlexFlow { + +template +bool is_superseteq_of(std::unordered_set const &super, + std::unordered_set const &sub) { + return is_subseteq_of(sub, super); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index e252333e93..4e5352748d 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -1,21 +1,35 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H +#include "utils/containers/keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/exception.h" #include #include namespace FlexFlow { +/** + * @brief Applies the given function to all the keys within the given map and + * returns the updated map. + */ template > std::unordered_map map_keys(std::unordered_map const &m, F const &f) { + std::unordered_map result; for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); } + if (keys(m).size() != keys(result).size()) { + throw mk_runtime_error( + "keys passed to map_keys must be transformed into distinct keys"); + } + return result; } diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h index 634bb61bc1..b3d6d0c6d7 100644 --- a/lib/utils/include/utils/containers/maximum.h +++ b/lib/utils/include/utils/containers/maximum.h @@ -1,18 +1,19 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H +#include "utils/exception.h" #include -#include namespace FlexFlow { -template -std::optional maximum(C const &v) { - if (v.empty()) { - return std::nullopt; +template +typename C::value_type maximum(C const &c) { + if (c.empty()) { + throw mk_runtime_error( + fmt::format("maximum expected non-empty container but received {}", c)); } - return *std::max_element(std::cbegin(v), std::cend(v)); + return *std::max_element(c.begin(), c.end()); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/product.h b/lib/utils/include/utils/containers/product.h index 52ff36e790..af04edcb81 100644 --- a/lib/utils/include/utils/containers/product.h +++ b/lib/utils/include/utils/containers/product.h @@ -5,6 +5,9 @@ namespace FlexFlow { +/** + * @details An empty container vacuously has product 1 + **/ template Element product(Container const &container) { Element result = 1; diff --git a/lib/utils/include/utils/containers/product_where.h b/lib/utils/include/utils/containers/product_where.h new file mode 100644 index 0000000000..51af47c2fa --- /dev/null +++ b/lib/utils/include/utils/containers/product_where.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PRODUCT_WHERE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PRODUCT_WHERE_H + +#include + +namespace FlexFlow { + +/** + * @details An empty container vacuously has product 1 + **/ +template +Element product_where(Container const &container, ConditionF const &condition) { + Element result = 1; + for (Element const &element : container) { + if (condition(element)) { + result *= element; + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/reversed_container.h b/lib/utils/include/utils/containers/reversed_container.h new file mode 100644 index 0000000000..cffef3c6e6 --- /dev/null +++ b/lib/utils/include/utils/containers/reversed_container.h @@ -0,0 +1,85 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_CONTAINER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_CONTAINER_H + +namespace FlexFlow { + +template +struct reversed_container_t { + reversed_container_t() = delete; + reversed_container_t(C const &c) : container(c) {} + + reversed_container_t(reversed_container_t const &) = delete; + reversed_container_t(reversed_container_t &&) = delete; + reversed_container_t &operator=(reversed_container_t const &) = delete; + reversed_container_t &operator=(reversed_container_t &&) = delete; + + using iterator = typename C::reverse_iterator; + using const_iterator = typename C::const_reverse_iterator; + using reverse_iterator = typename C::iterator; + using const_reverse_iterator = typename C::const_iterator; + using value_type = typename C::value_type; + using pointer = typename C::pointer; + using const_pointer = typename C::const_pointer; + using reference = typename C::reference; + using const_reference = typename C::const_reference; + + iterator begin() { + return this->container.rend(); + } + + iterator end() { + return this->container.rbegin(); + } + + const_iterator cbegin() const { + return this->container.crend(); + } + + const_iterator cend() const { + return this->container.crbegin(); + } + + const_iterator begin() const { + return this->cbegin(); + } + + const_iterator end() const { + return this->cend(); + } + + reverse_iterator rbegin() { + return this->container.begin(); + } + + reverse_iterator rend() { + return this->container.end(); + } + + const_reverse_iterator crbegin() const { + return this->container.cbegin(); + } + + const_reverse_iterator crend() const { + return this->container.cend(); + } + + const_reverse_iterator rbegin() const { + return this->crbegin(); + } + + const_reverse_iterator rend() const { + return this->crend(); + } + +private: + C const &container; +}; + +template +reversed_container_t reversed_container(C const &c) { + return reversed_container_t(c); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 5ae90ec5ba..c89e9227de 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -1,7 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUBVEC_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUBVEC_H +#include "utils/exception.h" #include +#include #include namespace FlexFlow { @@ -15,11 +17,15 @@ std::vector subvec(std::vector const &v, auto resolve_loc = [&](int idx) -> typename std::vector::iterator::difference_type { + int size = static_cast(v.size()); + int new_idx = idx; if (idx < 0) { - return v.size() + idx; - } else { - return idx; + new_idx = size + idx; } + if (new_idx < 0 || new_idx > size) { + throw mk_runtime_error("Index {} is out of bounds for array {}"); + } + return new_idx; }; if (maybe_start.has_value()) { @@ -29,6 +35,9 @@ std::vector subvec(std::vector const &v, if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } + if (begin_iter >= end_iter) { + return {}; + } if (end_iter < begin_iter) { end_iter = begin_iter; diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h index 5dbd620781..135e704045 100644 --- a/lib/utils/include/utils/containers/sum.h +++ b/lib/utils/include/utils/containers/sum.h @@ -3,11 +3,14 @@ namespace FlexFlow { -template -T sum(C const &c) { - T result = 0; - for (T const &t : c) { - result += t; +/** + * @details An empty container vacuously has sum 0 + **/ +template +Element sum(Container const &container) { + Element result = 0; + for (Element const &element : container) { + result += element; } return result; } diff --git a/lib/utils/include/utils/containers/sum_where.h b/lib/utils/include/utils/containers/sum_where.h new file mode 100644 index 0000000000..214f51c1c9 --- /dev/null +++ b/lib/utils/include/utils/containers/sum_where.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_WHERE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_WHERE_H + +namespace FlexFlow { + +/** + * @details An empty container vacuously has sum 0 + **/ +template +Element sum_where(Container const &container, ConditionF const &condition) { + Element result = 0; + for (Element const &element : container) { + if (condition(element)) { + result += element; + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/value_all.h b/lib/utils/include/utils/containers/value_all.h new file mode 100644 index 0000000000..5727bd8396 --- /dev/null +++ b/lib/utils/include/utils/containers/value_all.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUE_ALL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUE_ALL_H + +#include "utils/containers/transform.h" +#include "utils/exception.h" +#include "utils/optional.h" + +namespace FlexFlow { + +template +std::vector value_all(std::vector> const &v) { + return transform(v, [&](std::optional const &element) { + return unwrap(element, [&] { + throw mk_runtime_error(fmt::format( + "value_all expected all elements to have values, but received {}", + v)); + }); + }); +} + +template +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, [&](std::optional const &element) { + return unwrap(element, [&] { + throw mk_runtime_error(fmt::format( + "value_all expected all elements to have values, but received {}", + v)); + }); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/vector_split.h b/lib/utils/include/utils/containers/vector_split.h index a1ab12a070..872733e3ce 100644 --- a/lib/utils/include/utils/containers/vector_split.h +++ b/lib/utils/include/utils/containers/vector_split.h @@ -1,14 +1,20 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H +#include "utils/fmt/vector.h" +#include +#include #include namespace FlexFlow { template std::pair, std::vector> vector_split(std::vector const &v, - std::size_t idx) { - assert(v.size() > idx); + int idx) { + if (idx < 0 || idx > static_cast(v.size())) { + throw std::out_of_range(fmt::format( + "Index out of range in vector_split: index = {}, vector = {}", idx, v)); + } std::vector prefix(v.begin(), v.begin() + idx); std::vector postfix(v.begin() + idx, v.end()); diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index 4170882ae6..7ef7f24eb7 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::tl::expected const &m, FormatContext &ctx) + auto format(::tl::expected const &m, FormatContext &ctx) const -> decltype(ctx.out()) { std::string result; diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 46bf9ca8fa..9225040d4d 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -17,7 +17,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::map const &m, FormatContext &ctx) + auto format(::std::map const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(K); CHECK_FMTABLE(V); diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h index 616b784aac..4234d90e94 100644 --- a/lib/utils/include/utils/fmt/multiset.h +++ b/lib/utils/include/utils/fmt/multiset.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::multiset const &m, FormatContext &ctx) + auto format(::std::multiset const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 2364e49568..16ccf61878 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -14,7 +14,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::optional const &m, FormatContext &ctx) + auto format(::std::optional const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); @@ -29,6 +29,14 @@ struct formatter< } }; +template +struct formatter : formatter { + template + auto format(std::nullopt_t, FormatContext &ctx) const -> decltype(ctx.out()) { + return formatter::format("nullopt", ctx); + } +}; + } // namespace fmt namespace FlexFlow { @@ -40,6 +48,8 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { return s << fmt::to_string(t); } +std::ostream &operator<<(std::ostream &, std::nullopt_t); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index ab5ddd4e28..d261c344a1 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -14,7 +14,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::pair const &m, FormatContext &ctx) + auto format(::std::pair const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(L); CHECK_FMTABLE(R); diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index a183d37542..c46984cc5a 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -16,7 +16,7 @@ struct formatter<::std::set, std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::set const &m, FormatContext &ctx) + auto format(::std::set const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 876a032fe6..12faa64e32 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -18,7 +18,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_map const &m, FormatContext &ctx) + auto format(::std::unordered_map const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(K); CHECK_FMTABLE(V); diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index 09dd3c5eab..a4c17f7f5e 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_multiset const &m, FormatContext &ctx) + auto format(::std::unordered_multiset const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index be347ec5ea..20a08916fc 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -16,7 +16,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_set const &m, FormatContext &ctx) + auto format(::std::unordered_set const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 06a56417c3..1690955286 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -11,7 +11,7 @@ struct formatter, Char> /* std::enable_if_t>::value>> */ : formatter<::std::string> { template - auto format(std::variant const &m, FormatContext &ctx) + auto format(std::variant const &m, FormatContext &ctx) const -> decltype(ctx.out()) { std::string result = diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 5d9ca0aeae..1eec7a306b 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::vector const &m, FormatContext &ctx) + auto format(::std::vector const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 9a655ae072..7aed437136 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -2,8 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_COW_PTR_T_H #include "utils/type_traits.h" -#include "utils/unique.h" -#include "utils/variant.h" #include #include diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h b/lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h similarity index 100% rename from lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h rename to lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h index 9eb717b066..7b53384b39 100644 --- a/lib/utils/include/utils/join_strings.h +++ b/lib/utils/include/utils/join_strings.h @@ -13,15 +13,12 @@ std::string join_strings(InputIt first, F const &f) { std::ostringstream oss; bool first_iter = true; - /* int i = 0; */ for (; first != last; first++) { if (!first_iter) { oss << delimiter; } oss << f(*first); - /* break; */ first_iter = false; - /* i++; */ } return oss.str(); } @@ -38,6 +35,12 @@ std::string join_strings(Container const &c, std::string const &delimiter) { return join_strings(c.cbegin(), c.cend(), delimiter); } +template +std::string + join_strings(Container const &c, std::string const &delimiter, F const &f) { + return join_strings(c.cbegin(), c.cend(), delimiter, f); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/rapidcheck/variant.h b/lib/utils/include/utils/rapidcheck/variant.h new file mode 100644 index 0000000000..bc741ea340 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/variant.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::oneOf( + gen::construct>(gen::arbitrary())...); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 7a936ebd7b..2a7b2d1849 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -82,6 +82,13 @@ void from_json(nlohmann::json const &j, stack_string &v) { v = stack_string{as_string}; } +template +std::basic_ostream & + operator<<(std::basic_ostream &s, + stack_basic_string const &v) { + return s << fmt::to_string(v); +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/unique.h b/lib/utils/include/utils/unique.h deleted file mode 100644 index cf6eb39026..0000000000 --- a/lib/utils/include/utils/unique.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H -#define _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H - -#include - -namespace FlexFlow { -template -std::unique_ptr make_unique(Args &&...args) { - return std::unique_ptr(new T(std::forward(args)...)); -} -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index bb2286a9cd..241d631200 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -213,15 +213,4 @@ std::optional cast(VariantIn const &v) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::oneOf(gen::cast>(gen::arbitrary())...); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/src/containers.cc b/lib/utils/src/containers.cc deleted file mode 100644 index 2af7fd1892..0000000000 --- a/lib/utils/src/containers.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers.h" diff --git a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc deleted file mode 100644 index f70be2355f..0000000000 --- a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/bidict/algorithms/merge_bidicts.h" diff --git a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..754b8d2e90 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" diff --git a/lib/utils/src/utils/cli/cli_get_help_message.cc b/lib/utils/src/utils/cli/cli_get_help_message.cc index 03c53c9356..51947aa885 100644 --- a/lib/utils/src/utils/cli/cli_get_help_message.cc +++ b/lib/utils/src/utils/cli/cli_get_help_message.cc @@ -2,6 +2,7 @@ #include "utils/containers/concat_vectors.h" #include "utils/containers/maximum.h" #include "utils/containers/transform.h" +#include "utils/fmt/vector.h" #include "utils/integer_conversions.h" #include "utils/join_strings.h" #include @@ -53,7 +54,7 @@ std::string cli_get_help_message(std::string const &program_name, if (!all_arg_columns.empty()) { int max_column_width = - std::min(int_from_size_t(maximum(all_arg_column_widths).value()), 20); + std::min(int_from_size_t(maximum(all_arg_column_widths)), 20); auto render_column = [&](std::string const &key, std::optional const &description) { diff --git a/lib/utils/src/utils/containers/are_all_distinct.cc b/lib/utils/src/utils/containers/are_all_distinct.cc new file mode 100644 index 0000000000..52c665d191 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_distinct.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_distinct.h" diff --git a/lib/utils/src/utils/containers/compare_by.cc b/lib/utils/src/utils/containers/compare_by.cc new file mode 100644 index 0000000000..b7df348f37 --- /dev/null +++ b/lib/utils/src/utils/containers/compare_by.cc @@ -0,0 +1 @@ +#include "utils/containers/compare_by.h" diff --git a/lib/utils/src/utils/containers/get_first.cc b/lib/utils/src/utils/containers/get_first.cc deleted file mode 100644 index ce8eb9cbea..0000000000 --- a/lib/utils/src/utils/containers/get_first.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/get_first.h" diff --git a/lib/utils/src/utils/containers/get_one_of.cc b/lib/utils/src/utils/containers/get_one_of.cc new file mode 100644 index 0000000000..4ce017f6b9 --- /dev/null +++ b/lib/utils/src/utils/containers/get_one_of.cc @@ -0,0 +1 @@ +#include "utils/containers/get_one_of.h" diff --git a/lib/utils/src/utils/containers/index_of.cc b/lib/utils/src/utils/containers/index_of.cc new file mode 100644 index 0000000000..d0c4b4dfd3 --- /dev/null +++ b/lib/utils/src/utils/containers/index_of.cc @@ -0,0 +1 @@ +#include "utils/containers/index_of.h" diff --git a/lib/utils/src/utils/containers/is_submapeq_of.cc b/lib/utils/src/utils/containers/is_submapeq_of.cc new file mode 100644 index 0000000000..567d94fac5 --- /dev/null +++ b/lib/utils/src/utils/containers/is_submapeq_of.cc @@ -0,0 +1 @@ +#include "utils/containers/is_submapeq_of.h" diff --git a/lib/utils/src/utils/containers/is_superseteq_of.cc b/lib/utils/src/utils/containers/is_superseteq_of.cc new file mode 100644 index 0000000000..0728c96f17 --- /dev/null +++ b/lib/utils/src/utils/containers/is_superseteq_of.cc @@ -0,0 +1 @@ +#include "utils/containers/is_superseteq_of.h" diff --git a/lib/utils/src/utils/containers/product_where.cc b/lib/utils/src/utils/containers/product_where.cc new file mode 100644 index 0000000000..3a435e7d80 --- /dev/null +++ b/lib/utils/src/utils/containers/product_where.cc @@ -0,0 +1 @@ +#include "utils/containers/product_where.h" diff --git a/lib/utils/src/utils/containers/reversed_container.cc b/lib/utils/src/utils/containers/reversed_container.cc new file mode 100644 index 0000000000..1a2fe3cf63 --- /dev/null +++ b/lib/utils/src/utils/containers/reversed_container.cc @@ -0,0 +1 @@ +#include "utils/containers/reversed_container.h" diff --git a/lib/utils/src/utils/containers/sum_where.cc b/lib/utils/src/utils/containers/sum_where.cc new file mode 100644 index 0000000000..c09f9c573e --- /dev/null +++ b/lib/utils/src/utils/containers/sum_where.cc @@ -0,0 +1 @@ +#include "utils/containers/sum_where.h" diff --git a/lib/utils/src/utils/containers/value_all.cc b/lib/utils/src/utils/containers/value_all.cc new file mode 100644 index 0000000000..1f863a20e4 --- /dev/null +++ b/lib/utils/src/utils/containers/value_all.cc @@ -0,0 +1 @@ +#include "utils/containers/value_all.h" diff --git a/lib/utils/src/utils/fmt/optional.cc b/lib/utils/src/utils/fmt/optional.cc index e21b32eaa9..4642292920 100644 --- a/lib/utils/src/utils/fmt/optional.cc +++ b/lib/utils/src/utils/fmt/optional.cc @@ -1 +1,9 @@ #include "utils/fmt/optional.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, std::nullopt_t) { + return (s << std::string{"nullopt"}); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc index d06a64597e..0e4d0c6759 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc @@ -1,5 +1,5 @@ #include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" namespace FlexFlow { @@ -13,7 +13,7 @@ std::optional if (all_isomorphisms.empty()) { return std::nullopt; } else { - return get_first(all_isomorphisms); + return get_one_of(all_isomorphisms); } } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 8afe7da926..92bd1e32ca 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,12 +1,9 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/extend.h" -#include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" -#include "utils/containers/set_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" -#include "utils/fmt/set.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" @@ -16,7 +13,6 @@ #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "utils/graph/node/algorithms.h" -#include "utils/hash/unordered_set.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc index 2eab8371b2..ccd2808603 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" -#include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc index 1d98c44a9f..39523f2ec1 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc @@ -1,6 +1,6 @@ #include "utils/graph/digraph/algorithms/get_imm_post_dominator.h" #include "utils/containers/generate_map.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/get_only.h" #include "utils/containers/intersection.h" #include "utils/containers/restrict_keys.h" @@ -31,7 +31,7 @@ std::optional return get_imm_post_dominator(g, get_only(nodes)); } - Node contracted_node = get_first(nodes); + Node contracted_node = get_one_of(nodes); std::unordered_map contraction = generate_map(nodes, [&](Node const &) { return contracted_node; }); return get_imm_post_dominator(apply_contraction(g, contraction), diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc index d622497629..d75a447127 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc @@ -1,5 +1,5 @@ #include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" namespace FlexFlow { @@ -13,7 +13,7 @@ std::optional if (all_isomorphisms.empty()) { return std::nullopt; } else { - return get_first(all_isomorphisms); + return get_one_of(all_isomorphisms); } } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index 1dd5353301..fa17678943 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -3,7 +3,7 @@ #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/get_all_permutations.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/values.h" diff --git a/lib/utils/src/utils/rapidcheck/variant.cc b/lib/utils/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..f0537d454f --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/variant.cc @@ -0,0 +1,10 @@ +#include "utils/rapidcheck/variant.h" + +namespace rc { + +using T0 = int; +using T1 = std::string; + +template struct Arbitrary>; + +} // namespace rc diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h index 519cde7d74..25552003e3 100644 --- a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h @@ -13,6 +13,11 @@ struct StringMaker> { } }; +template <> +struct StringMaker { + static String convert(std::nullopt_t const &); +}; + } // namespace doctest #endif diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc index 8a3f7f158e..09b02ac059 100644 --- a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc @@ -1 +1,9 @@ #include "test/utils/doctest/fmt/optional.h" + +namespace doctest { + +String StringMaker::convert(std::nullopt_t const &m) { + return toString(fmt::to_string(m)); +} + +} // namespace doctest diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc deleted file mode 100644 index 44f602f3bc..0000000000 --- a/lib/utils/test/src/test_algorithms.cc +++ /dev/null @@ -1,246 +0,0 @@ -#include "utils/graph/algorithms.h" -#include "utils/graph/construction.h" -#include "utils/graph/hashmap_undirected_graph.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/undirected.h" -#include -#include -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MultiDiGraph") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector p = add_node_ports(g, 4); - - MultiDiEdge e0{n[3], p[3], n[0], p[0]}; - MultiDiEdge e1{n[2], p[2], n[1], p[0]}; - MultiDiEdge e2{n[3], p[3], n[1], p[1]}; - MultiDiEdge e3{n[3], p[3], n[2], p[2]}; - - std::vector e = {e0, e1, e2, e3}; - - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[1], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{e[3]}); - std::unordered_map> expected_result = - std::unordered_map>{ - {n[1], {}}, - {n[2], {n[1]}}, - {n[3], {n[0], n[1], n[2]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - } - - TEST_CASE("DiGraph") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - {n[0], n[3]}, - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[2]}, - }; - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[2], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{}); - auto expected_result = std::unordered_map>{ - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - {n[3], {n[0]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - - SUBCASE("get_imm_dominators") { - std::unordered_map> result = get_imm_dominators(g); - - std::unordered_map> expected_result = { - {n[2], n[0]}, - {n[1], n[0]}, - {n[3], n[0]}, - {n[0], nullopt}, - }; - CHECK(result == expected_result); - } - - SUBCASE("get_dominators") { - std::unordered_map> expected = { - {n[0], {n[0]}}, - {n[1], {n[0], n[1]}}, - {n[2], {n[0], n[2]}}, - {n[3], {n[0], n[3]}}, - }; - CHECK(get_dominators(g) == expected); - } - - SUBCASE("get_sinks") { - auto expected = std::unordered_set{n[2], n[3]}; - CHECK(get_sinks(g) == expected); - } - - SUBCASE("get_bfs") { - std::unordered_set start_points = std::unordered_set{n[0]}; - auto expected = std::vector{n[0], n[2], n[1], n[3]}; - CHECK(get_bfs_ordering(g, start_points) == expected); - } - - SUBCASE("get_predecessors") { - std::unordered_map> expected_result = { - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); - } - } - - TEST_CASE("traversal") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 5); - std::vector edges = { - {n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; - add_edges(g, edges); - - CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); - CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == true); - CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); - CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); - - SUBCASE("with root") { - g.add_edge({n[3], n[2]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - - SUBCASE("without root") { - g.add_edge({n[3], n[0]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - SUBCASE("nonlinear") { - g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs - } - - SUBCASE("not connected") { - g.remove_edge({n[2], n[3]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); - } - } - - TEST_CASE("bfs") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 7); - - std::vector e = { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }; - - add_edges(g, e); - - std::vector ordering = get_bfs_ordering(g, {n[0]}); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < - index_of(ordering, n[r]).value()); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - - CHECK_BEFORE(1, 3); - CHECK_BEFORE(1, 6); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(2, 6); - - CHECK_BEFORE(3, 4); - CHECK_BEFORE(6, 4); - - CHECK_BEFORE(4, 5); - } - - TEST_CASE("get_topological_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - std::vector edges = {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}; - add_edges(g, edges); - std::vector ordering = get_topological_ordering(g); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - CHECK_BEFORE(1, 5); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(3, 4); - CHECK_BEFORE(4, 5); - } - - TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - - CHECK(get_connected_components(g) == expected_components); - } - - TEST_CASE("get_weakly_connected_components") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); - - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - - CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); - - CHECK(get_weakly_connected_components(g) == expected_components); - } -} diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc deleted file mode 100644 index 76b7fd0d31..0000000000 --- a/lib/utils/test/src/test_containers.cc +++ /dev/null @@ -1,393 +0,0 @@ -#include "utils/containers.h" -#include -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("join_strings") { - std::vector const v = {"Hello", "world", "!"}; - CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); - } - - TEST_CASE("join_strings with container") { - std::vector const v = {"Hello", "world"}; - CHECK(join_strings(v, " ") == "Hello world"); - } - - TEST_CASE("find") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(find(v, 3) != v.cend()); - CHECK(find(v, 6) == v.cend()); - } - - TEST_CASE("sum") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(sum(v) == 15); - } - - TEST_CASE("sum with condition") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Sum of even numbers only - CHECK(sum_where(v, condition) == 6); - } - - TEST_CASE("product") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(product(v) == 120); - } - - TEST_CASE("product_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Product of even numbers only - CHECK(product_where(v, condition) == 8); - } - - TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); - } - - TEST_CASE("contains_key") { - std::unordered_map m = { - {"one", 1}, {"two", 2}, {"three", 3}}; - CHECK(contains_key(m, "one")); - CHECK(!contains_key(m, "four")); - } - - TEST_CASE("map_keys") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; // Mapping function - auto result = map_keys(m, f); - CHECK(result.size() == 2); - CHECK(result[1] == "one"); - CHECK(result[4] == "two"); - } - - TEST_CASE("filter_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = [](int x) { return x % 2 == 1; }; // Filtering function - std::unordered_map result = filter_keys(m, f); - std::unordered_map expected = {{1, "one"}, {3, "three"}}; - CHECK(result == expected); - } - - TEST_CASE("map_values") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](std::string const &s) { return s.size(); }; // Mapping function - std::unordered_map result = map_values(m, f); - std::unordered_map expected = {{1, 3}, {2, 3}}; - CHECK(result == expected); - } - - TEST_CASE("keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {3, 2, 1}; - CHECK(result == expected); - } - - TEST_CASE("values") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::vector result = values(m); - std::vector expected = {"three", "two", "one"}; - CHECK(result == expected); - } - - // TEST_CASE("items") { - // std::unordered_map m = {{1, std::string("one")}, {2, - // std::string("two")}, {3,std::string("three")}}; - // std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; - std::unordered_set result = unique(v); - std::unordered_set expected = {1, 2, 3}; - CHECK(result == expected); - } - - TEST_CASE("unordered_multiset_of") { - std::vector v = {1, 4, 6, 4, 6}; - std::unordered_set expected = {1, 4, 6}; - CHECK(unordered_set_of(v) == expected); - } - - TEST_CASE("index_of") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3) == 2); - CHECK(!index_of(v, 6).has_value()); - } - - TEST_CASE("intersection") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {2, 3, 4}; - std::unordered_set result = intersection(l, r); - std::unordered_set expected = {2, 3}; - CHECK(result == expected); - } - - TEST_CASE("are_disjoint") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {4, 5, 6}; - CHECK(are_disjoint(l, r)); - r.insert(3); - CHECK_FALSE(are_disjoint(l, r)); - } - - TEST_CASE("restrict_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set mask = {2, 3, 4}; - std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); - } - - TEST_CASE("merge_maps(unordered_map)") { - std::unordered_map lhs = {{1, "one"}, {2, "two"}}; - std::unordered_map rhs = {{3, "three"}, {4, "four"}}; - std::unordered_map result = merge_maps(lhs, rhs); - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); - } - - TEST_CASE("merge_maps(bidict)") { - std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; - std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; - std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; - std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; - bidict lhs{fwd_map1, bwd_map1}; - bidict rhs{fwd_map2, bwd_map2}; - - std::unordered_map result = - merge_maps(lhs, rhs); // impicit conversion - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); - } - - TEST_CASE("lookup_in") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = lookup_in(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - CHECK(f(3) == "three"); - } - - TEST_CASE("lookup_in_l") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_l(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - } - - TEST_CASE("lookup_in_r") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_r(m); - CHECK(f("one") == 1); - CHECK(f("two") == 2); - } - - TEST_CASE("set_union") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {2, 3, 4}; - std::unordered_set result = set_union(s1, s2); - std::unordered_set expected = {1, 2, 3, 4}; - CHECK(result == expected); - } - - TEST_CASE("is_subseteq_of") { - std::unordered_set s1 = {1, 2}; - std::unordered_set s2 = {1, 2, 3}; - CHECK(is_subseteq_of(s1, s2) == true); - CHECK(is_subseteq_of(s2, s1) == false); - CHECK(is_subseteq_of(s1, s1) == true); - CHECK(is_subseteq_of(s2, s2) == true); - } - - TEST_CASE("is_superseteq_of") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {1, 2}; - CHECK(is_supserseteq_of(s1, s2) == true); - CHECK(is_supserseteq_of(s2, s1) == false); - } - - TEST_CASE("get_only") { - std::unordered_set s = {42}; - CHECK(get_only(s) == 42); - } - - TEST_CASE("get_first") { - std::unordered_set s = {1, 2, 3}; - CHECK(s.count(get_first(s)) == 1); - } - - TEST_CASE("extend") { - std::vector v = {1, 2, 3}; - std::unordered_set s = {4, 5, 6}; - extend(v, s); - CHECK(v.size() == 6); - std::vector expected = {1, 2, 3, 6, 5, 4}; - CHECK(v == expected); - } - - TEST_CASE("all_of") { - std::vector v = {2, 4, 6, 8}; - CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); - CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); - } - - TEST_CASE("count") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); - CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); - } - - TEST_CASE("are_all_same") { - std::vector v1 = {2, 2, 2, 2}; - std::vector v2 = {1, 2, 3, 4}; - CHECK(are_all_same(v1) == true); - CHECK(are_all_same(v2) == false); - } - - TEST_CASE("vector_transform") { - std::vector v = {1, 2, 3}; - auto result = vector_transform([](int x) { return x * 2; }, v); - CHECK(result == std::vector({2, 4, 6})); - } - - TEST_CASE("vector_of") { - std::unordered_set s = {1, 2, 3}; - std::vector result = vector_of(s); - CHECK(result == std::vector({3, 2, 1})); - } - - TEST_CASE("transform_vector") { - std::vector v = {1, 2, 3}; - auto result = transform(v, [](int x) { return x * 2; }); - CHECK(result == std::vector({2, 4, 6})); - } - - TEST_CASE("transform_unordered_set") { - std::unordered_set s = {1, 2, 3}; - auto result = transform(s, [](int x) { return x * 2; }); - CHECK(result == std::unordered_set({2, 4, 6})); - } - - TEST_CASE("transform_string") { - std::string s = "abc"; - auto result = transform(s, ::toupper); - CHECK(result == "ABC"); - } - - TEST_CASE("repeat") { - int ctr = 0; - std::vector result = repeat(5, [&] { return ctr++; }); - - CHECK(result == std::vector{0, 1, 2, 3, 4}); - } - - TEST_CASE("Testing the 'enumerate' function") { - std::unordered_set input_set = {1, 2, 3, 4, 5}; - std::unordered_map result = enumerate(input_set); - std::unordered_map expected = { - {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; - CHECK(result == expected); - } - - TEST_CASE("Testing the 'maximum' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - auto result = maximum(input_vec); - - // Checking the maximum is as expected - REQUIRE(result == 5); - } - - TEST_CASE("Testing the 'reversed' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - std::vector result = reversed(input_vec); - std::vector expected = {5, 4, 3, 2, 1}; - - // Checking the reversed sequence is as expected - CHECK(result == expected); - } - - TEST_CASE("Testing sorted_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); - CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); - - std::unordered_set s2 = {-5, -1, -3, -2, -4}; - auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); - CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); - } - - TEST_CASE("Testing compare_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - std::vector result = - sorted_by(s, compare_by([](int i) { return (-i); })); - CHECK(result == std::vector{5, 4, 3, 2, 1}); - } - - TEST_CASE("Testing vector_split function") { - std::vector v = {1, 2, 3, 4, 5}; - auto result = vector_split(v, 2); - std::vector prefix = result.first; - std::vector postfix = result.second; - CHECK(prefix == std::vector({1, 2})); - CHECK(postfix == std::vector({3, 4, 5})); - } - - TEST_CASE("Testing value_all function") { - std::vector> v = {1, 2, 3, 4, 5}; - auto value_all_v = value_all(v); - CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); - } - - TEST_CASE("Testing subvec function") { - std::vector v = {1, 2, 3, 4, 5}; - auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); - - CHECK(subvec_v == std::vector({2, 3, 4})); - - auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); - CHECK(subvec_v2 == std::vector({1, 2, 3})); - } - - auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); - } - } - return factors; - }; - - // Example for vector - TEST_CASE("Test for flatmap function on vectors") { - std::vector v = {2, 3, 4, 5}; - auto result = flatmap(v, get_factors); - CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); - } -} diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc deleted file mode 100644 index cc7ac1de32..0000000000 --- a/lib/utils/test/src/test_multidigraph.cc +++ /dev/null @@ -1,94 +0,0 @@ -#include "utils/graph/adjacency_multidigraph.h" -#include "utils/graph/multidiedge.h" -#include "utils/graph/multidigraph_interfaces.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { - MultiDiGraph g = MultiDiGraph::create(); - - std::vector n = repeat(3, [&] { return g.add_node(); }); - std::vector p = repeat(3, [&] { return g.add_node_port(); }); - - std::vector e = {{n[1], p[1], n[0], p[0]}, - {n[2], p[2], n[0], p[0]}, - {n[0], p[0], n[2], p[2]}, - {n[1], p[1], n[2], p[2]}}; - for (MultiDiEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[0], n[1], n[2]}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[0], e[1], e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( - {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( - {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs( - query_set({p[1], p[2]}))) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs( - query_set({p[0], p[2]}))) == - std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all() - .with_src_nodes({n[1]}) - .with_dst_nodes({n[2]}) - .with_src_idxs({p[1]}) - .with_dst_idxs({p[2]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == - std::unordered_set{e[1]}); - - SUBCASE("remove node") { - g.remove_node_unsafe(n[0]); - - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[1], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == - std::unordered_set{}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == - std::unordered_set{e[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == - std::unordered_set{e[2]}); - } - - SUBCASE("remove_edge") { - g.remove_edge(e[0]); - - CHECK(g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( - {n[1]})) == std::unordered_set{}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == - std::unordered_set{e[1]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - } - } -} diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc deleted file mode 100644 index 2b816eea4f..0000000000 --- a/lib/utils/test/src/test_random_utils.cc +++ /dev/null @@ -1,67 +0,0 @@ -#include "utils/random_utils.h" -#include -#include - -void checkProbabilities(std::vector const &counts, - int numIterations, - std::vector const &weights, - float totalWeight) { - for (int i = 0; i < counts.size(); i++) { - float expectedProbability = weights[i] / totalWeight; - float observedProbability = static_cast(counts[i]) / numIterations; - CHECK(observedProbability == - doctest::Approx(expectedProbability).epsilon(0.01f)); - } -} - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("select_random") { - std::vector values = {1, 2, 3, 4, 5}; - - SUBCASE("Select random value") { - int result = select_random(values); - - CHECK(std::find(values.begin(), values.end(), result) != values.end()); - } - - SUBCASE("Invalid arguments") { - std::vector weights = {0.1f, 0.3f, 0.2f}; - CHECK(select_random(values, weights) == 2); - } - } - - TEST_CASE("select_random - Weighted Random Selection") { - SUBCASE("Test with equal weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; - } - - checkProbabilities(counts, numIterations, weights, values.size()); - } - - SUBCASE("Test with different weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; - - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; - } - - float totalWeight = 0.0f; - for (float weight : weights) { - totalWeight += weight; - } - - checkProbabilities(counts, numIterations, weights, totalWeight); - } - } -} diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc deleted file mode 100644 index e7ce12346a..0000000000 --- a/lib/utils/test/src/test_type_index.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "utils/type_index.h" -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("type_index function") { - SUBCASE("int type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(int); - CHECK(idx == expected_idx); - } - - SUBCASE("string type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(std::string); - CHECK(idx == expected_idx); - } - } - - TEST_CASE("matches function") { - std::type_index idx = typeid(float); - - SUBCASE("matching type") { - bool result = matches(idx); - CHECK(result == true); - } - - SUBCASE("non-matching type") { - bool result = matches(idx); - CHECK(result == false); - } - } -} diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc deleted file mode 100644 index ea519478d3..0000000000 --- a/lib/utils/test/src/test_undirected_graph.cc +++ /dev/null @@ -1,62 +0,0 @@ -#include "test/utils/rapidcheck.h" -#include "test/utils/rapidcheck/visitable.h" -#include "utils/graph/hashmap_undirected_graph.h" -#include "utils/graph/undirected.h" -#include - -/* namespace rc { */ - -/* template <> */ -/* struct Arbitrary { */ -/* static Gen arbitrary() { */ -/* int num_nodes = *gen::inRange( */ -/* } */ -/* }; */ - -/* } */ - -using namespace FlexFlow; - -using namespace rc; - -/* static_assert(supports_rc_arbitrary::value, ""); */ -/* static_assert(is_strong_typedef::value, ""); */ -/* static_assert(supports_rc_arbitrary>::value, ""); - */ -/* static_assert(supports_rc_arbitrary>::value, - * ""); */ -/* static_assert(supports_rc_arbitrary::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ -/* static_assert(is_streamable::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE( - "UndirectedGraph implementations", T, HashmapUndirectedGraph) { - - RC_SUBCASE("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct(gen::elementOf(n), - gen::elementOf(n))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == unordered_set_of(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(UndirectedEdgeQuery::all()) == unordered_set_of(e)); - }); - } -} diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..0a1babd9f9 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -0,0 +1,42 @@ +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("merge_disjoint_bidicts") { + + SUBCASE("disjoint keys and values") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{3, "three"}, {4, "four"}}; + + bidict result = merge_disjoint_bidicts(bd1, bd2); + bidict correct = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + + CHECK(result == correct); + } + + SUBCASE("overlapping key, different associated value") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{2, "three"}, {3, "four"}}; + + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + } + + SUBCASE("overlapping key, same associated value") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{2, "two"}, {3, "three"}}; + + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + } + + SUBCASE("overlapping values") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{3, "two"}, {4, "four"}}; + + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + } + } +} diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index fed655013f..d158af129f 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -12,7 +12,37 @@ TEST_SUITE(FF_TEST_SUITE) { dict.equate(1, "one"); dict.equate(2, "two"); - // Test the equate() function + SUBCASE("L type is the same as R type") { + bidict bd; + bd.equate(1, 3); + + SUBCASE("bidict::contains_l") { + CHECK(bd.contains_l(1)); + CHECK_FALSE(bd.contains_l(3)); + } + + SUBCASE("bidict::contains_r") { + CHECK(bd.contains_r(3)); + CHECK_FALSE(bd.contains_r(1)); + } + } + + SUBCASE("L type is not the same as R type") { + bidict dict; + dict.equate(1, "one"); + dict.equate(2, "two"); + + SUBCASE("bidict::contains_l") { + CHECK(dict.contains_l(1)); + CHECK_FALSE(dict.contains_l(3)); + } + + SUBCASE("bidict::contains_r") { + CHECK(dict.contains_r("one")); + CHECK_FALSE(dict.contains_r("three")); + } + } + SUBCASE("bidict::equate") { CHECK(dict.at_l(1) == "one"); CHECK(dict.at_r("one") == 1); @@ -20,7 +50,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_r("two") == 2); } - // Test the erase_l() function SUBCASE("bidict::erase_l") { dict.erase_l(1); CHECK(dict.size() == 1); @@ -28,7 +57,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_r("two") == 2); } - // Test the erase_r() function SUBCASE("bidict::erase_r") { dict.erase_r("one"); CHECK(dict.size() == 1); @@ -36,14 +64,12 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_l(2) == "two"); } - // Test the reversed() function SUBCASE("bidict::reversed") { bidict reversed_dict = dict.reversed(); CHECK(reversed_dict.at_l("one") == 1); CHECK(reversed_dict.at_r(2) == "two"); } - // Test the size() function SUBCASE("bidict::size") { CHECK(dict.size() == 2); } diff --git a/lib/utils/test/src/utils/containers/all_of.cc b/lib/utils/test/src/utils/containers/all_of.cc new file mode 100644 index 0000000000..247dd62787 --- /dev/null +++ b/lib/utils/test/src/utils/containers/all_of.cc @@ -0,0 +1,13 @@ +#include "utils/containers/all_of.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("all_of") { + std::vector v = {2, 4, 6, 8}; + CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); + CHECK(all_of(v, [](int x) { return x % 4 == 0; }) == false); + } +} diff --git a/lib/utils/test/src/utils/containers/are_all_distinct.cc b/lib/utils/test/src/utils/containers/are_all_distinct.cc new file mode 100644 index 0000000000..6c2e9ea445 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_distinct.cc @@ -0,0 +1,24 @@ +#include "utils/containers/are_all_distinct.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_distinct") { + + SUBCASE("Empty Container") { + std::vector input = {}; + CHECK(are_all_distinct(input)); + } + SUBCASE("All elements are distinct") { + std::vector input = {1, 2, 3, 4}; + CHECK(are_all_distinct(input)); + } + + SUBCASE("Not all elements are distinct") { + std::vector input = {2, 2, 3, 4}; + CHECK_FALSE(are_all_distinct(input)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/are_disjoint.cc b/lib/utils/test/src/utils/containers/are_disjoint.cc new file mode 100644 index 0000000000..17516dbf13 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_disjoint.cc @@ -0,0 +1,31 @@ +#include "utils/containers/are_disjoint.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_disjoint") { + SUBCASE("disjoint") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {4, 5, 6}; + CHECK(are_disjoint(l, r)); + } + SUBCASE("not disjoint") { + std::unordered_set l = {1, 2, 3, 4}; + std::unordered_set r = {3, 4, 5, 6}; + CHECK_FALSE(are_disjoint(l, r)); + } + + SUBCASE("one empty set") { + std::unordered_set l = {1, 2}; + std::unordered_set r = {}; + CHECK(are_disjoint(l, r)); + } + SUBCASE("both empty sets") { + std::unordered_set l = {}; + std::unordered_set r = {}; + CHECK(are_disjoint(l, r)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/compare_by.cc b/lib/utils/test/src/utils/containers/compare_by.cc new file mode 100644 index 0000000000..8c7221ffb4 --- /dev/null +++ b/lib/utils/test/src/utils/containers/compare_by.cc @@ -0,0 +1,19 @@ +#include "utils/containers/compare_by.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("compare_by") { + std::vector input = {"abc", "a", "ab"}; + auto comp = compare_by( + [](std::string const &s) { return s.length(); }); + std::vector correct = {"a", "ab", "abc"}; + std::sort(input.begin(), input.end(), comp); + CHECK(correct == input); + } +} diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc new file mode 100644 index 0000000000..6e0a84c7ab --- /dev/null +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -0,0 +1,13 @@ +#include "utils/containers/contains.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("contains") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } +} diff --git a/lib/utils/test/src/utils/containers/count.cc b/lib/utils/test/src/utils/containers/count.cc new file mode 100644 index 0000000000..4d5e05fb9d --- /dev/null +++ b/lib/utils/test/src/utils/containers/count.cc @@ -0,0 +1,13 @@ +#include "utils/containers/count.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("count") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); + CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); + } +} diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index c6ce9942e9..2f9a5b3c02 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,10 +1,11 @@ #include "utils/containers/enumerate.h" #include "test/utils/doctest/fmt/map.h" #include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" #include "utils/containers/keys.h" -#include "utils/containers/unordered_set_of.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include @@ -42,25 +43,13 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("enumerate(std::unordered_set)") { - std::unordered_set input = {"zero", "one", "two", "three"}; - - std::map correct = { - {0, "zero"}, - {1, "one"}, - {2, "two"}, - {3, "three"}, - }; - - std::map result = enumerate(input); - - std::unordered_set result_keys = keys(correct); - std::unordered_set result_values = - unordered_set_of(values(correct)); + std::unordered_set input = {"A", "B", "C", "D"}; std::unordered_set correct_keys = {0, 1, 2, 3}; - std::unordered_set correct_values = input; + std::unordered_multiset correct_values = {"A", "B", "C", "D"}; + std::map result = enumerate(input); - CHECK(result_keys == correct_keys); - CHECK(result_values == correct_values); + CHECK(keys(result) == correct_keys); + CHECK(unordered_multiset_of(values(result)) == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/filter_keys.cc b/lib/utils/test/src/utils/containers/filter_keys.cc new file mode 100644 index 0000000000..00e327a6f1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/filter_keys.cc @@ -0,0 +1,18 @@ +#include "utils/containers/filter_keys.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filter_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = [](int x) { return x % 2 == 1; }; + std::unordered_map result = filter_keys(m, f); + std::unordered_map correct = {{1, "one"}, {3, "three"}}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/find.cc b/lib/utils/test/src/utils/containers/find.cc new file mode 100644 index 0000000000..36d4b771d8 --- /dev/null +++ b/lib/utils/test/src/utils/containers/find.cc @@ -0,0 +1,54 @@ +#include "utils/containers/find.h" +#include "test/utils/doctest/check_without_stringify.h" +#include +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find") { + + SUBCASE("vector") { + std::vector v = {1, 2, 3, 3, 4, 5, 3}; + + SUBCASE("element found") { + CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); + } + + SUBCASE("element not found") { + CHECK_WITHOUT_STRINGIFY(find(v, 6) == std::find(v.begin(), v.end(), 6)); + } + + SUBCASE("multiple occurrences of element") { + CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); + } + } + + SUBCASE("unordered_set") { + std::unordered_set s = {1, 2, 3, 4, 5}; + + SUBCASE("element in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); + } + + SUBCASE("element not in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + } + } + + SUBCASE("set") { + std::set s = {1, 2, 3, 4, 5}; + + SUBCASE("element in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); + } + + SUBCASE("element not in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + } + } + } +} diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index c10cc5ae75..bd6d3ae5be 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -2,6 +2,7 @@ #include "test/utils/doctest/fmt/pair.h" #include "test/utils/doctest/fmt/unordered_map.h" #include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/map_keys.h" #include "utils/hash/pair.h" #include @@ -10,6 +11,39 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flatmap(std::vector, F)") { + SUBCASE("same data-type") { + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } + } + return factors; + }; + + std::vector input = {2, 3, 4, 5}; + std::vector result = flatmap(input, get_factors); + std::vector correct = {1, 2, 1, 3, 1, 2, 4, 1, 5}; + CHECK(result == correct); + } + + SUBCASE("different data-type") { + auto get_string_sequence = [](int x) -> std::vector { + return { + std::to_string(x - 1), std::to_string(x), std::to_string(2 * x)}; + }; + + std::vector input = {2, 4, 10}; + std::vector result = flatmap(input, get_string_sequence); + std::vector correct = { + "1", "2", "4", "3", "4", "8", "9", "10", "20"}; + CHECK(result == correct); + } + } + TEST_CASE("flatmap(std::unordered_set, F)") { auto get_chars = [](std::string const &s) { std::unordered_set result; diff --git a/lib/utils/test/src/utils/containers/get_one_of.cc b/lib/utils/test/src/utils/containers/get_one_of.cc new file mode 100644 index 0000000000..326a292560 --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_one_of.cc @@ -0,0 +1,19 @@ +#include "utils/containers/get_one_of.h" +#include "utils/containers/contains.h" +#include +#include +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_one_of") { + SUBCASE("non-empty set") { + std::unordered_set s = {1, 2, 3}; + CHECK(contains(s, get_one_of(s))); + } + + SUBCASE("empty set") { + std::unordered_set s = {}; + CHECK_THROWS(get_one_of(s)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/index_of.cc b/lib/utils/test/src/utils/containers/index_of.cc new file mode 100644 index 0000000000..6ab49cfd42 --- /dev/null +++ b/lib/utils/test/src/utils/containers/index_of.cc @@ -0,0 +1,24 @@ +#include "utils/containers/index_of.h" +#include "test/utils/doctest/fmt/optional.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("index_of") { + + std::vector v = {1, 2, 3, 4, 3, 5}; + + SUBCASE("element occurs once in container") { + CHECK(index_of(v, 4).value() == 3); + } + SUBCASE("if element appears multiple times, return the first occurrence") { + CHECK(index_of(v, 3).value() == 2); + } + SUBCASE("element not in container") { + CHECK(index_of(v, 7) == std::nullopt); + } + } +} diff --git a/lib/utils/test/src/utils/containers/is_submapeq_of.cc b/lib/utils/test/src/utils/containers/is_submapeq_of.cc new file mode 100644 index 0000000000..df89444235 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_submapeq_of.cc @@ -0,0 +1,40 @@ +#include "utils/containers/is_submapeq_of.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_submapeq_of") { + std::unordered_map super = { + {1, "one"}, {2, "two"}, {3, "three"}}; + + SUBCASE("keys and values match") { + std::unordered_map sub = {{1, "one"}, {2, "two"}}; + CHECK(is_submapeq_of(sub, super)); + } + + SUBCASE("keys and values don't match") { + std::unordered_map sub = {{1, "one"}, {4, "four"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + + SUBCASE("keys match but values don't") { + std::unordered_map sub = {{1, "wrong_value"}, + {2, "two"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + + SUBCASE("values match but keys don't") { + std::unordered_map sub = {{5, "one"}, {6, "two"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + + SUBCASE("sub is a superset of super") { + std::unordered_map sub = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/is_subseteq_of.cc b/lib/utils/test/src/utils/containers/is_subseteq_of.cc new file mode 100644 index 0000000000..d762f171b6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_subseteq_of.cc @@ -0,0 +1,16 @@ +#include "utils/containers/is_subseteq_of.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_subseteq_of") { + std::unordered_set s1 = {1, 2}; + std::unordered_set s2 = {1, 2, 3}; + CHECK(is_subseteq_of(s1, s2) == true); + CHECK(is_subseteq_of(s2, s1) == false); + CHECK(is_subseteq_of(s1, s1) == true); + CHECK(is_subseteq_of(s2, s2) == true); + } +} diff --git a/lib/utils/test/src/utils/containers/is_superseteq_of.cc b/lib/utils/test/src/utils/containers/is_superseteq_of.cc new file mode 100644 index 0000000000..e3b429fa64 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_superseteq_of.cc @@ -0,0 +1,25 @@ +#include "utils/containers/is_superseteq_of.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_superseteq_of") { + std::unordered_set super = {1, 2, 3, 4}; + + SUBCASE("true containment") { + std::unordered_set sub = {1, 2, 3}; + CHECK(is_superseteq_of(super, sub)); + } + + SUBCASE("false containment") { + std::unordered_set sub = {1, 2, 5}; + CHECK_FALSE(is_superseteq_of(super, sub)); + } + + SUBCASE("reflexive") { + CHECK(is_superseteq_of(super, super)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/keys.cc b/lib/utils/test/src/utils/containers/keys.cc new file mode 100644 index 0000000000..5bdaef6d08 --- /dev/null +++ b/lib/utils/test/src/utils/containers/keys.cc @@ -0,0 +1,18 @@ +#include "utils/containers/keys.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set result = keys(m); + std::unordered_set expected = {1, 2, 3}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/map_keys.cc b/lib/utils/test/src/utils/containers/map_keys.cc new file mode 100644 index 0000000000..5c0a81d5e6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/map_keys.cc @@ -0,0 +1,26 @@ +#include "utils/containers/map_keys.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("map_keys") { + SUBCASE("Distinct keys after transformation") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](int x) { return x * x; }; + std::unordered_map result = map_keys(m, f); + std::unordered_map correct = {{1, "one"}, {4, "two"}}; + CHECK(correct == result); + } + + SUBCASE("Non-distinct keys after transformation") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {-1, "minus one"}}; + auto f = [](int x) { return std::abs(x); }; + CHECK_THROWS(map_keys(m, f)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/map_values.cc b/lib/utils/test/src/utils/containers/map_values.cc new file mode 100644 index 0000000000..a21645d0d5 --- /dev/null +++ b/lib/utils/test/src/utils/containers/map_values.cc @@ -0,0 +1,17 @@ +#include "utils/containers/map_values.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("map_values") { + std::unordered_map m = {{1, "one"}, {3, "three"}}; + auto f = [](std::string const &s) { return s.size(); }; + std::unordered_map result = map_values(m, f); + std::unordered_map correct = {{1, 3}, {3, 5}}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc index 71e7395805..2309458069 100644 --- a/lib/utils/test/src/utils/containers/maximum.cc +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -1,60 +1,25 @@ #include "utils/containers/maximum.h" -#include "test/utils/doctest/fmt/multiset.h" -#include "test/utils/doctest/fmt/optional.h" -#include "test/utils/doctest/fmt/set.h" -#include "test/utils/doctest/fmt/unordered_multiset.h" -#include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" #include -#include -#include #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("maximum(T)", - T, - std::vector, - std::unordered_set, - std::unordered_multiset, - std::set, - std::multiset) { - SUBCASE("input is empty") { - T input = {}; + TEST_CASE("maximum") { - std::optional result = maximum(input); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - - SUBCASE("input does not have duplicates") { - T input = {1, 3, 2}; - - std::optional result = maximum(input); - std::optional correct = 3; - - CHECK(result == correct); + SUBCASE("non-empty container") { + std::vector input = {1, 5, 3, 4, 2}; + int correct = 5; + int result = maximum(input); + CHECK(correct == result); } - SUBCASE("input has duplicates") { - T input = {1, 2, 2, 0}; + SUBCASE("empty container") { + std::vector input = {}; - std::optional result = maximum(input); - std::optional correct = 2; - - CHECK(result == correct); + CHECK_THROWS(maximum(input)); } } - - TEST_CASE("maximum(std::vector)") { - std::vector input = {"hello", "world"}; - - std::optional result = maximum(input); - std::optional correct = "world"; - - CHECK(result == correct); - } } diff --git a/lib/utils/test/src/utils/containers/merge_maps.cc b/lib/utils/test/src/utils/containers/merge_maps.cc new file mode 100644 index 0000000000..a083e94de3 --- /dev/null +++ b/lib/utils/test/src/utils/containers/merge_maps.cc @@ -0,0 +1,30 @@ +#include "utils/containers/merge_maps.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("merge_maps") { + + SUBCASE("disjoint keys") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{3, "three"}, {4, "four"}}; + + std::unordered_map result = merge_maps(lhs, rhs); + std::unordered_map correct = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + + CHECK(result == correct); + } + + SUBCASE("overlapping keys") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{2, "three"}, {3, "four"}}; + + CHECK_THROWS(merge_maps(lhs, rhs)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/product.cc b/lib/utils/test/src/utils/containers/product.cc new file mode 100644 index 0000000000..3fa94c8e9e --- /dev/null +++ b/lib/utils/test/src/utils/containers/product.cc @@ -0,0 +1,32 @@ +#include "utils/containers/product.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("product", + C, + std::vector, + std::vector, + std::set, + std::unordered_set) { + + SUBCASE("non-empty container") { + C input = {1, -2, 3, 5}; + auto correct = -30; + auto result = product(input); + CHECK(correct == result); + } + + SUBCASE("empty container") { + C input = {}; + auto correct = 1; + auto result = product(input); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/product_where.cc b/lib/utils/test/src/utils/containers/product_where.cc new file mode 100644 index 0000000000..098ae01252 --- /dev/null +++ b/lib/utils/test/src/utils/containers/product_where.cc @@ -0,0 +1,34 @@ +#include "utils/containers/product_where.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("product_where") { + + SUBCASE("empty starting container") { + std::vector input = {}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 1; + int result = product_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("non-empty filtered container") { + std::vector input = {1, -2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = -8; + int result = product_where(input, condition); + CHECK(correct == result); + } + SUBCASE("empty filtered container") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 1; + int result = product_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/restrict_keys.cc b/lib/utils/test/src/utils/containers/restrict_keys.cc new file mode 100644 index 0000000000..b4b376784e --- /dev/null +++ b/lib/utils/test/src/utils/containers/restrict_keys.cc @@ -0,0 +1,17 @@ +#include "utils/containers/restrict_keys.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("restrict_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set mask = {2, 3, 4}; + std::unordered_map result = restrict_keys(m, mask); + std::unordered_map correct = {{2, "two"}, {3, "three"}}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/set_union.cc b/lib/utils/test/src/utils/containers/set_union.cc new file mode 100644 index 0000000000..d842e4df96 --- /dev/null +++ b/lib/utils/test/src/utils/containers/set_union.cc @@ -0,0 +1,16 @@ +#include "utils/containers/set_union.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("set_union") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {2, 3, 4}; + std::unordered_set result = set_union(s1, s2); + std::unordered_set correct = {1, 2, 3, 4}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/sorted_by.cc b/lib/utils/test/src/utils/containers/sorted_by.cc new file mode 100644 index 0000000000..0ae2e0da77 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sorted_by.cc @@ -0,0 +1,35 @@ +#include "utils/containers/sorted_by.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sorted_by") { + SUBCASE("sort increasing") { + std::unordered_set s = {5, 2, 3, 4, 1}; + std::vector result = + sorted_by(s, [](int a, int b) { return a < b; }); + std::vector correct = {1, 2, 3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("sort decreasing") { + std::unordered_set input = {-5, -1, -3, -2, -4}; + std::vector result = + sorted_by(input, [](int a, int b) { return a > b; }); + std::vector correct = {-1, -2, -3, -4, -5}; + CHECK(result == correct); + } + + SUBCASE("container contains duplicate elements") { + std::vector input = {3, 1, 3, -4, 1}; + std::vector result = + sorted_by(input, [](int a, int b) { return a < b; }); + std::vector correct = {-4, 1, 1, 3, 3}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/subvec.cc b/lib/utils/test/src/utils/containers/subvec.cc new file mode 100644 index 0000000000..610fc55b5a --- /dev/null +++ b/lib/utils/test/src/utils/containers/subvec.cc @@ -0,0 +1,62 @@ +#include "utils/containers/subvec.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("subvec") { + std::vector v = {1, 2, 3, 4, 5}; + + SUBCASE("Basic subvector") { + auto result = subvec(v, 1, 4); + std::vector correct = {2, 3, 4}; + CHECK(result == correct); + } + + SUBCASE("From beginning to index") { + auto result = subvec(v, std::nullopt, 3); + std::vector correct = {1, 2, 3}; + CHECK(result == correct); + } + + SUBCASE("From index to end") { + auto result = subvec(v, 2, std::nullopt); + std::vector correct = {3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("All of the vector") { + auto result = subvec(v, std::nullopt, std::nullopt); + std::vector correct = {1, 2, 3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("Start greater than end") { + auto result = subvec(v, 3, 1); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("Start equal to end") { + auto result = subvec(v, 3, 3); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("Negative indices") { + auto result = subvec(v, -3, -1); + std::vector correct = {3, 4}; + CHECK(result == correct); + } + + SUBCASE("Upper index is out of bounds by 1") { + CHECK_THROWS(subvec(v, 2, 6)); + } + + SUBCASE("Lower index is out of bounds by 1") { + CHECK_THROWS(subvec(v, -6, 2)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/sum_where.cc b/lib/utils/test/src/utils/containers/sum_where.cc new file mode 100644 index 0000000000..7a909aea39 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sum_where.cc @@ -0,0 +1,36 @@ +#include "utils/containers/sum_where.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("sum_where") { + + SUBCASE("starting container is empty") { + std::vector input = {}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 0; + int result = sum_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("resulting container is non-empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 6; + int result = sum_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("resulting container is empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 0; + int result = sum_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/value_all.cc b/lib/utils/test/src/utils/containers/value_all.cc new file mode 100644 index 0000000000..1a5f2c508a --- /dev/null +++ b/lib/utils/test/src/utils/containers/value_all.cc @@ -0,0 +1,24 @@ +#include "utils/containers/value_all.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("value_all") { + SUBCASE("With nullopt") { + std::vector> input = {1, 2, std::nullopt, 4, 5}; + CHECK_THROWS(value_all(input)); + } + + SUBCASE("Without nullopt") { + std::vector> input = {1, 2, 3, 4, 5}; + std::vector correct = {1, 2, 3, 4, 5}; + std::vector result = value_all(input); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/values.cc b/lib/utils/test/src/utils/containers/values.cc new file mode 100644 index 0000000000..5fe69ac5e9 --- /dev/null +++ b/lib/utils/test/src/utils/containers/values.cc @@ -0,0 +1,19 @@ +#include "utils/containers/values.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("values") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}, {33, "three"}}; + std::unordered_multiset result = values(m); + std::unordered_multiset correct = { + "one", "two", "three", "three"}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/vector_split.cc b/lib/utils/test/src/utils/containers/vector_split.cc new file mode 100644 index 0000000000..76bb21348e --- /dev/null +++ b/lib/utils/test/src/utils/containers/vector_split.cc @@ -0,0 +1,39 @@ +#include "utils/containers/vector_split.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Testing vector_split function") { + std::vector v = {1, 2, 3, 4, 5}; + + SUBCASE("Normal case: idx = 2") { + auto [prefix, postfix] = vector_split(v, 2); + CHECK(prefix == std::vector({1, 2})); + CHECK(postfix == std::vector({3, 4, 5})); + } + + SUBCASE("Boundary case: idx = 0") { + auto [prefix, postfix] = vector_split(v, 0); + CHECK(prefix.empty()); + CHECK(postfix == std::vector({1, 2, 3, 4, 5})); + } + + SUBCASE("Boundary case: idx == list_size") { + auto [prefix, postfix] = vector_split(v, 5); + CHECK(prefix == std::vector({1, 2, 3, 4, 5})); + CHECK(postfix.empty()); + } + + SUBCASE("Out of bounds case: idx = -1") { + CHECK_THROWS_AS(vector_split(v, -1), std::out_of_range); + } + + SUBCASE("Out of bounds case: idx == list_size + 1") { + CHECK_THROWS_AS(vector_split(v, 6), std::out_of_range); + } + } +} diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/utils/deduplicated_priority_queue.cc similarity index 96% rename from lib/utils/test/src/test_deduplicated_priority_queue.cc rename to lib/utils/test/src/utils/deduplicated_priority_queue.cc index 048e95acb7..d5b35cf654 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/utils/deduplicated_priority_queue.cc @@ -1,6 +1,8 @@ #include "utils/deduplicated_priority_queue.h" #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { DeduplicatedPriorityQueue queue; diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/utils/disjoint_set.cc similarity index 73% rename from lib/utils/test/src/test_disjoint_set.cc rename to lib/utils/test/src/utils/disjoint_set.cc index 65037be3dd..be88a30cdd 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/utils/disjoint_set.cc @@ -1,4 +1,5 @@ #include "utils/disjoint_set.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace FlexFlow; @@ -18,10 +19,10 @@ std::string generate_element(int seed) { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { - disjoint_set> ds; + disjoint_set> ds; SUBCASE("SingleElementSets") { - optional element = generate_element(1); + std::optional element = generate_element(1); CHECK(ds.find(element) == element); element = generate_element(2); @@ -29,10 +30,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("UnionAndFind") { - optional element1 = generate_element(1); - optional element2 = generate_element(2); - optional element3 = generate_element(3); - optional element4 = generate_element(4); + std::optional element1 = generate_element(1); + std::optional element2 = generate_element(2); + std::optional element3 = generate_element(3); + std::optional element4 = generate_element(4); ds.m_union(element1, element2); CHECK(ds.find(element1) == ds.find(element2)); @@ -55,11 +56,11 @@ TEST_SUITE(FF_TEST_SUITE) { ds.m_union(1, 4); ds.m_union(5, 6); - std::map, optional, OptionalComparator> + std::map, std::optional, OptionalComparator> expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; - std::map, optional, OptionalComparator> mapping = - ds.get_mapping(); + std::map, std::optional, OptionalComparator> + mapping = ds.get_mapping(); for (auto const &kv : mapping) { CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/utils/dot_file.cc similarity index 100% rename from lib/utils/test/src/test_dot_file.cc rename to lib/utils/test/src/utils/dot_file.cc diff --git a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc new file mode 100644 index 0000000000..ce4d7a373b --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc @@ -0,0 +1,171 @@ +#include "utils/graph/multidigraph/multidigraph.h" +#include "utils/containers/contains.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/multidiedge_query.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph") { + MultiDiGraph g = MultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + MultiDiEdge e0 = g.add_edge(n0, n2); + MultiDiEdge e1 = g.add_edge(n1, n0); + MultiDiEdge e2 = g.add_edge(n1, n0); + MultiDiEdge e3 = g.add_edge(n1, n2); + MultiDiEdge e4 = g.add_edge(n1, n2); + MultiDiEdge e5 = g.add_edge(n2, n0); + MultiDiEdge e6 = g.add_edge(n2, n2); + + SUBCASE("add_node") { + Node n3 = g.add_node(); + std::unordered_set result = g.query_nodes(NodeQuery{{n3}}); + std::unordered_set correct = {n3}; + CHECK(result == correct); + } + + SUBCASE("add_edge") { + SUBCASE("non-duplicate edge") { + MultiDiEdge e7 = g.add_edge(n2, n1); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + std::unordered_set correct = {e7}; + CHECK(result == correct); + } + + SUBCASE("duplicate edge") { + MultiDiEdge e7 = g.add_edge(n2, n1); + MultiDiEdge e8 = g.add_edge(n2, n1); + + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + std::unordered_set correct = {e7, e8}; + CHECK(result == correct); + } + } + + SUBCASE("remove_node") { + g.remove_node(n0); + + std::unordered_set node_result = g.query_nodes(NodeQuery{{n0}}); + std::unordered_set node_correct = {}; + CHECK(node_result == node_correct); + + std::unordered_set edge_result = + g.query_edges(MultiDiEdgeQuery({n0}, {n1, n2})); + std::unordered_set edge_correct = {}; + CHECK(edge_result == edge_correct); + } + + SUBCASE("remove_edge") { + g.remove_edge(e3); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n2})); + std::unordered_set correct = {e4}; + CHECK(result == correct); + + SUBCASE("remove non-duplicate edge") { + g.remove_edge(e0); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0}, {n2})); + std::unordered_set correct = {}; + CHECK(result == correct); + } + + SUBCASE("remove duplicate edge") { + g.remove_edge(e1); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n0})); + std::unordered_set correct = {e2}; + CHECK(result == correct); + } + } + + SUBCASE("query_nodes") { + SUBCASE("all nodes") { + std::unordered_set result = + g.query_nodes(NodeQuery{{n0, n1, n2}}); + std::unordered_set correct = {n0, n1, n2}; + CHECK(result == correct); + } + + SUBCASE("specific nodes") { + std::unordered_set result = g.query_nodes(NodeQuery{{n0, n2}}); + std::unordered_set correct = {n0, n2}; + CHECK(result == correct); + } + + SUBCASE("matchall") { + std::unordered_set result = + g.query_nodes(NodeQuery{matchall()}); + std::unordered_set correct = {n0, n1, n2}; + CHECK(result == correct); + } + + SUBCASE("nodes not in graph") { + Node n3 = Node(3); + Node n4 = Node(4); + std::unordered_set result = g.query_nodes(NodeQuery{{n3, n4}}); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } + + SUBCASE("query_edges") { + SUBCASE("all edges") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n0, n1, n2})); + std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; + CHECK(result == correct); + } + + SUBCASE("edges from n1") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n0, n1, n2})); + std::unordered_set correct = {e1, e2, e3, e4}; + CHECK(result == correct); + } + + SUBCASE("edges to n2") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n2})); + std::unordered_set correct = {e0, e3, e4, e6}; + CHECK(result == correct); + } + + SUBCASE("matchall") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery(matchall(), matchall())); + std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; + CHECK(result == correct); + } + + SUBCASE("nodes that don't exist") { + Node n3 = Node(3); + Node n4 = Node(4); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1, n3}, {n4})); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } + SUBCASE("get_multidiedge_src") { + Node result = g.get_multidiedge_src(e0); + Node correct = n0; + CHECK(result == correct); + } + + SUBCASE("get_multidiedge_dst") { + Node result = g.get_multidiedge_dst(e0); + Node correct = n2; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/hash/map.cc b/lib/utils/test/src/utils/hash/map.cc new file mode 100644 index 0000000000..b4da6ddb68 --- /dev/null +++ b/lib/utils/test/src/utils/hash/map.cc @@ -0,0 +1,20 @@ +#include "utils/hash/map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::map map1{{1, 2}}; + std::map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({3, 4}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/set.cc b/lib/utils/test/src/utils/hash/set.cc new file mode 100644 index 0000000000..f9ccd925c6 --- /dev/null +++ b/lib/utils/test/src/utils/hash/set.cc @@ -0,0 +1,20 @@ +#include "utils/hash/set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::set set1{1, 2, 3}; + std::set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/tuple.cc b/lib/utils/test/src/utils/hash/tuple.cc new file mode 100644 index 0000000000..61240fd7b1 --- /dev/null +++ b/lib/utils/test/src/utils/hash/tuple.cc @@ -0,0 +1,21 @@ +#include "utils/hash/tuple.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::tuple tuple1{1, "test", 3.14}; + std::tuple tuple2{2, "test", 3.14}; + + size_t hash1 = get_std_hash(tuple1); + size_t hash2 = get_std_hash(tuple2); + + CHECK(hash1 != hash2); + + std::get<0>(tuple1) = 2; + hash1 = get_std_hash(tuple1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/utils/hash/unordered_map.cc similarity index 69% rename from lib/utils/test/src/test_hash.cc rename to lib/utils/test/src/utils/hash/unordered_map.cc index decf405e7a..bc0b2879a4 100644 --- a/lib/utils/test/src/test_hash.cc +++ b/lib/utils/test/src/utils/hash/unordered_map.cc @@ -1,10 +1,10 @@ -#include "utils/hash-utils.h" +#include "utils/hash/unordered_map.h" #include -using namespace FlexFlow; +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("hash:unordered_map") { + TEST_CASE("std::hash>") { std::unordered_map map1{{1, 2}}; std::unordered_map map2{{1, 2}, {3, 4}}; @@ -13,7 +13,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(hash1 != hash2); - map1.insert({1, 2}); + map1.insert({3, 4}); hash1 = get_std_hash(map1); CHECK(hash1 == hash2); } diff --git a/lib/utils/test/src/utils/hash/unordered_set.cc b/lib/utils/test/src/utils/hash/unordered_set.cc new file mode 100644 index 0000000000..299b3e5d13 --- /dev/null +++ b/lib/utils/test/src/utils/hash/unordered_set.cc @@ -0,0 +1,20 @@ +#include "utils/hash/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::unordered_set set1{1, 2, 3}; + std::unordered_set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/vector.cc b/lib/utils/test/src/utils/hash/vector.cc new file mode 100644 index 0000000000..ee5bb456eb --- /dev/null +++ b/lib/utils/test/src/utils/hash/vector.cc @@ -0,0 +1,20 @@ +#include "utils/hash/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::vector vec1{1, 2, 3}; + std::vector vec2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(vec1); + size_t hash2 = get_std_hash(vec2); + + CHECK(hash1 != hash2); + + vec1.push_back(4); + hash1 = get_std_hash(vec1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/join_strings.cc b/lib/utils/test/src/utils/join_strings.cc new file mode 100644 index 0000000000..ca1887e84d --- /dev/null +++ b/lib/utils/test/src/utils/join_strings.cc @@ -0,0 +1,47 @@ +#include "utils/join_strings.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("join_strings") { + std::vector v = {"Hello", "world", "!"}; + + SUBCASE("iterator") { + std::string result = join_strings(v.begin(), v.end(), " "); + std::string correct = "Hello world !"; + CHECK(result == correct); + } + + SUBCASE("join_strings with container") { + std::string result = join_strings(v, " "); + std::string correct = "Hello world !"; + CHECK(result == correct); + } + + SUBCASE("join_strings with transforming function") { + auto add_exclamation = [](std::string const &str) { return str + "!"; }; + std::string result = join_strings(v, " ", add_exclamation); + std::string correct = "Hello! world! !!"; + CHECK(result == correct); + } + + SUBCASE("join_strings with transforming function, iterator") { + auto add_exclamation = [](std::string const &str) { return str + "!"; }; + std::string result = + join_strings(v.begin(), v.end(), " ", add_exclamation); + std::string correct = "Hello! world! !!"; + CHECK(result == correct); + } + + SUBCASE("empty sequence") { + v = {}; + std::string result = join_strings(v, "!"); + std::string correct = ""; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/random_utils.cc b/lib/utils/test/src/utils/random_utils.cc new file mode 100644 index 0000000000..8e7d22138f --- /dev/null +++ b/lib/utils/test/src/utils/random_utils.cc @@ -0,0 +1,60 @@ +#include "utils/random_utils.h" +#include "utils/containers/contains.h" +#include "utils/containers/filter.h" +#include "utils/containers/repeat.h" +#include "utils/containers/sum.h" +#include "utils/containers/zip.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("select_random(std::vector)") { + std::vector values = {1, 2, 3, 4, 5}; + + SUBCASE("selected value is in container") { + SUBCASE("equal weights") { + int result = select_random(values); + CHECK(contains(values, result)); + } + + SUBCASE("unequal weights") { + std::vector weights = {0.1f, 0.3f, 0.2f, 0.2f, 0.2f}; + int result = select_random(values, weights); + CHECK(contains(values, result)); + } + } + + SUBCASE("correct distribution") { + auto check_probabilities = [](std::vector const &values, + std::vector const &weights) { + int num_iterations = 10'000; + std::vector trials = repeat( + num_iterations, [&]() { return select_random(values, weights); }); + + for (std::pair const &p : zip(values, weights)) { + int v = p.first; + float w = p.second; + float expectedProbability = w / sum(weights); + int num_occurrences = + filter(trials, [&](int c) { return (c == v); }).size(); + float observedProbability = + static_cast(num_occurrences) / num_iterations; + CHECK(observedProbability == + doctest::Approx(expectedProbability).epsilon(0.01f)); + } + }; + + SUBCASE("equal weights") { + std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + check_probabilities(values, weights); + } + + SUBCASE("unequal weights") { + std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; + check_probabilities(values, weights); + } + } + } +} diff --git a/lib/utils/test/src/utils/rapidcheck/variant.cc b/lib/utils/test/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..ca830fb3ab --- /dev/null +++ b/lib/utils/test/src/utils/rapidcheck/variant.cc @@ -0,0 +1,13 @@ +#include "utils/rapidcheck/variant.h" +#include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Arbitrary") { + RC_SUBCASE("valid type", [](std::variant v) { + return std::holds_alternative(v) || std::holds_alternative(v); + }); + } +} diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/utils/record_formatter.cc similarity index 100% rename from lib/utils/test/src/test_format.cc rename to lib/utils/test/src/utils/record_formatter.cc diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/utils/sequence.cc similarity index 100% rename from lib/utils/test/src/test_sequence.cc rename to lib/utils/test/src/utils/sequence.cc diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/utils/stack_map.cc similarity index 73% rename from lib/utils/test/src/test_stack_map.cc rename to lib/utils/test/src/utils/stack_map.cc index f117820c5d..3ab4faaa80 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/utils/stack_map.cc @@ -1,4 +1,6 @@ #include "utils/stack_map.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace FlexFlow; @@ -6,8 +8,8 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("stack_map") { stack_map map; - // Test the [] operator to insert and access elements - SUBCASE("BracketOperator") { + + SUBCASE("operator[]") { map[1] = 10; map[2] = 20; @@ -15,8 +17,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(map[2] == 20); } - // Test the insert() function - SUBCASE("Insert") { + SUBCASE("insert") { map.insert(1, 10); map.insert(2, 20); @@ -24,21 +25,19 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(map[2] == 20); } - // Test the at() function to access elements - SUBCASE("At") { + SUBCASE("at") { map[1] = 10; map[2] = 20; CHECK(map.at(1) == 10); CHECK(map.at(2) == 20); CHECK(map.at(1) != 20); - // Test const version of at() function + stack_map const &const_map = map; CHECK(const_map.at(1) == 10); CHECK(const_map.at(2) == 20); } - // Test the begin() and end() functions for iterator SUBCASE("Iterator") { map[1] = 10; map[2] = 20; diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/utils/stack_string.cc similarity index 95% rename from lib/utils/test/src/test_stack_string.cc rename to lib/utils/test/src/utils/stack_string.cc index b89e3277cd..8dbfe36c9d 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/utils/stack_string.cc @@ -1,5 +1,5 @@ -#include "test/utils/rapidcheck.h" #include "utils/stack_string.h" +#include "test/utils/rapidcheck.h" #include using namespace FlexFlow; @@ -84,6 +84,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { constexpr std::size_t MAXSIZE = 10; - RCSUBCASE([](stack_string const &s) {}); + RC_SUBCASE([&](stack_string const &s) { + RC_ASSERT(s.size() <= MAXSIZE); + }); } } diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/utils/stack_vector.cc similarity index 68% rename from lib/utils/test/src/test_stack_vector.cc rename to lib/utils/test/src/utils/stack_vector.cc index 577e61092c..6cdd91ece1 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/utils/stack_vector.cc @@ -1,28 +1,31 @@ -#include "test/utils/rapidcheck.h" #include "utils/stack_vector.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" #include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::push_back", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; vector.push_back(10); - std::vector res = vector; - std::vector expected = {10}; - CHECK(res == expected); + std::vector result = vector; + std::vector correct = {10}; + CHECK(result == correct); vector.push_back(20); - expected = {10, 20}; - res = vector; - CHECK(res == expected); + correct = {10, 20}; + result = vector; + CHECK(result == correct); } - TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::operator[]", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -36,7 +39,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector[2] == 30); } - TEST_CASE_TEMPLATE("Size", T, int, double, char) { + TEST_CASE_TEMPLATE("stack_vector::size", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -50,7 +53,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector.size() == 2); } - TEST_CASE_TEMPLATE("==", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::operator==", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector1, vector2; @@ -66,7 +70,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector1 == vector2); } - TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { + TEST_CASE_TEMPLATE("stack_vector::back", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -78,7 +82,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector.back() == 20); } - TEST_CASE_TEMPLATE("Arbitrary", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector - check for size bound", T, int, double, char) { constexpr std::size_t MAXSIZE = 10; RC_SUBCASE("within bound", [&](stack_vector v) { RC_ASSERT(v.size() <= MAXSIZE); diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/utils/tuple.cc similarity index 73% rename from lib/utils/test/src/test_tuple.cc rename to lib/utils/test/src/utils/tuple.cc index 96171510a7..01a1ebca18 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/utils/tuple.cc @@ -1,6 +1,5 @@ #include "utils/tuple.h" #include - #include #include @@ -37,23 +36,23 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("tuple_prepend function") { - std::tuple t1(3.14f, 2.71828); + TEST_CASE("tuple_prepend") { + std::tuple t1 = {3.14f, 2.71828}; int value = 42; - auto result = tuple_prepend(value, t1); - std::tuple expected(42, 3.14f, 2.71828); - CHECK(result == expected); + std::tuple result = tuple_prepend(value, t1); + std::tuple correct = {42, 3.14f, 2.71828}; + CHECK(tuple_compare(result, correct)); } - TEST_CASE("Testing tuple_head_t") { + TEST_CASE("tuple_head_t") { CHECK(std::is_same>, std::tuple>::value); CHECK(std::is_same>, std::tuple<>>::value); } - TEST_CASE("Testing tuple_slice_t") { + TEST_CASE("tuple_slice_t") { CHECK(std::is_same>, std::tuple>::value); CHECK(std::is_same>, @@ -62,17 +61,17 @@ TEST_SUITE(FF_TEST_SUITE) { std::tuple>::value); } - TEST_CASE("Testing tuple_compare function") { - std::tuple tup1{1, 3.14, 'a'}; - std::tuple tup2{1, 3.14, 'a'}; - std::tuple tup3{2, 3.14, 'b'}; + TEST_CASE("tuple_compare") { + std::tuple tup1 = {1, 3.14, 'a'}; + std::tuple tup2 = {1, 3.14, 'a'}; + std::tuple tup3 = {2, 3.14, 'b'}; CHECK(tuple_compare(tup1, tup2)); CHECK(!tuple_compare(tup1, tup3)); } - TEST_CASE("Testing get function with valid index") { - std::tuple tup{1, 3.14, 'a'}; + TEST_CASE("get") { + std::tuple tup = {1, 3.14, 'a'}; CHECK(get(tup) == 1); CHECK(get(tup) == 3.14); diff --git a/lib/utils/test/src/utils/type_index.cc b/lib/utils/test/src/utils/type_index.cc new file mode 100644 index 0000000000..0d53868a92 --- /dev/null +++ b/lib/utils/test/src/utils/type_index.cc @@ -0,0 +1,34 @@ +#include "utils/type_index.h" +#include "test/utils/doctest/check_without_stringify.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_type_index_for_type") { + SUBCASE("int type") { + std::type_index idx = get_type_index_for_type(); + std::type_index expected_idx = typeid(int); + CHECK_WITHOUT_STRINGIFY(idx == expected_idx); + } + + SUBCASE("string type") { + std::type_index idx = get_type_index_for_type(); + std::type_index expected_idx = typeid(std::string); + CHECK_WITHOUT_STRINGIFY(idx == expected_idx); + } + } + + TEST_CASE("matches(std::type_index)") { + std::type_index idx = typeid(float); + + SUBCASE("matching type") { + CHECK(matches(idx)); + } + + SUBCASE("non-matching type") { + CHECK_FALSE(matches(idx)); + } + } +} diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/utils/variant.cc similarity index 91% rename from lib/utils/test/src/test_variant.cc rename to lib/utils/test/src/utils/variant.cc index 0bd01b8dfe..3f6feadda0 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/utils/variant.cc @@ -1,7 +1,10 @@ -#include "test/utils/rapidcheck.h" #include "utils/variant.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/variant.h" #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { SUBCASE("widen function") { @@ -70,10 +73,4 @@ TEST_SUITE(FF_TEST_SUITE) { // Check the result CHECK(get(wider_variant) == 42); } - - TEST_CASE("Arbitrary") { - RC_SUBCASE("valid type", [](std::variant v) { - return std::holds_alternative(v) || std::holds_alternative(v); - }); - } } diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/utils/vector.cc similarity index 91% rename from lib/utils/test/src/test_vector.cc rename to lib/utils/test/src/utils/vector.cc index c6eb0828b8..18eda49543 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/utils/vector.cc @@ -1,5 +1,9 @@ #include "utils/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") {