diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index f5d4c788af..f9ad92a147 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,8 +1,8 @@ -add_subdirectory(pcg) -add_subdirectory(compiler) +#add_subdirectory(pcg) +#add_subdirectory(compiler) # add_subdirectory(runtime) -add_subdirectory(op-attrs) -add_subdirectory(kernels) +#add_subdirectory(op-attrs) +#add_subdirectory(kernels) add_subdirectory(utils) # add_subdirectory(ffi) -add_subdirectory(substitutions) +#add_subdirectory(substitutions) diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index 9c8012543a..f92ae58637 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -36,4 +36,4 @@ define_ff_vars(${project_target}) ff_set_cxx_properties(${project_target}) add_subdirectory(ffi) -# add_subdirectory(test) +add_subdirectory(test) diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index c75876ecf1..df156c9060 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -4,6 +4,7 @@ #include "bidict.h" #include "invoke.h" #include "optional.h" +#include "utils/exception.h" #include #include #include @@ -137,7 +138,7 @@ template map_keys(std::unordered_map const &m, F const &f) { std::unordered_map result; - for (auto const &kv : f) { + for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); } return result; @@ -149,7 +150,7 @@ template ()(std::declval()))> bidict map_keys(bidict const &m, F const &f) { bidict result; - for (auto const &kv : f) { + for (auto const &kv : m) { result.equate(f(kv.first), kv.second); } return result; @@ -159,7 +160,7 @@ template std::unordered_map filter_keys(std::unordered_map const &m, F const &f) { std::unordered_map result; - for (auto const &kv : f) { + for (auto const &kv : m) { if (f(kv.first)) { result.insert(kv); } @@ -167,6 +168,17 @@ std::unordered_map filter_keys(std::unordered_map const &m, return result; } +template +std::unordered_map filter_values(bidict const &m, F const &f) { + std::unordered_map result; + for (auto const &kv : m) { + if (f(kv.second)) { + result.insert(kv); + } + } + return result; +} + template intersection(std::unordered_set const &l, return result; } +template +std::unordered_set + intersection(std::vector> const &v) { + if (v.empty()) { + throw mk_runtime_error("cannot take intersection of no sets"); + } + + std::unordered_set result = v[0]; + for (int i = 1; i < v.size(); i++) { + result = intersection(result, v[i]); + } + return result; +} + template bool are_disjoint(std::unordered_set const &l, std::unordered_set const &r) { @@ -311,12 +337,12 @@ std::function lookup_in(std::unordered_map const &m) { template std::function lookup_in_l(bidict const &m) { - return [&m](L const &l) -> L { return m.at_l(l); }; + return [&m](L const &l) -> R { return m.at_l(l); }; } template std::function lookup_in_r(bidict const &m) { - return [&m](R const &r) -> R { return m.at_r(r); }; + return [&m](R const &r) -> L { return m.at_r(r); }; } template @@ -509,13 +535,6 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } -template -std::vector sorted_by(std::unordered_set const &s, F const &f) { - std::vector result(s.begin(), s.end()); - inplace_sorted_by(s, f); - return result; -} - template void inplace_sorted_by(std::vector &v, F const &f) { auto custom_comparator = [&](T const &lhs, T const &rhs) -> bool { @@ -524,6 +543,13 @@ void inplace_sorted_by(std::vector &v, F const &f) { std::sort(v.begin(), v.end(), custom_comparator); } +template +std::vector sorted_by(std::unordered_set const &s, F const &f) { + std::vector result(s.begin(), s.end()); + inplace_sorted_by(result, f); + return result; +} + template C filter(C const &v, F const &f) { C result(v); @@ -559,13 +585,13 @@ std::pair, std::vector> vector_split(std::vector const &v, template typename C::value_type maximum(C const &v) { - return std::max_element(v.begin(), v.end()); + return *std::max_element(v.begin(), v.end()); } template T reversed(T const &t) { T r; - for (auto i = t.cend() - 1; i >= t.begin(); i++) { + for (auto i = t.cend() - 1; i >= t.begin(); i--) { r.push_back(*i); } return r; @@ -573,13 +599,13 @@ T reversed(T const &t) { template std::vector value_all(std::vector> const &v) { - std::vector result; - - for (auto const &element : v) { - result.push_back(element.value()); - } - - return result; + return transform(v, [](optional const &t) { + if (t == nullopt) { + throw mk_runtime_error("value is nullopt"); + } else { + return t.value(); + } + }); } template @@ -602,7 +628,7 @@ std::vector subvec(std::vector const &v, begin_iter += resolve_loc(maybe_start.value()); } if (maybe_end.has_value()) { - end_iter = v.cbegin() + resolve_loc(maybe_start.value()); + end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } std::vector output(begin_iter, end_iter); @@ -688,4 +714,4 @@ reversed_container_t reversed_container(C const &c) { } // namespace FlexFlow -#endif +#endif \ No newline at end of file diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 4c1d62fe93..f7f277cc17 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -14,6 +14,12 @@ template struct is_fmtable()))>> : std::true_type {}; +// template +// typename std::enable_if::value, std::ostream &>::type +// operator<<(std::ostream &s, T const &t) { +// return s << fmt::to_string(t); +// } + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/adjacency_digraph.h b/lib/utils/include/utils/graph/adjacency_digraph.h index d8bd410874..6909821382 100644 --- a/lib/utils/include/utils/graph/adjacency_digraph.h +++ b/lib/utils/include/utils/graph/adjacency_digraph.h @@ -9,6 +9,7 @@ namespace FlexFlow { class AdjacencyDiGraph : public IDiGraph { public: + AdjacencyDiGraph() = default; Node add_node() override; void add_node_unsafe(Node const &) override; void remove_node_unsafe(Node const &) override; @@ -28,8 +29,8 @@ class AdjacencyDiGraph : public IDiGraph { private: using ContentsType = std::unordered_map>; - AdjacencyDiGraph(std::size_t, ContentsType); - + AdjacencyDiGraph(std::size_t next_node_idx, ContentsType adjacency) + : next_node_idx(next_node_idx), adjacency(adjacency) {} std::size_t next_node_idx = 0; ContentsType adjacency; }; diff --git a/lib/utils/include/utils/graph/adjacency_multidigraph.h b/lib/utils/include/utils/graph/adjacency_multidigraph.h index e244d58fd1..a1be0dc112 100644 --- a/lib/utils/include/utils/graph/adjacency_multidigraph.h +++ b/lib/utils/include/utils/graph/adjacency_multidigraph.h @@ -9,8 +9,11 @@ namespace FlexFlow { class AdjacencyMultiDiGraph : public IMultiDiGraph { public: + AdjacencyMultiDiGraph() = default; Node add_node() override; void add_node_unsafe(Node const &) override; + NodePort add_node_port() override; + void add_node_port_unsafe(NodePort const &) override; void remove_node_unsafe(Node const &) override; void add_edge(Edge const &) override; void remove_edge(Edge const &) override; @@ -18,7 +21,8 @@ class AdjacencyMultiDiGraph : public IMultiDiGraph { std::unordered_set query_nodes(NodeQuery const &) const override; AdjacencyMultiDiGraph *clone() const override { - return new AdjacencyMultiDiGraph(this->next_node_idx, this->adjacency); + return new AdjacencyMultiDiGraph( + this->next_node_idx, this->next_node_port, this->adjacency); } private: @@ -28,9 +32,12 @@ class AdjacencyMultiDiGraph : public IMultiDiGraph { Node, std::unordered_map>>>; - AdjacencyMultiDiGraph(std::size_t, ContentsType const &); + AdjacencyMultiDiGraph(std::size_t next_node_idx, + std::size_t next_node_port, + ContentsType const &adjacency) + : next_node_idx(next_node_idx), next_node_port(next_node_port), + adjacency(adjacency) {} -private: std::size_t next_node_idx = 0; std::size_t next_node_port = 0; ContentsType adjacency; diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 720463dcb4..beab50be2e 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -15,8 +15,7 @@ #include namespace FlexFlow { - -std::vector add_nodes(Graph &, int); +std::vector add_nodes(DiGraph &, int); std::unordered_set get_nodes(GraphView const &); std::unordered_set get_node_ports(MultiDiGraphView const &); @@ -144,6 +143,10 @@ std::unordered_map> std::unordered_map> get_predecessors(DiGraphView const &, std::unordered_set const &); +std::unordered_set get_neighbors(UndirectedGraphView const &, + Node const &); +std::unordered_set get_neighbors(DiGraphView const &, Node const &); +std::unordered_set get_neighbors(MultiDiGraphView const &, Node const &); std::unordered_set get_sources(DiGraphView const &); std::unordered_set get_sources(MultiDiGraphView const &); @@ -230,6 +233,10 @@ MultiDiGraphView get_subgraph(MultiDiGraphView const &, OpenMultiDiGraphView get_subgraph(OpenMultiDiGraphView const &, std::unordered_set const &); +std::unordered_map calculate_topo_rank(DiGraphView const &); +Node get_node_with_greatest_topo_rank(std::unordered_set const &, + DiGraphView const &); + MultiDiGraphView join(MultiDiGraphView const &lhs, MultiDiGraphView const &rhs); DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs); UndirectedGraphView join(UndirectedGraphView const &lhs, @@ -241,6 +248,7 @@ DiGraphView with_added_edges(DiGraphView const &, std::unordered_set const &); UndirectedGraphView as_undirected(DiGraphView const &); +UndirectedGraphView as_undirected(MultiDiGraphView const &); MultiDiGraphView as_multidigraph(DiGraphView const &); DiGraphView as_digraph(MultiDiGraphView const &); MultiDiGraphView as_multidigraph(OpenMultiDiGraphView const &); diff --git a/lib/utils/include/utils/graph/digraph.h b/lib/utils/include/utils/graph/digraph.h index 16f00a4cfb..0bfcd11853 100644 --- a/lib/utils/include/utils/graph/digraph.h +++ b/lib/utils/include/utils/graph/digraph.h @@ -14,15 +14,16 @@ struct DirectedEdge { Node src; Node dst; }; + +std::ostream &operator<<(std::ostream &s, DirectedEdge const &e); + FF_VISITABLE_STRUCT(DirectedEdge, src, dst); struct DirectedEdgeQuery { query_set srcs; query_set dsts; - static DirectedEdgeQuery all() { - NOT_IMPLEMENTED(); - } + static DirectedEdgeQuery all(); }; FF_VISITABLE_STRUCT(DirectedEdgeQuery, srcs, dsts); @@ -38,7 +39,7 @@ struct IDiGraphView : public IGraphView { IDiGraphView &operator=(IDiGraphView const &) = delete; virtual std::unordered_set query_edges(EdgeQuery const &) const = 0; - virtual ~IDiGraphView(); + virtual ~IDiGraphView() = default; protected: IDiGraphView() = default; @@ -56,15 +57,9 @@ struct DiGraphView { friend void swap(DiGraphView &, DiGraphView &); - bool operator==(DiGraphView const &) const; - bool operator!=(DiGraphView const &) const; - std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(EdgeQuery const &) const; - - IDiGraphView const *unsafe() const { - return this->ptr.get(); - } + friend bool is_ptr_equal(DiGraphView const &, DiGraphView const &); template static typename std::enable_if::value, @@ -73,18 +68,19 @@ struct DiGraphView { return DiGraphView(std::make_shared(std::forward(args)...)); } -private: - DiGraphView(std::shared_ptr); + static DiGraphView + unsafe_create_without_ownership(IDiGraphView const &graphView); - friend DiGraphView unsafe(IDiGraphView const &); + DiGraphView(std::shared_ptr const &ptr, + should_only_be_used_internally_tag_t const &tag) + : DiGraphView(ptr) {} private: + DiGraphView(std::shared_ptr ptr) : ptr(ptr) {} std::shared_ptr ptr; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); -DiGraphView unsafe(IDiGraphView const &); - struct IDiGraph : public IDiGraphView, public IGraph { virtual void add_edge(Edge const &) = 0; virtual void remove_edge(Edge const &) = 0; @@ -101,7 +97,15 @@ struct DiGraph { DiGraph(DiGraph const &) = default; DiGraph &operator=(DiGraph const &) = default; - operator DiGraphView() const; + operator DiGraphView() const { + return DiGraphView(this->ptr.get(), should_only_be_used_internally_tag_t{}); + } + + operator Graph() const { + return Graph(std::static_pointer_cast(this->ptr.get_mutable()), + should_only_be_used_internally_tag_t{}); + } // Note(lambda):because ptr is cow_ptr_t, but ptr in Graph is + // cow_ptr_t, so we need to static_pointer_cast friend void swap(DiGraph &, DiGraph &); @@ -123,9 +127,7 @@ struct DiGraph { } private: - DiGraph(std::unique_ptr); - -private: + DiGraph(std::unique_ptr ptr); cow_ptr_t ptr; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph); diff --git a/lib/utils/include/utils/graph/multidiedge.h b/lib/utils/include/utils/graph/multidiedge.h index f521357816..3d5d0332f3 100644 --- a/lib/utils/include/utils/graph/multidiedge.h +++ b/lib/utils/include/utils/graph/multidiedge.h @@ -26,6 +26,7 @@ struct MultiDiEdge { NodePort srcIdx, dstIdx; }; FF_VISITABLE_STRUCT(MultiDiEdge, src, dst, srcIdx, dstIdx); +std::ostream &operator<<(std::ostream &os, MultiDiEdge const &edge); struct MultiDiInput { Node node; diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h index 5a0ae65490..5866f82665 100644 --- a/lib/utils/include/utils/graph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.h @@ -15,10 +15,6 @@ struct MultiDiGraphView { friend void swap(MultiDiGraphView &, MultiDiGraphView &); - IMultiDiGraphView const *unsafe() const { - return this->ptr.get(); - } - std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(EdgeQuery const &) const; @@ -30,18 +26,15 @@ struct MultiDiGraphView { std::make_shared(std::forward(args)...)); } -private: - MultiDiGraphView(std::shared_ptr); - - friend struct MultiDiGraph; - friend MultiDiGraphView unsafe(IMultiDiGraphView const &); + static MultiDiGraphView + unsafe_create_without_ownership(IMultiDiGraphView const &); private: + MultiDiGraphView(std::shared_ptr ptr) : ptr(ptr) {} std::shared_ptr ptr; }; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraphView); -MultiDiGraphView unsafe(IMultiDiGraphView const &); +CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(MultiDiGraphView); struct MultiDiGraph { public: @@ -57,13 +50,15 @@ struct MultiDiGraph { friend void swap(MultiDiGraph &, MultiDiGraph &); Node add_node(); + std::vector add_nodes(size_t); NodePort add_node_port(); + std::vector add_node_ports(size_t); void add_node_unsafe(Node const &); void add_node_port_unsafe(NodePort const &); void remove_node_unsafe(Node const &); - - void add_edge(Edge const &e); - void remove_edge(Edge const &e); + void add_edges(std::vector const &); + void add_edge(Edge const &); + void remove_edge(Edge const &); std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(EdgeQuery const &) const; @@ -76,7 +71,7 @@ struct MultiDiGraph { } private: - MultiDiGraph(std::unique_ptr); + MultiDiGraph(std::unique_ptr ptr) : ptr(std::move(ptr)) {} private: cow_ptr_t ptr; diff --git a/lib/utils/include/utils/graph/multidigraph_interfaces.h b/lib/utils/include/utils/graph/multidigraph_interfaces.h index aeb8e51a75..a5b122e30c 100644 --- a/lib/utils/include/utils/graph/multidigraph_interfaces.h +++ b/lib/utils/include/utils/graph/multidigraph_interfaces.h @@ -35,13 +35,13 @@ struct IMultiDiGraphView : public IGraphView { using EdgeQuery = MultiDiEdgeQuery; virtual std::unordered_set query_edges(EdgeQuery const &) const = 0; - virtual ~IMultiDiGraphView(); + virtual ~IMultiDiGraphView() = default; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraphView); struct IMultiDiGraph : public IMultiDiGraphView, public IGraph { - virtual NodePort add_node_port(); - virtual void add_node_port_unsafe(NodePort const &); + virtual NodePort add_node_port() = 0; + virtual void add_node_port_unsafe(NodePort const &) = 0; virtual void add_edge(Edge const &) = 0; virtual void remove_edge(Edge const &) = 0; diff --git a/lib/utils/include/utils/graph/node.h b/lib/utils/include/utils/graph/node.h index 2379693a45..4e2ce21770 100644 --- a/lib/utils/include/utils/graph/node.h +++ b/lib/utils/include/utils/graph/node.h @@ -4,6 +4,7 @@ #include "cow_ptr_t.h" #include "query_set.h" #include "utils/fmt.h" +#include "utils/internal_only_tag.h" #include "utils/optional.h" #include "utils/strong_typedef.h" #include "utils/type_traits.h" @@ -23,14 +24,16 @@ struct Node : public strong_typedef { FF_TYPEDEF_HASHABLE(Node); FF_TYPEDEF_PRINTABLE(Node, "Node"); +std::ostream &operator<<(std::ostream &, Node const &); +std::ostream &operator<<(std::ostream &os, + std::unordered_set const &nodes); + struct NodeQuery { NodeQuery(query_set const &nodes) : nodes(nodes) {} query_set nodes; - static NodeQuery all() { - NOT_IMPLEMENTED(); - } + static NodeQuery all(); }; FF_VISITABLE_STRUCT(NodeQuery, nodes); @@ -53,11 +56,7 @@ struct GraphView { std::unordered_set query_nodes(NodeQuery const &) const; - IGraphView const *unsafe() const { - return this->ptr.get(); - } - - static GraphView unsafe_create(IGraphView const &); + static GraphView unsafe_create_without_ownership(IGraphView const &); template static typename std::enable_if::value, @@ -65,17 +64,19 @@ struct GraphView { create(Args &&...args) { return GraphView(std::make_shared(std::forward(args)...)); } + GraphView(std::shared_ptr const &ptr, + should_only_be_used_internally_tag_t const &tag) + : GraphView(ptr) {} private: - GraphView(std::shared_ptr); - -private: + GraphView(std::shared_ptr ptr) : ptr(ptr) {} std::shared_ptr ptr; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IGraphView); struct IGraph : IGraphView { IGraph(IGraph const &) = delete; + IGraph() = default; IGraph &operator=(IGraph const &) = delete; virtual Node add_node() = 0; @@ -106,10 +107,13 @@ struct Graph { return Graph(make_unique()); } -private: - Graph(std::unique_ptr); + Graph(cow_ptr_t const &ptr, + should_only_be_used_internally_tag_t const &tag) + : ptr(std::move(ptr)) {} private: + Graph(std::unique_ptr ptr) : ptr(std::move(ptr)) {} + // Graph(std::shared_ptr ptr): ptr(to_cow_ptr(ptr)) {} cow_ptr_t ptr; }; diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h index 60c98a10bc..470d0870f4 100644 --- a/lib/utils/include/utils/graph/open_graphs.h +++ b/lib/utils/include/utils/graph/open_graphs.h @@ -19,12 +19,8 @@ struct OpenMultiDiGraphView { friend void swap(OpenMultiDiGraphView &, OpenMultiDiGraphView &); - std::unordered_set query_nodes(NodeQuery const &); - std::unordered_set query_edges(EdgeQuery const &); - - IOpenMultiDiGraphView const *unsafe() const { - return this->ptr.get(); - } + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; template static @@ -36,9 +32,8 @@ struct OpenMultiDiGraphView { } private: - OpenMultiDiGraphView(std::shared_ptr); - -private: + OpenMultiDiGraphView(std::shared_ptr ptr) + : ptr(ptr) {} std::shared_ptr ptr; }; @@ -50,8 +45,6 @@ struct OpenMultiDiGraph { OpenMultiDiGraph() = delete; OpenMultiDiGraph(OpenMultiDiGraph const &); - OpenMultiDiGraph &operator=(OpenMultiDiGraph); - friend void swap(OpenMultiDiGraph &, OpenMultiDiGraph &); operator OpenMultiDiGraphView() const; @@ -73,9 +66,7 @@ struct OpenMultiDiGraph { } private: - OpenMultiDiGraph(std::unique_ptr); - -private: + OpenMultiDiGraph(std::unique_ptr ptr); cow_ptr_t ptr; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(OpenMultiDiGraph); @@ -92,10 +83,6 @@ struct UpwardOpenMultiDiGraphView { std::unordered_set query_nodes(NodeQuery const &); std::unordered_set query_edges(EdgeQuery const &); - IUpwardOpenMultiDiGraphView const *unsafe() const { - return this->ptr.get(); - } - template static typename std::enable_if< std::is_base_of::value, @@ -163,10 +150,6 @@ struct DownwardOpenMultiDiGraphView { std::unordered_set query_nodes(NodeQuery const &); std::unordered_set query_edges(EdgeQuery const &); - IDownwardOpenMultiDiGraphView const *unsafe() const { - return this->ptr.get(); - } - template static typename std::enable_if< std::is_base_of::value, diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index 2282953c62..b2ef7ff495 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_QUERY_SET_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_QUERY_SET_H +#include "utils/bidict.h" #include "utils/containers.h" #include "utils/exception.h" #include "utils/optional.h" @@ -11,15 +12,11 @@ namespace FlexFlow { template struct query_set { query_set() = delete; - query_set(T const &) { - NOT_IMPLEMENTED(); - } - query_set(std::unordered_set const &) { - NOT_IMPLEMENTED(); - } - query_set(optional> const &) { - NOT_IMPLEMENTED(); - } + query_set(T const &query) : query({query}) {} + + query_set(std::unordered_set const &query) : query(query) {} + + query_set(optional> const &query) : query(query) {} friend bool operator==(query_set const &lhs, query_set const &rhs) { return lhs.value == rhs.value; @@ -73,14 +70,29 @@ template std::unordered_map query_keys(query_set const &q, C const &m) { - NOT_IMPLEMENTED(); + if (is_matchall(q)) { + return m; + } + return filter_keys(m, [&](K const &key) { return includes(q, key); }); +} + +template +std::unordered_map query_keys(query_set const &q, + bidict const &m) { + if (is_matchall(q)) { + return filter_values(m, [](V const &value) { return true; }); + } + return filter_values(m, [&](V const &value) { return includes(q, value); }); } template std::unordered_map query_values(query_set const &q, C const &m) { - NOT_IMPLEMENTED(); + if (is_matchall(q)) { + return m; + } + return filter_values(m, [&](V const &value) { return includes(q, value); }); } template diff --git a/lib/utils/include/utils/graph/serialparallel.h b/lib/utils/include/utils/graph/serialparallel.h index 737bb9e324..6d658749b0 100644 --- a/lib/utils/include/utils/graph/serialparallel.h +++ b/lib/utils/include/utils/graph/serialparallel.h @@ -20,15 +20,12 @@ struct Serial { std::vector> children; }; -FF_VISITABLE_STRUCT(Serial, children); -MAKE_VISIT_HASHABLE(Serial); - struct Parallel { std::vector> children; }; -FF_VISITABLE_STRUCT(Parallel, children); -MAKE_VISIT_HASHABLE(Parallel); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(Parallel, children); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(Serial, children); using SerialParallelDecomposition = variant; diff --git a/lib/utils/include/utils/graph/undirected.h b/lib/utils/include/utils/graph/undirected.h index edaca75a9d..5fe1225f36 100644 --- a/lib/utils/include/utils/graph/undirected.h +++ b/lib/utils/include/utils/graph/undirected.h @@ -26,9 +26,7 @@ namespace FlexFlow { struct UndirectedEdgeQuery { query_set nodes; - static UndirectedEdgeQuery all() { - NOT_IMPLEMENTED(); - } + static UndirectedEdgeQuery all(); }; FF_VISITABLE_STRUCT(UndirectedEdgeQuery, nodes); @@ -44,7 +42,7 @@ struct IUndirectedGraphView : public IGraphView { virtual std::unordered_set query_edges(UndirectedEdgeQuery const &) const = 0; - virtual ~IUndirectedGraphView(); + virtual ~IUndirectedGraphView() = default; protected: IUndirectedGraphView() = default; @@ -60,18 +58,14 @@ struct UndirectedGraphView { UndirectedGraphView() = delete; - operator GraphView const &() const; - operator GraphView &(); + operator GraphView() const; + // operator GraphView &(); friend void swap(UndirectedGraphView &, UndirectedGraphView &); std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(EdgeQuery const &) const; - IUndirectedGraphView const *unsafe() const { - return this->ptr.get(); - } - template static typename std::enable_if::value, @@ -81,12 +75,15 @@ struct UndirectedGraphView { std::make_shared(std::forward(args)...)); } -private: - UndirectedGraphView(std::shared_ptr); - - friend UndirectedGraphView unsafe(IUndirectedGraphView const &); + static UndirectedGraphView + unsafe_create_without_ownership(IUndirectedGraphView const &); + UndirectedGraphView(std::shared_ptr const &ptr, + should_only_be_used_internally_tag_t const &tag) + : UndirectedGraphView(ptr) {} private: + UndirectedGraphView(std::shared_ptr ptr) + : ptr(ptr) {} std::shared_ptr ptr; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraphView); @@ -129,9 +126,7 @@ struct UndirectedGraph { } private: - UndirectedGraph(std::unique_ptr); - -private: + UndirectedGraph(std::unique_ptr ptr); cow_ptr_t ptr; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraph); diff --git a/lib/utils/include/utils/graph/views.h b/lib/utils/include/utils/graph/views.h index 811131e4d1..6455aad554 100644 --- a/lib/utils/include/utils/graph/views.h +++ b/lib/utils/include/utils/graph/views.h @@ -60,8 +60,8 @@ struct DiSubgraphView : public IDiGraphView { struct MultiDiSubgraphView : public IMultiDiGraphView { public: MultiDiSubgraphView() = delete; - explicit MultiDiSubgraphView(MultiDiGraphView const &, - std::unordered_set const &); + explicit MultiDiSubgraphView(MultiDiGraphView const &g, + std::unordered_set const &subgraph_nodes); std::unordered_set query_edges(MultiDiEdgeQuery const &) const override; @@ -86,23 +86,18 @@ enum class LRDirection { LEFT, RIGHT }; struct JoinNodeKey { JoinNodeKey() = delete; - JoinNodeKey(Node const &, LRDirection); + JoinNodeKey(Node const &node, LRDirection direction) + : node(node), direction(direction) {} - bool operator==(JoinNodeKey const &) const; bool operator<(JoinNodeKey const &) const; Node node; - LRDirection direction; + req direction; }; -} // namespace FlexFlow +FF_VISITABLE_STRUCT(JoinNodeKey, node, direction); -namespace std { -template <> -struct hash<::FlexFlow::JoinNodeKey> { - std::size_t operator()(::FlexFlow::JoinNodeKey const &) const; -}; -} // namespace std +} // namespace FlexFlow namespace FlexFlow { @@ -221,9 +216,10 @@ struct SingleSourceNodeView : public IDiGraphView { struct ContractNodeView : public IDiGraphView { ContractNodeView() = delete; - explicit ContractNodeView(DiGraphView const &, + explicit ContractNodeView(DiGraphView const &g, Node const &removed, - Node const &into); + Node const &into) + : g(g), from(removed), to(into) {} std::unordered_set query_edges(DirectedEdgeQuery const &) const override; @@ -240,8 +236,8 @@ struct ContractNodeView : public IDiGraphView { struct OpenMultiDiSubgraphView : public IOpenMultiDiGraphView { public: OpenMultiDiSubgraphView() = delete; - explicit OpenMultiDiSubgraphView(OpenMultiDiGraphView const &, - std::unordered_set const &); + OpenMultiDiSubgraphView(OpenMultiDiGraphView const &, + std::unordered_set const &); std::unordered_set query_edges(OpenMultiDiEdgeQuery const &) const override; @@ -305,7 +301,8 @@ struct ViewMultiDiGraphAsDiGraph : public IDiGraphView { struct ViewOpenMultiDiGraphAsMultiDiGraph : public IMultiDiGraphView { public: ViewOpenMultiDiGraphAsMultiDiGraph() = delete; - explicit ViewOpenMultiDiGraphAsMultiDiGraph(OpenMultiDiGraphView const &); + explicit ViewOpenMultiDiGraphAsMultiDiGraph(OpenMultiDiGraphView const &g) + : g(g) {} std::unordered_set query_edges(MultiDiEdgeQuery const &) const override; @@ -349,6 +346,14 @@ Impl materialize_multidigraph_view(IMultiDiGraphView const &g) { } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::JoinNodeKey, node, direction); +namespace std { +template <> +struct hash> { + std::size_t operator()( + FlexFlow::required const &direction) const { + return std::hash{}(direction.value()); + } +}; +} // namespace std #endif diff --git a/lib/utils/include/utils/internal_only_tag.h b/lib/utils/include/utils/internal_only_tag.h new file mode 100644 index 0000000000..649ce4cf12 --- /dev/null +++ b/lib/utils/include/utils/internal_only_tag.h @@ -0,0 +1,10 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_INTERNAL_ONLY_TAG_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_INTERNAL_ONLY_TAG_H + +namespace FlexFlow { +struct should_only_be_used_internally_tag_t { + explicit should_only_be_used_internally_tag_t() = default; +}; +} // namespace FlexFlow + +#endif \ No newline at end of file diff --git a/lib/utils/include/utils/visitable.h b/lib/utils/include/utils/visitable.h index 6a671400cd..4f0dc50cbe 100644 --- a/lib/utils/include/utils/visitable.h +++ b/lib/utils/include/utils/visitable.h @@ -457,4 +457,4 @@ struct Arbitrary< _GET_VISITABLE_CASE_FROM_NUM_ARGS(__VA_ARGS__), \ __VA_ARGS__) -#endif +#endif \ No newline at end of file diff --git a/lib/utils/src/graph/adjacency_multidigraph.cc b/lib/utils/src/graph/adjacency_multidigraph.cc index e9dbbb51ca..79e7de7d75 100644 --- a/lib/utils/src/graph/adjacency_multidigraph.cc +++ b/lib/utils/src/graph/adjacency_multidigraph.cc @@ -10,6 +10,16 @@ Node AdjacencyMultiDiGraph::add_node() { return node; } +NodePort AdjacencyMultiDiGraph::add_node_port() { + auto nodePort = NodePort{this->next_node_port}; + this->next_node_port++; + return nodePort; +} + +void AdjacencyMultiDiGraph::add_node_port_unsafe(NodePort const &nodePort) { + this->next_node_port = std::max(this->next_node_port, nodePort.value() + 1); +} + void AdjacencyMultiDiGraph::add_node_unsafe(Node const &node) { adjacency[node]; this->next_node_idx = std::max(this->next_node_idx, node.value() + 1); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index adf479f11d..01e5987c80 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -8,11 +8,11 @@ #include namespace FlexFlow { - -std::vector add_nodes(IGraph &g, int num_nodes) { +std::vector add_nodes(DiGraph &g, int num_nodes) { std::vector nodes; - std::generate_n( - std::back_inserter(nodes), num_nodes, [&g]() { return g.add_node(); }); + for (int i = 0; i < num_nodes; i++) { + nodes.push_back(g.add_node()); + } return nodes; } @@ -86,6 +86,22 @@ std::size_t num_nodes(GraphView const &g) { return get_nodes(g).size(); } +DiGraphView + contract_node(DiGraphView const &g, Node const &from, Node const &into) { + return DiGraphView::create(g, from, into); +} + +DiGraphView apply_contraction(DiGraphView const &g, + std::unordered_map const &nodes) { + DiGraphView contractedView = g; + for (auto const &kv : nodes) { + Node from = kv.first; + Node into = kv.second; + contractedView = contract_node(contractedView, from, into); + } + return contractedView; +} + void add_edges(MultiDiGraph &g, std::vector const &edges) { for (MultiDiEdge const &e : edges) { g.add_edge(e); @@ -177,6 +193,11 @@ std::unordered_set return to_directed_edges(get_incoming_edges(multidigraph_view, dsts)); } +std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, + Node const &n) { + return get_outgoing_edges(g, std::unordered_set{n}); +} + std::unordered_set get_outgoing_edges(MultiDiGraphView const &g, std::unordered_set const &srcs) { @@ -190,6 +211,11 @@ std::unordered_set return to_directed_edges(get_outgoing_edges(multidigraph_view, dsts)); } +std::unordered_set get_outgoing_edges(DiGraphView const &g, + Node const &n) { + return get_outgoing_edges(g, std::unordered_set{n}); +} + std::unordered_map> get_predecessors(DiGraphView const &g, std::unordered_set const &nodes) { @@ -204,7 +230,7 @@ std::unordered_map> } std::unordered_set get_predecessors(DiGraphView const &g, Node const &n) { - return get_predecessors(g, {n}); + return get_predecessors(g, std::unordered_set{n}).at(n); } std::unordered_map> @@ -238,15 +264,25 @@ std::vector return {bfs_view.begin(), bfs_view.end()}; } +std::unordered_set get_sinks(DiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return get_outgoing_edges(g, n).size() == 0; + }); +} + +std::unordered_set get_sinks(MultiDiGraphView const &g) { + DiGraphView digraph_view = as_digraph(g); + return get_sinks(digraph_view); +} + +DiGraphView flipped(DiGraphView const &g) { + return DiGraphView::create(g); +} + std::unordered_set get_sources(DiGraphView const &g) { - std::unordered_set sources; - for (Node const &n : get_nodes(g)) { - auto incoming = get_incoming_edges(g, n); - if (incoming.size() == 0) { - sources.insert(n); - } - } - return sources; + return filter(get_nodes(g), [&](Node const &n) { + return get_incoming_edges(g, n).size() == 0; + }); } optional is_acyclic(DiGraphView const &g) { @@ -282,7 +318,6 @@ std::vector get_unchecked_topological_ordering(DiGraphView const &g) { std::unordered_set seen; std::unordered_map> predecessors = get_predecessors(g, get_nodes(g)); - auto all_predecessors_seen = [&](Node const &n) -> bool { bool result = true; for (Node const &pred : predecessors.at(n)) { @@ -327,6 +362,32 @@ std::vector get_edge_topological_ordering(DiGraphView const &g) { return result; } +std::unordered_set get_neighbors(DiGraphView const &g, Node const &n) { + UndirectedGraphView undirected = as_undirected(g); + return get_neighbors(undirected, n); +} + +std::unordered_set get_neighbors(MultiDiGraphView const &g, + Node const &n) { + UndirectedGraphView undirected = as_undirected(g); + return get_neighbors(undirected, n); +} + +std::unordered_set get_neighbors(UndirectedGraphView const &g, + Node const &n) { + std::unordered_set edges = + filter(get_node_edges(g, n), [&](UndirectedEdge const &edge) { + return ((edge.smaller == n && edge.bigger != n) || + (edge.smaller != n && edge.bigger == n)); + }); + std::cout << "get_neighbors: " << edges.size() << "\n"; + return map_over_unordered_set( + [&](UndirectedEdge const &edge) -> Node { + return (edge.smaller == n) ? edge.bigger : edge.smaller; + }, + edges); +} + std::vector get_edge_topological_ordering(MultiDiGraphView const &g) { std::vector result; @@ -354,14 +415,14 @@ std::unordered_map> for (Node const &n : topo) { for (Node const &pred : get_predecessors(g, n)) { if (contains_key(result, n)) { - result[n] = result.at(pred); + result.at(n) = result.at(pred); } else { + result[n] = std::unordered_set(); result.at(n) = intersection(result.at(n), result.at(pred)); } } result[n].insert(n); } - return result; } @@ -377,29 +438,11 @@ std::unordered_map> std::unordered_map> get_imm_dominators(DiGraphView const &g) { - std::unordered_map topo_rank = [&g]() { - std::vector topo_ordering = get_topological_ordering(g); - std::unordered_map topo_rank; - for (int i = 0; i < topo_ordering.size(); i++) { - topo_rank[topo_ordering[i]] = i; - } - return topo_rank; - }(); - - auto with_greatest_topo_rank = - [&topo_rank](std::unordered_set const &nodes) -> Node { - return *std::max_element(nodes.cbegin(), - nodes.cend(), - [&topo_rank](Node const &lhs, Node const &rhs) { - return topo_rank.at(lhs) < topo_rank.at(rhs); - }); - }; std::unordered_map> result; for (auto const &kv : get_dominators(g)) { Node node = kv.first; std::unordered_set node_dominators = kv.second; - assert(node_dominators.size() >= 1); // a node cannot immediately dominate itself @@ -408,7 +451,7 @@ std::unordered_map> result[node] = nullopt; } else { node_dominators.erase(node); - result[node] = with_greatest_topo_rank(node_dominators); + result[node] = get_node_with_greatest_topo_rank(node_dominators, g); } } return result; @@ -437,6 +480,42 @@ optional imm_post_dominator(MultiDiGraphView const &g, Node const &n) { return get_imm_post_dominators(g).at(n); } +optional get_imm_post_dominator(DiGraphView const &g, Node const &n) { + return get_imm_post_dominators(g).at(n); +} + +std::unordered_map calculate_topo_rank(DiGraphView const &g) { + std::vector topo_ordering = get_topological_ordering(g); + std::unordered_map topo_rank; + for (int i = 0; i < topo_ordering.size(); i++) { + topo_rank[topo_ordering[i]] = i; + } + return topo_rank; +} + +Node get_node_with_greatest_topo_rank(std::unordered_set const &nodes, + DiGraphView const &g) { + std::unordered_map topo_rank = calculate_topo_rank(g); + return *std::max_element(nodes.cbegin(), + nodes.cend(), + [&topo_rank](Node const &lhs, Node const &rhs) { + return topo_rank.at(lhs) < topo_rank.at(rhs); + }); +} + +optional get_imm_post_dominator(DiGraphView const &g, + std::unordered_set const &nodes) { + + std::unordered_set commonDoms = + intersection(values(restrict_keys(get_post_dominators(g), nodes))); + + if (!commonDoms.empty()) { + return get_node_with_greatest_topo_rank(commonDoms, g); + } else { + return nullopt; + } +} + std::pair split_edge(MultiDiEdge const &e) { return {OutputMultiDiEdge{{e.dst.value(), e.dstIdx.value()}, e.src, e.srcIdx}, @@ -486,6 +565,11 @@ UndirectedGraphView as_undirected(DiGraphView const &g) { return UndirectedGraphView::create(g); } +UndirectedGraphView as_undirected(MultiDiGraphView const &g) { + DiGraphView dg = as_digraph(g); + return as_undirected(dg); +} + MultiDiGraphView as_multidigraph(DiGraphView const &g) { return MultiDiGraphView::create(g); } @@ -498,4 +582,51 @@ MultiDiGraphView as_multidigraph(OpenMultiDiGraphView const &g) { return MultiDiGraphView::create(g); } +std::vector> + get_weakly_connected_components(DiGraphView const &g) { + return get_connected_components(as_undirected(g)); +} + +std::vector> + get_weakly_connected_components(MultiDiGraphView const &g) { + return get_connected_components(as_undirected(g)); +} + +std::vector> + get_connected_components(UndirectedGraphView const &g) { + std::vector> components; + std::unordered_set visited; + + std::unordered_set nodes = g.query_nodes(NodeQuery::all()); + for (Node const &node : nodes) { + if (visited.count(node) > 0) { + continue; + } + + std::unordered_set component; + std::stack stack; + stack.push(node); + + while (!stack.empty()) { + Node current = stack.top(); + stack.pop(); + + if (visited.count(current) > 0) { + continue; + } + + component.insert(current); + visited.insert(current); + std::unordered_set neighbors = get_neighbors(g, current); + + for (auto const &neighbor : neighbors) { + stack.push(neighbor); + } + } + + components.push_back(std::move(component)); + } + return components; +} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index 7e25324c86..fd0612eceb 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -2,12 +2,21 @@ namespace FlexFlow { +std::ostream &operator<<(std::ostream &s, DirectedEdge const &e) { + std::string str = fmt::format("DirectedEdge(src={}, dst={})", e.src, e.dst); + return s << str; +} + void swap(DiGraph &lhs, DiGraph &rhs) { using std::swap; swap(lhs.ptr, rhs.ptr); } +bool is_ptr_equal(DiGraphView const &lhs, DiGraphView const &rhs) { + return lhs.ptr == rhs.ptr; +} + Node DiGraph::add_node() { return this->ptr.get_mutable()->add_node(); } @@ -35,4 +44,56 @@ std::unordered_set DiGraph::DiGraph(std::unique_ptr _ptr) : ptr(std::move(_ptr)) {} +DiGraphView::operator GraphView() const { + return GraphView(this->ptr, should_only_be_used_internally_tag_t{}); +} + +std::unordered_set DiGraphView::query_nodes(NodeQuery const &q) const { + return this->ptr->query_nodes(q); +} + +std::unordered_set + DiGraphView::query_edges(EdgeQuery const &query) const { + return ptr->query_edges(query); +} + +// Set the shared_ptr's destructor to a nop so that effectively there is no +// ownership +DiGraphView DiGraphView::unsafe_create_without_ownership( + IDiGraphView const &graphView) { + std::shared_ptr ptr((&graphView), + [](IDiGraphView const *) {}); + return DiGraphView(ptr); +} + +DirectedEdgeQuery DirectedEdgeQuery::all() { + return {matchall(), matchall()}; +} + +DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, + DirectedEdgeQuery const &rhs) { + std::unordered_set srcs_tl; + if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_tl = allowed_values(rhs.srcs); + } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { + srcs_tl = allowed_values(lhs.srcs); + } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_tl = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); + } + + std::unordered_set dsts_tl; + if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_tl = allowed_values(rhs.dsts); + } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { + dsts_tl = allowed_values(lhs.dsts); + } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_tl = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); + } + + DirectedEdgeQuery result = DirectedEdgeQuery::all(); + result.srcs = srcs_tl; + result.dsts = dsts_tl; + return result; +} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 28c00a1aef..98075041b8 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -13,27 +13,78 @@ MultiDiOutput get_output(MultiDiEdge const &e) { MultiDiEdgeQuery MultiDiEdgeQuery::with_src_nodes(query_set const &nodes) const { MultiDiEdgeQuery e = *this; - if (is_matchall(e.srcs)) { + if (!is_matchall(e.srcs)) { throw mk_runtime_error("Expected matchall previous value"); } e.srcs = nodes; return e; } +std::ostream &operator<<(std::ostream &os, MultiDiEdge const &edge) { + return os << "MultiDiEdge{" << edge.src.value() << "," << edge.dst.value() + << "," << edge.srcIdx.value() << "," << edge.dstIdx.value() << "}"; +} + MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_nodes(query_set const &nodes) const { MultiDiEdgeQuery e = *this; - if (is_matchall(e.dsts)) { + if (!is_matchall(e.dsts)) { throw mk_runtime_error("Expected matchall previous value"); } e.dsts = nodes; return e; } +MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs, + MultiDiEdgeQuery const &rhs) { + std::unordered_set srcs_t1; + if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(rhs.srcs); + } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(lhs.srcs); + } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); + } + + std::unordered_set dsts_t1; + if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(rhs.dsts); + } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(lhs.dsts); + } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); + } + + std::unordered_set srcIdxs_t1; + if (is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { + srcIdxs_t1 = allowed_values(rhs.srcIdxs); + } else if (!is_matchall(lhs.srcIdxs) && is_matchall(rhs.srcIdxs)) { + srcIdxs_t1 = allowed_values(lhs.srcIdxs); + } else if (!is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { + srcIdxs_t1 = allowed_values(query_intersection(lhs.srcIdxs, rhs.srcIdxs)); + } + + std::unordered_set dstIdxs_t1; + if (is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { + dstIdxs_t1 = allowed_values(rhs.dstIdxs); + } else if (!is_matchall(lhs.dstIdxs) && is_matchall(rhs.dstIdxs)) { + dstIdxs_t1 = allowed_values(lhs.dstIdxs); + } else if (!is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { + dstIdxs_t1 = allowed_values(query_intersection(lhs.dstIdxs, rhs.dstIdxs)); + } + + MultiDiEdgeQuery e = MultiDiEdgeQuery::all(); + e.srcs = srcs_t1; + e.dsts = dsts_t1; + e.srcIdxs = srcIdxs_t1; + e.dstIdxs = dstIdxs_t1; + return e; +} + MultiDiEdgeQuery MultiDiEdgeQuery::with_src_idxs(query_set const &idxs) const { - MultiDiEdgeQuery e = *this; - if (is_matchall(e.srcIdxs)) { + MultiDiEdgeQuery e{*this}; + if (!is_matchall(e.srcIdxs)) { throw mk_runtime_error("Expected matchall previous value"); } e.srcIdxs = idxs; @@ -43,7 +94,7 @@ MultiDiEdgeQuery MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_idxs(query_set const &idxs) const { MultiDiEdgeQuery e = *this; - if (is_matchall(e.dstIdxs)) { + if (!is_matchall(e.dstIdxs)) { throw mk_runtime_error("Expected matchall previous value"); } e.dstIdxs = idxs; @@ -57,6 +108,29 @@ MultiDiEdgeQuery MultiDiEdgeQuery::all() { matchall()}; } +std::unordered_set + MultiDiGraphView::query_nodes(NodeQuery const &q) const { + return this->ptr->query_nodes(q); +} + +std::unordered_set + MultiDiGraphView::query_edges(MultiDiEdgeQuery const &q) const { + return this->ptr->query_edges(q); +} + +MultiDiGraphView::operator GraphView() const { + return GraphView(this->ptr, should_only_be_used_internally_tag_t{}); +} + +// Set the shared_ptr's destructor to a nop so that effectively there is no +// ownership +MultiDiGraphView MultiDiGraphView::unsafe_create_without_ownership( + IMultiDiGraphView const &graphView) { + std::shared_ptr ptr( + (&graphView), [](IMultiDiGraphView const *ptr) {}); + return MultiDiGraphView(ptr); +} + void swap(MultiDiGraphView &lhs, MultiDiGraphView &rhs) { using std::swap; @@ -70,13 +144,37 @@ void swap(MultiDiGraph &lhs, MultiDiGraph &rhs) { } MultiDiGraph::operator MultiDiGraphView() const { - return MultiDiGraphView(this->ptr.get()); + return MultiDiGraphView::unsafe_create_without_ownership(*this->ptr); } Node MultiDiGraph::add_node() { return this->ptr.get_mutable()->add_node(); } +std::vector MultiDiGraph::add_nodes(size_t n) { + std::vector nodes; + for (size_t i = 0; i < n; i++) { + nodes.push_back(add_node()); + } + return nodes; +} + +std::vector MultiDiGraph::add_node_ports(size_t n) { + std::vector ports; + for (size_t i = 0; i < n; i++) { + ports.push_back(add_node_port()); + } + return ports; +} + +NodePort MultiDiGraph::add_node_port() { + return this->ptr.get_mutable()->add_node_port(); +} + +void MultiDiGraph::add_node_port_unsafe(NodePort const &np) { + return this->ptr.get_mutable()->add_node_port_unsafe(np); +} + void MultiDiGraph::add_node_unsafe(Node const &n) { return this->ptr.get_mutable()->add_node_unsafe(n); } @@ -89,6 +187,12 @@ void MultiDiGraph::add_edge(MultiDiEdge const &e) { return this->ptr.get_mutable()->add_edge(e); } +void MultiDiGraph::add_edges(std::vector const &edges) { + for (MultiDiEdge const &e : edges) { + add_edge(e); + } +} + void MultiDiGraph::remove_edge(MultiDiEdge const &e) { return this->ptr.get_mutable()->remove_edge(e); } @@ -98,4 +202,8 @@ std::unordered_set return this->ptr->query_edges(q); } +std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { + return this->ptr->query_nodes(q); +} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc new file mode 100644 index 0000000000..70f898306a --- /dev/null +++ b/lib/utils/src/graph/node.cc @@ -0,0 +1,58 @@ +#include "utils/graph/node.h" +#include + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &os, Node const &node) { + return os << fmt::format("Node({})", node.value()); +} + +std::ostream &operator<<(std::ostream &os, + std::unordered_set const &nodes) { + os << fmt::format("{ "); + for (auto it = nodes.begin(); it != nodes.end(); ++it) { + if (it != nodes.begin()) { + os << fmt::format(", "); + } + os << *it; + } + os << fmt::format(" }"); + return os; +} + +NodeQuery NodeQuery::all() { + return {matchall()}; +} + +NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) { + + std::unordered_set nodes; + + if (is_matchall(lhs.nodes) && !is_matchall(rhs.nodes)) { + nodes = allowed_values(rhs.nodes); + } else if (!is_matchall(lhs.nodes) && is_matchall(rhs.nodes)) { + nodes = allowed_values(lhs.nodes); + } else if (!is_matchall(lhs.nodes) && !is_matchall(rhs.nodes)) { + nodes = allowed_values(query_intersection(lhs.nodes, rhs.nodes)); + } + + NodeQuery intersection_result = NodeQuery::all(); + intersection_result.nodes = nodes; + + return intersection_result; +} + +std::unordered_set GraphView::query_nodes(NodeQuery const &g) const { + return this->ptr->query_nodes(g); +} + +// Set the shared_ptr's destructor to a nop so that effectively there is no +// ownership +GraphView + GraphView::unsafe_create_without_ownership(IGraphView const &graphView) { + std::shared_ptr ptr((&graphView), + [](IGraphView const *) {}); + return GraphView(ptr); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index fc1a35b139..0332ddc597 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -1,7 +1,28 @@ #include "utils/graph/open_graphs.h" +#include "utils/graph/query_set.h" namespace FlexFlow { +InputMultiDiEdgeQuery InputMultiDiEdgeQuery::all() { + return {matchall(), matchall()}; +} + +OutputMultiDiEdgeQuery OutputMultiDiEdgeQuery::all() { + return {matchall(), matchall()}; +} + +std::unordered_set + OpenMultiDiGraphView::query_nodes(NodeQuery const &q) const { + return this->ptr->query_nodes(q); +} +std::unordered_set + OpenMultiDiGraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { + return this->ptr->query_edges(q); +} + +OpenMultiDiGraph::OpenMultiDiGraph(OpenMultiDiGraph const &other) + : ptr(other.ptr) {} + void swap(OpenMultiDiGraph &lhs, OpenMultiDiGraph &rhs) { using std::swap; diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index de164ff60e..176bcf6cb9 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -131,6 +131,15 @@ SplitAST parallel_decomposition(DiGraphView const &g) { return split; } +SplitASTNode::SplitASTNode(SplitType type, + SplitAST const &lhs, + SplitAST const &rhs) + : SplitASTNode(type, {lhs, rhs}) {} + +SplitASTNode::SplitASTNode(SplitType type, + std::vector const &children) + : type(type), children(children) {} + struct FlattenAST { void add_flattened_child_to_parent(SplitASTNode &parent, SplitAST const &child) { diff --git a/lib/utils/src/graph/serialparallel_internal.h b/lib/utils/src/graph/serialparallel_internal.h index 6b329803e5..4c9c22d95e 100644 --- a/lib/utils/src/graph/serialparallel_internal.h +++ b/lib/utils/src/graph/serialparallel_internal.h @@ -18,7 +18,7 @@ struct SplitASTNode; using SplitAST = mpark::variant; struct SplitASTNode { - SplitASTNode(SplitType); + SplitASTNode(SplitType type) : type(type) {} SplitASTNode(SplitType, SplitAST const &, SplitAST const &); SplitASTNode(SplitType, std::vector const &); diff --git a/lib/utils/src/graph/traversal.cc b/lib/utils/src/graph/traversal.cc index f774f09a08..d0f5951858 100644 --- a/lib/utils/src/graph/traversal.cc +++ b/lib/utils/src/graph/traversal.cc @@ -35,7 +35,13 @@ udi &udi::operator++() { std::unordered_set outgoing = get_outgoing_edges(graph, {last}); for (DirectedEdge const &e : outgoing) { - stack.push_back(e.dst); + auto it = std::find(stack.begin(), stack.end(), e.dst); + if (it == stack.end()) { + stack.push_back(e.dst); + } else { + stack.erase(it); + stack.push_back(e.dst); + } } return *this; } @@ -146,12 +152,12 @@ bfi bfi::operator++(int) { bool bfi::operator==(bfi const &other) const { return this->q == other.q && this->seen == other.seen && - this->graph == other.graph; + is_ptr_equal(this->graph, other.graph); } bool bfi::operator!=(bfi const &other) const { return this->q != other.q || - this->seen != other.seen && this->graph != other.graph; + this->seen != other.seen && is_ptr_equal(this->graph, other.graph); } CheckedDFSView::CheckedDFSView(DiGraphView const &g, diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index 1b87b05ae3..6fa468f93d 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -7,6 +7,10 @@ namespace FlexFlow { UndirectedEdge::UndirectedEdge(Node const &src, Node const &dst) : smaller(std::min(smaller, bigger)), bigger(std::max(smaller, bigger)) {} +UndirectedEdgeQuery UndirectedEdgeQuery::all() { + return {matchall()}; +} + UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, UndirectedEdgeQuery const &rhs) { return { @@ -48,4 +52,32 @@ std::unordered_set UndirectedGraph::UndirectedGraph(std::unique_ptr _ptr) : ptr(std::move(_ptr)) {} +UndirectedGraph::operator UndirectedGraphView() const { + return UndirectedGraphView(this->ptr.get(), + should_only_be_used_internally_tag_t{}); +} + +std::unordered_set + UndirectedGraphView::query_edges(UndirectedEdgeQuery const &q) const { + return this->ptr->query_edges(q); +} + +std::unordered_set + UndirectedGraphView::query_nodes(NodeQuery const &q) const { + return this->ptr->query_nodes(q); +} + +// Set the shared_ptr's destructor to a nop so that effectively there is no +// ownership +UndirectedGraphView UndirectedGraphView::unsafe_create_without_ownership( + IUndirectedGraphView const &g) { + std::shared_ptr ptr( + (&g), [](IUndirectedGraphView const *) {}); + return UndirectedGraphView(ptr); +} + +UndirectedGraphView::operator GraphView() const { + return GraphView(this->ptr, should_only_be_used_internally_tag_t{}); +} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index 32a1bedbd0..863039d567 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -21,6 +21,39 @@ std::unordered_set return this->g.query_nodes(query); } +std::unordered_set + ContractNodeView::query_edges(DirectedEdgeQuery const &q) const { + return g.query_edges(q); +} + +std::unordered_set + ContractNodeView::query_nodes(NodeQuery const &q) const { + return g.query_nodes(q); +} + +std::unordered_set ViewOpenMultiDiGraphAsMultiDiGraph::query_nodes( + NodeQuery const &query) const { + return g.query_nodes(query); +} + +std::unordered_set ViewOpenMultiDiGraphAsMultiDiGraph::query_edges( + MultiDiEdgeQuery const &query) const { + + InputMultiDiEdgeQuery input_edge_query = InputMultiDiEdgeQuery::all(); + OutputMultiDiEdgeQuery output_edge_query = OutputMultiDiEdgeQuery::all(); + + OpenMultiDiEdgeQuery q{input_edge_query, query, output_edge_query}; + + std::unordered_set edges = g.query_edges(q); + + return transform( + filter(edges, + [](OpenMultiDiEdge const &edge) { + return holds_alternative(edge); + }), + [](OpenMultiDiEdge const &edge) { return get(edge); }); +} + DirectedEdge flipped(DirectedEdge const &e) { return {e.src, e.dst}; } @@ -69,6 +102,11 @@ std::unordered_set return this->g.query_edges(query_intersection(query, subgraph_query)); } +std::unordered_set + MultiDiSubgraphView::query_nodes(NodeQuery const &query) const { + return this->g.query_nodes(query_intersection(query, {this->subgraph_nodes})); +} + UndirectedGraphView view_subgraph(UndirectedGraphView const &g, std::unordered_set const &subgraph_nodes) { @@ -189,6 +227,7 @@ std::unordered_set std::unordered_set JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const { + std::unordered_set srcs = this->query_nodes(query.srcs); std::unordered_set dsts = this->query_nodes(query.dsts); auto traced_srcs = this->joined_nodes.trace_nodes(srcs); diff --git a/lib/utils/test/src/doctest.h b/lib/utils/test/src/doctest.h index 891cc81d32..ff7c7cff20 100644 --- a/lib/utils/test/src/doctest.h +++ b/lib/utils/test/src/doctest.h @@ -5,6 +5,8 @@ #include #include +using namespace FlexFlow; + namespace doctest { template diff --git a/lib/utils/test/src/test_adjacency_multidigraph.cc b/lib/utils/test/src/test_adjacency_multidigraph.cc index 30f911f689..9006a3a650 100644 --- a/lib/utils/test/src/test_adjacency_multidigraph.cc +++ b/lib/utils/test/src/test_adjacency_multidigraph.cc @@ -1,53 +1,87 @@ #include "doctest.h" #include "utils/graph/adjacency_multidigraph.h" +#include "utils/graph/multidigraph_interfaces.h" -using namespace FlexFlow::utils; +using namespace FlexFlow; TEST_CASE("AdjacencyMultiDiGraph:basic_test") { - AdjacencyMultiDiGraph g; - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - MultiDiEdge e1{n1, n2, 0, 0}; - MultiDiEdge e2{n1, n3, 0, 1}; - MultiDiEdge e3{n3, n1, 1, 1}; - MultiDiEdge e4{n3, n2, 1, 1}; - MultiDiEdge e5{n2, n2, 2, 2}; - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - g.add_edge(e5); - - CHECK(g.query_nodes({}) == std::unordered_set{n1, n2, n3}); - - CHECK(g.query_nodes(NodeQuery{{n1, n3}}) == std::unordered_set{n1, n3}); - - CHECK(g.query_edges({}) == - std::unordered_set{e1, e2, e3, e4, e5}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n1)) == - std::unordered_set{e1, e2}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n1)) == - std::unordered_set{e3}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idx(1)) == - std::unordered_set{e3, e4}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(1)) == - std::unordered_set{e2, e3, e4}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n1, n2})) == - std::unordered_set{e1, e2, e5}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n1, n3})) == - std::unordered_set{e2, e3}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({1, 2})) == - std::unordered_set{e3, e4, e5}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({0, 2})) == - std::unordered_set{e1, e5}); + MultiDiGraph g = MultiDiGraph::create(); + + std::vector n = g.add_nodes(3); + std::vector p = g.add_node_ports(3); + + std::vector e = {{n[0], n[1], p[0], p[1]}, + {n[0], n[2], p[0], p[2]}, + {n[2], n[0], p[2], p[0]}, + {n[2], n[1], p[2], p[1]}}; + g.add_edges(e); + + CHECK(g.query_nodes(NodeQuery::all()) == + std::unordered_set{n[0], n[1], n[2]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all()) == + std::unordered_set{e[0], e[1], e[2], e[3]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( + {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( + {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( + {p[1], p[2]}))) == std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( + {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); CHECK(g.query_edges(MultiDiEdgeQuery::all() - .with_src_node(n1) - .with_dst_node(n3) - .with_src_idx(0) - .with_dst_idx(1)) == - std::unordered_set{e2}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(3)) == + .with_src_nodes({n[1]}) + .with_dst_nodes({n[2]}) + .with_src_idxs({p[1]}) + .with_dst_idxs({p[2]})) == std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == + std::unordered_set{e[1]}); + + SUBCASE("remove node") { + g.remove_node_unsafe(n[0]); + + CHECK(g.query_nodes(NodeQuery::all()) == + std::unordered_set{n[1], n[2]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all()) == + std::unordered_set{e[2], e[3]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == + std::unordered_set{}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == + std::unordered_set{e[2]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == + std::unordered_set{e[2]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges( + MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( + {n[1]})) == std::unordered_set{}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == + std::unordered_set{e[1]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + } } diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 299cfe086a..35534f5b3a 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -4,118 +4,160 @@ #include "utils/graph/adjacency_multidigraph.h" #include "utils/graph/algorithms.h" #include "utils/graph/construction.h" +#include +#include +#include +#include -using namespace FlexFlow::utils; +using namespace FlexFlow; TEST_CASE("MultiDiGraph") { - AdjacencyMultiDiGraph g; - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - Node n4 = g.add_node(); - MultiDiEdge e1{n1, n4, 0, 0}; - MultiDiEdge e2{n1, n2, 0, 1}; - MultiDiEdge e3{n1, n3, 0, 0}; - MultiDiEdge e4{n2, n3, 0, 0}; - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - - CHECK(get_nodes(g) == std::unordered_set{n1, n2, n3, n4}); - CHECK(get_edges(g) == std::unordered_set{e1, e2, e3, e4}); - CHECK(get_incoming_edges(g, {n2, n4}) == - std::unordered_set{e1, e2}); - CHECK(get_incoming_edges(g, {n1}) == std::unordered_set{}); - CHECK(get_outgoing_edges(g, {n2, n4}) == std::unordered_set{e4}); - CHECK(get_predecessors(g, {n1, n2, n3}) == - std::unordered_map>{ - {n1, {}}, - {n2, {n1}}, - {n3, {n1, n2}}, - }); + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = g.add_nodes(4); + std::vector p = g.add_node_ports(4); + + std::vector e = { + {n[0], n[3], p[0], p[3]}, + {n[1], n[2], p[0], p[2]}, + {n[1], n[3], p[1], p[3]}, + {n[2], n[3], p[2], p[3]}, + }; + + g.add_edges(e); + + CHECK(get_incoming_edges(g, {n[1], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); + CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{e[3]}); + std::unordered_map> expected_result = + std::unordered_map>{ + {n[1], {}}, + {n[2], {n[1]}}, + {n[3], {n[0], n[1], n[2]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); } TEST_CASE("DiGraph") { - AdjacencyDiGraph g; - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - Node n4 = g.add_node(); - DirectedEdge e1{n1, n4}; - DirectedEdge e2{n1, n2}; - DirectedEdge e3{n1, n3}; - DirectedEdge e4{n2, n3}; - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - - CHECK(get_nodes(g) == std::unordered_set{n1, n2, n3, n4}); - CHECK(get_edges(g) == std::unordered_set{e1, e2, e3, e4}); - CHECK(get_incoming_edges(g, {n2, n4}) == - std::unordered_set{e1, e2}); - CHECK(get_outgoing_edges(g, {n2, n4}) == - std::unordered_set{e4}); - CHECK(get_predecessors(g, {n1, n2, n3}) == - std::unordered_map>{ - {n1, {}}, - {n2, {n1}}, - {n3, {n1, n2}}, - }); + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + {n[0], n[3]}, + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[2]}, + }; + add_edges(g, e); + + CHECK(get_incoming_edges(g, {n[2], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{}); + auto expected_result = std::unordered_map>{ + {n[1], {n[0]}}, + {n[2], {n[0], n[1]}}, + {n[3], {n[0]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); + + SUBCASE("get_imm_dominators") { + std::unordered_map result = + map_values(get_imm_dominators(g), + [&](optional const &node) { return *node; }); + + std::unordered_map expected_result = { + {n[2], n[0]}, + {n[1], n[0]}, + {n[3], n[0]}, + {n[0], n[0]}, + }; + CHECK(result == expected_result); + } + + SUBCASE("get_dominators") { + std::unordered_map> expected = { + {n[0], {n[0]}}, + {n[1], {n[1]}}, + {n[2], {n[0], n[2]}}, + {n[3], {n[3]}}, + }; + CHECK(get_dominators(g) == expected); + } + + SUBCASE("get_sinks") { + auto expected = std::unordered_set{n[2], n[3]}; + CHECK(get_sinks(g) == expected); + } + + SUBCASE("get_bfs") { + std::unordered_set start_points = std::unordered_set{n[0]}; + auto expected = std::vector{n[0], n[2], n[1], n[3]}; + CHECK(get_bfs_ordering(g, start_points) == expected); + } + + SUBCASE("get_predecessors") { + std::unordered_map> expected_result = { + {n[1], {n[0]}}, + {n[2], {n[0], n[1]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); + } } TEST_CASE("traversal") { - AdjacencyDiGraph g; + DiGraph g = DiGraph::create(); std::vector const n = add_nodes(g, 4); - g.add_edge({n[0], n[1]}); - g.add_edge({n[1], n[2]}); - g.add_edge({n[2], n[3]}); + std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; + add_edges(g, edges); - /* CHECK(get_incoming_edges(g, n[0]) == std::unordered_set{}); - */ CHECK(get_sources(g) == std::unordered_set{n[0]}); - CHECK(unchecked_dfs_ordering(g, {n[0]}) == + CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_bfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); - CHECK(bfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); CHECK(is_acyclic(g) == true); SUBCASE("with root") { g.add_edge({n[3], n[2]}); - CHECK(dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); CHECK(is_acyclic(g) == false); } SUBCASE("without root") { g.add_edge({n[3], n[0]}); - CHECK(dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); CHECK(is_acyclic(g) == false); } - SUBCASE("nonlinear") { g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == false); + CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs } } TEST_CASE("bfs") { - AdjacencyDiGraph g; + DiGraph g = DiGraph::create(); std::vector const n = add_nodes(g, 7); - add_edges(g, - { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }); - - std::vector ordering = bfs_ordering(g, {n[0]}); + + std::vector e = { + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[6]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}, + {n[5], n[6]}, + {n[6], n[0]}, + }; + + add_edges(g, e); + + std::vector ordering = get_bfs_ordering(g, {n[0]}); auto CHECK_BEFORE = [&](int l, int r) { CHECK(index_of(ordering, n[l]).has_value()); CHECK(index_of(ordering, n[r]).has_value()); @@ -137,18 +179,17 @@ TEST_CASE("bfs") { CHECK_BEFORE(4, 5); } -TEST_CASE("topological_ordering") { - AdjacencyDiGraph g; - std::vector const n = add_nodes(g, 6); - add_edges(g, - {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}); - - std::vector ordering = topological_ordering(g); +TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {{n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[5]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); auto CHECK_BEFORE = [&](int l, int r) { CHECK(index_of(ordering, n[l]).has_value()); CHECK(index_of(ordering, n[r]).has_value()); @@ -163,3 +204,20 @@ TEST_CASE("topological_ordering") { CHECK_BEFORE(3, 4); CHECK_BEFORE(4, 5); } + +TEST_CASE("get_weakly_connected_components") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + + std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; + + add_edges(g, edges); + std::vector> expected_components = { + {n[0]}, + {n[1]}, + {n[2]}, + {n[3]}, + }; + + CHECK(get_weakly_connected_components(g) == expected_components); +} \ No newline at end of file