diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 9da33023a0..64419acce4 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -104,7 +104,8 @@ tl::expected to_v1_including_node_numbering(computation_graph); V1ComputationGraph v1_cg = v1_result.first; bidict layer_numbering = v1_result.second; - V1BinarySPDecomposition v1_sp_decomposition = to_v1(sp_decomposition, layer_numbering); + V1BinarySPDecomposition v1_sp_decomposition = + to_v1(sp_decomposition, layer_numbering); return JsonSPModelExport{ v1_sp_decomposition, diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h index 2eccd36719..29e9e7c90b 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -10,8 +10,11 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation - generic_binary_sp_impl_for_mm_problem_tree(); +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree(); SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h index eb50ee365e..fdc80a1e37 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h @@ -12,16 +12,18 @@ namespace FlexFlow { GenericBinarySPDecompositionTreeImplementation< - ComputationGraphBinarySPDecomposition, - ComputationGraphBinarySeriesSplit, - ComputationGraphBinaryParallelSplit, - layer_guid_t> generic_impl_for_computation_graph_sp_tree(); + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree(); SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &); -ComputationGraphBinarySPDecomposition - computation_graph_sp_decomp_from_binary_sp_decomp(BinarySPDecompositionTree const &); +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &); std::optional get_computation_graph_left_assoc_binary_sp_decomposition( @@ -34,7 +36,7 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &); std::unordered_multiset get_layers(ComputationGraphBinarySPDecomposition const &); -V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &, +V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &, bidict const &layer_numbering); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h index 05a1ae1169..f348b1a851 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h @@ -6,7 +6,8 @@ namespace FlexFlow { -BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &); +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h index 83e53e3d41..0842ffb48f 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h @@ -6,7 +6,8 @@ namespace FlexFlow { -BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); +BinarySeriesSplit + binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h index e8c02ebfb5..86fa1a59aa 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -13,13 +13,14 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - PCGBinarySPDecomposition, - PCGBinarySeriesSplit, - PCGBinaryParallelSplit, - parallel_layer_guid_t> generic_impl_for_pcg_sp_tree(); - -BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree(); + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); std::optional get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 53b8d5bdd6..0e0f60c891 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -16,9 +16,8 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( std::unordered_set edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); - auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) - -> AbstractedSingleTensorMovement - { + auto get_movement_for_tensor = + [&](parallel_tensor_guid_t const &t) -> AbstractedSingleTensorMovement { std::unordered_set tensor_edges = filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index bf44ef0fd7..10abd7ff90 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -45,7 +45,8 @@ MachineMappingResult } } - MachineMappingResult result = problem_tree.visit(overload{ + MachineMappingResult result = + problem_tree.visit(overload{ [&](MMProblemTreeSeriesSplit const &series_split) { return get_optimal_machine_mapping( result_cache, @@ -86,8 +87,9 @@ MachineMappingResult [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedOpCostEstimateKey leaf = mm_problem_tree_get_subtree_at_path( - MachineMappingProblemTree{series_split}, l) - .value().get(); + MachineMappingProblemTree{series_split}, l) + .value() + .get(); return context.allowed_machine_views(leaf, resources); }); return transform( @@ -130,7 +132,8 @@ MachineMappingResult }; MachineMappingResult result = infeasible_machine_mapping_result(); - AbstractedTensorSetMovement tensor_movement = series_split.tensor_set_movement; + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; for (ParallelLayerGuidObliviousMachineMapping const &assigned_pre_machine_views : @@ -178,9 +181,9 @@ MachineMappingResult get_optimal_machine_mapping( MachineMappingResult series_result = [&] { MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{ - /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), - /*left_child=*/lhs, - /*right_child=*/rhs, + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*left_child=*/lhs, + /*right_child=*/rhs, }; return get_optimal_machine_mapping(result_cache, diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index ada271580f..367af3701e 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -20,32 +20,31 @@ MachineMappingProblemTree get_machine_mapping_problem_tree( to_problem_tree = [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { return sp.visit(overload{ - [&](PCGBinarySeriesSplit const &series) { - AbstractedTensorSetMovement tensor_movement = - get_abstracted_tensor_set_movement_across_split(tr_pcg, - series); - return MachineMappingProblemTree{ - MMProblemTreeSeriesSplit{ - /*tensor_set_movement=*/tensor_movement, - /*lhs=*/to_problem_tree(series.get_left_child()), - /*rhs=*/to_problem_tree(series.get_right_child()), - }, - }; - }, - [&](PCGBinaryParallelSplit const ¶llel) { - return MachineMappingProblemTree{ - MMProblemTreeParallelSplit{ - to_problem_tree(parallel.get_left_child()), - to_problem_tree(parallel.get_right_child()), - }, - }; - }, - [&](parallel_layer_guid_t const &leaf) { - return MachineMappingProblemTree{ - get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), - }; - }, - }); + [&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = + get_abstracted_tensor_set_movement_across_split(tr_pcg, series); + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_movement, + /*lhs=*/to_problem_tree(series.get_left_child()), + /*rhs=*/to_problem_tree(series.get_right_child()), + }, + }; + }, + [&](PCGBinaryParallelSplit const ¶llel) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + to_problem_tree(parallel.get_left_child()), + to_problem_tree(parallel.get_right_child()), + }, + }; + }, + [&](parallel_layer_guid_t const &leaf) { + return MachineMappingProblemTree{ + get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), + }; + }, + }); }; return to_problem_tree(sp_decomposition_tree); diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index a5b3cab43e..1e39a7be19 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -5,46 +5,69 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation - generic_binary_sp_impl_for_mm_problem_tree() { +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree() { return GenericBinarySPDecompositionTreeImplementation< - MachineMappingProblemTree, - MMProblemTreeSeriesSplit, - MMProblemTreeParallelSplit, - UnmappedOpCostEstimateKey>{ - /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & { - return split.get_left_child(); - }, - /*series_get_right_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeSeriesSplit const & { - return tree.get(); - }, - /*require_parallel=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeParallelSplit const & { - return tree.get(); - }, - /*require_leaf=*/[](MachineMappingProblemTree const &tree) -> UnmappedOpCostEstimateKey const & { - return tree.get(); - }, + MachineMappingProblemTree, + MMProblemTreeSeriesSplit, + MMProblemTreeParallelSplit, + UnmappedOpCostEstimateKey>{ + /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeSeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](MachineMappingProblemTree const &tree) + -> UnmappedOpCostEstimateKey const & { + return tree.get(); + }, }; } SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { - return tree.visit(overload { - [](MMProblemTreeSeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](MMProblemTreeParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, - [](UnmappedOpCostEstimateKey const &) { return SPDecompositionTreeNodeType::NODE; }, + return tree.visit(overload{ + [](MMProblemTreeSeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](MMProblemTreeParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](UnmappedOpCostEstimateKey const &) { + return SPDecompositionTreeNodeType::NODE; + }, }); } @@ -61,7 +84,8 @@ std::unordered_set std::optional mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, BinaryTreePath const &path) { - return get_subtree_at_path(tree, generic_binary_sp_impl_for_mm_problem_tree(), path); + return get_subtree_at_path( + tree, generic_binary_sp_impl_for_mm_problem_tree(), path); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc index 004aca6a81..96c8106cad 100644 --- a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -41,7 +41,8 @@ std::unordered_set TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); std::unordered_set raw_edges = get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); @@ -57,7 +58,8 @@ std::unordered_set TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); std::unordered_set raw_outputs = get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); @@ -72,7 +74,8 @@ PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( TransitiveReducedDataflowGraphView raw_tr_g = get_underlying_transitive_reduced_dataflow_graph(tr_pcg); - BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split); + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); SplitBoundaryNodes raw_boundary = get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc index f26b899109..32fb53b58a 100644 --- a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc @@ -10,55 +10,68 @@ namespace FlexFlow { GenericBinarySPDecompositionTreeImplementation< - ComputationGraphBinarySPDecomposition, - ComputationGraphBinarySeriesSplit, - ComputationGraphBinaryParallelSplit, - layer_guid_t> generic_impl_for_computation_graph_sp_tree() { - - return GenericBinarySPDecompositionTreeImplementation< - ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySPDecomposition, ComputationGraphBinarySeriesSplit, ComputationGraphBinaryParallelSplit, - layer_guid_t>{ - /*series_get_left_child=*/[](ComputationGraphBinarySeriesSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](ComputationGraphBinaryParallelSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*series_get_right_child=*/[](ComputationGraphBinarySeriesSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](ComputationGraphBinaryParallelSplit const &split) -> ComputationGraphBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](ComputationGraphBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](ComputationGraphBinarySPDecomposition const &tree) -> ComputationGraphBinarySeriesSplit const & { - return tree.get(); - }, - /*require_parallel=*/[](ComputationGraphBinarySPDecomposition const &tree) -> ComputationGraphBinaryParallelSplit const & { - return tree.get(); - }, - /*require_leaf=*/[](ComputationGraphBinarySPDecomposition const &tree) -> layer_guid_t const & { - return tree.get(); - }, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t>{ + /*series_get_left_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> SPDecompositionTreeNodeType { return get_node_type(tree); }, + /*require_series=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> layer_guid_t const & { return tree.get(); }, }; } SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &tree) { - return tree.visit(overload { - [](ComputationGraphBinarySeriesSplit const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](ComputationGraphBinaryParallelSplit const ¶llel) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - [](layer_guid_t const &leaf) { - return SPDecompositionTreeNodeType::NODE; - }, + return tree.visit(overload{ + [](ComputationGraphBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](ComputationGraphBinaryParallelSplit const ¶llel) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](layer_guid_t const &leaf) { + return SPDecompositionTreeNodeType::NODE; + }, }); } @@ -66,30 +79,35 @@ layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &tree) { return tree.get(); } -ComputationGraphBinarySPDecomposition - computation_graph_sp_decomp_from_binary_sp_decomp(BinarySPDecompositionTree const &bin) { - return bin.visit(overload { - [](BinarySeriesSplit const &series) { - return ComputationGraphBinarySPDecomposition{ - ComputationGraphBinarySeriesSplit{ - computation_graph_sp_decomp_from_binary_sp_decomp(series.get_left_child()), - computation_graph_sp_decomp_from_binary_sp_decomp(series.get_right_child()), - }, - }; - }, - [](BinaryParallelSplit const ¶llel) { - return ComputationGraphBinarySPDecomposition{ - ComputationGraphBinaryParallelSplit{ - computation_graph_sp_decomp_from_binary_sp_decomp(parallel.get_left_child()), - computation_graph_sp_decomp_from_binary_sp_decomp(parallel.get_right_child()), - }, - }; - }, - [](Node const &node) { - return ComputationGraphBinarySPDecomposition{ - layer_guid_t{node}, - }; - }, +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &bin) { + return bin.visit(overload{ + [](BinarySeriesSplit const &series) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinarySeriesSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const ¶llel) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinaryParallelSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_right_child()), + }, + }; + }, + [](Node const &node) { + return ComputationGraphBinarySPDecomposition{ + layer_guid_t{node}, + }; + }, }); } @@ -130,11 +148,13 @@ std::optional } bool is_left_associative(ComputationGraphBinarySPDecomposition const &tree) { - return is_binary_sp_tree_left_associative(tree, generic_impl_for_computation_graph_sp_tree()); + return is_binary_sp_tree_left_associative( + tree, generic_impl_for_computation_graph_sp_tree()); } bool is_right_associative(ComputationGraphBinarySPDecomposition const &tree) { - return is_binary_sp_tree_right_associative(tree, generic_impl_for_computation_graph_sp_tree()); + return is_binary_sp_tree_right_associative( + tree, generic_impl_for_computation_graph_sp_tree()); } std::unordered_multiset @@ -142,31 +162,31 @@ std::unordered_multiset return get_leaves(tree, generic_impl_for_computation_graph_sp_tree()); } -V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &tree, - bidict const &layer_numbering) { - return tree.visit(overload { - [&](ComputationGraphBinarySeriesSplit const &series) { - return V1BinarySPDecomposition{ - V1BinarySeriesSplit{ - to_v1(series.get_left_child(), layer_numbering), - to_v1(series.get_right_child(), layer_numbering), - }, - }; - }, - [&](ComputationGraphBinaryParallelSplit const ¶llel) { - return V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - to_v1(parallel.get_left_child(), layer_numbering), - to_v1(parallel.get_right_child(), layer_numbering), - }, - }; - }, - [&](layer_guid_t const &layer) { - return V1BinarySPDecomposition{ - layer_numbering.at_r(layer), - }; - } - }); +V1BinarySPDecomposition + to_v1(ComputationGraphBinarySPDecomposition const &tree, + bidict const &layer_numbering) { + return tree.visit( + overload{[&](ComputationGraphBinarySeriesSplit const &series) { + return V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + to_v1(series.get_left_child(), layer_numbering), + to_v1(series.get_right_child(), layer_numbering), + }, + }; + }, + [&](ComputationGraphBinaryParallelSplit const ¶llel) { + return V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + to_v1(parallel.get_left_child(), layer_numbering), + to_v1(parallel.get_right_child(), layer_numbering), + }, + }; + }, + [&](layer_guid_t const &layer) { + return V1BinarySPDecomposition{ + layer_numbering.at_r(layer), + }; + }}); } } // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc index 7e6327d06a..657a3c3166 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc @@ -3,10 +3,11 @@ namespace FlexFlow { -BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &pcg_split) { +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &pcg_split) { return BinaryParallelSplit{ - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), }; } diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc index b0fec5f6ce..304ad224b1 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc @@ -3,10 +3,11 @@ namespace FlexFlow { -BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &pcg_split) { +BinarySeriesSplit binary_series_split_from_pcg_series_split( + PCGBinarySeriesSplit const &pcg_split) { return BinarySeriesSplit{ - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), - binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), }; } diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index 0555c2a14d..5eb993c6ef 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -1,70 +1,83 @@ #include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" #include "compiler/series_parallel/pcg/pcg_binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/overload.h" namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - PCGBinarySPDecomposition, - PCGBinarySeriesSplit, - PCGBinaryParallelSplit, - parallel_layer_guid_t> generic_impl_for_pcg_sp_tree() { +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree() { return GenericBinarySPDecompositionTreeImplementation< - PCGBinarySPDecomposition, - PCGBinarySeriesSplit, - PCGBinaryParallelSplit, - parallel_layer_guid_t>{ - /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) -> PCGBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](PCGBinaryParallelSplit const &split) -> PCGBinarySPDecomposition const & { - return split.get_left_child(); - }, - /*series_get_right_child=*/[](PCGBinarySeriesSplit const &split) -> PCGBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](PCGBinaryParallelSplit const &split) -> PCGBinarySPDecomposition const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { - return tree.get(); - }, - /*require_parallel=*/[](PCGBinarySPDecomposition const &tree) -> PCGBinaryParallelSplit const & { - return tree.get(); - }, - /*require_leaf=*/[](PCGBinarySPDecomposition const &tree) -> parallel_layer_guid_t const & { - return tree.get(); - }, + PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t>{ + /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](PCGBinarySPDecomposition const &tree) + -> PCGBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](PCGBinarySPDecomposition const &tree) + -> parallel_layer_guid_t const & { + return tree.get(); + }, }; } - -BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { - return pcg_tree.visit(overload { - [](PCGBinarySeriesSplit const &series) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - binary_series_split_from_pcg_series_split(series), - }; - }, - [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinaryParallelSplit{ - binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), - binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), - }, - }; - }, - [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - layer.raw_graph_node, - }; - }, +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { + return pcg_tree.visit(overload{ + [](PCGBinarySeriesSplit const &series) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + binary_series_split_from_pcg_series_split(series), + }; + }, + [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), + }, + }; + }, + [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + layer.raw_graph_node, + }; + }, }); } @@ -78,11 +91,18 @@ std::unordered_multiset return get_leaves(tree, generic_impl_for_pcg_sp_tree()); } -SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &tree) { - return tree.visit(overload { - [](PCGBinarySeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](PCGBinaryParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, - [](parallel_layer_guid_t const &) { return SPDecompositionTreeNodeType::NODE; }, +SPDecompositionTreeNodeType + get_node_type(PCGBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](PCGBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](PCGBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](parallel_layer_guid_t const &) { + return SPDecompositionTreeNodeType::NODE; + }, }); } diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index b63ce95ae0..5c8ea1c0f1 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -9,11 +9,13 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_abstracted_tensor_set_movement_across_split") { - auto make_series_split = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + auto make_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { + auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; }; @@ -70,8 +72,8 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_leaf(input1.parallel_layer), - make_leaf(input2.parallel_layer), + make_leaf(input1.parallel_layer), + make_leaf(input2.parallel_layer), }; AbstractedTensorSetMovement result = @@ -94,9 +96,9 @@ TEST_SUITE(FF_TEST_SUITE) { pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_series_split(make_leaf(input.parallel_layer), - make_leaf(layer_1.parallel_layer)), - make_leaf(layer_2.parallel_layer), + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_leaf(layer_2.parallel_layer), }; AbstractedTensorSetMovement result = @@ -140,12 +142,11 @@ TEST_SUITE(FF_TEST_SUITE) { {relu_output_attrs}); PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ - make_series_split( - make_leaf(input.parallel_layer), - make_series_split( - make_leaf(layer_1.parallel_layer), - make_leaf(layer_2.parallel_layer))), - make_leaf(layer_3.parallel_layer), + make_series_split( + make_leaf(input.parallel_layer), + make_series_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_leaf(layer_3.parallel_layer), }; AbstractedTensorSetMovement result = @@ -188,9 +189,9 @@ TEST_SUITE(FF_TEST_SUITE) { PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split(make_leaf(input.parallel_layer), - make_leaf(layer_1.parallel_layer)), + make_leaf(layer_1.parallel_layer)), make_parallel_split(make_leaf(layer_2.parallel_layer), - make_leaf(layer_3.parallel_layer)), + make_leaf(layer_3.parallel_layer)), }; AbstractedTensorSetMovement result = @@ -244,12 +245,10 @@ TEST_SUITE(FF_TEST_SUITE) { PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ make_series_split( make_leaf(input.parallel_layer), - make_parallel_split( - make_leaf(layer_1.parallel_layer), - make_leaf(layer_2.parallel_layer))), + make_parallel_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), make_parallel_split(make_leaf(layer_3.parallel_layer), - make_leaf(layer_4.parallel_layer)) - }; + make_leaf(layer_4.parallel_layer))}; AbstractedTensorSetMovement result = get_abstracted_tensor_set_movement_across_split( diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 7194fc038c..0a874948e4 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -16,28 +16,29 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_optimal_machine_mapping") { auto make_leaf = [](UnmappedOpCostEstimateKey const &k) { - return MachineMappingProblemTree{k}; + return MachineMappingProblemTree{k}; }; - auto make_series_split = [](AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - MMProblemTreeSeriesSplit{ - /*tensor_set_movement=*/tensor_set_movement, - /*left_child=*/lhs, - /*right_child=*/rhs, - }, - }; - }; + auto make_series_split = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_set_movement, + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; auto make_parallel_split = [](MachineMappingProblemTree const &lhs, MachineMappingProblemTree const &rhs) { return MachineMappingProblemTree{ - MMProblemTreeParallelSplit{ - /*left_child=*/lhs, - /*right_child=*/rhs, - }, + MMProblemTreeParallelSplit{ + /*left_child=*/lhs, + /*right_child=*/rhs, + }, }; }; @@ -200,7 +201,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("pair of layers in parallel") { MachineMappingProblemTree problem_tree = - make_parallel_split(make_leaf(k1), make_leaf(k2)); + make_parallel_split(make_leaf(k1), make_leaf(k2)); MachineMappingConstraints constraints = get_unconstrained_solution_for_layers( diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc index 09d4af7756..06ab1e5b8c 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -12,24 +12,23 @@ TEST_SUITE(FF_TEST_SUITE) { return PCGBinarySPDecomposition{l}; }; - auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, + auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { return PCGBinarySPDecomposition{ - PCGBinarySeriesSplit{ - lhs, - rhs, - }, + PCGBinarySeriesSplit{ + lhs, + rhs, + }, }; }; auto pcg_make_parallel = [](PCGBinarySPDecomposition const &lhs, PCGBinarySPDecomposition const &rhs) { - return PCGBinarySPDecomposition{ - PCGBinaryParallelSplit{ - lhs, - rhs, - }, + PCGBinaryParallelSplit{ + lhs, + rhs, + }, }; }; @@ -37,28 +36,29 @@ TEST_SUITE(FF_TEST_SUITE) { return MachineMappingProblemTree{k}; }; - auto mm_problem_tree_make_series = [](AbstractedTensorSetMovement const &tensor_set_movement, - MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { - return MachineMappingProblemTree{ - MMProblemTreeSeriesSplit{ - tensor_set_movement, - lhs, - rhs, - }, - }; - }; - - auto mm_problem_tree_make_parallel = [](MachineMappingProblemTree const &lhs, - MachineMappingProblemTree const &rhs) { + auto mm_problem_tree_make_series = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + tensor_set_movement, + lhs, + rhs, + }, + }; + }; - return MachineMappingProblemTree{ - MMProblemTreeParallelSplit{ - lhs, - rhs, - }, - }; - }; + auto mm_problem_tree_make_parallel = + [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + lhs, + rhs, + }, + }; + }; ParallelComputationGraph pcg = empty_parallel_computation_graph(); @@ -113,7 +113,8 @@ TEST_SUITE(FF_TEST_SUITE) { UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); - PCGBinarySPDecomposition sp_decomposition = PCGBinarySPDecomposition{input_layer}; + PCGBinarySPDecomposition sp_decomposition = + PCGBinarySPDecomposition{input_layer}; MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); @@ -198,9 +199,9 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); - MachineMappingProblemTree correct = mm_problem_tree_make_parallel( - mm_problem_tree_make_leaf(input1_key), - mm_problem_tree_make_leaf(input2_key)); + MachineMappingProblemTree correct = + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)); CHECK(result == correct); } @@ -240,10 +241,10 @@ TEST_SUITE(FF_TEST_SUITE) { /*output_shapes=*/{ew_op_output_shape}, }; - PCGBinarySPDecomposition sp_decomposition = pcg_make_series( - pcg_make_parallel(pcg_make_leaf(input1_layer), - pcg_make_leaf(input2_layer)), - pcg_make_leaf(ew_op_layer)); + PCGBinarySPDecomposition sp_decomposition = + pcg_make_series(pcg_make_parallel(pcg_make_leaf(input1_layer), + pcg_make_leaf(input2_layer)), + pcg_make_leaf(ew_op_layer)); MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); @@ -278,9 +279,8 @@ TEST_SUITE(FF_TEST_SUITE) { }, }}, /*pre=*/ - mm_problem_tree_make_parallel( - mm_problem_tree_make_leaf(input1_key), - mm_problem_tree_make_leaf(input2_key)), + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)), /*post=*/mm_problem_tree_make_leaf(ew_op_key)); CHECK(result == correct); diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h index 62cfd6ec62..a1ca0aceed 100644 --- a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H #define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H -#include #include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include namespace nlohmann { diff --git a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc index 3adb79eb8f..5341e03c0a 100644 --- a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc +++ b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -1,75 +1,84 @@ #include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" #include "utils/exception.h" -#include "utils/overload.h" #include "utils/fmt/json.h" +#include "utils/overload.h" using namespace ::FlexFlow; namespace nlohmann { -V1BinarySPDecomposition adl_serializer::from_json(json const &j) { +V1BinarySPDecomposition + adl_serializer::from_json(json const &j) { std::string type = j.at("type").get(); if (type == "series") { return V1BinarySPDecomposition{ - j.get(), + j.get(), }; } else if (type == "parallel") { return V1BinarySPDecomposition{ - j.get(), + j.get(), }; } else if (type == "leaf") { return V1BinarySPDecomposition{ - j.at("value").get(), + j.at("value").get(), }; } else { - throw mk_runtime_error(fmt::format("Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" in json object: {}", type, j)); + throw mk_runtime_error(fmt::format( + "Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" " + "in json object: {}", + type, + j)); } } -void adl_serializer::to_json(json &j, V1BinarySPDecomposition const &tree) { - tree.visit(overload { - [&](V1BinarySeriesSplit const &split) { - j = split; - j["type"] = "series"; - return std::monostate{}; - }, - [&](V1BinaryParallelSplit const &split) { - j = split; - j["type"] = "parallel"; - return std::monostate{}; - }, - [&](int leaf) { - j["value"] = leaf; - j["type"] = "leaf"; - return std::monostate{}; - }, +void adl_serializer::to_json( + json &j, V1BinarySPDecomposition const &tree) { + tree.visit(overload{ + [&](V1BinarySeriesSplit const &split) { + j = split; + j["type"] = "series"; + return std::monostate{}; + }, + [&](V1BinaryParallelSplit const &split) { + j = split; + j["type"] = "parallel"; + return std::monostate{}; + }, + [&](int leaf) { + j["value"] = leaf; + j["type"] = "leaf"; + return std::monostate{}; + }, }); } -V1BinarySeriesSplit adl_serializer::from_json(json const &j) { +V1BinarySeriesSplit + adl_serializer::from_json(json const &j) { return V1BinarySeriesSplit{ - /*lhs=*/j.at("left_child").get(), - /*rhs=*/j.at("right_child").get(), + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), }; } -void adl_serializer::to_json(json &j, V1BinarySeriesSplit const &series) { +void adl_serializer::to_json( + json &j, V1BinarySeriesSplit const &series) { j["left_child"] = series.get_left_child(); j["right_child"] = series.get_right_child(); } -V1BinaryParallelSplit adl_serializer::from_json(json const &j) { +V1BinaryParallelSplit + adl_serializer::from_json(json const &j) { return V1BinaryParallelSplit{ - /*lhs=*/j.at("left_child").get(), - /*rhs=*/j.at("right_child").get(), + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), }; } -void adl_serializer::to_json(json &j, V1BinaryParallelSplit const &series) { +void adl_serializer::to_json( + json &j, V1BinaryParallelSplit const &series) { j["left_child"] = series.get_left_child(); j["right_child"] = series.get_right_child(); } - -} // namespace FlexFlow +} // namespace nlohmann diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc index e9f2573914..9068e14517 100644 --- a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -6,46 +6,46 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("adl_serializer") { V1BinarySPDecomposition example_tree = V1BinarySPDecomposition{ - V1BinarySeriesSplit{ - V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - V1BinarySPDecomposition{2}, - V1BinarySPDecomposition{2}, - }, + V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, }, - V1BinarySPDecomposition{3}, - }, }; nlohmann::json example_json = { - {"type", "series"}, - { - "left_child", + {"type", "series"}, { - {"type", "parallel"}, - { "left_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, }, - }, - { + }, + { "right_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "leaf"}, + {"value", 3}, }, - }, }, - }, - { - "right_child", - { - {"type", "leaf"}, - {"value", 3}, - }, - }, }; SUBCASE("to_json") { @@ -56,7 +56,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("from_json") { - V1BinarySPDecomposition result = example_json.get(); + V1BinarySPDecomposition result = + example_json.get(); V1BinarySPDecomposition correct = example_tree; CHECK(result == correct); @@ -65,43 +66,43 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("adl_serializer") { V1BinarySeriesSplit example_split = V1BinarySeriesSplit{ - V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - V1BinarySPDecomposition{2}, - V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, }, - }, - V1BinarySPDecomposition{3}, + V1BinarySPDecomposition{3}, }; nlohmann::json example_json = { - { - "left_child", { - {"type", "parallel"}, - { "left_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, }, - }, - { + }, + { "right_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "leaf"}, + {"value", 3}, }, - }, - }, - }, - { - "right_child", - { - {"type", "leaf"}, - {"value", 3}, }, - }, }; SUBCASE("to_json") { @@ -121,43 +122,43 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("adl_serializer") { V1BinaryParallelSplit example_split = V1BinaryParallelSplit{ - V1BinarySPDecomposition{ - V1BinaryParallelSplit{ - V1BinarySPDecomposition{2}, - V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, }, - }, - V1BinarySPDecomposition{3}, + V1BinarySPDecomposition{3}, }; nlohmann::json example_json = { - { - "left_child", { - {"type", "parallel"}, - { "left_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, }, - }, - { + }, + { "right_child", { - {"type", "leaf"}, - {"value", 2}, + {"type", "leaf"}, + {"value", 3}, }, - }, - }, - }, - { - "right_child", - { - {"type", "leaf"}, - {"value", 3}, }, - }, }; SUBCASE("to_json") { diff --git a/lib/utils/include/utils/archetypes/value_type.h b/lib/utils/include/utils/archetypes/value_type.h index 4831afa408..1635747612 100644 --- a/lib/utils/include/utils/archetypes/value_type.h +++ b/lib/utils/include/utils/archetypes/value_type.h @@ -10,14 +10,26 @@ template struct value_type { value_type() = delete; - value_type(value_type const &) { assert(false); } - value_type &operator=(value_type const &) { assert(false); } - - value_type(value_type &&) { assert(false); } - value_type &operator=(value_type &&) { assert(false); } - - bool operator==(value_type const &) const { assert(false); } - bool operator!=(value_type const &) const { assert(false); } + value_type(value_type const &) { + assert(false); + } + value_type &operator=(value_type const &) { + assert(false); + } + + value_type(value_type &&) { + assert(false); + } + value_type &operator=(value_type &&) { + assert(false); + } + + bool operator==(value_type const &) const { + assert(false); + } + bool operator!=(value_type const &) const { + assert(false); + } }; } // namespace FlexFlow @@ -27,10 +39,10 @@ namespace std { template struct hash<::FlexFlow::value_type> { size_t operator()(::FlexFlow::value_type const &) const { - assert (false); + assert(false); }; }; -} +} // namespace std #endif diff --git a/lib/utils/include/utils/fmt/json.h b/lib/utils/include/utils/fmt/json.h index 15ad0de4e0..c7aa87e3eb 100644 --- a/lib/utils/include/utils/fmt/json.h +++ b/lib/utils/include/utils/fmt/json.h @@ -1,16 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H -#include #include +#include namespace fmt { template struct formatter<::nlohmann::json, Char> : formatter { - template + template auto format(::nlohmann::json const &j, FormatContext &ctx) { - std::ostringstream oss; + std::ostringstream oss; oss << j; return formatter::format(oss.str(), ctx); } diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h index 07928f7871..9cf5d63210 100644 --- a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -11,27 +11,34 @@ namespace FlexFlow { template -std::unordered_set - find_paths_to_leaf(Tree const &tree, FullBinaryTreeImplementation const &impl, Leaf const &needle) { - auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) -> std::unordered_set { - return set_union( - transform(find_paths_to_leaf(impl.get_left_child(parent), impl, needle), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(find_paths_to_leaf(impl.get_right_child(parent), impl, needle), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - }, - [&](Leaf const &leaf) -> std::unordered_set { - if (leaf == needle) { - return {binary_tree_root_path()}; - } else { - return {}; - } - }, +std::unordered_set find_paths_to_leaf( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + Leaf const &needle) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform( + find_paths_to_leaf(impl.get_left_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform( + find_paths_to_leaf(impl.get_right_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + if (leaf == needle) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h index 20c2eb8b62..822acfe9ee 100644 --- a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -12,24 +12,27 @@ namespace FlexFlow { template -std::unordered_set - get_all_leaf_paths(Tree const &tree, - FullBinaryTreeImplementation const &impl) { - auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) -> std::unordered_set { - return set_union( - transform(get_all_leaf_paths(impl.get_left_child(parent), impl), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(get_all_leaf_paths(impl.get_right_child(parent), impl), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - }, - [&](Leaf const &leaf) -> std::unordered_set { - return {binary_tree_root_path()}; - }, +std::unordered_set get_all_leaf_paths( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform(get_all_leaf_paths(impl.get_left_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(impl.get_right_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + return {binary_tree_root_path()}; + }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h index 5c1e21014d..7517028ec0 100644 --- a/lib/utils/include/utils/full_binary_tree/get_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -9,7 +9,7 @@ namespace FlexFlow { template -Tree get_child(Parent const &parent, +Tree get_child(Parent const &parent, FullBinaryTreeImplementation const &impl, BinaryTreePathEntry const &e) { switch (e) { diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h index 87633f29a9..8f9d8e919f 100644 --- a/lib/utils/include/utils/full_binary_tree/get_leaves.h +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -13,19 +13,17 @@ std::unordered_multiset get_leaves(Tree const &tree, FullBinaryTreeImplementation const &impl) { - auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) - -> std::unordered_multiset - { - return multiset_union(get_leaves(impl.get_left_child(parent), impl), - get_leaves(impl.get_right_child(parent), impl)); - }, - [](Leaf const &leaf) - -> std::unordered_multiset - { - return {leaf}; - }, - }; + auto visitor = + FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::unordered_multiset { + return multiset_union( + get_leaves(impl.get_left_child(parent), impl), + get_leaves(impl.get_right_child(parent), impl)); + }, + [](Leaf const &leaf) -> std::unordered_multiset { + return {leaf}; + }, + }; return visit(tree, impl, visitor); } diff --git a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h index 69d4e2ea49..922a42242c 100644 --- a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h @@ -7,20 +7,21 @@ namespace FlexFlow { template -int get_num_tree_nodes(Tree const &tree, - FullBinaryTreeImplementation const &impl) { - +int get_num_tree_nodes( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor{ - [&](Parent const &parent) -> int { - return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + get_num_tree_nodes(impl.get_right_child(parent), impl); - }, - [](Leaf const &) -> int { return 1; }, + [&](Parent const &parent) -> int { + return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + + get_num_tree_nodes(impl.get_right_child(parent), impl); + }, + [](Leaf const &) -> int { return 1; }, }; return visit(tree, impl, visitor); } - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h index bbdc74850c..83ce1367b9 100644 --- a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -3,31 +3,29 @@ #include "utils/full_binary_tree/binary_tree_path.dtg.h" #include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/full_binary_tree/visit.h" #include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" #include namespace FlexFlow { template -std::optional - get_subtree_at_path(Tree const &tree, - FullBinaryTreeImplementation const &impl, - BinaryTreePath const &p) { +std::optional get_subtree_at_path( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + BinaryTreePath const &p) { if (p == binary_tree_root_path()) { return tree; } auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ - [&](Parent const &parent) -> std::optional { - BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); - BinaryTreePath rest = binary_tree_path_get_non_top_level(p); - - return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); - }, - [](Leaf const &leaf) -> std::optional { - return std::nullopt; - }, + [&](Parent const &parent) -> std::optional { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); + }, + [](Leaf const &leaf) -> std::optional { return std::nullopt; }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index 87aa115c8c..832d39bdff 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -2,13 +2,13 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H #include "utils/exception.h" -#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" namespace FlexFlow { template -Result visit(Tree const &tree, +Result visit(Tree const &tree, FullBinaryTreeImplementation const &impl, FullBinaryTreeVisitor const &visitor) { if (impl.is_leaf(tree)) { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index 28e9beeebd..de48cd17e9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -10,11 +10,11 @@ namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> generic_impl_for_binary_sp_tree(); +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree(); bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h index 9eaf84149f..105f5490a4 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -7,12 +7,15 @@ namespace FlexFlow { template -std::unordered_set - find_paths_to_leaf(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl, - Leaf const &needle) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); +std::unordered_set find_paths_to_leaf( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + Leaf const &needle) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return find_paths_to_leaf(tree, full_binary_impl, needle); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h index fd29b69567..0bddbee81c 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h @@ -1,59 +1,69 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/exception.h" #include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" -#include +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/overload.h" -#include "utils/exception.h" +#include namespace FlexFlow { template -FullBinaryTreeImplementation, Leaf> - get_full_binary_impl_from_generic_sp_impl(GenericBinarySPDecompositionTreeImplementation const &impl) { +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &impl) { using Parent = std::variant; auto full_binary_impl = FullBinaryTreeImplementation{ - /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { - return std::visit(overload { - [&](Series const &series) -> Tree const & { - return impl.series_get_left_child(series); - }, - [&](Parallel const ¶llel) -> Tree const & { - return impl.parallel_get_left_child(parallel); - }, - }, parent); - }, - /*get_right_child=*/[impl](Parent const &parent) -> Tree const & { - return std::visit(overload { - [&](Series const &series) -> Tree const & { - return impl.series_get_right_child(series); - }, - [&](Parallel const ¶llel) -> Tree const & { - return impl.parallel_get_right_child(parallel); - }, - }, parent); - }, - /*is_leaf=*/[impl](Tree const &tree) -> bool { - return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; - }, - /*require_leaf=*/[impl](Tree const &tree) -> Leaf const & { - return impl.require_leaf(tree); - }, - /*require_parent=*/[impl](Tree const &tree) -> Parent { - SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); - switch (node_type) { - case SPDecompositionTreeNodeType::SERIES: - return Parent{impl.require_series(tree)}; - case SPDecompositionTreeNodeType::PARALLEL: - return Parent{impl.require_parallel(tree)}; - default: - throw mk_runtime_error(fmt::format("Unexpected SPDecompositionTreeNodeType: {}", node_type)); - } - } - }; + /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_left_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_left_child(parallel); + }, + }, + parent); + }, + /*get_right_child=*/ + [impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_right_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_right_child(parallel); + }, + }, + parent); + }, + /*is_leaf=*/ + [impl](Tree const &tree) -> bool { + return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; + }, + /*require_leaf=*/ + [impl](Tree const &tree) -> Leaf const & { + return impl.require_leaf(tree); + }, + /*require_parent=*/ + [impl](Tree const &tree) -> Parent { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: + return Parent{impl.require_series(tree)}; + case SPDecompositionTreeNodeType::PARALLEL: + return Parent{impl.require_parallel(tree)}; + default: + throw mk_runtime_error(fmt::format( + "Unexpected SPDecompositionTreeNodeType: {}", node_type)); + } + }}; return full_binary_impl; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h index 4637cbd81c..b0bb8355db 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -9,10 +9,13 @@ namespace FlexFlow { template std::unordered_set get_all_leaf_paths( Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { + GenericBinarySPDecompositionTreeImplementation const &impl) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_all_leaf_paths(tree, full_binary_impl); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index 7bbc5cf603..c543375148 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -7,12 +7,15 @@ namespace FlexFlow { template -std::unordered_multiset - get_leaves(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { +std::unordered_multiset get_leaves( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_leaves(tree, full_binary_impl); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h index b5fe0d4131..4678e0c0f7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -7,11 +7,15 @@ namespace FlexFlow { template -int get_num_tree_nodes(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { +int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_num_tree_nodes(tree, full_binary_impl); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h index 8a687d9702..c48185fb7f 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -8,12 +8,15 @@ namespace FlexFlow { template -std::optional - get_subtree_at_path(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl, - BinaryTreePath const &path) { - FullBinaryTreeImplementation, Leaf> - full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); +std::optional get_subtree_at_path( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + BinaryTreePath const &path) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); return get_subtree_at_path(tree, full_binary_impl, path); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 17ff9c5dd1..68e0a3af32 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -9,22 +9,33 @@ namespace FlexFlow { template bool is_binary_sp_tree_left_associative( Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { + GenericBinarySPDecompositionTreeImplementation const &impl) { - auto visitor = GenericBinarySPDecompositionTreeVisitor{ - [&](Series const &split) { - return impl.get_node_type(impl.series_get_right_child(split)) != SPDecompositionTreeNodeType::SERIES && - is_binary_sp_tree_left_associative(impl.series_get_left_child(split), impl) && - is_binary_sp_tree_left_associative(impl.series_get_right_child(split), impl); - }, - [&](Parallel const &split) { - return impl.get_node_type(impl.parallel_get_right_child(split)) != SPDecompositionTreeNodeType::PARALLEL && - is_binary_sp_tree_left_associative(impl.parallel_get_left_child(split), impl) && - is_binary_sp_tree_left_associative(impl.parallel_get_right_child(split), impl); - }, - [&](Leaf const &leaf) { - return true; - }, + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_right_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_left_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_right_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_left_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index b284ce763e..7042765203 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -9,21 +9,32 @@ namespace FlexFlow { template bool is_binary_sp_tree_right_associative( Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl) { - auto visitor = GenericBinarySPDecompositionTreeVisitor{ - [&](Series const &split) { - return impl.get_node_type(impl.series_get_left_child(split)) != SPDecompositionTreeNodeType::SERIES && - is_binary_sp_tree_right_associative(impl.series_get_left_child(split), impl) && - is_binary_sp_tree_right_associative(impl.series_get_right_child(split), impl); - }, - [&](Parallel const &split) { - return impl.get_node_type(impl.parallel_get_left_child(split)) != SPDecompositionTreeNodeType::PARALLEL && - is_binary_sp_tree_right_associative(impl.parallel_get_left_child(split), impl) && - is_binary_sp_tree_right_associative(impl.parallel_get_right_child(split), impl); - }, - [&](Leaf const &leaf) { - return true; - }, + GenericBinarySPDecompositionTreeImplementation const &impl) { + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_left_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_right_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_left_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_right_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, }; return visit(tree, impl, visitor); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index 89bb45f0fb..c06db135b2 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -1,17 +1,28 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" namespace FlexFlow { -template -ReturnType visit( - Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &impl, - GenericBinarySPDecompositionTreeVisitor const &visitor) { +template +ReturnType + visit(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + GenericBinarySPDecompositionTreeVisitor const &visitor) { SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); switch (node_type) { case SPDecompositionTreeNodeType::SERIES: { @@ -27,7 +38,8 @@ ReturnType visit( return result; } default: - throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType value: {}", node_type)); + throw mk_runtime_error(fmt::format( + "Unknown SPDecompositionTreeNodeType value: {}", node_type)); } } diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 98eb913aeb..7374b45a60 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,16 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#include "utils/graph/series_parallel/series_split.dtg.h" #include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" namespace FlexFlow { // struct SeriesSplit { // public: // SeriesSplit() = delete; -// explicit SeriesSplit(std::vector> const &); -// explicit SeriesSplit( +// explicit SeriesSplit(std::vector> const +// &); explicit SeriesSplit( // std::initializer_list> const &); // // bool operator==(SeriesSplit const &) const; @@ -71,6 +71,6 @@ namespace FlexFlow { // size_t operator()(::FlexFlow::ParallelSplit const &) const; // }; -} // namespace std +} // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/json/check_is_json_deserializable.h b/lib/utils/include/utils/json/check_is_json_deserializable.h index f72485dcbd..dd5f397c19 100644 --- a/lib/utils/include/utils/json/check_is_json_deserializable.h +++ b/lib/utils/include/utils/json/check_is_json_deserializable.h @@ -6,7 +6,7 @@ namespace FlexFlow { #define CHECK_IS_JSON_DESERIALIZABLE(TYPENAME) \ - static_assert(::FlexFlow::is_json_deserializable::value, \ + static_assert(::FlexFlow::is_json_deserializable::value, \ #TYPENAME " should be json deserializeable") } // namespace FlexFlow diff --git a/lib/utils/include/utils/json/check_is_json_serializable.h b/lib/utils/include/utils/json/check_is_json_serializable.h index f3d1a058f8..dfcb26081d 100644 --- a/lib/utils/include/utils/json/check_is_json_serializable.h +++ b/lib/utils/include/utils/json/check_is_json_serializable.h @@ -5,8 +5,8 @@ namespace FlexFlow { -#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ - static_assert(::FlexFlow::is_json_serializable::value, \ +#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_serializable::value, \ #TYPENAME " should be json serializeable") } // namespace FlexFlow diff --git a/lib/utils/src/utils/archetypes/value_type.cc b/lib/utils/src/utils/archetypes/value_type.cc index 9c197112a1..f7da47d8f9 100644 --- a/lib/utils/src/utils/archetypes/value_type.cc +++ b/lib/utils/src/utils/archetypes/value_type.cc @@ -2,7 +2,6 @@ namespace FlexFlow { -template - struct value_type<0>; +template struct value_type<0>; } // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/json.cc b/lib/utils/src/utils/fmt/json.cc index 783b75973c..49ad57fba7 100644 --- a/lib/utils/src/utils/fmt/json.cc +++ b/lib/utils/src/utils/fmt/json.cc @@ -2,7 +2,6 @@ namespace fmt { -template - struct formatter<::nlohmann::json, char>; +template struct formatter<::nlohmann::json, char>; } diff --git a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc index b3ddab6cbc..47845720ed 100644 --- a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc +++ b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc @@ -7,8 +7,9 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - std::unordered_set - find_paths_to_leaf(Tree const &, FullBinaryTreeImplementation const &, Leaf const &); +template std::unordered_set + find_paths_to_leaf(Tree const &, + FullBinaryTreeImplementation const &, + Leaf const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc index cbbffb0b4a..b4d8aa1011 100644 --- a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc +++ b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc @@ -3,9 +3,10 @@ namespace FlexFlow { -template - std::unordered_set - get_all_leaf_paths(value_type<0> const &, - FullBinaryTreeImplementation, value_type<1>, value_type<2>> const &); +template std::unordered_set + get_all_leaf_paths(value_type<0> const &, + FullBinaryTreeImplementation, + value_type<1>, + value_type<2>> const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_child.cc b/lib/utils/src/utils/full_binary_tree/get_child.cc index 3283db398b..19362ae510 100644 --- a/lib/utils/src/utils/full_binary_tree/get_child.cc +++ b/lib/utils/src/utils/full_binary_tree/get_child.cc @@ -7,9 +7,9 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - Tree get_child(Parent const &, - FullBinaryTreeImplementation const &, - BinaryTreePathEntry const &); +template Tree + get_child(Parent const &, + FullBinaryTreeImplementation const &, + BinaryTreePathEntry const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_leaves.cc b/lib/utils/src/utils/full_binary_tree/get_leaves.cc index 18221cd98a..0d7e9106f6 100644 --- a/lib/utils/src/utils/full_binary_tree/get_leaves.cc +++ b/lib/utils/src/utils/full_binary_tree/get_leaves.cc @@ -7,9 +7,8 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - std::unordered_multiset - get_leaves(Tree const &, - FullBinaryTreeImplementation const &); +template std::unordered_multiset + get_leaves(Tree const &, + FullBinaryTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc index b651309c32..7a99dd60fa 100644 --- a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc @@ -7,8 +7,7 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - int get_num_tree_nodes(Tree const &, - FullBinaryTreeImplementation const &); +template int get_num_tree_nodes( + Tree const &, FullBinaryTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc index 689237752a..1eea13fedd 100644 --- a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc +++ b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc @@ -7,10 +7,9 @@ using Tree = value_type<0>; using Parent = value_type<1>; using Leaf = value_type<2>; -template - std::optional - get_subtree_at_path(Tree const &, - FullBinaryTreeImplementation const &, - BinaryTreePath const &); +template std::optional get_subtree_at_path( + Tree const &, + FullBinaryTreeImplementation const &, + BinaryTreePath const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc index c8a36dff66..4a4f7c9302 100644 --- a/lib/utils/src/utils/full_binary_tree/visit.cc +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -2,9 +2,8 @@ namespace FlexFlow { -template - int visit(std::string const &, - FullBinaryTreeImplementation const &, - FullBinaryTreeVisitor const &); +template int visit(std::string const &, + FullBinaryTreeImplementation const &, + FullBinaryTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 56718fa71f..62489ff75f 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -1,66 +1,84 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" namespace FlexFlow { -GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> generic_impl_for_binary_sp_tree() { +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree() { return GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> - { - /*series_get_left_child=*/[](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { - return split.get_left_child(); - }, - /*parallel_get_left_child=*/[](BinaryParallelSplit const &split) -> BinarySPDecompositionTree const & { - return split.get_left_child(); - }, - /*series_get_right_child=*/[](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { - return split.get_right_child(); - }, - /*parallel_get_right_child=*/[](BinaryParallelSplit const &split) -> BinarySPDecompositionTree const & { - return split.get_right_child(); - }, - /*get_node_type=*/[](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { - return get_node_type(tree); - }, - /*require_series=*/[](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { - return tree.require_series(); - }, - /*require_parallel=*/[](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { - return tree.require_parallel(); - }, - /*require_leaf=*/[](BinarySPDecompositionTree const &tree) -> Node const & { - return tree.require_node(); - }, + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node>{ + /*series_get_left_child=*/[](BinarySeriesSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { + return tree.require_series(); + }, + /*require_parallel=*/ + [](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { + return tree.require_parallel(); + }, + /*require_leaf=*/ + [](BinarySPDecompositionTree const &tree) -> Node const & { + return tree.require_node(); + }, }; } bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tree) { - return is_binary_sp_tree_left_associative(tree, generic_impl_for_binary_sp_tree()); + return is_binary_sp_tree_left_associative(tree, + generic_impl_for_binary_sp_tree()); } -bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tree) { - return is_binary_sp_tree_right_associative(tree, generic_impl_for_binary_sp_tree()); +bool is_binary_sp_tree_right_associative( + BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_right_associative(tree, + generic_impl_for_binary_sp_tree()); } -std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tree) { +std::unordered_multiset + get_leaves(BinarySPDecompositionTree const &tree) { return get_leaves(tree, generic_impl_for_binary_sp_tree()); } -SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tree) { - return tree.visit(overload { - [](BinarySeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](BinaryParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; }, - [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, +SPDecompositionTreeNodeType + get_node_type(BinarySPDecompositionTree const &tree) { + return tree.visit(overload{ + [](BinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](BinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, }); } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc index e30b9f97a6..07e2c3e3e3 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -8,10 +8,12 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::unordered_set - find_paths_to_leaf(Tree const &, - GenericBinarySPDecompositionTreeImplementation const &, - Leaf const &); +template std::unordered_set find_paths_to_leaf( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + Leaf const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc index bc6b4b1ccf..56a6d0cc85 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc @@ -9,6 +9,10 @@ using Parallel = value_type<2>; using Leaf = value_type<3>; FullBinaryTreeImplementation, Leaf> - get_full_binary_impl_from_generic_sp_impl(GenericBinarySPDecompositionTreeImplementation const &); + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc index 7bc9c4bfe4..71d3f6ac31 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -8,9 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::unordered_set get_all_leaf_paths( - Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 6c80f4ba9b..3bb90bfa32 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -8,9 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::unordered_multiset - get_leaves(Tree const &, - GenericBinarySPDecompositionTreeImplementation const &); +template std::unordered_multiset + get_leaves(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 89e8deb437..3d166145c1 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -8,8 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - int get_num_tree_nodes(Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc index e95284fa5e..d1d8079c0b 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -8,10 +8,12 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - std::optional - get_subtree_at_path(Tree const &, - GenericBinarySPDecompositionTreeImplementation const &, - BinaryTreePath const &); +template std::optional get_subtree_at_path( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + BinaryTreePath const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 2b478edb20..69cbb28582 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -8,9 +8,11 @@ using Series = value_type<1>; using Parallel = value_type<2>; using Leaf = value_type<3>; -template - bool is_binary_sp_tree_left_associative( - Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template bool is_binary_sp_tree_left_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index e50a861219..584099e33e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -8,9 +8,11 @@ using Series = value_type<2>; using Parallel = value_type<3>; using Leaf = value_type<4>; -template - bool is_binary_sp_tree_right_associative( - Tree const &tree, - GenericBinarySPDecompositionTreeImplementation const &); +template bool is_binary_sp_tree_right_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc index b7175e0e1b..056ae2a8d4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -9,10 +9,16 @@ using Series = value_type<2>; using Parallel = value_type<3>; using Leaf = value_type<4>; -template - ReturnType visit( - Tree const &, - GenericBinarySPDecompositionTreeImplementation const &, - GenericBinarySPDecompositionTreeVisitor const &); +template ReturnType + visit(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + GenericBinarySPDecompositionTreeVisitor const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 33ac5f00e9..69b2ebea8e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -23,26 +23,28 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( auto from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { std::vector children = transform(s.children, from_series_child); - return foldl1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinarySeriesSplit{accum, x}, - }; - }); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{accum, x}, + }; + }); }; auto from_parallel = [&](ParallelSplit const &s) -> BinarySPDecompositionTree { std::vector children = transform(vector_of(s.get_children()), from_parallel_child); - return foldl1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinaryParallelSplit{accum, x}, - }; - }); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{accum, x}, + }; + }); }; from_parallel_child = [&](std::variant const &v) diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 2477140d71..478d90e0c3 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -21,25 +21,27 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( auto from_series = [&](SeriesSplit const &s) { std::vector children = transform(s.children, from_series_child); - return foldr1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinarySeriesSplit{x, accum}, - }; - }); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{x, accum}, + }; + }); }; auto from_parallel = [&](ParallelSplit const &s) { std::vector children = transform(vector_of(s.get_children()), from_parallel_child); - return foldr1(children, - [](BinarySPDecompositionTree const &accum, - BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { - return BinarySPDecompositionTree{ - BinaryParallelSplit{x, accum}, - }; - }); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{x, accum}, + }; + }); }; from_parallel_child = [&](std::variant const &v) diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 84ef2fc106..cd29af59a0 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -33,10 +33,10 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); std::unordered_map - ttsp_edge_to_sp_tree = - map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict - .as_unordered_map(), - [](Node const &n) { return BinarySPDecompositionTree{n}; }); + ttsp_edge_to_sp_tree = map_values( + inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return BinarySPDecompositionTree{n}; }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -47,10 +47,10 @@ std::optional auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinaryParallelSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, + BinaryParallelSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); @@ -67,10 +67,10 @@ std::optional MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ - BinarySeriesSplit{ - ttsp_edge_to_sp_tree.at(e1), - ttsp_edge_to_sp_tree.at(e2), - }, + BinarySeriesSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 07df693ae1..410a40236d 100644 --- a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -49,27 +49,29 @@ std::variant flatten_ast( std::variant from_binary_sp_tree(BinarySPDecompositionTree const &binary) { - return binary.template visit>(overload{ - [](Node const &n) { return n; }, - [](BinarySeriesSplit const &s) { - return IntermediateSpDecompositionTree{ - SplitType::SERIES, - { - from_binary_sp_tree(s.get_left_child()), - from_binary_sp_tree(s.get_right_child()), - }, - }; - }, - [](BinaryParallelSplit const &p) { - return IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - { - from_binary_sp_tree(p.get_left_child()), - from_binary_sp_tree(p.get_right_child()), - }, - }; - }, - }); + return binary + .template visit>( + overload{ + [](Node const &n) { return n; }, + [](BinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + from_binary_sp_tree(s.get_left_child()), + from_binary_sp_tree(s.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(p.get_left_child()), + from_binary_sp_tree(p.get_right_child()), + }, + }; + }, + }); } } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc index 9364e02afc..c35789044d 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -10,14 +10,13 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_boundary_nodes_for_split") { - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; - + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc index 1b49c7218d..1f8f66b932 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -12,17 +12,17 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_edges_across_split") { DataflowGraph g = DataflowGraph::create(); - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("multiple nodes with edges across") { NodeAddedResult n1_added = g.add_node({}, 1); @@ -82,7 +82,7 @@ TEST_SUITE(FF_TEST_SUITE) { get_dataflow_graph_transitive_reduction(g); BinarySeriesSplit split = BinarySeriesSplit{ - make_leaf(n1), + make_leaf(n1), make_leaf(n2), }; diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc index 222e9b20bb..0e77739434 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -10,14 +10,13 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_transitive_reduced_outputs_across_split") { - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; - + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 8981312c4b..9ca869b2b0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1,6 +1,6 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include @@ -12,11 +12,11 @@ TEST_SUITE(FF_TEST_SUITE) { Node n2 = Node{2}; Node n3 = Node{3}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = generic_impl_for_binary_sp_tree(); + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); auto generic_get_leaves = [&](BinarySPDecompositionTree const &tree) { return get_leaves(tree, impl); @@ -33,13 +33,12 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("series split") { SUBCASE("children are not the same") { - BinarySPDecompositionTree input = - BinarySPDecompositionTree{ + BinarySPDecompositionTree input = BinarySPDecompositionTree{ BinarySeriesSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n2}, + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, }, - }; + }; std::unordered_multiset result = generic_get_leaves(input); std::unordered_multiset correct = {n1, n2}; @@ -48,13 +47,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("children are the same") { - BinarySPDecompositionTree input = - BinarySPDecompositionTree{ + BinarySPDecompositionTree input = BinarySPDecompositionTree{ BinarySeriesSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, }, - }; + }; std::unordered_multiset result = generic_get_leaves(input); std::unordered_multiset correct = {n1, n1}; @@ -66,10 +64,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("parallel split") { SUBCASE("children are not the same") { BinarySPDecompositionTree input = BinarySPDecompositionTree{ - BinaryParallelSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n2}, - }, + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, }; std::unordered_multiset result = generic_get_leaves(input); @@ -80,10 +78,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("children are the same") { BinarySPDecompositionTree input = BinarySPDecompositionTree{ - BinaryParallelSplit{ - BinarySPDecompositionTree{n1}, - BinarySPDecompositionTree{n1}, - }, + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, }; std::unordered_multiset result = generic_get_leaves(input); @@ -93,29 +91,23 @@ TEST_SUITE(FF_TEST_SUITE) { } } - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("nested") { - BinarySPDecompositionTree input = - make_parallel_split( - make_series_split( - make_leaf(n1), - make_series_split( - make_leaf(n2), - make_leaf(n3))), - make_parallel_split( - make_leaf(n2), - make_leaf(n1))); + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); std::unordered_multiset result = generic_get_leaves(input); std::unordered_multiset correct = {n1, n1, n2, n2, n3}; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index f61ff83bf9..ad7e1c2609 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -11,31 +11,31 @@ TEST_SUITE(FF_TEST_SUITE) { Node n2 = Node{2}; Node n3 = Node{3}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = generic_impl_for_binary_sp_tree(); - - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; - auto generic_get_num_tree_nodes = [&](BinarySPDecompositionTree const &tree) { - return get_num_tree_nodes(tree, impl); - }; + auto generic_get_num_tree_nodes = + [&](BinarySPDecompositionTree const &tree) { + return get_num_tree_nodes(tree, impl); + }; SUBCASE("leaf") { - BinarySPDecompositionTree input = - make_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); int result = generic_get_num_tree_nodes(input); int correct = 1; @@ -88,16 +88,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - BinarySPDecompositionTree input = - make_parallel_split( - make_series_split( - make_leaf(n1), - make_series_split( - make_leaf(n2), - make_leaf(n3))), - make_parallel_split( - make_leaf(n2), - make_leaf(n1))); + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); int result = generic_get_num_tree_nodes(input); int correct = 9; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 05ff0b4aaa..3fae155280 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -12,23 +12,23 @@ TEST_SUITE(FF_TEST_SUITE) { Node n3 = Node{3}; Node n4 = Node{4}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = generic_impl_for_binary_sp_tree(); - - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually left associative") { SUBCASE("just node") { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 324008fdca..5b4e26107e 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -12,23 +12,23 @@ TEST_SUITE(FF_TEST_SUITE) { Node n3 = Node{3}; Node n4 = Node{4}; - GenericBinarySPDecompositionTreeImplementation< - BinarySPDecompositionTree, - BinarySeriesSplit, - BinaryParallelSplit, - Node> impl = generic_impl_for_binary_sp_tree(); - - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually right associative") { SUBCASE("just node") { diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 20f939a8f0..fee971e5e0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -18,17 +18,17 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; @@ -49,8 +49,7 @@ TEST_SUITE(FF_TEST_SUITE) { left_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_series_split(make_leaf(n1), make_leaf(n2)), - make_leaf(n3)); + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc index 5db50ab2ef..fd540f853f 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -14,17 +14,17 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("leaf") { BinarySPDecompositionTree input = make_leaf(n1); @@ -37,8 +37,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative series") { BinarySPDecompositionTree input = make_series_split( - make_series_split(make_leaf(n2), make_leaf(n1)), - make_leaf(n3)); + make_series_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = @@ -49,8 +48,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("right associative series") { BinarySPDecompositionTree input = make_series_split( - make_leaf(n2), - make_series_split(make_leaf(n1), make_leaf(n3))); + make_leaf(n2), make_series_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = @@ -73,8 +71,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_parallel_split(make_leaf(n2), make_leaf(n1)), - make_leaf(n3)); + make_parallel_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = @@ -85,8 +82,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("right associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_leaf(n2), - make_parallel_split(make_leaf(n1), make_leaf(n3))); + make_leaf(n2), make_parallel_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = @@ -113,9 +109,9 @@ TEST_SUITE(FF_TEST_SUITE) { make_parallel_split( make_leaf(n1), make_series_split( - make_series_split(make_series_split(make_leaf(n2), - make_leaf(n3)), - make_leaf(n3)), + make_series_split( + make_series_split(make_leaf(n2), make_leaf(n3)), + make_leaf(n3)), make_leaf(n5))), make_series_split(make_leaf(n6), make_leaf(n4))), make_leaf(n5)); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 19b9cfd944..532ff86c90 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -16,17 +16,17 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; - auto make_series_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; }; - auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, BinarySPDecompositionTree const &rhs) { + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](Node const &n) { - return BinarySPDecompositionTree{n}; - }; + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; @@ -47,8 +47,7 @@ TEST_SUITE(FF_TEST_SUITE) { right_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_leaf(n1), - make_series_split(make_leaf(n2), make_leaf(n3))); + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); CHECK(result == correct); }