Skip to content

Commit

Permalink
Further code simplification and polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Oct 7, 2024
1 parent 4b180df commit dcd2e13
Show file tree
Hide file tree
Showing 201 changed files with 2,960 additions and 2,988 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef _FLEXFLOW_BIN_EXPORT_MODEL_ARCH_INCLUDE_EXPORT_MODEL_ARCH_JSON_SP_MODEL_EXPORT_H
#define _FLEXFLOW_BIN_EXPORT_MODEL_ARCH_INCLUDE_EXPORT_MODEL_ARCH_JSON_SP_MODEL_EXPORT_H

#include <nlohmann/json.hpp>
#include "export_model_arch/json_sp_model_export.dtg.h"

namespace nlohmann {

template <>
struct adl_serializer<::FlexFlow::JsonSPModelExport> {
static ::FlexFlow::JsonSPModelExport from_json(json const &);
static void to_json(json &, ::FlexFlow::JsonSPModelExport const &);
};

} // namespace nlohmann

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,22 @@ name = "JsonSPModelExport"
features = [
"eq",
"hash",
"json",
"fmt",
"json",
]

includes = [
"pcg/file_format/v1/v1_computation_graph.dtg.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.dtg.h",
]

src_includes = [
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h",
"utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/json.h",
]

[[fields]]
name = "sp_decomposition"
type = "::FlexFlow::GenericBinarySPDecompositionTree<int>"
type = "::FlexFlow::LeafOnlyBinarySPDecompositionTree<int>"

[[fields]]
name = "computation_graph"
Expand Down
4 changes: 2 additions & 2 deletions bin/export-model-arch/src/export_model_arch.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h"
#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h"
#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h"
#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h"
#include "export_model_arch/json_sp_model_export.dtg.h"
#include "models/bert/bert.h"
#include "models/candle_uno/candle_uno.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "export_model_arch/json_sp_model_export.h"

using namespace ::FlexFlow;

namespace nlohmann {

JsonSPModelExport adl_serializer<JsonSPModelExport>::from_json(json const &j) {
NOT_IMPLEMENTED();
}

static void sp_decomposition_to_json(json &j, LeafOnlyBinarySPDecompositionTree<int> const &t) {
}

void adl_serializer<JsonSPModelExport>::to_json(json &j, JsonSPModelExport const &m) {
j["computation_graph"] = m.computation_graph;
sp_decomposition_to_json(j["sp_decomposition"], m.sp_decomposition);
}


} // namespace nlohmann
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h"
#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h"
#include "compiler/series_parallel/pcg_binary_series_split.dtg.h"
#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h"

namespace FlexFlow {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "compiler/cost_estimator/tensor_set_movement.dtg.h"
#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h"
#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h"
#include "compiler/series_parallel/pcg_binary_series_split.dtg.h"
#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h"

namespace FlexFlow {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "compiler/machine_mapping/machine_mapping.dtg.h"
#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h"
#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h"
#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h"
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h"
#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h"

namespace FlexFlow {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H

#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h"
#include "compiler/series_parallel/pcg_binary_sp_decomposition.dtg.h"
#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h"
#include "pcg/machine_specification.dtg.h"
#include "pcg/machine_view.dtg.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,16 @@
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h"
#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h"
#include "utils/full_binary_tree/binary_tree_path.dtg.h"
#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h"
#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h"

namespace FlexFlow {

MachineMappingProblemTree mm_problem_tree_make_series_split(
AbstractedTensorSetMovement const &tensor_set_movement,
MachineMappingProblemTree const &pre,
MachineMappingProblemTree const &post);
MachineMappingProblemTree
mm_problem_tree_make_parallel_split(MachineMappingProblemTree const &lhs,
MachineMappingProblemTree const &rhs);
MachineMappingProblemTree
mm_problem_tree_make_leaf(UnmappedOpCostEstimateKey const &);
GenericBinarySPDecompositionTreeImplementation<MachineMappingProblemTree, MMProblemTreeSeriesSplit, MMProblemTreeParallelSplit, UnmappedOpCostEstimateKey>
generic_binary_sp_impl_for_mm_problem_tree();

SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &);

MMProblemTreeSeriesSplit
require_series_split(MachineMappingProblemTree const &);
MMProblemTreeParallelSplit
require_parallel_split(MachineMappingProblemTree const &);
UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &);

MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &);
MachineMappingProblemTree
wrap_parallel_split(MMProblemTreeParallelSplit const &);

std::unordered_multiset<UnmappedOpCostEstimateKey>
get_leaves(MachineMappingProblemTree const &);
std::unordered_set<BinaryTreePath>
Expand All @@ -40,28 +24,6 @@ std::optional<MachineMappingProblemTree>
mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &,
BinaryTreePath const &);

template <typename Result, typename F>
Result visit(MachineMappingProblemTree const &t, F &&f) {
SPDecompositionTreeNodeType node_type = get_node_type(t);
switch (node_type) {
case SPDecompositionTreeNodeType::SERIES: {
Result result = f(require_series_split(t));
return result;
}
case SPDecompositionTreeNodeType::PARALLEL: {
Result result = f(require_parallel_split(t));
return result;
}
case SPDecompositionTreeNodeType::NODE: {
Result result = f(require_leaf(t));
return result;
}
default:
throw mk_runtime_error(
fmt::format("Unknown SPDecompositionTreeNodeType: {}", node_type));
}
}

} // namespace FlexFlow

#endif

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace = "FlexFlow"
name = "MachineMappingProblemTree"
features = [
"eq",
"hash",
"fmt",
]

includes = [
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h",
]

[[values]]
type = "::FlexFlow::MMProblemTreeSeriesSplit"
key = "series"

[[values]]
type = "::FlexFlow::MMProblemTreeParallelSplit"
key = "parallel"

[[values]]
type = "::FlexFlow::UnmappedOpCostEstimateKey"
key = "leaf"

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,22 @@ features = [
"fmt",
]

includes = [
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h",
fwd_decls = [
"struct MachineMappingProblemTree",
]

post_includes = [
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h",
]

includes = []

[[fields]]
name = "left_child"
type = "::FlexFlow::MachineMappingProblemTree"
indirect = true

[[fields]]
name = "raw_split"
type = "::FlexFlow::GenericBinaryParallelSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>"
name = "right_child"
type = "::FlexFlow::MachineMappingProblemTree"
indirect = true

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,28 @@ features = [
"fmt",
]

fwd_decls = [
"struct MachineMappingProblemTree",
]

post_includes = [
"compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h",
]

includes = [
"utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split_label.dtg.h",
"compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h",
"compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h",
]

[[fields]]
name = "raw_split"
type = "::FlexFlow::GenericBinarySeriesSplit<::FlexFlow::MMProblemTreeSeriesSplitLabel, ::FlexFlow::MMProblemTreeParallelSplitLabel, ::FlexFlow::UnmappedOpCostEstimateKey>"
name = "tensor_set_movement"
type = "::FlexFlow::AbstractedTensorSetMovement"

[[fields]]
name = "left_child"
type = "::FlexFlow::MachineMappingProblemTree"
indirect = true

[[fields]]
name = "right_child"
type = "::FlexFlow::MachineMappingProblemTree"
indirect = true

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "compiler/machine_mapping/pcg_split_boundary_layers.dtg.h"
#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h"
#include "compiler/series_parallel/pcg_binary_series_split.dtg.h"
#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h"
#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h"
#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h"
#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
namespace = "FlexFlow"
name = "ComputationGraphBinaryParallelSplit"
features = [
"eq",
"hash",
"fmt",
]

fwd_decls = [
"struct ComputationGraphBinarySPDecomposition",
]

post_includes = [
"compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h",
]

includes = []

[[fields]]
name = "left_child"
type = "::FlexFlow::ComputationGraphBinarySPDecomposition"
indirect = true

[[fields]]
name = "right_child"
type = "::FlexFlow::ComputationGraphBinarySPDecomposition"
indirect = true
Loading

0 comments on commit dcd2e13

Please sign in to comment.