Skip to content

Commit

Permalink
Move back to templated FullBinaryTree
Browse files Browse the repository at this point in the history
  • Loading branch information
lockshaw committed Oct 3, 2024
1 parent 0c2ab05 commit e4073bc
Show file tree
Hide file tree
Showing 78 changed files with 792 additions and 462 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,18 @@ SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &tree)

MMProblemTreeSeriesSplit require_series_split(MachineMappingProblemTree const &t) {
return MMProblemTreeSeriesSplit{
require_series(t.raw_tree),
require_generic_binary_series_split(t.raw_tree),
};
}

MMProblemTreeParallelSplit require_parallel_split(MachineMappingProblemTree const &t) {
return MMProblemTreeParallelSplit{
require_parallel(t.raw_tree),
require_generic_binary_parallel_split(t.raw_tree),
};
}

UnmappedOpCostEstimateKey require_leaf(MachineMappingProblemTree const &t) {
return require_leaf(t.raw_tree);
return require_generic_binary_leaf(t.raw_tree);
}

MachineMappingProblemTree wrap_series_split(MMProblemTreeSeriesSplit const &series) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ SPDecompositionTreeNodeType
}

layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) {
return require_leaf(d.raw_tree);
return require_leaf_only_binary_leaf(d.raw_tree);
}

std::optional<ComputationGraphBinarySPDecomposition>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,18 @@ PCGBinarySPDecomposition wrap_parallel_split(PCGBinaryParallelSplit const &p) {

PCGBinarySeriesSplit require_series(PCGBinarySPDecomposition const &d) {
return PCGBinarySeriesSplit{
require_series(d.raw_tree),
require_leaf_only_binary_series_split(d.raw_tree),
};
}

PCGBinaryParallelSplit require_parallel(PCGBinarySPDecomposition const &d) {
return PCGBinaryParallelSplit{
require_parallel(d.raw_tree),
require_leaf_only_binary_parallel_split(d.raw_tree),
};
}

parallel_layer_guid_t require_leaf(PCGBinarySPDecomposition const &d) {
return require_leaf(d.raw_tree);
return require_leaf_only_binary_leaf(d.raw_tree);
}

std::unordered_set<BinaryTreePath> find_paths_to_leaf(PCGBinarySPDecomposition const &spd,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_RAW_FULL_BINARY_TREE_ANY_VALUE_TYPE_H
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H

#include <any>
#include <functional>
Expand Down Expand Up @@ -64,6 +64,6 @@ struct hash<::FlexFlow::any_value_type> {
size_t operator()(::FlexFlow::any_value_type const &) const;
};

}
} // namespace FlexFlow

#endif
32 changes: 32 additions & 0 deletions lib/utils/include/utils/fmt/monostate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H

#include <variant>
#include <fmt/format.h>

namespace fmt {

template <typename Char>
struct formatter<
::std::monostate,
Char,
std::enable_if_t<!detail::has_format_as<::std::monostate>::value>>
: formatter<::std::string> {
template <typename FormatContext>
auto format(::std::monostate const &, FormatContext &ctx)
-> decltype(ctx.out()) {
std::string result = "<monostate>";

return formatter<std::string>::format(result, ctx);
}
};

} // namespace fmt

namespace FlexFlow {

std::ostream &operator<<(std::ostream &, std::monostate const &);

} // namespace FlexFlow

#endif
26 changes: 24 additions & 2 deletions lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,37 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H

#include "utils/full_binary_tree/binary_tree_path.dtg.h"
#include "utils/full_binary_tree/full_binary_tree.dtg.h"
#include "utils/full_binary_tree/binary_tree_path.h"
#include "utils/full_binary_tree/full_binary_tree.h"
#include "utils/full_binary_tree/visit.h"
#include <unordered_set>
#include "utils/overload.h"
#include "utils/containers/transform.h"
#include "utils/containers/set_union.h"

namespace FlexFlow {

template <typename ParentLabel, typename LeafLabel>
std::unordered_set<BinaryTreePath> find_paths_to_leaf(FullBinaryTree<ParentLabel, LeafLabel> const &tree,
LeafLabel const &leaf) {
return find_paths_to_leaf(tree.raw_tree, make_any_value_type<LeafLabel>(leaf));
return visit<std::unordered_set<BinaryTreePath>>(
tree,
overload {
[&](LeafLabel const &l) -> std::unordered_set<BinaryTreePath> {
if (l == leaf) {
return {binary_tree_root_path()};
} else {
return {};
}
},
[&](FullBinaryTreeParentNode<ParentLabel, LeafLabel> const &parent) {
return set_union(
transform(find_paths_to_leaf(get_left_child(parent), leaf),
nest_inside_left_child),
transform(find_paths_to_leaf(get_right_child(parent), leaf),
nest_inside_right_child));
}
});
}

} // namespace FlexFlow
Expand Down
47 changes: 47 additions & 0 deletions lib/utils/include/utils/full_binary_tree/fmt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FMT_H

#include "utils/full_binary_tree/full_binary_tree.h"
#include "utils/full_binary_tree/get_left_child.h"
#include "utils/full_binary_tree/get_right_child.h"
#include "utils/full_binary_tree/visit.h"
#include "utils/overload.h"
#include <fmt/format.h>

namespace FlexFlow {

template <typename ParentLabel, typename LeafLabel>
std::string format_as(FullBinaryTreeParentNode<ParentLabel, LeafLabel> const &t) {
return fmt::format("<{} ({} {})>",
t.label,
get_left_child(t),
get_right_child(t));
}

template <typename ParentLabel, typename LeafLabel>
std::string format_as(FullBinaryTree<ParentLabel, LeafLabel> const &t) {
auto visitor = FullBinaryTreeVisitor<std::string, ParentLabel, LeafLabel>{
[](FullBinaryTreeParentNode<ParentLabel, LeafLabel> const &parent) {
return fmt::to_string(parent);
},
[](LeafLabel const &leaf) {
return fmt::format("{}", leaf);
},
};

return visit(t, visitor);
}

template <typename ParentLabel, typename LeafLabel>
std::ostream &operator<<(std::ostream &s, FullBinaryTreeParentNode<ParentLabel, LeafLabel> const &t) {
return (s << fmt::to_string(t));
}

template <typename ParentLabel, typename LeafLabel>
std::ostream &operator<<(std::ostream &s, FullBinaryTree<ParentLabel, LeafLabel> const &t) {
return (s << fmt::to_string(t));
}

} // namespace FlexFlow

#endif
102 changes: 102 additions & 0 deletions lib/utils/include/utils/full_binary_tree/full_binary_tree.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FULL_BINARY_TREE_H

#include <memory>
#include <variant>
#include <tuple>

namespace FlexFlow {

template <typename ParentLabel, typename LeafLabel>
struct FullBinaryTree;

template <typename ParentLabel, typename LeafLabel>
struct FullBinaryTreeParentNode {
explicit FullBinaryTreeParentNode(
ParentLabel const &label,
FullBinaryTree<ParentLabel, LeafLabel> const &lhs,
FullBinaryTree<ParentLabel, LeafLabel> const &rhs)
: label(label),
left_child_ptr(
std::make_shared<FullBinaryTree<ParentLabel, LeafLabel>>(lhs)),
right_child_ptr(
std::make_shared<FullBinaryTree<ParentLabel, LeafLabel>>(rhs))
{ }

FullBinaryTreeParentNode(FullBinaryTreeParentNode const &) = default;

bool operator==(FullBinaryTreeParentNode const &other) const {
if (this->tie_ptr() == other.tie_ptr()) {
return true;
}

return this->tie() == other.tie();
}

bool operator!=(FullBinaryTreeParentNode const &other) const {
if (this->tie_ptr() == other.tie_ptr()) {
return false;
}

return this->tie() != other.tie();
}

bool operator<(FullBinaryTreeParentNode const &other) const {
return this->tie() < other.tie();
}
public:
ParentLabel label;
std::shared_ptr<FullBinaryTree<ParentLabel, LeafLabel>> left_child_ptr;
std::shared_ptr<FullBinaryTree<ParentLabel, LeafLabel>> right_child_ptr;
private:
std::tuple<ParentLabel const &,
std::shared_ptr<FullBinaryTree<ParentLabel, LeafLabel>> const &,
std::shared_ptr<FullBinaryTree<ParentLabel, LeafLabel>> const &>
tie_ptr() const {
return std::tie(this->label, this->left_child_ptr, this->right_child_ptr);
}

std::tuple<ParentLabel const &,
FullBinaryTree<ParentLabel, LeafLabel> const &,
FullBinaryTree<ParentLabel, LeafLabel> const &>
tie() const {
return std::tie(this->label, *this->left_child_ptr, *this->right_child_ptr);
}

friend std::hash<FullBinaryTreeParentNode>;
};

template <typename ParentLabel, typename LeafLabel>
struct FullBinaryTree {
public:
FullBinaryTree() = delete;
explicit FullBinaryTree(FullBinaryTreeParentNode<ParentLabel, LeafLabel> const &t)
: root{t} {}

explicit FullBinaryTree(LeafLabel const &t)
: root{t} {}

bool operator==(FullBinaryTree const &other) const {
return this->tie() == other.tie();
}

bool operator!=(FullBinaryTree const &other) const {
return this->tie() != other.tie();
}

bool operator<(FullBinaryTree const &other) const {
return this->tie() < other.tie();
}
public:
std::variant<FullBinaryTreeParentNode<ParentLabel, LeafLabel>, LeafLabel> root;
private:
std::tuple<decltype(root) const &> tie() const {
return std::tie(this->root);
}

friend std::hash<FullBinaryTree>;
};

} // namespace FlexFlow

#endif

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ template_params = [

includes = [
"<functional>",
"utils/full_binary_tree/full_binary_tree_parent_node.dtg.h",
"utils/full_binary_tree/full_binary_tree.h",
]

[[fields]]
Expand Down
27 changes: 24 additions & 3 deletions lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,36 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H

#include "utils/full_binary_tree/binary_tree_path.dtg.h"
#include "utils/full_binary_tree/full_binary_tree.dtg.h"
#include "utils/full_binary_tree/binary_tree_path.h"
#include "utils/full_binary_tree/full_binary_tree.h"
#include "utils/full_binary_tree/visit.h"
#include <unordered_set>
#include "utils/full_binary_tree/raw_full_binary_tree/algorithms.h"
#include "utils/overload.h"
#include "utils/containers/set_union.h"
#include "utils/containers/transform.h"

namespace FlexFlow {

template <typename ParentLabel, typename LeafLabel>
std::unordered_set<BinaryTreePath> get_all_leaf_paths(FullBinaryTree<ParentLabel, LeafLabel> const &tree) {
return get_all_leaf_paths(tree.raw_tree);
return visit<std::unordered_set<BinaryTreePath>>
(tree,
overload {
[](LeafLabel const &) {
return std::unordered_set{binary_tree_root_path()};
},
[](FullBinaryTreeParentNode<ParentLabel, LeafLabel> const &parent) {
return set_union(
transform(get_all_leaf_paths(get_left_child(parent)),
[](BinaryTreePath const &path) {
return nest_inside_left_child(path);
}),
transform(get_all_leaf_paths(get_right_child(parent)),
[](BinaryTreePath const &path) {
return nest_inside_right_child(path);
}));
}
});
}

} // namespace FlexFlow
Expand Down
Loading

0 comments on commit e4073bc

Please sign in to comment.