diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc index 6b75d3943b..6d14fbe3cf 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -46,18 +46,18 @@ SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree) MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) { return MMProblemTreeSeriesSplit{ - require_series(t.raw_tree), + require_generic_binary_series_split(t.raw_tree), }; } MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) { return MMProblemTreeParallelSplit{ - require_parallel(t.raw_tree), + require_generic_binary_parallel_split(t.raw_tree), }; } UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &t) { - return require_leaf(t.raw_tree); + return require_generic_binary_leaf(t.raw_tree); } MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &series) { diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc index 00d0d74959..e1c118f891 100644 --- a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -17,7 +17,7 @@ SPDecompositionTreeNodeType } layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { - return require_leaf(d.raw_tree); + return require_leaf_only_binary_leaf(d.raw_tree); } std::optional diff --git a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc index df0245a4d2..f15bf0fe53 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg_binary_sp_decomposition.cc @@ -54,18 +54,18 @@ PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &p) { PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &d) { return PCGBinarySeriesSplit{ - require_series(d.raw_tree), + require_leaf_only_binary_series_split(d.raw_tree), }; } PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &d) { return PCGBinaryParallelSplit{ - require_parallel(d.raw_tree), + require_leaf_only_binary_parallel_split(d.raw_tree), }; } parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &d) { - return require_leaf(d.raw_tree); + return require_leaf_only_binary_leaf(d.raw_tree); } std::unordered_set find_paths_to_leaf(PCGBinarySPDecomposition const &spd, diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h b/lib/utils/include/utils/any_value_type/any_value_type.h similarity index 89% rename from lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h rename to lib/utils/include/utils/any_value_type/any_value_type.h index 8cd7d62101..eb211b1a1b 100644 --- a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/any_value_type.h +++ b/lib/utils/include/utils/any_value_type/any_value_type.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H #include #include @@ -64,6 +64,6 @@ struct hash<::FlexFlow::any_value_type> { size_t operator()(::FlexFlow::any_value_type const &) const; }; -} +} // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/fmt/monostate.h b/lib/utils/include/utils/fmt/monostate.h new file mode 100644 index 0000000000..b03609171f --- /dev/null +++ b/lib/utils/include/utils/fmt/monostate.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H + +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::monostate, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { + template + auto format(::std::monostate const &, FormatContext &ctx) + -> decltype(ctx.out()) { + std::string result = ""; + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &, std::monostate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h index 833013d6f6..4410f06e67 100644 --- a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -2,15 +2,37 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" #include +#include "utils/overload.h" +#include "utils/containers/transform.h" +#include "utils/containers/set_union.h" namespace FlexFlow { template std::unordered_set find_paths_to_leaf(FullBinaryTree const &tree, LeafLabel const &leaf) { - return find_paths_to_leaf(tree.raw_tree, make_any_value_type(leaf)); + return visit>( + tree, + overload { + [&](LeafLabel const &l) -> std::unordered_set { + if (l == leaf) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + [&](FullBinaryTreeParentNode const &parent) { + return set_union( + transform(find_paths_to_leaf(get_left_child(parent), leaf), + nest_inside_left_child), + transform(find_paths_to_leaf(get_right_child(parent), leaf), + nest_inside_right_child)); + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/fmt.h b/lib/utils/include/utils/full_binary_tree/fmt.h new file mode 100644 index 0000000000..96d384c3ae --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/fmt.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::string format_as(FullBinaryTreeParentNode const &t) { + return fmt::format("<{} ({} {})>", + t.label, + get_left_child(t), + get_right_child(t)); +} + +template +std::string format_as(FullBinaryTree const &t) { + auto visitor = FullBinaryTreeVisitor{ + [](FullBinaryTreeParentNode const &parent) { + return fmt::to_string(parent); + }, + [](LeafLabel const &leaf) { + return fmt::format("{}", leaf); + }, + }; + + return visit(t, visitor); +} + +template +std::ostream &operator<<(std::ostream &s, FullBinaryTreeParentNode const &t) { + return (s << fmt::to_string(t)); +} + +template +std::ostream &operator<<(std::ostream &s, FullBinaryTree const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.h b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h new file mode 100644 index 0000000000..45d0c5f151 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree.h @@ -0,0 +1,102 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H + +#include +#include +#include + +namespace FlexFlow { + +template +struct FullBinaryTree; + +template +struct FullBinaryTreeParentNode { + explicit FullBinaryTreeParentNode( + ParentLabel const &label, + FullBinaryTree const &lhs, + FullBinaryTree const &rhs) + : label(label), + left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) + { } + + FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default; + + bool operator==(FullBinaryTreeParentNode const &other) const { + if (this->tie_ptr() == other.tie_ptr()) { + return true; + } + + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTreeParentNode const &other) const { + if (this->tie_ptr() == other.tie_ptr()) { + return false; + } + + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTreeParentNode const &other) const { + return this->tie() < other.tie(); + } +public: + ParentLabel label; + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; +private: + std::tuple> const &, + std::shared_ptr> const &> + tie_ptr() const { + return std::tie(this->label, this->left_child_ptr, this->right_child_ptr); + } + + std::tuple const &, + FullBinaryTree const &> + tie() const { + return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct FullBinaryTree { +public: + FullBinaryTree() = delete; + explicit FullBinaryTree(FullBinaryTreeParentNode const &t) + : root{t} {} + + explicit FullBinaryTree(LeafLabel const &t) + : root{t} {} + + bool operator==(FullBinaryTree const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(FullBinaryTree const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(FullBinaryTree const &other) const { + return this->tie() < other.tie(); + } +public: + std::variant, LeafLabel> root; +private: + std::tuple tie() const { + return std::tie(this->root); + } + + friend std::hash; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml deleted file mode 100644 index aa9a1d8574..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTree" -features = [ - "eq", - "hash", - "fmt", -] - -template_params = [ - "ParentLabel", - "LeafLabel", -] - -includes = [ - "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::RawBinaryTree" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml deleted file mode 100644 index 277405a23c..0000000000 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_parent_node.struct.toml +++ /dev/null @@ -1,20 +0,0 @@ -namespace = "FlexFlow" -name = "FullBinaryTreeParentNode" -features = [ - "eq", - "hash", - "fmt", -] - -template_params = [ - "ParentLabel", - "LeafLabel", -] - -includes = [ - "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::RawBinaryTree" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml index 0849ba2683..cb637057db 100644 --- a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml @@ -10,7 +10,7 @@ template_params = [ includes = [ "", - "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h", + "utils/full_binary_tree/full_binary_tree.h", ] [[fields]] diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h index 4076447f57..926cc0ea9c 100644 --- a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -2,15 +2,36 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" #include -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/overload.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" namespace FlexFlow { template std::unordered_set get_all_leaf_paths(FullBinaryTree const &tree) { - return get_all_leaf_paths(tree.raw_tree); + return visit> + (tree, + overload { + [](LeafLabel const &) { + return std::unordered_set{binary_tree_root_path()}; + }, + [](FullBinaryTreeParentNode const &parent) { + return set_union( + transform(get_all_leaf_paths(get_left_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(get_right_child(parent)), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h index 675e385ca3..e9ceddff6d 100644 --- a/lib/utils/include/utils/full_binary_tree/get_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -1,18 +1,26 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_left_child.h" +#include "utils/full_binary_tree/get_right_child.h" +#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" +#include "utils/exception.h" +#include namespace FlexFlow { template FullBinaryTree get_child(FullBinaryTreeParentNode const &t, BinaryTreePathEntry const &e) { - return FullBinaryTreeParentNode{ - get_child(t.raw_tree, e), - }; + switch (e) { + case BinaryTreePathEntry::LEFT_CHILD: + return get_left_child(t); + case BinaryTreePathEntry::RIGHT_CHILD: + return get_right_child(t); + default: + throw mk_runtime_error(fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); + } } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_label.h b/lib/utils/include/utils/full_binary_tree/get_label.h index 9f0099e609..1b48965b01 100644 --- a/lib/utils/include/utils/full_binary_tree/get_label.h +++ b/lib/utils/include/utils/full_binary_tree/get_label.h @@ -1,13 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LABEL_H -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template ParentLabel get_full_binary_tree_parent_label(FullBinaryTreeParentNode const &p) { - return p.raw_tree.label.template get(); + return p.label; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h index 41fea3c5c2..c58a850a6d 100644 --- a/lib/utils/include/utils/full_binary_tree/get_leaves.h +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -1,15 +1,28 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include +#include "utils/containers/multiset_union.h" namespace FlexFlow { -template -std::unordered_multiset - get_leaves(FullBinaryTree const &t) { - return transform(get_leaves(t.raw_tree), [](any_value_type const &v) { return v.get(); }); +template +std::unordered_multiset + get_leaves(FullBinaryTree const &t) { + return visit>( + t, + overload { + [](FullBinaryTreeParentNode const &parent) { + return multiset_union(get_leaves(get_left_child(parent)), + get_leaves(get_right_child(parent))); + }, + [](ChildLabel const &leaf) { + return std::unordered_multiset{leaf}; + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_left_child.h b/lib/utils/include/utils/full_binary_tree/get_left_child.h index 394b9042fe..163503abfd 100644 --- a/lib/utils/include/utils/full_binary_tree/get_left_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_left_child.h @@ -1,16 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEFT_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template -FullBinaryTree get_left_child(FullBinaryTreeParentNode const &t) { - return FullBinaryTree{ - t.raw_tree.left_child(), - }; +FullBinaryTree const &get_left_child(FullBinaryTreeParentNode const &t) { + return *t.left_child_ptr; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_node_type.h b/lib/utils/include/utils/full_binary_tree/get_node_type.h index 5d2c613101..0ee8eea6d8 100644 --- a/lib/utils/include/utils/full_binary_tree/get_node_type.h +++ b/lib/utils/include/utils/full_binary_tree/get_node_type.h @@ -1,14 +1,21 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NODE_TYPE_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" namespace FlexFlow { template FullBinaryTreeNodeType get_node_type(FullBinaryTree const &t) { - return get_node_type(t.raw_tree); + if (std::holds_alternative(t.root)) { + return FullBinaryTreeNodeType::LEAF; + } else { + bool is_parent = std::holds_alternative>(t.root); + assert (is_parent); + + return FullBinaryTreeNodeType::PARENT; + } } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_right_child.h b/lib/utils/include/utils/full_binary_tree/get_right_child.h index 957ddbede8..e40f2024a1 100644 --- a/lib/utils/include/utils/full_binary_tree/get_right_child.h +++ b/lib/utils/include/utils/full_binary_tree/get_right_child.h @@ -1,16 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_RIGHT_CHILD_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template -FullBinaryTree get_right_child(FullBinaryTreeParentNode const &t) { - return FullBinaryTree{ - t.raw_tree.right_child(), - }; +FullBinaryTree const &get_right_child(FullBinaryTreeParentNode const &t) { + return *t.right_child_ptr; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h index 59d24b6aad..6909d9e1ef 100644 --- a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -2,19 +2,35 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H #include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/containers/transform.h" -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include namespace FlexFlow { template std::optional> get_subtree_at_path(FullBinaryTree const &t, BinaryTreePath const &p) { - return transform(get_subtree_at_path(t.raw_tree, p), - [](RawBinaryTree const &raw) { - return FullBinaryTree{raw}; - }); + if (p == binary_tree_root_path()) { + return t; + } + + return visit>>( + t, + overload { + [&](FullBinaryTreeParentNode const &parent) { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, curr), rest); + }, + [&](LeafLabel const &leaf) { + return std::nullopt; + } + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/hash.h b/lib/utils/include/utils/full_binary_tree/hash.h new file mode 100644 index 0000000000..a29836f972 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/hash.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_HASH_H + +#include "utils/full_binary_tree/full_binary_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace std { + +template +struct hash<::FlexFlow::FullBinaryTreeParentNode> { + size_t operator()(::FlexFlow::FullBinaryTreeParentNode const &t) const { + return get_std_hash(t.tie()); + } +}; + +template +struct hash<::FlexFlow::FullBinaryTree> { + size_t operator()(::FlexFlow::FullBinaryTree const &t) const { + return get_std_hash(t.tie()); + } +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/json.h b/lib/utils/include/utils/full_binary_tree/json.h index 585c05813e..0d830890dc 100644 --- a/lib/utils/include/utils/full_binary_tree/json.h +++ b/lib/utils/include/utils/full_binary_tree/json.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_JSON_H #include "utils/exception.h" -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" #include "utils/full_binary_tree/get_left_child.h" #include "utils/full_binary_tree/get_right_child.h" #include "utils/full_binary_tree/visit.h" diff --git a/lib/utils/include/utils/full_binary_tree/make.h b/lib/utils/include/utils/full_binary_tree/make.h index ac458f0f4d..a4ef47c7df 100644 --- a/lib/utils/include/utils/full_binary_tree/make.h +++ b/lib/utils/include/utils/full_binary_tree/make.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_MAKE_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { @@ -10,14 +10,18 @@ FullBinaryTree make_full_binary_tree_parent(ParentLabel FullBinaryTree const &lhs, FullBinaryTree const &rhs) { return FullBinaryTree{ - raw_binary_tree_make_parent(make_any_value_type(label), lhs.raw_tree, rhs.raw_tree), + FullBinaryTreeParentNode{ + label, + lhs, + rhs, + }, }; } template FullBinaryTree make_full_binary_tree_leaf(LeafLabel const &label) { return FullBinaryTree{ - raw_binary_tree_make_leaf(make_any_value_type(label)), + label, }; } diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h deleted file mode 100644 index 6d0d77caa9..0000000000 --- a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/algorithms.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ALGORITHMS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ALGORITHMS_H - -#include "utils/full_binary_tree/binary_tree_path.dtg.h" -#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_node_type.dtg.h" -#include "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h" -#include - -namespace FlexFlow { - -RawBinaryTree get_child(RawBinaryTree const &, BinaryTreePathEntry const &); -std::unordered_set get_all_leaf_paths(RawBinaryTree const &); -std::unordered_set find_paths_to_leaf(RawBinaryTree const &, any_value_type const &leaf); -std::unordered_multiset get_leaves(RawBinaryTree const &); -FullBinaryTreeNodeType get_node_type(RawBinaryTree const &); -std::optional get_subtree_at_path(RawBinaryTree const &, BinaryTreePath const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h b/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h deleted file mode 100644 index 0bebe12109..0000000000 --- a/lib/utils/include/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_RAW_BINARY_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_RAW_BINARY_TREE_H - -#include -#include -#include "utils/full_binary_tree/raw_full_binary_tree/any_value_type.h" -#include - -namespace FlexFlow { - -struct RawBinaryTree { - explicit RawBinaryTree( - any_value_type const &label, - RawBinaryTree const &lhs, - RawBinaryTree const &rhs); - explicit RawBinaryTree( - any_value_type const &label); - - RawBinaryTree(RawBinaryTree const &) = default; - - bool operator==(RawBinaryTree const &) const; - bool operator!=(RawBinaryTree const &) const; - - RawBinaryTree const &left_child() const; - RawBinaryTree const &right_child() const; - - bool is_leaf() const; -public: - any_value_type label; - std::shared_ptr left_child_ptr; - std::shared_ptr right_child_ptr; -private: - std::tuple, - std::optional> - value_tie() const; - std::tuple const &, - std::shared_ptr const &> - ptr_tie() const; - - friend std::hash; -}; - -std::string format_as(RawBinaryTree const &); -std::ostream &operator<<(std::ostream &, RawBinaryTree const &); - -RawBinaryTree raw_binary_tree_make_leaf(any_value_type const &label); -RawBinaryTree raw_binary_tree_make_parent(any_value_type const &label, RawBinaryTree const &lhs, RawBinaryTree const &rhs); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::RawBinaryTree> { - size_t operator()(::FlexFlow::RawBinaryTree const &) const; -}; - -} - -#endif diff --git a/lib/utils/include/utils/full_binary_tree/require.h b/lib/utils/include/utils/full_binary_tree/require.h index f897908c86..f7be417945 100644 --- a/lib/utils/include/utils/full_binary_tree/require.h +++ b/lib/utils/include/utils/full_binary_tree/require.h @@ -1,29 +1,18 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_REQUIRE_H -#include "utils/full_binary_tree/full_binary_tree.dtg.h" -#include "utils/full_binary_tree/full_binary_tree_parent_node.dtg.h" +#include "utils/full_binary_tree/full_binary_tree.h" namespace FlexFlow { template -FullBinaryTreeParentNode require_parent_node(FullBinaryTree const &t) { - if (t.raw_tree.is_leaf()) { - throw mk_runtime_error(fmt::format("require_parent_node called on leaf node {}", t)); - } - - return FullBinaryTreeParentNode{ - t.raw_tree, - }; +FullBinaryTreeParentNode const &require_full_binary_tree_parent_node(FullBinaryTree const &t) { + return std::get>(t.root); } template -LeafLabel require_leaf(FullBinaryTree const &t) { - if (!t.raw_tree.is_leaf()) { - throw mk_runtime_error(fmt::format("require_leaf called on non-leaf node {}", t)); - } - - return t.raw_tree.label.template get(); +LeafLabel const &require_full_binary_tree_leaf(FullBinaryTree const &t) { + return std::get(t.root); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h index ea5729bd6c..860e60fcca 100644 --- a/lib/utils/include/utils/full_binary_tree/visit.h +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -4,6 +4,7 @@ #include "utils/full_binary_tree/get_node_type.h" #include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" #include "utils/exception.h" +#include "utils/full_binary_tree/require.h" namespace FlexFlow { @@ -21,9 +22,9 @@ Result visit(FullBinaryTree const &t, FullBinaryTreeVisi FullBinaryTreeNodeType node_type = get_node_type(t); switch (node_type) { case FullBinaryTreeNodeType::PARENT: - return v.parent_func(require_parent_node(t)); + return v.parent_func(require_full_binary_tree_parent_node(t)); case FullBinaryTreeNodeType::LEAF: - return v.leaf_func(require_leaf(t)); + return v.leaf_func(require_full_binary_tree_leaf(t)); default: throw mk_runtime_error(fmt::format("Unhandled FullBinaryTreeNodeType value: {}", node_type)); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml index 9734912f35..00c49992ef 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.struct.toml @@ -14,12 +14,14 @@ template_params = [ ] includes = [ - "utils/full_binary_tree/full_binary_tree.dtg.h", + "utils/full_binary_tree/full_binary_tree.h", "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.dtg.h", ] src_includes = [ "utils/full_binary_tree/json.h", + "utils/full_binary_tree/hash.h", + "utils/full_binary_tree/fmt.h", ] [[fields]] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h index c856f35d68..0c08a0462b 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h @@ -10,11 +10,39 @@ namespace FlexFlow { template SPDecompositionTreeNodeType get_node_type(GenericBinarySPSplitLabel const &label) { return label.template visit(overload { - [](SeriesLabel const &) { return SPDecompositionTreeNodeType::SERIES; }, - [](ParallelLabel const &) { return SPDecompositionTreeNodeType::PARALLEL; }, + [](GenericBinarySeriesSplitLabel const &) { return SPDecompositionTreeNodeType::SERIES; }, + [](GenericBinaryParallelSplitLabel const &) { return SPDecompositionTreeNodeType::PARALLEL; }, }); } +template +GenericBinarySPSplitLabel make_generic_binary_series_split_label(SeriesLabel const &label) { + return GenericBinarySPSplitLabel{ + GenericBinarySeriesSplitLabel{ + label, + }, + }; +} + +template +GenericBinarySPSplitLabel make_generic_binary_parallel_split_label(ParallelLabel const &label) { + return GenericBinarySPSplitLabel{ + GenericBinaryParallelSplitLabel{ + label, + }, + }; +} + +template +SeriesLabel require_generic_binary_series_split_label(GenericBinarySPSplitLabel const &label) { + return label.template get>().raw_label; +} + +template +ParallelLabel require_generic_binary_parallel_split_label(GenericBinarySPSplitLabel const &label) { + return label.template get>().raw_label; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml index c50a7b878b..c528c61f37 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.variant.toml @@ -12,15 +12,15 @@ template_params = [ "ParallelSplitLabel", ] -# includes = [ -# "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.dtg.h", -# "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.dtg.h", -# ] +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split_label.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split_label.dtg.h", +] [[values]] -type = "SeriesSplitLabel" +type = "::FlexFlow::GenericBinarySeriesSplitLabel" key = "series" [[values]] -type = "ParallelSplitLabel" +type = "::FlexFlow::GenericBinaryParallelSplitLabel" key = "parallel" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h index c46be1c651..1dedf581fe 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -4,6 +4,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/full_binary_tree/get_label.h" #include "utils/full_binary_tree/visit.h" #include "utils/overload.h" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h index 20ea7e744e..98382c78c8 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -3,6 +3,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h" #include "utils/full_binary_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" namespace FlexFlow { @@ -13,7 +14,7 @@ GenericBinarySPDecompositionTree make_gen GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - GenericBinarySPSplitLabel{label}, + make_generic_binary_series_split_label(label), lhs.raw_tree, rhs.raw_tree), }; @@ -26,7 +27,7 @@ GenericBinarySPDecompositionTree make_gen GenericBinarySPDecompositionTree const &rhs) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - GenericBinarySPSplitLabel{label}, + make_generic_binary_parallel_split_label(label), lhs.raw_tree, rhs.raw_tree), }; diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h index b8b18c4125..4961dc7b61 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -3,6 +3,7 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" #include "utils/full_binary_tree/require.h" #include "utils/full_binary_tree/get_label.h" @@ -10,11 +11,11 @@ namespace FlexFlow { template GenericBinarySeriesSplit - require_series(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + require_generic_binary_series_split(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = require_full_binary_tree_parent_node(t.raw_tree); return GenericBinarySeriesSplit{ - /*label=*/get_full_binary_tree_parent_label(parent).template get(), + /*label=*/require_generic_binary_series_split_label(get_full_binary_tree_parent_label(parent)), /*pre=*/GenericBinarySPDecompositionTree{ get_left_child(parent), }, @@ -26,11 +27,11 @@ GenericBinarySeriesSplit template GenericBinaryParallelSplit - require_parallel(GenericBinarySPDecompositionTree const &t) { - FullBinaryTreeParentNode, LeafLabel> parent = require_parent_node(t.raw_tree); + require_generic_binary_parallel_split(GenericBinarySPDecompositionTree const &t) { + FullBinaryTreeParentNode, LeafLabel> parent = require_full_binary_tree_parent_node(t.raw_tree); return GenericBinaryParallelSplit{ - /*label=*/get_full_binary_tree_parent_label(parent).template get(), + /*label=*/require_generic_binary_parallel_split_label(get_full_binary_tree_parent_label(parent)), /*lhs=*/GenericBinarySPDecompositionTree{ get_left_child(parent), }, @@ -41,8 +42,8 @@ GenericBinaryParallelSplit } template -LeafLabel require_leaf(GenericBinarySPDecompositionTree const &t) { - return require_leaf(t.raw_tree); +LeafLabel require_generic_binary_leaf(GenericBinarySPDecompositionTree const &t) { + return require_full_binary_tree_leaf(t.raw_tree); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index a56ed952e9..a1ac10a6a0 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -13,15 +13,15 @@ Result visit(GenericBinarySPDecompositionTree wrap_series_split(GenericBinarySeriesSplit const &series_split) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - /*label=*/GenericBinarySPSplitLabel{series_split.label}, + /*label=*/make_generic_binary_series_split_label(series_split.label), /*lhs=*/series_split.pre.raw_tree, /*rhs=*/series_split.post.raw_tree), }; @@ -23,7 +25,7 @@ GenericBinarySPDecompositionTree wrap_parallel_split(GenericBinaryParallelSplit const ¶llel_split) { return GenericBinarySPDecompositionTree{ make_full_binary_tree_parent( - /*label=*/GenericBinarySPSplitLabel{parallel_split.label}, + /*label=*/make_generic_binary_parallel_split_label(parallel_split.label), /*lhs=*/parallel_split.lhs.raw_tree, /*rhs=*/parallel_split.rhs.raw_tree), }; diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml deleted file mode 100644 index 0506d36227..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinaryParallelSplitLabel" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", - "rapidcheck", -] - -fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml deleted file mode 100644 index b780bfeea6..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.struct.toml +++ /dev/null @@ -1,12 +0,0 @@ -namespace = "FlexFlow" -name = "LeafOnlyBinarySeriesSplitLabel" -features = [ - "eq", - "ord", - "hash", - "fmt", - "json", - "rapidcheck", -] - -fields = [] diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml index dacab0244a..bf52ecc6df 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree.struct.toml @@ -12,10 +12,13 @@ template_params = [ includes = [ "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split_label.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split_label.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/monostate.h", ] [[fields]] name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::LeafOnlyBinarySeriesSplitLabel, ::FlexFlow::LeafOnlyBinaryParallelSplitLabel, LeafLabel>" +type = "::FlexFlow::GenericBinarySPDecompositionTree" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h index a9dcb17f0d..3297a30ec7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h @@ -11,7 +11,7 @@ LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnl LeafOnlyBinarySPDecompositionTree const &post) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_series_split( - LeafOnlyBinarySeriesSplitLabel{}, + std::monostate{}, pre.raw_tree, post.raw_tree), }; @@ -22,7 +22,7 @@ LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafO LeafOnlyBinarySPDecompositionTree const &rhs) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_parallel_split( - LeafOnlyBinaryParallelSplitLabel{}, + std::monostate{}, lhs.raw_tree, rhs.raw_tree), }; @@ -32,8 +32,8 @@ template LeafOnlyBinarySPDecompositionTree leaf_only_make_leaf_node(LeafLabel const &label) { return LeafOnlyBinarySPDecompositionTree{ make_generic_binary_sp_leaf< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel>(label), }; } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h index 65d42eee7c..400b6be1de 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h @@ -12,12 +12,12 @@ namespace FlexFlow { template LeafOnlyBinarySeriesSplit - require_series(LeafOnlyBinarySPDecompositionTree const &t) { + require_leaf_only_binary_series_split(LeafOnlyBinarySPDecompositionTree const &t) { GenericBinarySeriesSplit< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel> raw = - require_series(t.raw_tree); + require_generic_binary_series_split(t.raw_tree); return LeafOnlyBinarySeriesSplit{ LeafOnlyBinarySPDecompositionTree{raw.pre}, @@ -27,12 +27,12 @@ LeafOnlyBinarySeriesSplit template LeafOnlyBinaryParallelSplit - require_parallel(LeafOnlyBinarySPDecompositionTree const &t) { + require_leaf_only_binary_parallel_split(LeafOnlyBinarySPDecompositionTree const &t) { GenericBinaryParallelSplit< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel> raw = - require_parallel(t.raw_tree); + require_generic_binary_parallel_split(t.raw_tree); return LeafOnlyBinaryParallelSplit{ LeafOnlyBinarySPDecompositionTree{raw.lhs}, @@ -41,8 +41,8 @@ LeafOnlyBinaryParallelSplit } template -LeafLabel require_leaf(LeafOnlyBinarySPDecompositionTree const &t) { - return require_leaf(t.raw_tree); +LeafLabel require_leaf_only_binary_leaf(LeafOnlyBinarySPDecompositionTree const &t) { + return require_generic_binary_leaf(t.raw_tree); } diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h index b4f4239d39..4cbd2b26bd 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h @@ -30,18 +30,18 @@ template LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &t, LeafOnlyBinarySPDecompositionTreeVisitor const &visitor) { using GenericVisitor = GenericBinarySPDecompositionTreeVisitor - ; GenericVisitor generic_visitor = GenericVisitor{ - [&](LeafOnlyBinarySeriesSplitLabel const &x) { + [&](std::monostate const &x) { return x; }, - [&](LeafOnlyBinaryParallelSplitLabel const &x) { + [&](std::monostate const &x) { return x; }, [&](LeafLabel const &t) { diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h index 0284f6ba41..21fae97633 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h @@ -14,10 +14,10 @@ LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySer return LeafOnlyBinarySPDecompositionTree{ wrap_series_split( GenericBinarySeriesSplit< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel>{ - LeafOnlyBinarySeriesSplitLabel{}, + std::monostate{}, split.pre.raw_tree, split.post.raw_tree, } @@ -30,10 +30,10 @@ LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryP return LeafOnlyBinarySPDecompositionTree{ wrap_parallel_split( GenericBinaryParallelSplit< - LeafOnlyBinarySeriesSplitLabel, - LeafOnlyBinaryParallelSplitLabel, + std::monostate, + std::monostate, LeafLabel>{ - LeafOnlyBinaryParallelSplitLabel{}, + std::monostate{}, split.lhs.raw_tree, split.rhs.raw_tree, } diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc b/lib/utils/src/utils/any_value_type/any_value_type.cc similarity index 93% rename from lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc rename to lib/utils/src/utils/any_value_type/any_value_type.cc index d54796ae49..b3a72dafa9 100644 --- a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/any_value_type.cc +++ b/lib/utils/src/utils/any_value_type/any_value_type.cc @@ -1,4 +1,4 @@ -#include "utils/full_binary_tree/raw_full_binary_tree/any_value_type.h" +#include "utils/any_value_type/any_value_type.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/fmt/monostate.cc b/lib/utils/src/utils/fmt/monostate.cc new file mode 100644 index 0000000000..55988cdce0 --- /dev/null +++ b/lib/utils/src/utils/fmt/monostate.cc @@ -0,0 +1,9 @@ +#include "utils/fmt/monostate.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, std::monostate const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/fmt.cc b/lib/utils/src/utils/full_binary_tree/fmt.cc new file mode 100644 index 0000000000..82bf382821 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/fmt.cc @@ -0,0 +1,10 @@ +#include "utils/full_binary_tree/fmt.h" + +namespace FlexFlow { + +template std::string format_as(FullBinaryTreeParentNode const &); +template std::string format_as(FullBinaryTree const &); +template std::ostream &operator<<(std::ostream &, FullBinaryTreeParentNode const &); +template std::ostream &operator<<(std::ostream &, FullBinaryTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_label.cc b/lib/utils/src/utils/full_binary_tree/get_label.cc new file mode 100644 index 0000000000..25ed6cf3f6 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_label.cc @@ -0,0 +1,8 @@ +#include "utils/full_binary_tree/get_label.h" + +namespace FlexFlow { + +template + int get_full_binary_tree_parent_label(FullBinaryTreeParentNode const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_node_type.cc b/lib/utils/src/utils/full_binary_tree/get_node_type.cc new file mode 100644 index 0000000000..a4c88a03f3 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_node_type.cc @@ -0,0 +1,7 @@ +#include "utils/full_binary_tree/get_node_type.h" + +namespace FlexFlow { + +template FullBinaryTreeNodeType get_node_type(FullBinaryTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/make.cc b/lib/utils/src/utils/full_binary_tree/make.cc new file mode 100644 index 0000000000..da48d2a2c4 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/make.cc @@ -0,0 +1,12 @@ +#include "utils/full_binary_tree/make.h" + +namespace FlexFlow { + +template + FullBinaryTree make_full_binary_tree_parent(int const &, + FullBinaryTree const &, + FullBinaryTree const &); +template + FullBinaryTree make_full_binary_tree_leaf(int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc deleted file mode 100644 index bc833f95d4..0000000000 --- a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/algorithms.cc +++ /dev/null @@ -1,83 +0,0 @@ -#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h" -#include "utils/full_binary_tree/binary_tree_path.h" -#include "utils/containers/transform.h" -#include "utils/containers/set_union.h" -#include "utils/containers/multiset_union.h" - -namespace FlexFlow { - -RawBinaryTree get_child(RawBinaryTree const &t, BinaryTreePathEntry const &e) { - if (e == BinaryTreePathEntry::LEFT_CHILD) { - return t.left_child(); - } else { - assert (e == BinaryTreePathEntry::RIGHT_CHILD); - return t.right_child(); - } -} - -std::unordered_set get_all_leaf_paths(RawBinaryTree const &t) { - if (t.is_leaf()) { - return {binary_tree_root_path()}; - } else { - return set_union( - transform(get_all_leaf_paths(t.left_child()), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(get_all_leaf_paths(t.right_child()), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - } -} - -std::unordered_set find_paths_to_leaf(RawBinaryTree const &t, any_value_type const &leaf) { - if (t.is_leaf()) { - if (t.label == leaf) { - return {binary_tree_root_path()}; - } else { - return {}; - } - } else { - return set_union( - transform(find_paths_to_leaf(t.left_child(), leaf), - [](BinaryTreePath const &path) { - return nest_inside_left_child(path); - }), - transform(find_paths_to_leaf(t.right_child(), leaf), - [](BinaryTreePath const &path) { - return nest_inside_right_child(path); - })); - } -} - -std::unordered_multiset get_leaves(RawBinaryTree const &t) { - if (t.is_leaf()) { - return {t.label}; - } else { - return multiset_union(get_leaves(t.left_child()), get_leaves(t.right_child())); - } -} - -FullBinaryTreeNodeType get_node_type(RawBinaryTree const &t) { - if (t.is_leaf()) { - return FullBinaryTreeNodeType::LEAF; - } else { - return FullBinaryTreeNodeType::PARENT; - } -} - -std::optional get_subtree_at_path(RawBinaryTree const &t, BinaryTreePath const &p) { - if (p == binary_tree_root_path()) { - return t; - } else if (t.is_leaf()) { - return std::nullopt; - } else { - BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); - BinaryTreePath rest = binary_tree_path_get_non_top_level(p); - - return get_subtree_at_path(get_child(t, curr), rest); - } -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc b/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc deleted file mode 100644 index d432d32eb9..0000000000 --- a/lib/utils/src/utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.cc +++ /dev/null @@ -1,101 +0,0 @@ -#include "utils/full_binary_tree/raw_full_binary_tree/raw_binary_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace FlexFlow { - -RawBinaryTree::RawBinaryTree( - any_value_type const &label, - RawBinaryTree const &lhs, - RawBinaryTree const &rhs) - : label(label), - left_child_ptr(std::make_shared(lhs)), - right_child_ptr(std::make_shared(rhs)) -{ } - -RawBinaryTree::RawBinaryTree( - any_value_type const &label) - : label(label), left_child_ptr(nullptr), right_child_ptr(nullptr) -{ } - -bool RawBinaryTree::operator==(RawBinaryTree const &other) const { - if (this->ptr_tie() == other.ptr_tie()) { - return true; - } - - return (this->value_tie() == other.value_tie()); -} - -bool RawBinaryTree::operator!=(RawBinaryTree const &other) const { - if (this->ptr_tie() == other.ptr_tie()) { - return false; - } - - return (this->value_tie() != other.value_tie()); -} - -RawBinaryTree const &RawBinaryTree::left_child() const { - return *this->left_child_ptr; -} - -RawBinaryTree const &RawBinaryTree::right_child() const { - return *this->right_child_ptr; -} - -bool RawBinaryTree::is_leaf() const { - return this->left_child_ptr == nullptr && this->right_child_ptr == nullptr; -} - -std::tuple, - std::optional> - RawBinaryTree::value_tie() const { - - auto ptr_to_optional = [](std::shared_ptr const &ptr) - -> std::optional { - if (ptr == nullptr) { - return std::nullopt; - } else { - return *ptr; - } - }; - - return {this->label, ptr_to_optional(this->left_child_ptr), ptr_to_optional(this->right_child_ptr)}; -} - -std::tuple const &, - std::shared_ptr const &> - RawBinaryTree::ptr_tie() const { - return std::tie(this->label, this->left_child_ptr, this->right_child_ptr); -} - -std::string format_as(RawBinaryTree const &t) { - if (t.is_leaf()) { - return fmt::to_string(t.label); - } else { - return fmt::format("({} {} {})", t.label, t.left_child(), t.right_child()); - } -} - -std::ostream &operator<<(std::ostream &s, RawBinaryTree const &t) { - return (s << fmt::to_string(t)); -} - -RawBinaryTree raw_binary_tree_make_leaf(any_value_type const &label) { - return RawBinaryTree{label}; -} - -RawBinaryTree raw_binary_tree_make_parent(any_value_type const &label, RawBinaryTree const &lhs, RawBinaryTree const &rhs) { - return RawBinaryTree{label, lhs, rhs}; -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::RawBinaryTree>::operator()(::FlexFlow::RawBinaryTree const &t) const { - return ::FlexFlow::get_std_hash(t.value_tie()); -} - -} // namespace std diff --git a/lib/utils/src/utils/full_binary_tree/require.cc b/lib/utils/src/utils/full_binary_tree/require.cc new file mode 100644 index 0000000000..d6b0bbeb68 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/require.cc @@ -0,0 +1,11 @@ +#include "utils/full_binary_tree/require.h" + +namespace FlexFlow { + +template + FullBinaryTreeParentNode const & + require_full_binary_tree_parent_node(FullBinaryTree const &); +template + int const &require_full_binary_tree_leaf(FullBinaryTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc new file mode 100644 index 0000000000..b43eb5bce6 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -0,0 +1,8 @@ +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template + int visit(FullBinaryTree const &, FullBinaryTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 2f51762db2..92a46c030d 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -44,18 +44,18 @@ std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { BinarySeriesSplit require_series(BinarySPDecompositionTree const &tt) { return BinarySeriesSplit{ - require_series(tt.raw_tree), + require_leaf_only_binary_series_split(tt.raw_tree), }; } BinaryParallelSplit require_parallel(BinarySPDecompositionTree const &tt) { return BinaryParallelSplit{ - require_parallel(tt.raw_tree), + require_leaf_only_binary_parallel_split(tt.raw_tree), }; } Node require_leaf(BinarySPDecompositionTree const &tt) { - return require_leaf(tt.raw_tree); + return require_leaf_only_binary_leaf(tt.raw_tree); } SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &tt) { diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..2ecd4c94d2 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -0,0 +1,9 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" + +namespace FlexFlow { + +template + std::unordered_set find_paths_to_leaf(GenericBinarySPDecompositionTree const &, + int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc new file mode 100644 index 0000000000..10bbc60c6d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.cc @@ -0,0 +1,16 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_split_label.h" + +namespace FlexFlow { + +template + SPDecompositionTreeNodeType get_node_type(GenericBinarySPSplitLabel const &); +template + GenericBinarySPSplitLabel make_generic_binary_series_split_label(int const &); +template + GenericBinarySPSplitLabel make_generic_binary_parallel_split_label(int const &); +template + int require_generic_binary_series_split_label(GenericBinarySPSplitLabel const &); +template + int require_generic_binary_parallel_split_label(GenericBinarySPSplitLabel const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..31e664b726 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" + +namespace FlexFlow { + +template + std::unordered_set get_all_leaf_paths(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 71b67acc54..20ba3fa5d7 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1 +1,13 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" + +namespace FlexFlow { + +template + std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &); +template + std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &); +template + std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc index 227e5bd79c..783a7a974b 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &); +template + GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc index 1618128226..9d652d44da 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc @@ -1 +1,9 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" + +namespace FlexFlow { + +template + SPDecompositionTreeNodeType + get_node_type(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 05ec6b5925..6c67fdc244 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" + +namespace FlexFlow { + +template + int get_num_tree_nodes(GenericBinarySPDecompositionTree const &); +template + int get_num_tree_nodes(GenericBinarySeriesSplit const &); +template + int get_num_tree_nodes(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc index f168ba1e2f..03c154fb67 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &); +template + GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..6bfb573359 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -0,0 +1,10 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" + +namespace FlexFlow { + +template + std::optional> + get_subtree_at_path(GenericBinarySPDecompositionTree const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc index 3da024743c..5e5b768ed7 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc @@ -1 +1,12 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" + +namespace FlexFlow { + +template + bool is_series_split(GenericBinarySPDecompositionTree const &); +template + bool is_parallel_split(GenericBinarySPDecompositionTree const &); +template + bool is_leaf(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 8fe9397003..87ae55b900 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1 +1,9 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_left_associative( + GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index d202f55964..5a40a3b6bf 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1 +1,9 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_right_associative( + GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc index fb1532b3ef..a36ccce359 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree make_generic_binary_series_split( + int const &, + GenericBinarySPDecompositionTree const &, + GenericBinarySPDecompositionTree const &); +template + GenericBinarySPDecompositionTree make_generic_binary_parallel_split( + int const &label, + GenericBinarySPDecompositionTree const &, + GenericBinarySPDecompositionTree const &); +template + GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc index 3fee45fcf5..8305a1243e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc @@ -1 +1,14 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" + +namespace FlexFlow { + +template + GenericBinarySeriesSplit + require_generic_binary_series_split(GenericBinarySPDecompositionTree const &); +template + GenericBinaryParallelSplit + require_generic_binary_parallel_split(GenericBinarySPDecompositionTree const &); +template + int require_generic_binary_leaf(GenericBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc index cabd66cff7..4495a60f92 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -1 +1,10 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" + +namespace FlexFlow { + +template + GenericBinarySeriesSplit + transform(GenericBinarySeriesSplit const &, + GenericBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc new file mode 100644 index 0000000000..0b3189b47b --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.cc @@ -0,0 +1,12 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/wrap.h" + +namespace FlexFlow { + +template + GenericBinarySPDecompositionTree + wrap_series_split(GenericBinarySeriesSplit const &); +template + GenericBinarySPDecompositionTree + wrap_parallel_split(GenericBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..41accc79d0 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_leaves.h" + +namespace FlexFlow { + +template + std::unordered_multiset get_leaves(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc new file mode 100644 index 0000000000..0959a42f01 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/get_node_type.h" + +namespace FlexFlow { + +template + SPDecompositionTreeNodeType get_node_type(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..dd94936997 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,9 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_left_associative(LeafOnlyBinarySPDecompositionTree const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..46b89aa98f --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,8 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" + +namespace FlexFlow { + +template + bool is_binary_sp_tree_right_associative(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc new file mode 100644 index 0000000000..5690ebe8a8 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.cc @@ -0,0 +1,10 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_parallel_split.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinaryParallelSplit const &); +template + LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinaryParallelSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc new file mode 100644 index 0000000000..ed0e5892da --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.cc @@ -0,0 +1,10 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/leaf_only_binary_series_split.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree get_left_child(LeafOnlyBinarySeriesSplit const &); +template + LeafOnlyBinarySPDecompositionTree get_right_child(LeafOnlyBinarySeriesSplit const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc new file mode 100644 index 0000000000..602aebc7e8 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.cc @@ -0,0 +1,14 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/make.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree leaf_only_make_series_split(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree const &); +template + LeafOnlyBinarySPDecompositionTree leaf_only_make_parallel_split(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTree const &); +template + LeafOnlyBinarySPDecompositionTree leaf_only_make_leaf_node(int const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc new file mode 100644 index 0000000000..1a1cd9909d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.cc @@ -0,0 +1,9 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/require.h" + +namespace FlexFlow { + +template LeafOnlyBinarySeriesSplit require_leaf_only_binary_series_split(LeafOnlyBinarySPDecompositionTree const &); +template LeafOnlyBinaryParallelSplit require_leaf_only_binary_parallel_split(LeafOnlyBinarySPDecompositionTree const &); +template int require_leaf_only_binary_leaf(LeafOnlyBinarySPDecompositionTree const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..22dd5e0db5 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1,16 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/transform.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySeriesSplit transform(LeafOnlyBinarySeriesSplit const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); +template + LeafOnlyBinaryParallelSplit transform(LeafOnlyBinaryParallelSplit const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); + +template + LeafOnlyBinarySPDecompositionTree transform(LeafOnlyBinarySPDecompositionTree const &, + LeafOnlyBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc new file mode 100644 index 0000000000..3836124eb6 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.cc @@ -0,0 +1,10 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/leaf_only_binary_sp_decomposition_tree/wrap.h" + +namespace FlexFlow { + +template + LeafOnlyBinarySPDecompositionTree wrap_series_split(LeafOnlyBinarySeriesSplit const &); +template + LeafOnlyBinarySPDecompositionTree wrap_parallel_split(LeafOnlyBinaryParallelSplit const &); + +} // namespace FlexFlow