Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Oct 8, 2024
1 parent 75f7e98 commit a2b8832
Show file tree
Hide file tree
Showing 76 changed files with 1,146 additions and 955 deletions.
3 changes: 2 additions & 1 deletion bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ tl::expected<JsonSPModelExport, std::string>
to_v1_including_node_numbering(computation_graph);
V1ComputationGraph v1_cg = v1_result.first;
bidict<int, layer_guid_t> layer_numbering = v1_result.second;
V1BinarySPDecomposition v1_sp_decomposition = to_v1(sp_decomposition, layer_numbering);
V1BinarySPDecomposition v1_sp_decomposition =
to_v1(sp_decomposition, layer_numbering);

return JsonSPModelExport{
v1_sp_decomposition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@

namespace FlexFlow {

GenericBinarySPDecompositionTreeImplementation<MachineMappingProblemTree, MMProblemTreeSeriesSplit, MMProblemTreeParallelSplit, UnmappedOpCostEstimateKey>
generic_binary_sp_impl_for_mm_problem_tree();
GenericBinarySPDecompositionTreeImplementation<MachineMappingProblemTree,
MMProblemTreeSeriesSplit,
MMProblemTreeParallelSplit,
UnmappedOpCostEstimateKey>
generic_binary_sp_impl_for_mm_problem_tree();

SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
namespace FlexFlow {

GenericBinarySPDecompositionTreeImplementation<
ComputationGraphBinarySPDecomposition,
ComputationGraphBinarySeriesSplit,
ComputationGraphBinaryParallelSplit,
layer_guid_t> generic_impl_for_computation_graph_sp_tree();
ComputationGraphBinarySPDecomposition,
ComputationGraphBinarySeriesSplit,
ComputationGraphBinaryParallelSplit,
layer_guid_t>
generic_impl_for_computation_graph_sp_tree();

SPDecompositionTreeNodeType
get_node_type(ComputationGraphBinarySPDecomposition const &);

ComputationGraphBinarySPDecomposition
computation_graph_sp_decomp_from_binary_sp_decomp(BinarySPDecompositionTree const &);
ComputationGraphBinarySPDecomposition
computation_graph_sp_decomp_from_binary_sp_decomp(
BinarySPDecompositionTree const &);

std::optional<ComputationGraphBinarySPDecomposition>
get_computation_graph_left_assoc_binary_sp_decomposition(
Expand All @@ -34,7 +36,7 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &);
std::unordered_multiset<layer_guid_t>
get_layers(ComputationGraphBinarySPDecomposition const &);

V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &,
V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &,
bidict<int, layer_guid_t> const &layer_numbering);

} // namespace FlexFlow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

namespace FlexFlow {

BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(PCGBinaryParallelSplit const &);
BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split(
PCGBinaryParallelSplit const &);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

namespace FlexFlow {

BinarySeriesSplit binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &);
BinarySeriesSplit
binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@

namespace FlexFlow {

GenericBinarySPDecompositionTreeImplementation<
PCGBinarySPDecomposition,
PCGBinarySeriesSplit,
PCGBinaryParallelSplit,
parallel_layer_guid_t> generic_impl_for_pcg_sp_tree();

BinarySPDecompositionTree binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &);
GenericBinarySPDecompositionTreeImplementation<PCGBinarySPDecomposition,
PCGBinarySeriesSplit,
PCGBinaryParallelSplit,
parallel_layer_guid_t>
generic_impl_for_pcg_sp_tree();

BinarySPDecompositionTree
binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &);

std::optional<PCGBinarySPDecomposition>
get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split(
std::unordered_set<ParallelComputationGraphEdge> edges_across_split =
pcg_get_transitive_reduced_edges_across_split(tr_pcg, split);

auto get_movement_for_tensor = [&](parallel_tensor_guid_t const &t)
-> AbstractedSingleTensorMovement
{
auto get_movement_for_tensor =
[&](parallel_tensor_guid_t const &t) -> AbstractedSingleTensorMovement {
std::unordered_set<ParallelComputationGraphEdge> tensor_edges =
filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) {
return get_parallel_tensor(e) == t;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ MachineMappingResult
}
}

MachineMappingResult result = problem_tree.visit<MachineMappingResult>(overload{
MachineMappingResult result =
problem_tree.visit<MachineMappingResult>(overload{
[&](MMProblemTreeSeriesSplit const &series_split) {
return get_optimal_machine_mapping(
result_cache,
Expand Down Expand Up @@ -86,8 +87,9 @@ MachineMappingResult
[&](BinaryTreePath const &l) -> std::unordered_set<MachineView> {
UnmappedOpCostEstimateKey leaf =
mm_problem_tree_get_subtree_at_path(
MachineMappingProblemTree{series_split}, l)
.value().get<UnmappedOpCostEstimateKey>();
MachineMappingProblemTree{series_split}, l)
.value()
.get<UnmappedOpCostEstimateKey>();
return context.allowed_machine_views(leaf, resources);
});
return transform(
Expand Down Expand Up @@ -130,7 +132,8 @@ MachineMappingResult
};

MachineMappingResult result = infeasible_machine_mapping_result();
AbstractedTensorSetMovement tensor_movement = series_split.tensor_set_movement;
AbstractedTensorSetMovement tensor_movement =
series_split.tensor_set_movement;

for (ParallelLayerGuidObliviousMachineMapping const
&assigned_pre_machine_views :
Expand Down Expand Up @@ -178,9 +181,9 @@ MachineMappingResult get_optimal_machine_mapping(

MachineMappingResult series_result = [&] {
MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{
/*tensor_set_movement=*/empty_abstracted_tensor_set_movement(),
/*left_child=*/lhs,
/*right_child=*/rhs,
/*tensor_set_movement=*/empty_abstracted_tensor_set_movement(),
/*left_child=*/lhs,
/*right_child=*/rhs,
};

return get_optimal_machine_mapping(result_cache,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,31 @@ MachineMappingProblemTree get_machine_mapping_problem_tree(
to_problem_tree =
[&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree {
return sp.visit<MachineMappingProblemTree>(overload{
[&](PCGBinarySeriesSplit const &series) {
AbstractedTensorSetMovement tensor_movement =
get_abstracted_tensor_set_movement_across_split(tr_pcg,
series);
return MachineMappingProblemTree{
MMProblemTreeSeriesSplit{
/*tensor_set_movement=*/tensor_movement,
/*lhs=*/to_problem_tree(series.get_left_child()),
/*rhs=*/to_problem_tree(series.get_right_child()),
},
};
},
[&](PCGBinaryParallelSplit const &parallel) {
return MachineMappingProblemTree{
MMProblemTreeParallelSplit{
to_problem_tree(parallel.get_left_child()),
to_problem_tree(parallel.get_right_child()),
},
};
},
[&](parallel_layer_guid_t const &leaf) {
return MachineMappingProblemTree{
get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf),
};
},
});
[&](PCGBinarySeriesSplit const &series) {
AbstractedTensorSetMovement tensor_movement =
get_abstracted_tensor_set_movement_across_split(tr_pcg, series);
return MachineMappingProblemTree{
MMProblemTreeSeriesSplit{
/*tensor_set_movement=*/tensor_movement,
/*lhs=*/to_problem_tree(series.get_left_child()),
/*rhs=*/to_problem_tree(series.get_right_child()),
},
};
},
[&](PCGBinaryParallelSplit const &parallel) {
return MachineMappingProblemTree{
MMProblemTreeParallelSplit{
to_problem_tree(parallel.get_left_child()),
to_problem_tree(parallel.get_right_child()),
},
};
},
[&](parallel_layer_guid_t const &leaf) {
return MachineMappingProblemTree{
get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf),
};
},
});
};

return to_problem_tree(sp_decomposition_tree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,69 @@

namespace FlexFlow {

GenericBinarySPDecompositionTreeImplementation<MachineMappingProblemTree, MMProblemTreeSeriesSplit, MMProblemTreeParallelSplit, UnmappedOpCostEstimateKey>
generic_binary_sp_impl_for_mm_problem_tree() {
GenericBinarySPDecompositionTreeImplementation<MachineMappingProblemTree,
MMProblemTreeSeriesSplit,
MMProblemTreeParallelSplit,
UnmappedOpCostEstimateKey>
generic_binary_sp_impl_for_mm_problem_tree() {
return GenericBinarySPDecompositionTreeImplementation<
MachineMappingProblemTree,
MMProblemTreeSeriesSplit,
MMProblemTreeParallelSplit,
UnmappedOpCostEstimateKey>{
/*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree const & {
return split.get_left_child();
},
/*parallel_get_left_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & {
return split.get_left_child();
},
/*series_get_right_child=*/[](MMProblemTreeSeriesSplit const &split) -> MachineMappingProblemTree const & {
return split.get_right_child();
},
/*parallel_get_right_child=*/[](MMProblemTreeParallelSplit const &split) -> MachineMappingProblemTree const & {
return split.get_right_child();
},
/*get_node_type=*/[](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType {
return get_node_type(tree);
},
/*require_series=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeSeriesSplit const & {
return tree.get<MMProblemTreeSeriesSplit>();
},
/*require_parallel=*/[](MachineMappingProblemTree const &tree) -> MMProblemTreeParallelSplit const & {
return tree.get<MMProblemTreeParallelSplit>();
},
/*require_leaf=*/[](MachineMappingProblemTree const &tree) -> UnmappedOpCostEstimateKey const & {
return tree.get<UnmappedOpCostEstimateKey>();
},
MachineMappingProblemTree,
MMProblemTreeSeriesSplit,
MMProblemTreeParallelSplit,
UnmappedOpCostEstimateKey>{
/*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split)
-> MachineMappingProblemTree const & {
return split.get_left_child();
},
/*parallel_get_left_child=*/
[](MMProblemTreeParallelSplit const &split)
-> MachineMappingProblemTree const & {
return split.get_left_child();
},
/*series_get_right_child=*/
[](MMProblemTreeSeriesSplit const &split)
-> MachineMappingProblemTree const & {
return split.get_right_child();
},
/*parallel_get_right_child=*/
[](MMProblemTreeParallelSplit const &split)
-> MachineMappingProblemTree const & {
return split.get_right_child();
},
/*get_node_type=*/
[](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType {
return get_node_type(tree);
},
/*require_series=*/
[](MachineMappingProblemTree const &tree)
-> MMProblemTreeSeriesSplit const & {
return tree.get<MMProblemTreeSeriesSplit>();
},
/*require_parallel=*/
[](MachineMappingProblemTree const &tree)
-> MMProblemTreeParallelSplit const & {
return tree.get<MMProblemTreeParallelSplit>();
},
/*require_leaf=*/
[](MachineMappingProblemTree const &tree)
-> UnmappedOpCostEstimateKey const & {
return tree.get<UnmappedOpCostEstimateKey>();
},
};
}

SPDecompositionTreeNodeType
get_node_type(MachineMappingProblemTree const &tree) {
return tree.visit<SPDecompositionTreeNodeType>(overload {
[](MMProblemTreeSeriesSplit const &) { return SPDecompositionTreeNodeType::SERIES; },
[](MMProblemTreeParallelSplit const &) { return SPDecompositionTreeNodeType::PARALLEL; },
[](UnmappedOpCostEstimateKey const &) { return SPDecompositionTreeNodeType::NODE; },
return tree.visit<SPDecompositionTreeNodeType>(overload{
[](MMProblemTreeSeriesSplit const &) {
return SPDecompositionTreeNodeType::SERIES;
},
[](MMProblemTreeParallelSplit const &) {
return SPDecompositionTreeNodeType::PARALLEL;
},
[](UnmappedOpCostEstimateKey const &) {
return SPDecompositionTreeNodeType::NODE;
},
});
}

Expand All @@ -61,7 +84,8 @@ std::unordered_set<BinaryTreePath>
std::optional<MachineMappingProblemTree>
mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree,
BinaryTreePath const &path) {
return get_subtree_at_path(tree, generic_binary_sp_impl_for_mm_problem_tree(), path);
return get_subtree_at_path(
tree, generic_binary_sp_impl_for_mm_problem_tree(), path);
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ std::unordered_set<ParallelComputationGraphEdge>
TransitiveReducedDataflowGraphView raw_tr_g =
get_underlying_transitive_reduced_dataflow_graph(tr_pcg);

BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split);
BinarySeriesSplit raw_split =
binary_series_split_from_pcg_series_split(split);

std::unordered_set<DataflowEdge> raw_edges =
get_transitive_reduced_edges_across_split(raw_tr_g, raw_split);
Expand All @@ -57,7 +58,8 @@ std::unordered_set<parallel_tensor_guid_t>
TransitiveReducedDataflowGraphView raw_tr_g =
get_underlying_transitive_reduced_dataflow_graph(tr_pcg);

BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split);
BinarySeriesSplit raw_split =
binary_series_split_from_pcg_series_split(split);

std::unordered_set<DataflowOutput> raw_outputs =
get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split);
Expand All @@ -72,7 +74,8 @@ PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split(
TransitiveReducedDataflowGraphView raw_tr_g =
get_underlying_transitive_reduced_dataflow_graph(tr_pcg);

BinarySeriesSplit raw_split = binary_series_split_from_pcg_series_split(split);
BinarySeriesSplit raw_split =
binary_series_split_from_pcg_series_split(split);

SplitBoundaryNodes raw_boundary =
get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split);
Expand Down
Loading

0 comments on commit a2b8832

Please sign in to comment.