Skip to content

Commit

Permalink
Move over to ProblemTree/ResultTree framework for machine mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Sep 29, 2024
1 parent 7e73162 commit bdcc10e
Show file tree
Hide file tree
Showing 92 changed files with 1,997 additions and 722 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ name = "src_machine_views"
type = "std::unordered_set<::FlexFlow::MachineView>"

[[fields]]
name = "dst_machine_view"
name = "dst_machine_views"
type = "std::unordered_set<::FlexFlow::MachineView>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
namespace = "FlexFlow"
name = "AbstractedSingleTensorMovement"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"op-attrs/parallel_tensor_shape.dtg.h",
"pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h",
"<unordered_set>",
]

src_includes = [
"utils/hash/unordered_set.h",
"utils/fmt/unordered_set.h",
]

[[fields]]
name = "parallel_tensor_shape"
type = "::FlexFlow::ParallelTensorShape"

[[fields]]
name = "src_machine_views"
type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>"

[[fields]]
name = "dst_machine_views"
type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H

#include "compiler/cost_estimator/tensor_set_movement.dtg.h"
#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h"
#include "compiler/machine_mapping/machine_mapping.dtg.h"

namespace FlexFlow {

AbstractedTensorSetMovement empty_abstracted_tensor_set_movement();

std::unordered_set<parallel_layer_guid_t> get_src_layers(AbstractedTensorSetMovement const &);
std::unordered_set<parallel_layer_guid_t> get_dst_layers(AbstractedTensorSetMovement const &);

TensorSetMovement concretize_abstracted_tensor_set_movement(AbstractedTensorSetMovement const &,
MachineMapping const &pre,
MachineMapping const &post);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
namespace = "FlexFlow"
name = "AbstractedTensorSetMovement"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h",
"<unordered_set>",
]

src_includes = [
"utils/fmt/unordered_multiset.h",
"utils/hash/unordered_multiset.h",
]

[[fields]]
name = "single_tensor_movements"
type = "std::unordered_multiset<::FlexFlow::AbstractedSingleTensorMovement>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H

#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h"
#include "compiler/series_parallel/pcg_binary_series_split.dtg.h"
#include "compiler/machine_mapping/abstracted_tensor_set_movement.dtg.h"

namespace FlexFlow {

AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(TransitiveReducedPCG const &transitive_reduced_pcg,
PCGBinarySeriesSplit const &split);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@

#include "compiler/machine_mapping/machine_mapping.h"
#include "compiler/machine_mapping/machine_mapping_cache.h"
#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h"
#include "compiler/machine_mapping/machine_mapping_context.dtg.h"
#include "compiler/machine_mapping/partial_machine_mapping.dtg.h"
#include "compiler/series_parallel/pcg_binary_parallel_split.dtg.h"
#include "compiler/series_parallel/pcg_binary_series_split.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h"
#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h"
#include "pcg/machine_specification.h"
#include "pcg/machine_view.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.h"
#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h"

namespace FlexFlow {

MachineMappingResult get_optimal_machine_mapping(
MachineMappingResultTree get_optimal_machine_mapping(
ParallelComputationGraph const &pcg,
std::function<std::unordered_set<MachineView>(
ParallelLayerAttrs const &, MachineSpecification const &)> const
Expand All @@ -23,38 +25,38 @@ MachineMappingResult get_optimal_machine_mapping(
MachineSpecification const &resources,
MachineMappingCache &cached_subgraph_results);

MachineMappingResult
MachineMappingResultTree
get_optimal_machine_mapping_internal(MachineMappingCache &result_cache,
MachineMappingContext const &context,
MachineSpecification const &resources);

MachineMappingResult get_optimal_machine_mapping_internal(
std::optional<MachineMappingResultTree> get_optimal_machine_mapping_internal(
MachineMappingCache &result_cache,
MachineMappingContext const &context,
PCGBinarySPDecomposition const &sp_decomposition,
MachineMappingProblemTree const &,
MachineSpecification const &resources,
PartialMachineMapping const &);
MachineMappingConstraints const &);

MachineMappingResult get_optimal_machine_mapping_internal(
std::optional<MachineMappingResultTree> get_optimal_machine_mapping_internal(
MachineMappingCache &result_cache,
MachineMappingContext const &context,
PCGBinarySeriesSplit const &series,
MMProblemTreeSeriesSplit const &,
MachineSpecification const &resources,
PartialMachineMapping const &);
MachineMappingConstraints const &);

MachineMappingResult get_optimal_machine_mapping_internal(
std::optional<MachineMappingResultTree> get_optimal_machine_mapping_internal(
MachineMappingCache &result_cache,
MachineMappingContext const &context,
PCGBinaryParallelSplit const &parallel,
MMProblemTreeParallelSplit const &,
MachineSpecification const &resources,
PartialMachineMapping const &);
MachineMappingConstraints const &);

MachineMappingResult get_optimal_machine_mapping_internal(
std::optional<MachineMappingResultTree> get_optimal_machine_mapping_internal(
MachineMappingCache &result_cache,
MachineMappingContext const &,
parallel_layer_guid_t const &,
MachineSpecification const &,
PartialMachineMapping const &);
MachineMappingConstraints const &);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H
#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_DP_CACHE_H

#include "compiler/machine_mapping/machine_mapping_result.dtg.h"
#include "compiler/machine_mapping/machine_mapping_result_tree/machine_mapping_result_tree.dtg.h"
#include "compiler/machine_mapping/machine_mapping_state.dtg.h"
#include "utils/optional.h"

Expand All @@ -11,11 +11,11 @@ class MachineMappingCache {
public:
MachineMappingCache() = default;

std::optional<MachineMappingResult> load(MachineMappingState const &) const;
void save(MachineMappingState const &, MachineMappingResult const &);
std::optional<MachineMappingResultTree> load(MachineMappingState const &) const;
void save(MachineMappingState const &, MachineMappingResultTree const &);

private:
std::unordered_map<MachineMappingState, MachineMappingResult> cache;
std::unordered_map<MachineMappingState, MachineMappingResultTree> cache;
};

} // namespace FlexFlow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H

#include "compiler/machine_mapping/machine_mapping.dtg.h"
#include "compiler/machine_mapping/machine_mapping_context.dtg.h"
#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h"
#include "compiler/machine_mapping/include_unconstrained.dtg.h"
#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h"
#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h"

namespace FlexFlow {

MachineMappingConstraints get_unconstrained_solution_for_layers(std::unordered_set<parallel_layer_guid_t> const &);

std::unordered_set<parallel_layer_guid_t> get_all_layers(MachineMappingConstraints const &,
IncludeUnconstrained const &);

std::optional<MachineView> get_machine_view_for_layer(MachineMappingConstraints const &,
parallel_layer_guid_t const &);

MachineMappingConstraints restrict_domain(MachineMappingConstraints const &,
std::unordered_set<parallel_layer_guid_t> const &);

MachineMappingConstraints with_additional_constraints(MachineMappingConstraints const &,
MachineMapping const &);

MachineMapping require_fully_constrained(MachineMappingConstraints const &);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
namespace = "FlexFlow"
name = "PartialMachineMapping"
name = "MachineMappingConstraints"
features = [
"eq",
"hash",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,8 @@ includes = [
"compiler/cost_estimator/cost_estimator.h",
"pcg/machine_view.dtg.h",
"pcg/machine_specification.dtg.h",
"compiler/machine_mapping/transitive_reduced_pcg.dtg.h",
]

[[fields]]
name = "transitive_reduced_pcg"
type = "::FlexFlow::TransitiveReducedPCG"

[[fields]]
name = "cost_estimator"
type = "::FlexFlow::CostEstimator"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H

#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h"
#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h"
#include "pcg/machine_specification.dtg.h"
#include "pcg/machine_view.dtg.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h"

namespace FlexFlow {

MachineMappingProblemTree get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg,
PCGBinarySPDecomposition const &sp);

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H

#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h"
#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h"

namespace FlexFlow {

MachineMappingProblemTree
mm_problem_tree_make_series_split(AbstractedTensorSetMovement const &tensor_set_movement,
MachineMappingProblemTree const &pre,
MachineMappingProblemTree const &post);
MachineMappingProblemTree
mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs,
MachineMappingProblemTree const &rhs);
MachineMappingProblemTree mm_problem_tree_make_leaf(PCGOperatorAttrs const &);

SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &);

MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &);
MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &);
PCGOperatorAttrs require_leaf(MachineMappingProblemTree const &);

std::unordered_multiset<PCGOperatorAttrs> get_leaves(MachineMappingProblemTree const &);

template <typename Result, typename F>
Result visit(MachineMappingProblemTree const &t, F &&f) {
SPDecompositionTreeNodeType node_type = get_node_type(t);
switch (node_type) {
case SPDecompositionTreeNodeType::SERIES: {
Result result = f(require_series_split(t));
return result;
}
case SPDecompositionTreeNodeType::PARALLEL: {
Result result = f(require_parallel_split(t));
return result;
}
case SPDecompositionTreeNodeType::NODE: {
Result result = f(require_leaf(t));
return result;
}
default:
throw mk_runtime_error(fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type));
}
}

} // namespace FlexFlow

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace = "FlexFlow"
name = "MachineMappingProblemTree"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h",
]

[[fields]]
name = "raw_tree"
type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace = "FlexFlow"
name = "MMProblemTreeParallelSplit"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h",
]

[[fields]]
name = "raw_split"
type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>"
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace = "FlexFlow"
name = "MMProblemTreeParallelSplitLabel"
features = [
"eq",
"hash",
"fmt",
]

includes = []

fields = []
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MM_PROBLEM_TREE_SERIES_SPLIT_H

#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h"

namespace FlexFlow {

MachineMappingProblemTree get_pre_child(MMProblemTreeSeriesSplit const &);
MachineMappingProblemTree get_post_child(MMProblemTreeSeriesSplit const &);
AbstractedTensorSetMovement const &get_abstracted_tensor_movement(MMProblemTreeSeriesSplit const &);

} // namespace FlexFlow

#endif
Loading

0 comments on commit bdcc10e

Please sign in to comment.