From bdcc10e9ceca707b73805d1b6b35cc4730b348da Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 28 Sep 2024 19:16:04 -0700 Subject: [PATCH] Move over to ProblemTree/ResultTree framework for machine mapping --- .../single_tensor_movement.struct.toml | 2 +- ...tracted_single_tensor_movement.struct.toml | 30 +++ .../abstracted_tensor_set_movement.h | 21 ++ ...abstracted_tensor_set_movement.struct.toml | 21 ++ ...tracted_tensor_set_movement_across_split.h | 15 ++ .../get_optimal_machine_mapping.h | 34 ++-- .../machine_mapping/machine_mapping_cache.h | 8 +- .../machine_mapping_constraints.h | 31 +++ ...> machine_mapping_constraints.struct.toml} | 2 +- .../machine_mapping_context.struct.toml | 5 - .../get_machine_mapping_problem_tree.h | 17 ++ .../machine_mapping_problem_tree.h | 51 +++++ .../machine_mapping_problem_tree.struct.toml | 18 ++ ...mm_problem_tree_parallel_split.struct.toml | 18 ++ ...blem_tree_parallel_split_label.struct.toml | 11 ++ .../mm_problem_tree_series_split.h | 15 ++ .../mm_problem_tree_series_split.struct.toml | 18 ++ ...roblem_tree_series_split_label.struct.toml | 15 ++ .../unmapped_op_cost_estimate_key.struct.toml | 36 ++++ .../machine_mapping/machine_mapping_result.h | 5 - .../machine_mapping_result_tree.h | 19 ++ .../machine_mapping_result_tree.struct.toml | 18 ++ .../mm_result_tree_parallel_split.struct.toml | 18 ++ ...sult_tree_parallel_split_label.struct.toml | 13 ++ .../mm_result_tree_series_split.struct.toml | 18 ++ ...result_tree_series_split_label.struct.toml | 13 ++ .../machine_mapping_state.struct.toml | 12 +- .../machine_mapping/partial_machine_mapping.h | 31 --- ..._graph_binary_sp_decomposition.struct.toml | 10 +- .../pcg_binary_parallel_split.struct.toml | 10 +- .../pcg_binary_series_split.struct.toml | 9 +- .../pcg_binary_sp_decomposition.struct.toml | 10 +- .../abstracted_tensor_set_movement.cc | 49 +++++ .../machine_mapping/estimate_layer_cost.cc | 5 +- ...racted_tensor_set_movement_across_split.cc | 48 +++++ .../get_machine_mapping_problem_tree.cc | 45 +++++ .../get_optimal_machine_mapping.cc | 183 +++++++++--------- .../get_tensor_set_movement_across_split.cc | 37 +--- ...ping.cc => machine_mapping_constraints.cc} | 24 ++- .../machine_mapping_problem_tree.cc | 95 +++++++++ .../mm_problem_tree_series_split.cc | 18 ++ .../mm_problem_tree_split_label.cc | 17 ++ .../pcg_binary_parallel_split.cc | 2 +- .../pcg_binary_series_split.cc | 5 +- .../get_machine_mapping_problem_tree.cc | 176 +++++++++++++++++ .../get_optimal_machine_mapping.cc | 9 +- .../parallel_computation_graph.h | 2 + .../parallel_computation_graph.cc | 5 + .../include/utils/full_binary_tree/fmt.h | 37 ++++ .../utils/full_binary_tree/full_binary_tree.h | 87 +++++++++ .../full_binary_tree_node_type.enum.toml | 16 ++ .../utils/full_binary_tree/get_leaves.h | 30 +++ .../utils/full_binary_tree/get_left_child.h | 15 ++ .../utils/full_binary_tree/get_node_type.h | 27 +++ .../utils/full_binary_tree/get_right_child.h | 15 ++ .../include/utils/full_binary_tree/hash.h | 26 +++ .../include/utils/full_binary_tree/require.h | 20 ++ .../utils/full_binary_tree/transform.h | 48 +++++ .../include/utils/full_binary_tree/visit.h | 23 +++ .../binary_parallel_split.struct.toml | 10 +- .../binary_series_split.struct.toml | 10 +- .../binary_sp_decomposition_tree.struct.toml | 10 +- .../fmt.h | 63 ------ .../generic_binary_parallel_split.struct.toml | 29 +++ .../generic_binary_series_split.struct.toml | 30 +++ .../generic_binary_sp_decomposition_tree.h | 155 --------------- ...c_binary_sp_decomposition_tree.struct.toml | 21 ++ .../get.h | 15 -- .../get_leaves.h | 24 +-- .../get_left_child.h | 42 ++-- .../get_node_type.h | 33 ++-- .../get_right_child.h | 42 ++-- .../hash.h | 34 ---- .../generic_binary_sp_decomposition_tree/is.h | 20 +- .../is_binary_sp_tree_left_associative.h | 12 +- .../is_binary_sp_tree_right_associative.h | 12 +- .../make.h | 56 +++--- .../require.h | 38 ++-- .../transform.h | 59 ++++-- .../visit.h | 43 ++-- .../get_leaves.h | 16 ++ .../leaf_only_binary_parallel_split.h | 21 ++ ...eaf_only_binary_parallel_split.struct.toml | 23 +++ ...ly_binary_parallel_split_label.struct.toml | 12 ++ .../leaf_only_binary_series_split.h | 21 ++ .../leaf_only_binary_series_split.struct.toml | 23 +++ ...only_binary_series_split_label.struct.toml | 12 ++ ...y_binary_sp_decomposition_tree.struct.toml | 21 ++ .../make.h | 44 +++++ .../require.h | 52 +++++ .../transform.h | 63 ++++++ .../test/src/utils/containers/flatmap.cc | 35 ++++ 92 files changed, 1997 insertions(+), 722 deletions(-) create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h rename lib/compiler/include/compiler/machine_mapping/{partial_machine_mapping.struct.toml => machine_mapping_constraints.struct.toml} (92%) create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml delete mode 100644 lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h create mode 100644 lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc rename lib/compiler/src/compiler/machine_mapping/{partial_machine_mapping.cc => machine_mapping_constraints.cc} (69%) create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc create mode 100644 lib/utils/include/utils/full_binary_tree/fmt.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml create mode 100644 lib/utils/include/utils/full_binary_tree/get_leaves.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_left_child.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_node_type.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_right_child.h create mode 100644 lib/utils/include/utils/full_binary_tree/hash.h create mode 100644 lib/utils/include/utils/full_binary_tree/require.h create mode 100644 lib/utils/include/utils/full_binary_tree/transform.h create mode 100644 lib/utils/include/utils/full_binary_tree/visit.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h create mode 100644 lib/utils/test/src/utils/containers/flatmap.cc diff --git a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml index 52f66f3420..70f73ebe51 100644 --- a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml +++ b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml @@ -26,5 +26,5 @@ name = "src_machine_views" type = "std::unordered_set<::FlexFlow::MachineView>" [[fields]] -name = "dst_machine_view" +name = "dst_machine_views" type = "std::unordered_set<::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml new file mode 100644 index 0000000000..fcae1e2356 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "AbstractedSingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h new file mode 100644 index 0000000000..80e91b0f85 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); + +std::unordered_set get_src_layers(AbstractedTensorSetMovement const &); +std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &); + +TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &, + MachineMapping const &pre, + MachineMapping const &post); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml new file mode 100644 index 0000000000..4cf184706b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "AbstractedTensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::AbstractedSingleTensorMovement>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..33f44a3a11 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H + +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h index 7b4ba275a2..3c71d78093 100644 --- a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -3,10 +3,12 @@ #include "compiler/machine_mapping/machine_mapping.h" #include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" #include "compiler/machine_mapping/machine_mapping_context.dtg.h" -#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" -#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h" -#include "compiler/series_parallel/pcg_binary_series_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" @@ -14,7 +16,7 @@ namespace FlexFlow { -MachineMappingResult get_optimal_machine_mapping( +MachineMappingResultTree get_optimal_machine_mapping( ParallelComputationGraph const &pcg, std::function( ParallelLayerAttrs const &, MachineSpecification const &)> const @@ -23,38 +25,38 @@ MachineMappingResult get_optimal_machine_mapping( MachineSpecification const &resources, MachineMappingCache &cached_subgraph_results); -MachineMappingResult +MachineMappingResultTree get_optimal_machine_mapping_internal(MachineMappingCache &result_cache, MachineMappingContext const &context, MachineSpecification const &resources); -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySPDecomposition const &sp_decomposition, + MachineMappingProblemTree const &, MachineSpecification const &resources, - PartialMachineMapping const &); + MachineMappingConstraints const &); -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySeriesSplit const &series, + MMProblemTreeSeriesSplit const &, MachineSpecification const &resources, - PartialMachineMapping const &); + MachineMappingConstraints const &); -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinaryParallelSplit const ¶llel, + MMProblemTreeParallelSplit const &, MachineSpecification const &resources, - PartialMachineMapping const &); + MachineMappingConstraints const &); -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &, parallel_layer_guid_t const &, MachineSpecification const &, - PartialMachineMapping const &); + MachineMappingConstraints const &); } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h index a721ea29ed..b4608a90e0 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H #define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H -#include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" #include "compiler/machine_mapping/machine_mapping_state.dtg.h" #include "utils/optional.h" @@ -11,11 +11,11 @@ class MachineMappingCache { public: MachineMappingCache() = default; - std::optional load(MachineMappingState const &) const; - void save(MachineMappingState const &, MachineMappingResult const &); + std::optional load(MachineMappingState const &) const; + void save(MachineMappingState const &, MachineMappingResultTree const &); private: - std::unordered_map cache; + std::unordered_map cache; }; } // namespace FlexFlow diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h new file mode 100644 index 0000000000..320a840bf6 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H + +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/include_unconstrained.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &); + +std::unordered_set get_all_layers(MachineMappingConstraints const &, + IncludeUnconstrained const &); + +std::optional get_machine_view_for_layer(MachineMappingConstraints const &, + parallel_layer_guid_t const &); + +MachineMappingConstraints restrict_domain(MachineMappingConstraints const &, + std::unordered_set const &); + +MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &, + MachineMapping const &); + +MachineMapping require_fully_constrained(MachineMappingConstraints const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml similarity index 92% rename from lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml rename to lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml index b1955185ad..7211c773bb 100644 --- a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "PartialMachineMapping" +name = "MachineMappingConstraints" features = [ "eq", "hash", diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml index 272d4c2097..505141d59f 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -6,13 +6,8 @@ includes = [ "compiler/cost_estimator/cost_estimator.h", "pcg/machine_view.dtg.h", "pcg/machine_specification.dtg.h", - "compiler/machine_mapping/transitive_reduced_pcg.dtg.h", ] -[[fields]] -name = "transitive_reduced_pcg" -type = "::FlexFlow::TransitiveReducedPCG" - [[fields]] name = "cost_estimator" type = "::FlexFlow::CostEstimator" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h new file mode 100644 index 0000000000..b5ab1988ad --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h new file mode 100644 index 0000000000..29b5cf24d5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -0,0 +1,51 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree + mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &pre, + MachineMappingProblemTree const &post); +MachineMappingProblemTree + mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs); +MachineMappingProblemTree mm_problem_tree_make_leaf(PCGOperatorAttrs const &); + +SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); + +MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &); +MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &); +PCGOperatorAttrs require_leaf(MachineMappingProblemTree const &); + +std::unordered_multiset get_leaves(MachineMappingProblemTree const &); + +template +Result visit(MachineMappingProblemTree const &t, F &&f) { + SPDecompositionTreeNodeType node_type = get_node_type(t); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + Result result = f(require_series_split(t)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + Result result = f(require_parallel_split(t)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + Result result = f(require_leaf(t)); + return result; + } + default: + throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml new file mode 100644 index 0000000000..e322133768 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingProblemTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml new file mode 100644 index 0000000000..b277ca44bd --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml new file mode 100644 index 0000000000..367ffb399f --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.struct.toml @@ -0,0 +1,11 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +fields = [] diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h new file mode 100644 index 0000000000..8332da66f9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_pre_child(MMProblemTreeSeriesSplit const &); +MachineMappingProblemTree get_post_child(MMProblemTreeSeriesSplit const &); +AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml new file mode 100644 index 0000000000..299114862c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml new file mode 100644 index 0000000000..0887d67b49 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", +] + +[[fields]] +name = "tensor_set_movement" +type = "::FlexFlow::AbstractedTensorSetMovement" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..fe76683eb7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "UnmappedOpCostEstimateKey" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h index 621285ae16..0cdd283582 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -5,11 +5,6 @@ namespace FlexFlow { -MachineMappingResult sequential_combine(MachineMappingResult const &s1, - float comm_cost, - MachineMappingResult const &s2); -MachineMappingResult parallel_combine(MachineMappingResult const &s1, - MachineMappingResult const &s2); MachineMappingResult get_infinity_machine_mapping_result(); void minimize_runtime(MachineMappingResult &m1, MachineMappingResult const &m2); diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h new file mode 100644 index 0000000000..0ddbc08297 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_TREE_MACHINE_MAPPING_RESULT_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_TREE_MACHINE_MAPPING_RESULT_TREE_H + +#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h" + +namespace FlexFlow { + +MachineMappingResultTree make_series_split(float comm_cost, + MachineMappingResultTree const &pre, + MachineMappingResultTree const &post); +MachineMappingResultTree make_parallel_split(MachineMappingResultTree const &lhs, + MachineMappingResultTree const &rhs); +MachineMappingResultTree make_leaf_node(float cost, MachineView const &); + +std::optional minimize_cost(std::optional const &, MachineMappingResultTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml new file mode 100644 index 0000000000..69c7a613e0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingResultTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "pcg/machine_view.dtg.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::MMResultTreeSeriesSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml new file mode 100644 index 0000000000..ceb85e26eb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MMResultTreeParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/parallel_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "pcg/machine_view.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml new file mode 100644 index 0000000000..6bc880e1fb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "MMResultTreeParallelSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "cost" +type = "float" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml new file mode 100644 index 0000000000..9210d1c80c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MMResultTreeSeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.dtg.h", + "compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_parallel_split_label.dtg.h", + "pcg/machine_view.dtg.h", +] + +[[fields]] +name = "raw_split" +type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::MMResultTreeSeriesSplitLabel, ::FlexFlow::MMResultTreeParallelSplitLabel, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml new file mode 100644 index 0000000000..0f0a326fb5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result_tree/mm_result_tree_series_split_label.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "MMResultTreeSeriesSplitLabel" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "cost" +type = "float" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml index 0fcb065b10..4d4a29eac7 100644 --- a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -8,18 +8,18 @@ features = [ includes = [ "pcg/machine_specification.dtg.h", - "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h", - "compiler/machine_mapping/partial_machine_mapping.dtg.h", + "compiler/machine_mapping/machine_mapping_constraints.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", ] [[fields]] -name = "subgraph" -type = "::FlexFlow::PCGBinarySPDecomposition" +name = "problem_tree" +type = "::FlexFlow::MachineMappingProblemTree" [[fields]] name = "resource" type = "::FlexFlow::MachineSpecification" [[fields]] -name = "partial_solution" -type = "::FlexFlow::PartialMachineMapping" +name = "constraints" +type = "::FlexFlow::MachineMappingConstraints" diff --git a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h deleted file mode 100644 index 4ed43b3470..0000000000 --- a/lib/compiler/include/compiler/machine_mapping/partial_machine_mapping.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H -#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARTIAL_MACHINE_MAPPING_H - -#include "compiler/machine_mapping/machine_mapping.dtg.h" -#include "compiler/machine_mapping/machine_mapping_context.dtg.h" -#include "compiler/machine_mapping/partial_machine_mapping.dtg.h" -#include "compiler/machine_mapping/include_unconstrained.dtg.h" -#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" - -namespace FlexFlow { - -PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set const &); - -std::unordered_set get_all_layers(PartialMachineMapping const &, - IncludeUnconstrained const &); - -std::optional get_machine_view_for_layer(PartialMachineMapping const &, - parallel_layer_guid_t const &); - -PartialMachineMapping get_sub_solution(PartialMachineMapping const &partial_solution, - PCGBinarySPDecomposition const &sub_problem); - -PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping const &partial_solution, - std::unordered_map const &additional); - -MachineMapping require_complete_mapping(PartialMachineMapping const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml index 147b1e3acf..98d0fc5faf 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "ComputationGraphBinarySPDecomposition" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ "pcg/layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition/leaf_only_binary_sp_decomposition.dtg.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinarySPDecomposition<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml index 75e1fec52f..f7d80138c5 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_parallel_split.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "PCGBinaryParallelSplit" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", ] [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::parallel_layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinaryParallelSplit<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml index 63fc7562cd..48e19022c9 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_series_split.struct.toml @@ -9,14 +9,9 @@ features = [ includes = [ "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", ] [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::parallel_layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinarySeriesSplit<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml index c9950bf3f4..bead04b307 100644 --- a/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml +++ b/lib/compiler/include/compiler/series_parallel/pcg_binary_sp_decomposition.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "PCGBinarySPDecomposition" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::parallel_layer_guid_t>" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc new file mode 100644 index 0000000000..96605fa238 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement.cc @@ -0,0 +1,49 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/partial_machine_mapping.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement() { + return AbstractedTensorSetMovement{{}}; +} + +std::unordered_set get_src_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.src_machine_views; + }); +} + +std::unordered_set get_dst_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.dst_machine_views; + }); +} + +TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &abstracted, + PartialMachineMapping const &pre_mapping, + PartialMachineMapping const &post_mapping) { + auto concretize_tensor_movement = [&](AbstractedSingleTensorMovement const &a) { + return SingleTensorMovement{ + /*parallel_tensor_shape=*/a.parallel_tensor_shape, + /*src_machine_views=*/transform(a.src_machine_views, + [&](parallel_layer_guid_t const &layer) { + return get_machine_view_for_layer(pre_mapping, layer).value(); + }), + /*dst_machine_views=*/transform(a.dst_machine_views, + [&](parallel_layer_guid_t const &layer) { + return get_machine_view_for_layer(post_mapping, layer).value(); + }), + }; + }; + + return TensorSetMovement{ + /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, concretize_tensor_movement), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc index 1caa31aefc..c01354f68b 100644 --- a/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc +++ b/lib/compiler/src/compiler/machine_mapping/estimate_layer_cost.cc @@ -3,9 +3,8 @@ namespace FlexFlow { -float estimate_layer_cost(ParallelComputationGraph const &pcg, - CostEstimator const &cost_estimator, - parallel_layer_guid_t const &layer, +float estimate_layer_cost(CostEstimator const &cost_estimator, + PCGOperatorAttrs const &layer, MachineView const &machine_view) { PCGOperatorAttrs op_attrs = get_parallel_layer_attrs(pcg, layer).op_attrs; diff --git a/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..2c17fc089d --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,48 @@ +#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split) { + std::unordered_set + edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); + + auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { + std::unordered_set tensor_edges = filter(edges_across_split, + [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; }); + + std::unordered_set src_layers = + transform(tensor_edges, + [&](ParallelComputationGraphEdge const &e) { + return get_src_layer(e); + }); + + std::unordered_set dst_layers = + transform(tensor_edges, + [&](ParallelComputationGraphEdge const &e) { + return get_dst_layer(e); + }); + + return AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), + /*src_machine_views=*/src_layers, + /*dst_machine_views=*/dst_layers, + }; + }; + + std::unordered_map single_tensor_movements = + generate_map(pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), + get_movement_for_tensor); + + return AbstractedTensorSetMovement{ + values(single_tensor_movements), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..8472228534 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc @@ -0,0 +1,45 @@ +#include "compiler/machine_mapping/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp_decomposition_tree) { + TransitiveReducedPCG tr_pcg = pcg_get_transitive_reduction(pcg); + + std::function to_problem_tree; + + to_problem_tree = [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { + return visit( + sp, + overload { + [&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_set_movement_across_split(tr_pcg, series); + return mm_problem_tree_make_series_split( + /*tensor_set_movement=*/tensor_movement, + /*lhs=*/to_problem_tree(get_left_child(series)), + /*rhs=*/to_problem_tree(get_right_child(series))); + }, + [&](PCGBinaryParallelSplit const ¶llel) { + return mm_problem_tree_make_parallel_split( + to_problem_tree(get_left_child(parallel)), + to_problem_tree(get_right_child(parallel))); + }, + [&](parallel_layer_guid_t const &leaf) { + return mm_problem_tree_make_leaf(pcg_get_op_attrs(pcg, leaf)); + } + }); + }; + + return to_problem_tree(sp_decomposition_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index b731913627..d24ccaf63e 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -1,8 +1,16 @@ #include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/get_allowed_machine_views_list.h" #include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.h" #include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.h" +#include "compiler/machine_mapping/mm_problem_tree_series_split.h" #include "compiler/machine_mapping/partial_machine_mapping.dtg.h" #include "compiler/machine_mapping/partial_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" @@ -49,41 +57,34 @@ MachineMappingResult get_optimal_machine_mapping( return result; } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, MachineSpecification const &resources) { - PCGBinarySPDecomposition sp_decomposition_tree = ({ - std::optional returned = get_pcg_balanced_binary_sp_decomposition(context.transitive_reduced_pcg.full_pcg); - if (!returned.has_value()) { - throw mk_runtime_error("Failed to get serial parallel decomposition"); - } - returned.value(); - }); - std::unordered_set all_layers = get_parallel_layers(context.transitive_reduced_pcg.full_pcg); - return get_optimal_machine_mapping_internal(result_cache, - context, - sp_decomposition_tree, - resources, - get_unconstrained_solution_for_layers(all_layers)); + NOT_IMPLEMENTED(); + // return get_optimal_machine_mapping_internal(result_cache, + // context, + // sp_decomposition_tree, + // resources, + // get_unconstrained_solution_for_layers(all_layers)); } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySPDecomposition const &sp_decomposition_tree, + MachineMappingProblemTree const &problem_tree, MachineSpecification const &resources, - PartialMachineMapping const &partial_solution) { + MachineMappingConstraints const &constraints) { MachineMappingState state = MachineMappingState{ - sp_decomposition_tree, resources, partial_solution, + problem_tree, resources, constraints, }; { - std::optional cached_result = + std::optional cached_result = result_cache.load(state); if (cached_result) { return cached_result.value(); @@ -100,107 +101,106 @@ MachineMappingResult get_optimal_machine_mapping_internal( return result; } -MachineMappingResult get_optimal_machine_mapping_internal( +std::optional get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinarySeriesSplit const &series_split, + MMProblemTreeSeriesSplit const &series_split, MachineSpecification const &resource, - PartialMachineMapping const &partial_solution) { + MachineMappingConstraints const &partial_solution) { - MachineMappingResult optimal_result = get_infinity_machine_mapping_result(); + std::optional result = std::nullopt; auto is_subgraph_input = [&](std::unordered_set const &subgraph_nodes, parallel_tensor_guid_t const &input_tensor) { return !contains(subgraph_nodes, input_tensor.raw_graph_output.node); }; - PCGBinarySPDecomposition pre_sub_tree = get_left_child(series_split); - PCGBinarySPDecomposition post_sub_tree = get_right_child(series_split); - - PCGSplitBoundaryLayers boundary_layers = - pcg_get_transitive_reduced_boundary_layers_for_split(context.transitive_reduced_pcg, - series_split); + AbstractedTensorSetMovement tensor_movement = get_abstracted_tensor_movement(series_split); auto get_boundary_machine_view_assignments = [&](std::unordered_set const &layers) - -> std::unordered_set> + -> std::unordered_set { std::unordered_map> allowed = generate_map(layers, [&](parallel_layer_guid_t const &l) { return get_allowed_machine_views_for_layer(context, l); }); - return get_all_assignments(allowed); + return transform(get_all_assignments(allowed), + [](std::unordered_map const &m) { + return MachineMapping{m}; + }); }; - for (std::unordered_map const &assigned_pre_machine_views - : get_boundary_machine_view_assignments(boundary_layers.pre_split_boundary)) { + for (MachineMapping const &assigned_pre_machine_views + : get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { - PartialMachineMapping pre_candidate = - with_additional_layer_machine_views( - get_sub_solution(partial_solution, pre_sub_tree), + MachineMappingConstraints pre_candidate = + with_additional_constraints( + restrict_domain(partial_solution, get_leaves(get_pre_child(series_split))), assigned_pre_machine_views); - MachineMappingResult pre_result = - get_optimal_machine_mapping_internal(result_cache, - context, - pre_sub_tree, - resource, - pre_candidate); - + MachineMappingResultTree pre_result = ({ + std::optional returned + = get_optimal_machine_mapping_internal(result_cache, + context, + get_pre_child(series_split), + resource, + pre_candidate); + if (!returned.has_value()) { + continue; + } + returned.value(); + }); - for (std::unordered_map const &assigned_post_machine_views - : get_boundary_machine_view_assignments(boundary_layers.post_split_boundary)) { + for (MachineMapping const &assigned_post_machine_views + : get_boundary_machine_view_assignments(get_dst_layers(tensor_movement))) { - PartialMachineMapping post_candidate = - with_additional_layer_machine_views( - get_sub_solution(partial_solution, post_sub_tree), + MachineMappingConstraints post_candidate = + with_additional_constraints( + restrict_domain(partial_solution, get_leaves(get_post_child(series_split))), assigned_post_machine_views); - MachineMappingResult post_result = - get_optimal_machine_mapping_internal(result_cache, - context, - post_sub_tree, - resource, - post_candidate); - - TensorSetMovement comm_across_split = get_tensor_set_movement_across_split( - /*transitive_reduced_pcg=*/context.transitive_reduced_pcg, - /*split=*/series_split, - /*pre_mapping=*/pre_candidate, - /*post_mapping=*/post_candidate); - + MachineMappingResultTree post_result = ({ + std::optional returned + = get_optimal_machine_mapping_internal(result_cache, + context, + get_post_child(series_split), + resource, + post_candidate); + if (!returned.has_value()) { + continue; + } + returned.value(); + }); + + TensorSetMovement comm_across_split = concretize_abstracted_tensor_set_movement(tensor_movement, + /*pre_mapping=*/assigned_pre_machine_views, + /*post_mapping=*/assigned_post_machine_views); float cost_across_split = context.cost_estimator.estimate_cost(comm_across_split); - minimize_runtime( - optimal_result, - sequential_combine(pre_result, cost_across_split, post_result)); + result = minimize_cost(result, make_series_split(cost_across_split, pre_result, post_result)); } } - return optimal_result; + return result; } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - PCGBinaryParallelSplit const ¶llel, + MMProblemTreeParallelSplit const ¶llel, MachineSpecification const &resources, - PartialMachineMapping const &partial_solution) { - - PCGBinarySPDecomposition left_subtree = get_left_child(parallel); - PartialMachineMapping left_sub_solution = get_sub_solution(partial_solution, - left_subtree); - - PCGBinarySPDecomposition right_subtree = get_right_child(parallel); - PartialMachineMapping right_sub_solution = get_sub_solution(partial_solution, - right_subtree); + MachineMappingConstraints const &partial_solution) { MachineMappingResult optimal_result = [&] { - PCGBinarySeriesSplit series = require_series(make_pcg_series_split( - get_left_child(parallel), - get_right_child(parallel))); + MMProblemTreeSeriesSplit series = MMProblemTreeSeriesSplit{ + MMProblemTreeSeriesSplitLabel{empty_abstracted_tensor_set_movement()}, + parallel.left, + parallel.right, + }; + return get_optimal_machine_mapping_internal(result_cache, context, series, @@ -208,17 +208,22 @@ MachineMappingResult get_optimal_machine_mapping_internal( partial_solution); }(); + MachineMappingConstraints left_sub_solution = restrict_domain(partial_solution, + get_leaves(parallel.left)); + MachineMappingConstraints right_sub_solution = restrict_domain(partial_solution, + get_leaves(parallel.right)); + for (auto const &resource_split : get_machine_resource_splits(resources)) { MachineMappingResult left_result = get_optimal_machine_mapping_internal(result_cache, context, - left_subtree, + parallel.left, resource_split.first, left_sub_solution); MachineMappingResult right_result = get_optimal_machine_mapping_internal(result_cache, context, - right_subtree, + parallel.right, resource_split.second, right_sub_solution); @@ -230,23 +235,25 @@ MachineMappingResult get_optimal_machine_mapping_internal( return optimal_result; } -MachineMappingResult get_optimal_machine_mapping_internal( +MachineMappingResultTree get_optimal_machine_mapping_internal( MachineMappingCache &result_cache, MachineMappingContext const &context, - parallel_layer_guid_t const &layer, + PCGOperatorAttrs const &layer, MachineSpecification const &resource, - PartialMachineMapping const &partial_solution) { + MachineMappingConstraints const &constraints) { + + assert (get_all_layers(constraints, IncludeUnconstrained{true}) == std::unordered_set{layer}); - assert (get_all_layers(partial_solution, IncludeUnconstrained{true}) == std::unordered_set{layer}); + MachineMapping concrete_mapping = require_fully_constrained(constraints); float cost = estimate_layer_cost(context.transitive_reduced_pcg.full_pcg, context.cost_estimator, layer, - get_machine_view_for_layer(partial_solution, layer).value()); + concrete_mapping.machine_views.at(layer)); - return MachineMappingResult{ + return make_leaf_node( /*runtime=*/cost, - /*machine_mapping=*/require_complete_mapping(partial_solution), + /*machine_mapping=*/concrete_mapping, }; } diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 8c84e227a7..f237fba88f 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,4 +1,6 @@ #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/get_abstracted_tensor_set_movement_across_split.h" #include "compiler/machine_mapping/partial_machine_mapping.h" #include "compiler/machine_mapping/transitive_reduced_pcg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" @@ -15,39 +17,8 @@ TensorSetMovement get_tensor_set_movement_across_split(TransitiveReducedPCG cons PCGBinarySeriesSplit const &split, PartialMachineMapping const &pre_mapping, PartialMachineMapping const &post_mapping) { - std::unordered_set - edges_across_split = pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); - - auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t) { - std::unordered_set tensor_edges = filter(edges_across_split, - [&](ParallelComputationGraphEdge const &e) { return get_parallel_tensor(e) == t; }); - - std::unordered_set src_machine_views = - transform(tensor_edges, - [&](ParallelComputationGraphEdge const &e) { - return get_machine_view_for_layer(pre_mapping, get_src_layer(e)).value(); - }); - - std::unordered_set dst_machine_views = - transform(tensor_edges, - [&](ParallelComputationGraphEdge const &e) { - return get_machine_view_for_layer(post_mapping, get_dst_layer(e)).value(); - }); - - return SingleTensorMovement{ - /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), - /*src_machine_views=*/src_machine_views, - /*dst_machine_views=*/dst_machine_views, - }; - }; - - std::unordered_map single_tensor_movements = - generate_map(pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), - get_movement_for_tensor); - - return TensorSetMovement{ - values(single_tensor_movements), - }; + AbstractedTensorSetMovement abstracted = get_abstracted_tensor_set_movement_across_split(tr_pcg, split); + return concretize_abstracted_tensor_set_movement(abstracted, pre_mapping, post_mapping); } diff --git a/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc similarity index 69% rename from lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc rename to lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 5ae3126184..721fa1e32b 100644 --- a/lib/compiler/src/compiler/machine_mapping/partial_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -10,8 +10,8 @@ namespace FlexFlow { -PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set const &layers) { - return PartialMachineMapping{ +MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set const &layers) { + return MachineMappingConstraints{ generate_map(layers, [](parallel_layer_guid_t const &) -> std::optional { return std::nullopt; @@ -19,7 +19,7 @@ PartialMachineMapping get_unconstrained_solution_for_layers(std::unordered_set

get_all_layers(PartialMachineMapping const &partial_solution, +std::unordered_set get_all_layers(MachineMappingConstraints const &partial_solution, IncludeUnconstrained const &include_unconstrained) { std::unordered_set with_unconstrained = keys(partial_solution.machine_views); @@ -31,24 +31,22 @@ std::unordered_set get_all_layers(PartialMachineMapping c } } -std::optional get_machine_view_for_layer(PartialMachineMapping const &partial_solution, +std::optional get_machine_view_for_layer(MachineMappingConstraints const &partial_solution, parallel_layer_guid_t const &layer) { return partial_solution.machine_views.at(layer); } -PartialMachineMapping get_sub_solution(PartialMachineMapping const &partial_solution, - PCGBinarySPDecomposition const &sub_problem) { +MachineMappingConstraints get_sub_solution(MachineMappingConstraints const &partial_solution, + std::unordered_set const &sub_problem) { - std::unordered_set sub_problem_layers = unordered_set_of(get_parallel_layers(sub_problem)); - - return PartialMachineMapping{ - restrict_keys(partial_solution.machine_views, sub_problem_layers), + return MachineMappingConstraints{ + restrict_keys(partial_solution.machine_views, sub_problem), }; } -PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping const &partial_solution, +MachineMappingConstraints with_additional_layer_machine_views(MachineMappingConstraints const &partial_solution, std::unordered_map const &additional) { - PartialMachineMapping result = partial_solution; + MachineMappingConstraints result = partial_solution; for (auto const &[layer, machine_view] : additional) { std::optional current_machine_view = result.machine_views.at(layer); @@ -68,7 +66,7 @@ PartialMachineMapping with_additional_layer_machine_views(PartialMachineMapping } -MachineMapping require_complete_mapping(PartialMachineMapping const &partial_mapping) { +MachineMapping require_complete_mapping(MachineMappingConstraints const &partial_mapping) { return MachineMapping{ map_values(partial_mapping.machine_views, [](std::optional const &mv) { return mv.value(); }), diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..3aace6b332 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree.cc @@ -0,0 +1,95 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/full_binary_tree/get_left_child.h" +#include "compiler/machine_mapping/full_binary_tree/get_right_child.h" +#include "compiler/machine_mapping/full_binary_tree/require.h" +#include "compiler/machine_mapping/full_binary_tree/visit.h" +#include "compiler/machine_mapping/full_binary_tree/get_leaves.h" +#include "utils/overload.h" +#include "compiler/machine_mapping/mm_problem_tree_split_label.h" + +namespace FlexFlow { + +MachineMappingProblemTree mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + FullBinaryTree{ + FullBinaryTreeParentNode{ + /*label=*/MMProblemTreeSplitLabel{ + MMProblemTreeSeriesSplitLabel{ + /*tensor_set_movement=*/tensor_set_movement, + }, + }, + /*lhs=*/lhs.raw_tree, + /*rhs=*/rhs.raw_tree, + }, + }, + }; +} + +MachineMappingProblemTree mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + FullBinaryTree{ + FullBinaryTreeParentNode{ + /*label=*/MMProblemTreeSplitLabel{ + MMProblemTreeParallelSplitLabel{}, + }, + /*lhs=*/lhs.raw_tree, + /*rhs=*/rhs.raw_tree, + }, + }, + }; +} + +MachineMappingProblemTree mm_problem_tree_make_leaf(PCGOperatorAttrs const &layer) { + return MachineMappingProblemTree{ + FullBinaryTree{ + layer, + }, + }; +} + +SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) { + return visit( + tree.raw_tree, + overload { + [](FullBinaryTreeParentNode const &parent) { + return split_label_get_node_type(parent.label); + }, + [](PCGOperatorAttrs const &) { + return SPDecompositionTreeNodeType::NODE; + } + }); +} + + +MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) { + FullBinaryTreeParentNode raw_node = require_parent_node(t.raw_tree); + + return MMProblemTreeSeriesSplit{ + /*label=*/raw_node.label.get(), + /*left=*/MachineMappingProblemTree{get_left_child(raw_node)}, + /*right=*/MachineMappingProblemTree{get_right_child(raw_node)}, + }; +} + +MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) { + FullBinaryTreeParentNode raw_node = require_parent_node(t.raw_tree); + + return MMProblemTreeParallelSplit{ + /*label=*/raw_node.label.get(), + /*left=*/MachineMappingProblemTree{get_left_child(raw_node)}, + /*right=*/MachineMappingProblemTree{get_right_child(raw_node)}, + }; +} + +PCGOperatorAttrs require_leaf(MachineMappingProblemTree const &t) { + return require_leaf(t.raw_tree); +} + +std::unordered_multiset get_leaves(MachineMappingProblemTree const &t) { + return get_leaves(t.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc new file mode 100644 index 0000000000..28c6137440 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_series_split.cc @@ -0,0 +1,18 @@ +#include "compiler/machine_mapping/mm_problem_tree_series_split.h" +#include "compiler/machine_mapping/full_binary_tree/require.h" + +namespace FlexFlow { + +MachineMappingProblemTree const &get_left_child(MMProblemTreeSeriesSplit const &s) { + FullBinaryTree< require_parent(s.problem_tree.raw_tree); +} + +MachineMappingProblemTree const &get_right_child(MMProblemTreeSeriesSplit const &) { + +} + +AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &) { + +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc new file mode 100644 index 0000000000..54b7a4eaf8 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/mm_problem_tree_split_label.cc @@ -0,0 +1,17 @@ +#include "compiler/machine_mapping/mm_problem_tree_split_label.h" +#include "utils/overload.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType split_label_get_node_type(MMProblemTreeSplitLabel const &l) { + return l.visit(overload { + [](MMProblemTreeSeriesSplitLabel const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](MMProblemTreeParallelSplitLabel const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc index 0fe344aef8..dad21c6c8c 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_parallel_split.cc @@ -1,5 +1,5 @@ #include "compiler/series_parallel/pcg_binary_parallel_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" namespace FlexFlow { diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc index efa919d5b9..31a90533ff 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_series_split.cc @@ -1,7 +1,6 @@ #include "compiler/series_parallel/pcg_binary_series_split.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" namespace FlexFlow { diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..de4da010e5 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_machine_mapping_problem_tree.cc @@ -0,0 +1,176 @@ +#include "compiler/machine_mapping/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree.h" +#include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_mapping_problem_tree") { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + auto make_output_attrs = [](ParallelTensorShape const &shape) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + }; + + auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }; + }; + + PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{}}; + + SUBCASE("single layer") { + ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_leaf_node(input_layer); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree correct = mm_problem_tree_make_leaf(input_attrs); + + CHECK(result == correct); + } + + SUBCASE("two layers in series") { + ParallelLayerAddedResult input_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + parallel_tensor_guid_t input = get_only(input_added.outputs); + + PCGOperatorAttrs relu_attrs = PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }; + ParallelTensorShape relu_output_shape = input_shape; + ParallelLayerAddedResult relu_added = add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {input}, + {make_output_attrs(relu_output_shape)}); + parallel_layer_guid_t relu_layer = relu_added.parallel_layer; + parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_series_split( + make_pcg_leaf_node(input_layer), + make_pcg_leaf_node(relu_layer)); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = \ + mm_problem_tree_make_series_split( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + input_shape, + {input_layer}, + {relu_layer}, + }, + }}, + mm_problem_tree_make_leaf(input_attrs), + mm_problem_tree_make_leaf(relu_attrs)); + + CHECK(result == correct); + } + + SUBCASE("two layers in parallel") { + ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + + ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_series_split( + make_pcg_leaf_node(input1_layer), + make_pcg_leaf_node(input2_layer)); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = \ + mm_problem_tree_make_parallel_split( + mm_problem_tree_make_leaf(input_attrs), + mm_problem_tree_make_leaf(input_attrs)); + + CHECK(result == correct); + } + + SUBCASE("multiple tensors across split") { + ParallelLayerAddedResult input1_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); + + ParallelLayerAddedResult input2_added = pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); + + PCGOperatorAttrs ew_op_attrs = PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }; + ParallelTensorShape ew_op_output_shape = input_shape; + ParallelLayerAddedResult ew_op_added = add_parallel_layer(pcg, + make_layer_attrs(ew_op_attrs), + {input1_tensor, input2_tensor}, + {make_output_attrs(ew_op_output_shape)}); + parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; + + PCGBinarySPDecomposition sp_decomposition = \ + make_pcg_series_split( + make_pcg_parallel_split( + make_pcg_leaf_node(input1_layer), + make_pcg_leaf_node(input2_layer)), + make_pcg_leaf_node(ew_op_layer)); + + MachineMappingProblemTree result = get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = \ + mm_problem_tree_make_series_split( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{input1_layer}, + /*dst_machine_views=*/{ew_op_layer}, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{input2_layer}, + /*dst_machine_views=*/{ew_op_layer}, + }, + }}, + /*pre=*/mm_problem_tree_make_parallel_split( + mm_problem_tree_make_leaf(input_attrs), + mm_problem_tree_make_leaf(input_attrs)), + /*post=*/mm_problem_tree_make_leaf(ew_op_attrs)); + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 3c4ac1174c..02b3fe4a03 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -8,12 +8,13 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_optimal_machine_mapping") { + TEST_CASE("get_optimal_machine_mapping_internal") { auto allowed_machine_views1 = [&](ParallelLayerAttrs const &, MachineSpecification const &) { return std::unordered_set{ make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; }; + MachineSpecification machine_spec = MachineSpecification{ /*num_nodes=*/2, /*num_cpus_per_node=*/1, @@ -32,7 +33,7 @@ TEST_SUITE(FF_TEST_SUITE) { MachineView mv1 = make_1d_machine_view(gpu_id_t{1}, gpu_id_t{2}); auto allowed_machine_views = [&](ParallelLayerAttrs const &, - MachineSpecification const &) { + MachineSpecification const &) { return std::unordered_set{mv1}; }; @@ -93,6 +94,10 @@ TEST_SUITE(FF_TEST_SUITE) { FAIL("TODO"); } + SUBCASE("multiple edges across split") { + FAIL("TODO"); + } + // SUBCASE("simple PCG") { // // ParallelComputationGraph pcg_simple = [&] { diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index b6f7790c49..a799e01dbc 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -46,6 +46,8 @@ std::vector ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index b26478107d..1562425a80 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -143,6 +143,11 @@ ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &pcg, return pcg.raw_graph.at(l.raw_graph_node); } +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_parallel_layer_attrs(pcg, l).op_attrs; +} + ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, parallel_tensor_guid_t const &t) { diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h b/lib/utils/include/utils/full_binary_tree/fmt.h new file mode 100644 index 0000000000..3d94996079 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/fmt.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::string format_as(FullBinaryTreeParentNode const &t) { + return fmt::format("<{} ({} {})>", + t.label, + get_left_child(t), + get_right_child(t)); +} + +template +std::string format_as(FullBinaryTree const &t) { + return visit( + t, + overload{ + [](FullBinaryTreeParentNode const &parent) { + return fmt::to_string(parent); + }, + [](LeafLabel const &leaf) { + return fmt::format("{}", leaf); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h new file mode 100644 index 0000000000..f90ffb88c4 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h @@ -0,0 +1,87 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H + +#include +#include +#include + +namespace FlexFlow { + +template +struct FullBinaryTree; + +template +struct FullBinaryTreeParentNode { + explicit FullBinaryTreeParentNode( + ParentLabel const &label, + FullBinaryTree const &lhs, + FullBinaryTree const &rhs) + : label(label), + left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) + { } + + FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default; + + bool operator==(FullBinaryTreeParentNode const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTreeParentNode const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTreeParentNode const &other) const { + return this->tie() < other.tie(); + } +public: + ParentLabel label; + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; +private: + std::tuple const &, + FullBinaryTree const &> + tie() const { + return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct FullBinaryTree { +public: + FullBinaryTree() = delete; + explicit FullBinaryTree(FullBinaryTreeParentNode const &t) + : root{t} {} + + explicit FullBinaryTree(LeafLabel const &t) + : root{t} {} + + bool operator==(FullBinaryTree const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTree const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTree const &other) const { + return this->tie() < other.tie(); + } +public: + std::variant, LeafLabel> root; +private: + std::tuple tie() const { + return std::tie(this->root); + } + + friend std::hash; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml new file mode 100644 index 0000000000..1f8af17cf3 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeNodeType" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "PARENT" +key = "parent" + +[[values]] +name = "LEAF" +key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h new file mode 100644 index 0000000000..c58a850a6d --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include +#include "utils/containers/multiset_union.h" + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(FullBinaryTree const &t) { + return visit>( + t, + overload { + [](FullBinaryTreeParentNode const &parent) { + return multiset_union(get_leaves(get_left_child(parent)), + get_leaves(get_right_child(parent))); + }, + [](ChildLabel const &leaf) { + return std::unordered_multiset{leaf}; + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_left_child.h b/lib/utils/include/utils/full_binary_tree/get_left_child.h new file mode 100644 index 0000000000..163503abfd --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_left_child.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +FullBinaryTree const &get_left_child(FullBinaryTreeParentNode const &t) { + return *t.left_child_ptr; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h new file mode 100644 index 0000000000..e1cbe909d5 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_node_type.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H + +#include "utils/overload.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" + +namespace FlexFlow { + +template +FullBinaryTreeNodeType get_node_type(FullBinaryTree const &t) { + return visit( + t, + overload { + [](FullBinaryTreeParentNode const &) { + return FullBinaryTreeNodeType::PARENT; + }, + [](LeafLabel const &) { + return FullBinaryTreeNodeType::LEAF; + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_right_child.h b/lib/utils/include/utils/full_binary_tree/get_right_child.h new file mode 100644 index 0000000000..e40f2024a1 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_right_child.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +FullBinaryTree const &get_right_child(FullBinaryTreeParentNode const &t) { + return *t.right_child_ptr; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/hash.h b/lib/utils/include/utils/full_binary_tree/hash.h new file mode 100644 index 0000000000..a29836f972 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/hash.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace std { + +template +struct hash<::FlexFlow::FullBinaryTreeParentNode> { + size_t operator()(::FlexFlow::FullBinaryTreeParentNode const &t) const { + return get_std_hash(t.tie()); + } +}; + +template +struct hash<::FlexFlow::FullBinaryTree> { + size_t operator()(::FlexFlow::FullBinaryTree const &t) const { + return get_std_hash(t.tie()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h new file mode 100644 index 0000000000..0e5ad4914a --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/require.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +FullBinaryTreeParentNode const &require_parent_node(FullBinaryTree const &t) { + return std::get>(t.root); +} + +template +LeafLabel const &require_leaf(FullBinaryTree const &t) { + return std::get(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/transform.h b/lib/utils/include/utils/full_binary_tree/transform.h new file mode 100644 index 0000000000..3fef8efd18 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/transform.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_TRANSFORM_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/overload.h" +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template , + typename LeafLabel2 = std::invoke_result_t> +FullBinaryTreeParentNode transform(FullBinaryTreeParentNode const &t, F f) { + return FullBinaryTreeParentNode{ + transform(get_left_child(t), f), + transform(get_right_child(t), f), + }; +} + +template , + typename LeafLabel2 = std::invoke_result_t> +FullBinaryTree transform(FullBinaryTree const &t, F f) { + return visit> + ( t, + overload { + [&](FullBinaryTreeParentNode const &parent) { + return FullBinaryTree{ + transform(parent, f), + }; + }, + [&](LeafLabel const &leaf) { + return FullBinaryTree{ + f(leaf), + }; + } + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h new file mode 100644 index 0000000000..93e5bfb504 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H + +#include "utils/full_binary_tree/full_binary_tree.h" + +namespace FlexFlow { + +template +Result visit(FullBinaryTree const &tt, F f) { + if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative(tt.root)) { + return f(std::get(tt.root)); + } else { + throw mk_runtime_error( + "Unexpected case in visit(FullBinaryTree)"); + } +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml index 985fb3089d..0dcae5177a 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "BinaryParallelSplit" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", "utils/graph/node/node.dtg.h", ] -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::Node>" +type = "::FlexFlow::LeafOnlyBinaryParallelSplit<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml index c7c89da6d2..45472cb243 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "BinarySeriesSplit" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h", "utils/graph/node/node.dtg.h", ] -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - [[fields]] name = "raw_split" -type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::Node>" +type = "::FlexFlow::LeafOnlyBinarySeriesSplit<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml index 1241311150..0000213398 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml @@ -2,21 +2,15 @@ namespace = "FlexFlow" name = "BinarySPDecompositionTree" features = [ "eq", - "ord", "hash", "fmt", ] includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", "utils/graph/node/node.dtg.h", ] -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::Node>" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h deleted file mode 100644 index 42d71ce54e..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include - -namespace FlexFlow { - -template -std::string format_as(GenericBinarySeriesSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySeriesSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinaryParallelSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinaryParallelSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return fmt::format("", s); - }, - [](GenericBinaryParallelSplit const &s) { - return fmt::format("", s); - }, - [](T const &t) { - return fmt::format("", t); - }, - }); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySPDecompositionTree const &t) { - return (s << fmt::to_string(t)); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..e3d92c7409 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "GenericBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "label" +type = "ParallelSplitLabel" + +[[fields]] +name = "lhs" +type = "GenericBinarySPDecompositionTree" + +[[fields]] +name = "rhs" +type = "GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml new file mode 100644 index 0000000000..db11340d6e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "GenericBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", +] + + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "label" +type = "SeriesSplitLabel" + +[[fields]] +name = "pre" +type = "GenericBinarySPDecompositionTree" + +[[fields]] +name = "post" +type = "GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h deleted file mode 100644 index 74f5ba5d8a..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h +++ /dev/null @@ -1,155 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H - -#include -#include -#include - -namespace FlexFlow { - -template -struct GenericBinarySPDecompositionTree; - -template -struct GenericBinarySeriesSplit { -public: - GenericBinarySeriesSplit() = delete; - explicit GenericBinarySeriesSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinarySeriesSplit(GenericBinarySeriesSplit const &) = default; - - bool operator==(GenericBinarySeriesSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySeriesSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySeriesSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinaryParallelSplit { -public: - GenericBinaryParallelSplit() = delete; - explicit GenericBinaryParallelSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinaryParallelSplit(GenericBinaryParallelSplit const &) = default; - - bool operator==(GenericBinaryParallelSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinaryParallelSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinaryParallelSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinarySPDecompositionTree { -public: - GenericBinarySPDecompositionTree() = delete; - explicit GenericBinarySPDecompositionTree( - GenericBinarySeriesSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree( - GenericBinaryParallelSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree(T const &t) : root{t} {} - - GenericBinarySPDecompositionTree(GenericBinarySPDecompositionTree const &) = - default; - - bool operator==(GenericBinarySPDecompositionTree const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySPDecompositionTree const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySPDecompositionTree const &other) const { - return this->tie() < other.tie(); - } - -public: - std::variant, GenericBinaryParallelSplit, T> - root; - -private: - std::tuple tie() const { - return std::tie(this->root); - } - - friend std::hash; -}; - -} // namespace FlexFlow - -// namespace rc { -// -// template <> -// struct Arbitrary<::FlexFlow::BinarySeriesSplit> { -// static Gen<::FlexFlow::BinarySeriesSplit> arbitrary(); -// }; -// -// template <> -// struct Arbitrary<::FlexFlow::GenericBinaryParallelSplit> { -// static Gen<::FlexFlow::GenericBinaryParallelSplit> arbitrary(); -// }; -// -// template <> -// struct Arbitrary<::FlexFlow::GenericBinarySPDecompositionTree> { -// static Gen<::FlexFlow::GenericBinarySPDecompositionTree> arbitrary(); -// }; -// -// } // namespace rc - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..236274e617 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "SeriesSplitLabel", + "ParallelSplitLabel", + "LeafLabel", +] + +includes = [ + "utils/full_binary_tree/full_binary_tree.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::FullBinaryTree, LeafLabel>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h deleted file mode 100644 index c6c1186d3d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace FlexFlow { - -template -TT const &get(GenericBinarySPDecompositionTree const &t) { - return std::get(t.root); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index 51e1e20bac..cad88d25b2 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H #include "utils/containers/multiset_union.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" @@ -11,26 +11,26 @@ namespace FlexFlow { -template -std::unordered_multiset - get_leaves(GenericBinarySPDecompositionTree const &tt) { - return visit>( +template +std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &tt) { + return visit>( tt, overload{ - [](T const &t) { return std::unordered_multiset{t}; }, - [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, - [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, + [](LeafLabel const &t) { return std::unordered_multiset{t}; }, + [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, + [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, }); } -template -std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { +template +std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { return multiset_union(get_leaves(get_left_child(s)), get_leaves(get_right_child(s))); } -template -std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { +template +std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { return multiset_union(get_leaves(get_left_child(p)), get_leaves(get_right_child(p))); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h index 46a460b64e..9e857341c6 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h @@ -1,42 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" namespace FlexFlow { -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySeriesSplit const &s) { - return *s.left_child_ptr; +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &s) { + return s.pre; } -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinaryParallelSplit const &p) { - return *p.left_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_left_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_left_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_left_child incorrectly called on leaf node"); - }, - }); +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &p) { + return p.lhs; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index 883acda480..888d3c6627 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -1,27 +1,32 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/full_binary_tree/visit.h" #include "utils/overload.h" namespace FlexFlow { -template +template SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree const &tt) { + get_node_type(GenericBinarySPDecompositionTree const &tt) { return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](GenericBinaryParallelSplit const &) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - [](T const &) { return SPDecompositionTreeNodeType::NODE; }, - }); + tt.raw_tree, + overload { + [](LeafLabel const &) { + return SPDecompositionTreeNodeType::NODE; + }, + [](FullBinaryTreeParentNode, LeafLabel> const &parent) { + if (std::holds_alternative(parent.label)) { + return SPDecompositionTreeNodeType::SERIES; + } else { + assert (std::holds_alternative(parent.label)); + + return SPDecompositionTreeNodeType::PARALLEL; + } + }, + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h index f0bfba43a2..766995b8a9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h @@ -1,42 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" namespace FlexFlow { -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySeriesSplit const &s) { - return *s.right_child_ptr; +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &s) { + return s.post; } -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinaryParallelSplit const &p) { - return *p.right_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_right_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_right_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_right_child incorrectly called on leaf node"); - }, - }); +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &p) { + return p.rhs; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h deleted file mode 100644 index 983dc4a572..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace std { - -template -struct hash<::FlexFlow::GenericBinarySeriesSplit> { - size_t operator()(::FlexFlow::GenericBinarySeriesSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinaryParallelSplit> { - size_t operator()(::FlexFlow::GenericBinaryParallelSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinarySPDecompositionTree> { - size_t operator()( - ::FlexFlow::GenericBinarySPDecompositionTree const &s) const { - return get_std_hash(s.tie()); - } -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h index 8086f38244..bdaf8bcc2b 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h @@ -1,23 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" namespace FlexFlow { -template -bool is_series_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); +template +bool is_series_split(GenericBinarySPDecompositionTree const &t) { + return get_node_type(t) == SPDecompositionTreeNodeType::SERIES; } -template -bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); +template +bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { + return get_node_type(t) == SPDecompositionTreeNodeType::PARALLEL; } -template -bool is_leaf(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative(t.root); +template +bool is_leaf(GenericBinarySPDecompositionTree const &t) { + return get_node_type(t) == SPDecompositionTreeNodeType::NODE; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 3ffa63753a..1ec84f194f 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" @@ -9,19 +9,19 @@ namespace FlexFlow { -template +template bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &tt) { + GenericBinarySPDecompositionTree const &tt) { return visit( tt, overload{ - [](T const &) { return true; }, - [](GenericBinarySeriesSplit const &s) { + [](LeafLabel const &) { return true; }, + [](GenericBinarySeriesSplit const &s) { return !is_series_split(get_right_child(s)) && is_binary_sp_tree_left_associative(get_left_child(s)) && is_binary_sp_tree_left_associative(get_right_child(s)); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return !is_parallel_split(get_right_child(p)) && is_binary_sp_tree_left_associative(get_left_child(p)) && is_binary_sp_tree_left_associative(get_right_child(p)); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index d88459b432..a3ff9d4012 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" @@ -9,19 +9,19 @@ namespace FlexFlow { -template +template bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &tt) { + GenericBinarySPDecompositionTree const &tt) { return visit( tt, overload{ - [](T const &t) { return true; }, - [](GenericBinarySeriesSplit const &s) { + [](LeafLabel const &t) { return true; }, + [](GenericBinarySeriesSplit const &s) { return !is_series_split(get_left_child(s)) && is_binary_sp_tree_right_associative(get_left_child(s)) && is_binary_sp_tree_right_associative(get_right_child(s)); }, - [](GenericBinaryParallelSplit const &p) { + [](GenericBinaryParallelSplit const &p) { return !is_parallel_split(get_left_child(p)) && is_binary_sp_tree_right_associative(get_left_child(p)) && is_binary_sp_tree_right_associative(get_right_child(p)); diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h index f55b71146a..e925292b35 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -1,37 +1,49 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { -template -GenericBinarySPDecompositionTree make_generic_binary_series_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{ - lhs, - rhs, - }, +template +GenericBinarySPDecompositionTree make_generic_binary_series_split( + SeriesLabel const &label, + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel>{ + FullBinaryTreeParentNode, LeafLabel>{ + label, + lhs.raw_tree, + rhs.raw_tree, + } + } }; } -template -GenericBinarySPDecompositionTree make_generic_binary_parallel_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{ - lhs, - rhs, - }, +template +GenericBinarySPDecompositionTree make_generic_binary_parallel_split( + SeriesLabel const &label, + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel>{ + FullBinaryTreeParentNode, LeafLabel>{ + label, + lhs.raw_tree, + rhs.raw_tree, + } + } }; } -template -GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(T const &t) { - return GenericBinarySPDecompositionTree{t}; +template +GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(LeafLabel const &leaf) { + return GenericBinarySPDecompositionTree{ + FullBinaryTree, LeafLabel>{ + leaf, + }, + }; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index 4137585c1a..1c20de06dc 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -1,26 +1,38 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" namespace FlexFlow { -template -GenericBinarySeriesSplit const & - require_series(GenericBinarySPDecompositionTree const &t) { - return get>(t); +template +GenericBinarySeriesSplit + require_series(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + + return GenericBinarySeriesSplit{ + /*label=*/std::get(parent.label), + /*pre=*/get_left_child(parent), + /*post=*/get_right_child(parent), + }; } -template -GenericBinaryParallelSplit const & - require_parallel(GenericBinarySPDecompositionTree const &t) { - return get>(t); +template +GenericBinaryParallelSplit + require_parallel(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + + return GenericBinarySeriesSplit{ + /*label=*/std::get(parent.label), + /*pre=*/get_left_child(parent), + /*post=*/get_right_child(parent), + }; } -template -T const &require_leaf(GenericBinarySPDecompositionTree const &t) { - return get(t); +template +LeafLabel require_leaf(GenericBinarySPDecompositionTree const &t) { + return require_leaf(t.raw_tree); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h index 08ab99a292..c557711a3b 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -1,49 +1,70 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" namespace FlexFlow { -template > -GenericBinarySeriesSplit - transform(GenericBinarySeriesSplit const &s, F f) { - return GenericBinarySeriesSplit{ +template , + typename ParallelLabel2 = std::invoke_result_t, + typename LeafLabel2 = std::invoke_result_t> +GenericBinarySeriesSplit + transform(GenericBinarySeriesSplit const &s, F f) { + return GenericBinarySeriesSplit{ + f(s.label), transform(get_left_child(s), f), transform(get_right_child(s), f), }; }; -template > -GenericBinaryParallelSplit - transform(GenericBinaryParallelSplit const &s, F f) { - return GenericBinaryParallelSplit{ +template , + typename ParallelLabel2 = std::invoke_result_t, + typename LeafLabel2 = std::invoke_result_t> +GenericBinaryParallelSplit + transform(GenericBinaryParallelSplit const &s, F f) { + return GenericBinaryParallelSplit{ + f(s.label), transform(get_left_child(s), f), transform(get_right_child(s), f), }; }; -template > -GenericBinarySPDecompositionTree - transform(GenericBinarySPDecompositionTree const &tt, F f) { - return visit>( +template , + typename ParallelLabel2 = std::invoke_result_t, + typename LeafLabel2 = std::invoke_result_t> +GenericBinarySPDecompositionTree + transform(GenericBinarySPDecompositionTree const &tt, F f) { + return visit>( tt, overload{ - [&](GenericBinarySeriesSplit const &s) { - return GenericBinarySPDecompositionTree{ + [&](GenericBinarySeriesSplit const &s) { + return GenericBinarySPDecompositionTree{ transform(s, f), }; }, - [&](GenericBinaryParallelSplit const &s) { - return GenericBinarySPDecompositionTree{ + [&](GenericBinaryParallelSplit const &s) { + return GenericBinarySPDecompositionTree{ transform(s, f), }; }, - [&](T const &t) { - return GenericBinarySPDecompositionTree{ + [&](LeafLabel const &t) { + return GenericBinarySPDecompositionTree{ f(t), }; }, diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index 0d9503e59f..ce4e4ebf55 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -2,34 +2,29 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H #include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" namespace FlexFlow { -template -Result visit(GenericBinarySPDecompositionTree const &tt, F f) { - if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative(tt.root)) { - return f(std::get(tt.root)); - } else { - throw mk_runtime_error( - "Unexpected case in visit(GenericBinarySPDecompositionTree)"); +template +Result visit(GenericBinarySPDecompositionTree const &tt, F f) { + SPDecompositionTreeNodeType node_type = get_node_type(tt); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + Result result = f(require_series_split(t)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + Result result = f(require_parallel_split(t)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + Result result = f(require_leaf(t)); + return result; + } + default: + throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type)); } - - // return std::visit(tt.root, overload { - // [&](GenericBinarySeriesSplit const &s) -> Result { - // return f(s); - // }, - // [&](GenericBinaryParallelSplit const &p) -> Result { - // return f(p); - // }, - // [&](T const &t) -> Result { - // return f(t); - // }, - // }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h new file mode 100644 index 0000000000..628cf89a44 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" + +namespace FlexFlow { + +template +std::unordered_multiset get_leaves(LeafOnlyBinarySPDecompositionTree const &t) { + return get_leaves(t.raw_tree); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h new file mode 100644 index 0000000000..9d4ce10cb4 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_PARALLEL_SPLIT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinaryParallelSplit const &s) { + return s.lhs; +} + +template +LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinaryParallelSplit const &s) { + return s.rhs; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..b92175b16f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h", +] + +[[fields]] +name = "lhs" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" + +[[fields]] +name = "rhs" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml new file mode 100644 index 0000000000..0506d36227 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinaryParallelSplitLabel" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h new file mode 100644 index 0000000000..853def2c60 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SERIES_SPLIT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinarySeriesSplit const &s) { + return s.pre; +} + +template +LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinarySeriesSplit const &s) { + return s.post; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml new file mode 100644 index 0000000000..a7ff2dcc70 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = [ + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "pre" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" + +[[fields]] +name = "post" +type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml new file mode 100644 index 0000000000..b780bfeea6 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySeriesSplitLabel" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..dacab0244a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "LeafOnlyBinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt" +] + +template_params = [ + "LeafLabel", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.dtg.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::LeafOnlyBinarySeriesSplitLabel, ::FlexFlow::LeafOnlyBinaryParallelSplitLabel, LeafLabel>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h new file mode 100644 index 0000000000..222799dbe9 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_MAKE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_MAKE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySPDecompositionTree make_series_split(LeafOnlyBinarySPDecompositionTree const &pre, + LeafOnlyBinarySPDecompositionTree const &post) { + return LeafOnlyBinarySPDecompositionTree{ + make_generic_binary_series_split( + LeafOnlyBinaryParallelSplitLabel{}, + pre, + post), + }; +} + +template +LeafOnlyBinarySPDecompositionTree make_parallel_split(LeafOnlyBinarySPDecompositionTree const &lhs, + LeafOnlyBinarySPDecompositionTree const &rhs) { + return LeafOnlyBinarySPDecompositionTree{ + make_generic_binary_series_split( + LeafOnlyBinaryParallelSplitLabel{}, + lhs, + rhs), + }; +} + +template +LeafOnlyBinarySPDecompositionTree make_leaf_node(LeafLabel const &label) { + return LeafOnlyBinarySPDecompositionTree{ + make_generic_binary_sp_leaf< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel>(label), + }; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h new file mode 100644 index 0000000000..9011fadd78 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h @@ -0,0 +1,52 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" + +namespace FlexFlow { + +template +LeafOnlyBinarySeriesSplit + require_series(LeafOnlyBinarySPDecompositionTree const &t) { + GenericBinarySeriesSplit< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel> raw = + require_series(t.raw_tree); + + return LeafOnlyBinarySeriesSplit{ + LeafOnlyBinarySeriesSplitLabel{}, + LeafOnlyBinarySPDecompositionTree{raw.pre}, + LeafOnlyBinarySPDecompositionTree{raw.post}, + }; +} + +template +LeafOnlyBinaryParallelSplit + require_parallel(LeafOnlyBinarySPDecompositionTree const &t) { + GenericBinarySeriesSplit< + LeafOnlyBinarySeriesSplitLabel, + LeafOnlyBinaryParallelSplitLabel, + LeafLabel> raw = + require_series(t.raw_tree); + + return LeafOnlyBinarySeriesSplit{ + LeafOnlyBinaryParallelSplitLabel{}, + LeafOnlyBinarySPDecompositionTree{raw.pre}, + LeafOnlyBinarySPDecompositionTree{raw.post}, + }; +} + +template +LeafLabel require_leaf(LeafOnlyBinarySPDecompositionTree const &t) { + return require_leaf(t.raw_tree); +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h new file mode 100644 index 0000000000..364a3200b1 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h @@ -0,0 +1,63 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEAF_ONLY_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" + +namespace FlexFlow { + +template > +LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &t, F &&f) { + auto ff = overload { + [&](T const &t) { + return f(t); + }, + [&](auto const &x) { + return x; + }, + }; + + return LeafOnlyBinarySeriesSplit{ + transform(t.pre, f), + transform(t.post, f), + }; +} + +template > +LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &t, F &&f) { + auto ff = overload { + [&](T const &t) { + return f(t); + }, + [&](auto const &x) { + return x; + }, + }; + + return LeafOnlyBinaryParallelSplit{ + transform(t.lhs, f), + transform(t.rhs, f), + }; +} + +template > +LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &t, F &&f) { + auto ff = overload { + [&](T const &t) { + return f(t); + }, + [&](auto const &x) { + return x; + }, + }; + + return LeafOnlyBinarySPDecompositionTree{ + transform(t.raw_tree, ff), + }; +} + + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc new file mode 100644 index 0000000000..41b7b79101 --- /dev/null +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -0,0 +1,35 @@ +#include "utils/containers/flatmap.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flatmap(std::unordered_set, F)") { + auto get_chars = [](std::string const &s) { + std::unordered_set result; + for (char c : s) { + result.insert(c); + } + return result; + }; + + SUBCASE("type changing") { + std::unordered_set input = {"hello", " ", "", "world", "!"}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = {'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', '!'}; + + CHECK(result == correct); + } + + SUBCASE("input is empty") { + std::unordered_set input = {}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +}