diff --git a/lib/utils/include/utils/graph/adjacency_digraph.h b/lib/utils/include/utils/graph/adjacency_digraph.h index dfdab31d14..49e5f7553a 100644 --- a/lib/utils/include/utils/graph/adjacency_digraph.h +++ b/lib/utils/include/utils/graph/adjacency_digraph.h @@ -29,7 +29,7 @@ class AdjacencyDiGraph : public IDiGraph { private: using ContentsType = std::unordered_map>; - + AdjacencyDiGraph(std::size_t next_node_idx, ContentsType adjacency) : next_node_idx(next_node_idx), adjacency(adjacency) {} std::size_t next_node_idx = 0; diff --git a/lib/utils/include/utils/graph/adjacency_multidigraph.h b/lib/utils/include/utils/graph/adjacency_multidigraph.h index d5d9ee5d1a..a1be0dc112 100644 --- a/lib/utils/include/utils/graph/adjacency_multidigraph.h +++ b/lib/utils/include/utils/graph/adjacency_multidigraph.h @@ -21,7 +21,8 @@ class AdjacencyMultiDiGraph : public IMultiDiGraph { std::unordered_set query_nodes(NodeQuery const &) const override; AdjacencyMultiDiGraph *clone() const override { - return new AdjacencyMultiDiGraph(this->next_node_idx, this->next_node_port, this->adjacency); + return new AdjacencyMultiDiGraph( + this->next_node_idx, this->next_node_port, this->adjacency); } private: @@ -31,9 +32,12 @@ class AdjacencyMultiDiGraph : public IMultiDiGraph { Node, std::unordered_map>>>; - AdjacencyMultiDiGraph(std::size_t next_node_idx, std::size_t next_node_port, ContentsType const & adjacency) - : next_node_idx(next_node_idx), next_node_port(next_node_port), adjacency(adjacency) {} - + AdjacencyMultiDiGraph(std::size_t next_node_idx, + std::size_t next_node_port, + ContentsType const &adjacency) + : next_node_idx(next_node_idx), next_node_port(next_node_port), + adjacency(adjacency) {} + std::size_t next_node_idx = 0; std::size_t next_node_port = 0; ContentsType adjacency; diff --git a/lib/utils/include/utils/graph/digraph.h b/lib/utils/include/utils/graph/digraph.h index 17d7d6dc1a..a30e023a38 100644 --- a/lib/utils/include/utils/graph/digraph.h +++ b/lib/utils/include/utils/graph/digraph.h @@ -41,7 +41,7 @@ struct IDiGraphView : public IGraphView { IDiGraphView &operator=(IDiGraphView const &) = delete; virtual std::unordered_set query_edges(EdgeQuery const &) const = 0; - virtual ~IDiGraphView()=default; + virtual ~IDiGraphView() = default; protected: IDiGraphView() = default; @@ -76,8 +76,9 @@ struct DiGraphView { return DiGraphView(std::make_shared(std::forward(args)...)); } - DiGraphView(std::shared_ptr ptr): ptr(ptr) {} + DiGraphView(std::shared_ptr ptr) : ptr(ptr) {} static DiGraphView unsafe_create(IDiGraphView const &graphView); + private: friend DiGraphView unsafe_create(IDiGraphView const &); diff --git a/lib/utils/include/utils/graph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph.h index 2e85c523b4..62e7c16f1d 100644 --- a/lib/utils/include/utils/graph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph.h @@ -29,7 +29,7 @@ struct MultiDiGraphView { return MultiDiGraphView( std::make_shared(std::forward(args)...)); } - MultiDiGraphView(std::shared_ptr ptr):ptr(ptr){} + MultiDiGraphView(std::shared_ptr ptr) : ptr(ptr) {} static MultiDiGraphView unsafe_create(IMultiDiGraphView const &); private: @@ -74,7 +74,7 @@ struct MultiDiGraph { } private: - MultiDiGraph(std::unique_ptr ptr):ptr(std::move(ptr)){} + MultiDiGraph(std::unique_ptr ptr) : ptr(std::move(ptr)) {} private: cow_ptr_t ptr; diff --git a/lib/utils/include/utils/graph/node.h b/lib/utils/include/utils/graph/node.h index 039c752a07..2c7e4d38cd 100644 --- a/lib/utils/include/utils/graph/node.h +++ b/lib/utils/include/utils/graph/node.h @@ -66,7 +66,7 @@ struct GraphView { return GraphView(std::make_shared(std::forward(args)...)); } - GraphView(std::shared_ptr ptr):ptr(ptr){} + GraphView(std::shared_ptr ptr) : ptr(ptr) {} private: std::shared_ptr ptr; @@ -75,7 +75,7 @@ CHECK_RC_COPY_VIRTUAL_COMPLIANT(IGraphView); struct IGraph : IGraphView { IGraph(IGraph const &) = delete; - IGraph()=default; + IGraph() = default; IGraph &operator=(IGraph const &) = delete; virtual Node add_node() = 0; diff --git a/lib/utils/include/utils/graph/open_graph_interfaces.h b/lib/utils/include/utils/graph/open_graph_interfaces.h index d8794ec5fd..6ccfbab2a3 100644 --- a/lib/utils/include/utils/graph/open_graph_interfaces.h +++ b/lib/utils/include/utils/graph/open_graph_interfaces.h @@ -25,7 +25,8 @@ struct OutputMultiDiEdge { }; FF_VISITABLE_STRUCT(OutputMultiDiEdge, uid, src, srcIdx); -using OpenMultiDiEdge =variant; +using OpenMultiDiEdge = + variant; using DownwardOpenMultiDiEdge = variant; diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h index c6b86ead51..e9e0479692 100644 --- a/lib/utils/include/utils/graph/open_graphs.h +++ b/lib/utils/include/utils/graph/open_graphs.h @@ -36,7 +36,8 @@ struct OpenMultiDiGraphView { } private: - OpenMultiDiGraphView(std::shared_ptr ptr):ptr(ptr){} + OpenMultiDiGraphView(std::shared_ptr ptr) + : ptr(ptr) {} std::shared_ptr ptr; }; diff --git a/lib/utils/include/utils/graph/undirected.h b/lib/utils/include/utils/graph/undirected.h index 54e1d30930..ff952597c6 100644 --- a/lib/utils/include/utils/graph/undirected.h +++ b/lib/utils/include/utils/graph/undirected.h @@ -44,7 +44,7 @@ struct IUndirectedGraphView : public IGraphView { virtual std::unordered_set query_edges(UndirectedEdgeQuery const &) const = 0; - virtual ~IUndirectedGraphView()=default; + virtual ~IUndirectedGraphView() = default; protected: IUndirectedGraphView() = default; @@ -81,8 +81,8 @@ struct UndirectedGraphView { std::make_shared(std::forward(args)...)); } - UndirectedGraphView(std::shared_ptr ptr): ptr(ptr) { - } + UndirectedGraphView(std::shared_ptr ptr) + : ptr(ptr) {} private: friend UndirectedGraphView unsafe(IUndirectedGraphView const &); diff --git a/lib/utils/include/utils/graph/views.h b/lib/utils/include/utils/graph/views.h index e10b74f33c..2596721aab 100644 --- a/lib/utils/include/utils/graph/views.h +++ b/lib/utils/include/utils/graph/views.h @@ -61,8 +61,8 @@ struct DiSubgraphView : public IDiGraphView { struct MultiDiSubgraphView : public IMultiDiGraphView { public: MultiDiSubgraphView() = delete; - explicit MultiDiSubgraphView(MultiDiGraphView const & g, - std::unordered_set const & subgraph_nodes); + explicit MultiDiSubgraphView(MultiDiGraphView const &g, + std::unordered_set const &subgraph_nodes); std::unordered_set query_edges(MultiDiEdgeQuery const &) const override; @@ -87,8 +87,8 @@ enum class LRDirection { LEFT, RIGHT }; struct JoinNodeKey { JoinNodeKey() = delete; - JoinNodeKey(Node const & node, LRDirection direction): - node(node), direction(direction) {} + JoinNodeKey(Node const &node, LRDirection direction) + : node(node), direction(direction) {} bool operator==(JoinNodeKey const &) const; bool operator<(JoinNodeKey const &) const; @@ -218,9 +218,10 @@ struct SingleSourceNodeView : public IDiGraphView { struct ContractNodeView : public IDiGraphView { ContractNodeView() = delete; - explicit ContractNodeView(DiGraphView const & g, + explicit ContractNodeView(DiGraphView const &g, Node const &removed, - Node const &into): g(g), from(removed), to(into) {} + Node const &into) + : g(g), from(removed), to(into) {} std::unordered_set query_edges(DirectedEdgeQuery const &) const override; @@ -302,7 +303,7 @@ struct ViewMultiDiGraphAsDiGraph : public IDiGraphView { struct ViewOpenMultiDiGraphAsMultiDiGraph : public IMultiDiGraphView { public: ViewOpenMultiDiGraphAsMultiDiGraph() = delete; - explicit ViewOpenMultiDiGraphAsMultiDiGraph(OpenMultiDiGraphView const & g) + explicit ViewOpenMultiDiGraphAsMultiDiGraph(OpenMultiDiGraphView const &g) : g(g) {} std::unordered_set diff --git a/lib/utils/src/graph/adjacency_digraph.cc b/lib/utils/src/graph/adjacency_digraph.cc index 442c9cca5e..1438edc78b 100644 --- a/lib/utils/src/graph/adjacency_digraph.cc +++ b/lib/utils/src/graph/adjacency_digraph.cc @@ -10,7 +10,6 @@ Node AdjacencyDiGraph::add_node() { return node; } - void AdjacencyDiGraph::add_node_unsafe(Node const &node) { adjacency[node]; this->next_node_idx = std::max(this->next_node_idx, node.value() + 1); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index b49fadb56a..6ecdb0049f 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -93,13 +93,15 @@ std::size_t num_nodes(GraphView const &g) { return get_nodes(g).size(); } -DiGraphView contract_node(DiGraphView const &g , Node const &from, Node const &into) { +DiGraphView + contract_node(DiGraphView const &g, Node const &from, Node const &into) { return DiGraphView::create(g, from, into); } -DiGraphView apply_contraction(DiGraphView const & g, std::unordered_map const & nodes){ - DiGraphView contractedView = g; - for(auto const & kv : nodes){ +DiGraphView apply_contraction(DiGraphView const &g, + std::unordered_map const &nodes) { + DiGraphView contractedView = g; + for (auto const &kv : nodes) { Node from = kv.first; Node into = kv.second; contractedView = contract_node(contractedView, from, into); @@ -216,9 +218,9 @@ std::unordered_set return to_directed_edges(get_outgoing_edges(multidigraph_view, dsts)); } -std::unordered_set get_outgoing_edges(DiGraphView const & g, - Node const & n){ - return get_outgoing_edges(g, std::unordered_set{n}); +std::unordered_set get_outgoing_edges(DiGraphView const &g, + Node const &n) { + return get_outgoing_edges(g, std::unordered_set{n}); } std::unordered_map> @@ -269,26 +271,24 @@ std::vector return {bfs_view.begin(), bfs_view.end()}; } - -std::unordered_set get_sinks(DiGraphView const & g){ - std::unordered_set dsts ; - for(Node const &n : get_nodes(g)) { +std::unordered_set get_sinks(DiGraphView const &g) { + std::unordered_set dsts; + for (Node const &n : get_nodes(g)) { auto outgoing = get_outgoing_edges(g, n); - if(outgoing.size() == 0){ + if (outgoing.size() == 0) { dsts.insert(n); } } return dsts; } -std::unordered_set get_sinks(MultiDiGraphView const & g){ +std::unordered_set get_sinks(MultiDiGraphView const &g) { DiGraphView digraph_view = as_digraph(g); return get_sinks(digraph_view); } -DiGraphView flipped(DiGraphView const & g) { +DiGraphView flipped(DiGraphView const &g) { return DiGraphView::create(g); - } std::unordered_set get_sources(DiGraphView const &g) { @@ -381,13 +381,16 @@ std::vector get_edge_topological_ordering(DiGraphView const &g) { } /* -transform(get_outgoing_edges(g, n), [](DirectedEdge const &n) { return n.dst; }) return std::unorder_set -set_union return st::unorder_set -as_vector convert the std::unorder_set to std::vector +transform(get_outgoing_edges(g, n), [](DirectedEdge const &n) { return n.dst; }) +return std::unorder_set set_union return st::unorder_set as_vector +convert the std::unorder_set to std::vector */ -std::vector get_neighbors(DiGraphView const & g, Node const & n) { - return as_vector(set_union( transform(get_outgoing_edges(g, n), [](DirectedEdge const &n) { return n.dst; }), transform(get_incoming_edges(g, n), [](DirectedEdge const &n) { return n.src; }) )); - +std::vector get_neighbors(DiGraphView const &g, Node const &n) { + return as_vector( + set_union(transform(get_outgoing_edges(g, n), + [](DirectedEdge const &n) { return n.dst; }), + transform(get_incoming_edges(g, n), + [](DirectedEdge const &n) { return n.src; }))); } std::vector @@ -500,23 +503,24 @@ optional imm_post_dominator(MultiDiGraphView const &g, Node const &n) { return get_imm_post_dominators(g).at(n); } -optional get_imm_post_dominator(DiGraphView const & g, Node const & n) { +optional get_imm_post_dominator(DiGraphView const &g, Node const &n) { return get_imm_post_dominators(g).at(n); } +optional get_imm_post_dominator(DiGraphView const &g, + std::unordered_set const &nodes) { + std::unordered_set commonDoms = + get_post_dominators(g).at(get_first(nodes)); -optional get_imm_post_dominator(DiGraphView const & g, std::unordered_set const & nodes ){ - std::unordered_set commonDoms = get_post_dominators(g).at(get_first(nodes)); - - for(Node const & node : nodes){ - commonDoms = intersection(get_post_dominators(g).at(node), commonDoms); - } + for (Node const &node : nodes) { + commonDoms = intersection(get_post_dominators(g).at(node), commonDoms); + } - if (!commonDoms.empty()) { - return get_first(commonDoms); - } else { - return tl::nullopt; - } + if (!commonDoms.empty()) { + return get_first(commonDoms); + } else { + return tl::nullopt; + } } std::pair @@ -581,44 +585,45 @@ MultiDiGraphView as_multidigraph(OpenMultiDiGraphView const &g) { } std::vector> - get_weakly_connected_components(DiGraphView const & g) { - std::unordered_set start_pointes = get_sources(g); - std::vector dfs_order = get_dfs_ordering(g, start_pointes); + get_weakly_connected_components(DiGraphView const &g) { + std::unordered_set start_pointes = get_sources(g); + std::vector dfs_order = get_dfs_ordering(g, start_pointes); - std::vector> components; - std::unordered_set visited; + std::vector> components; + std::unordered_set visited; - for (Node const & node : dfs_order) { - if (contains(visited, node)) { - continue; // Skip nodes already in a component - } - - std::unordered_set component; - std::stack stack; - stack.push(node); + for (Node const &node : dfs_order) { + if (contains(visited, node)) { + continue; // Skip nodes already in a component + } - while (!stack.empty()) { - Node current = stack.top(); - stack.pop(); + std::unordered_set component; + std::stack stack; + stack.push(node); - if (contains(visited, current)) { - continue; - } + while (!stack.empty()) { + Node current = stack.top(); + stack.pop(); - component.insert(current); - visited.insert(current); + if (contains(visited, current)) { + continue; + } - std::vector neighbors = get_neighbors(g, current); // Replace with your own function to get neighbors + component.insert(current); + visited.insert(current); - for (Node const & neighbor : neighbors) { - stack.push(neighbor); - } - } + std::vector neighbors = get_neighbors( + g, current); // Replace with your own function to get neighbors - components.push_back(std::move(component)); + for (Node const &neighbor : neighbors) { + stack.push(neighbor); + } } - return components; + components.push_back(std::move(component)); + } + + return components; } } // namespace FlexFlow diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index f5627ac4ce..f45078cca2 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -53,35 +53,40 @@ bool DiGraphView::operator!=(DiGraphView const &other) const { return ptr != other.ptr; } -std::unordered_set DiGraphView::query_nodes(NodeQuery const& q) const { +std::unordered_set DiGraphView::query_nodes(NodeQuery const &q) const { return this->ptr->query_nodes(q); } -std::unordered_set DiGraphView::query_edges(EdgeQuery const & query) const { +std::unordered_set + DiGraphView::query_edges(EdgeQuery const &query) const { return ptr->query_edges(query); } /* unsafe_create: -1 use the graphView to creae the std::shared_ptr ptr, and define a empty lambda function to delete the ptr -2 we use this ptr to create a DiGraphView, this DiGraphView is read-only. It creates a DiGraphView object that is not responsible for ownership management +1 use the graphView to creae the std::shared_ptr ptr, and +define a empty lambda function to delete the ptr 2 we use this ptr to create a +DiGraphView, this DiGraphView is read-only. It creates a DiGraphView object that +is not responsible for ownership management */ DiGraphView unsafe_create(IDiGraphView const &graphView) { std::shared_ptr ptr((&graphView), - [](IDiGraphView const *){}); + [](IDiGraphView const *) {}); return DiGraphView(ptr); } -DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, DirectedEdgeQuery const &rhs){ +DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, + DirectedEdgeQuery const &rhs) { assert(lhs != tl::nullopt); assert(rhs != tl::nullopt); - assert (lhs.srcs.has_value() && lhs.dsts.has_value() && rhs.srcs.has_value() && rhs.dsts.has_value()); + assert(lhs.srcs.has_value() && lhs.dsts.has_value() && rhs.srcs.has_value() && + rhs.dsts.has_value()); - tl::optional> srcs_t1 = intersection(*lhs.srcs, *rhs.srcs); - tl::optional> dsts_t1 = intersection(*lhs.dsts, *rhs.dsts); + tl::optional> srcs_t1 = + intersection(*lhs.srcs, *rhs.srcs); + tl::optional> dsts_t1 = + intersection(*lhs.dsts, *rhs.dsts); return DirectedEdgeQuery(srcs_t1, dsts_t1); } - - } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index eed6b1e245..716124924a 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -21,23 +21,22 @@ MultiDiEdgeQuery } <<<<<<< HEAD -std::ostream& operator<<(std::ostream& os, const MultiDiEdge& edge) { - return os<<"MultiDiEdge{"<> const &srcs, - tl::optional> const &dsts, - tl::optional> const &srcIdxs , - tl::optional> const &dstIdxs ) - :srcs(srcs), dsts(dsts),srcIdxs(srcIdxs), dstIdxs(dstIdxs) -{} +MultiDiEdgeQuery::MultiDiEdgeQuery( + tl::optional> const &srcs, + tl::optional> const &dsts, + tl::optional> const &srcIdxs, + tl::optional> const &dstIdxs) + : srcs(srcs), dsts(dsts), srcIdxs(srcIdxs), dstIdxs(dstIdxs) {} MultiDiEdgeQuery MultiDiEdgeQuery::with_src_node(Node const &n) const { return this->with_src_nodes({n}); } - MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_nodes(query_set const &nodes) const { MultiDiEdgeQuery e = *this; @@ -48,10 +47,14 @@ MultiDiEdgeQuery return e; } -MultiDiEuery query_intersection(MultiDiEdgeQuery const &lhs, MultiDiEdgeQuery const &rhs){ - assert (lhs.srcs.has_value() &&dgeQ lhs.dsts.has_value() && rhs.srcs.has_value() && rhs.dsts.has_value()); - tl::optional> srcs = intersection(*lhs.srcs, *rhs.srcs); - tl::optional> dsts = intersection(*lhs.dsts, *rhs.dsts); +MultiDiEuery query_intersection(MultiDiEdgeQuery const &lhs, + MultiDiEdgeQuery const &rhs) { + assert(lhs.srcs.has_value() && dgeQ lhs.dsts.has_value() && + rhs.srcs.has_value() && rhs.dsts.has_value()); + tl::optional> srcs = + intersection(*lhs.srcs, *rhs.srcs); + tl::optional> dsts = + intersection(*lhs.dsts, *rhs.dsts); return MultiDiEdgeQuery(srcs, dsts); } @@ -86,15 +89,17 @@ MultiDiEdgeQuery MultiDiEdgeQuery::all() { matchall()}; } -std::unordered_set MultiDiGraphView::query_nodes(NodeQuery const & q) const { +std::unordered_set + MultiDiGraphView::query_nodes(NodeQuery const &q) const { return this->ptr->query_nodes(q); } -std::unordered_set MultiDiGraphView::query_edges(MultiDiEdgeQuery const &q) const { +std::unordered_set + MultiDiGraphView::query_edges(MultiDiEdgeQuery const &q) const { return this->ptr->query_edges(q); } -NodePort IMultiDiGraph::add_node_port(){ +NodePort IMultiDiGraph::add_node_port() { NodePort np{this->next_nodeport_idx}; this->next_nodeport_idx += 1; return np; @@ -109,12 +114,11 @@ MultiDiGraphView::operator GraphView() const { } MultiDiGraphView unsafe_create(IMultiDiGraphView const &graphView) { - std::shared_ptr ptr((&graphView), - [](IMultiDiGraphView const *ptr) {}); + std::shared_ptr ptr( + (&graphView), [](IMultiDiGraphView const *ptr) {}); return MultiDiGraphView(ptr); } - void swap(MultiDiGraphView &lhs, MultiDiGraphView &rhs) { using std::swap; @@ -139,7 +143,7 @@ NodePort MultiDiGraph::add_node_port() { return this->ptr.get_mutable()->add_node_port(); } -void MultiDiGraph::add_node_port_unsafe(NodePort const & np) { +void MultiDiGraph::add_node_port_unsafe(NodePort const &np) { return this->ptr.get_mutable()->add_node_port_unsafe(np); } @@ -161,10 +165,10 @@ void MultiDiGraph::remove_edge(MultiDiEdge const &e) { std::unordered_set MultiDiGraph::query_edges(MultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->ptr->query_edges(q); } -std::unordered_set MultiDiGraph::query_nodes(NodeQuery const & q) const { +std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { return this->ptr->query_nodes(q); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index e1c9667227..45640b7b5f 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -3,8 +3,8 @@ namespace FlexFlow { -std::ostream &operator<<(std::ostream & os, Node const &node){ - return os << fmt::format("Node({})", node.value()); +std::ostream &operator<<(std::ostream &os, Node const &node) { + return os << fmt::format("Node({})", node.value()); } NodeQuery::NodeQuery(std::unordered_set const &nodes) @@ -13,19 +13,19 @@ NodeQuery::NodeQuery(std::unordered_set const &nodes) NodeQuery::NodeQuery(tl::optional> const &nodes) : nodes(nodes) {} -NodeQuery query_intersection(NodeQuery const & lhs, NodeQuery const & rhs){ - assert(lhs != tl::nullopt && rhs != tl::nullopt); - return intersection(*lhs.nodes, *rhs.nodes) ; +NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) { + assert(lhs != tl::nullopt && rhs != tl::nullopt); + return intersection(*lhs.nodes, *rhs.nodes); } -std::unordered_set GraphView::query_nodes(NodeQuery const & g) const { - return this->ptr->query_nodes(g); -} +std::unordered_set GraphView::query_nodes(NodeQuery const &g) const { + return this->ptr->query_nodes(g); +} GraphView GraphView::unsafe_create(IGraphView const &graphView) { - std::shared_ptr ptr((&graphView), - [](IGraphView const *) { }); - return GraphView(ptr); - } + std::shared_ptr ptr((&graphView), + [](IGraphView const *) {}); + return GraphView(ptr); +} } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index e597f01a92..ad72d22cd2 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -2,13 +2,14 @@ namespace FlexFlow { - - std::unordered_set OpenMultiDiGraphView::query_nodes(NodeQuery const & q) const{ - return this->ptr->query_nodes(q); - } - std::unordered_set OpenMultiDiGraphView::query_edges(OpenMultiDiEdgeQuery const & q) const{ - return this->ptr->query_edges(q); - } +std::unordered_set + OpenMultiDiGraphView::query_nodes(NodeQuery const &q) const { + return this->ptr->query_nodes(q); +} +std::unordered_set + OpenMultiDiGraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { + return this->ptr->query_edges(q); +} OpenMultiDiGraph::OpenMultiDiGraph(OpenMultiDiGraph const &other) : ptr(other.ptr->clone()) {} diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 308b5610ba..21997e8a06 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -131,9 +131,14 @@ SplitAST parallel_decomposition(DiGraphView const &g) { return split; } -SplitASTNode::SplitASTNode(SplitType type, SplitAST const & lhs, SplitAST const & rhs):type(type), children({lhs, rhs}){} - -SplitASTNode::SplitASTNode(SplitType type, std::vector const & children):type(type), children(children){} +SplitASTNode::SplitASTNode(SplitType type, + SplitAST const &lhs, + SplitAST const &rhs) + : type(type), children({lhs, rhs}) {} + +SplitASTNode::SplitASTNode(SplitType type, + std::vector const &children) + : type(type), children(children) {} struct FlattenAST { void add_flattened_child_to_parent(SplitASTNode &parent, diff --git a/lib/utils/src/graph/serialparallel_internal.h b/lib/utils/src/graph/serialparallel_internal.h index 1f8b0da709..4c9c22d95e 100644 --- a/lib/utils/src/graph/serialparallel_internal.h +++ b/lib/utils/src/graph/serialparallel_internal.h @@ -18,7 +18,7 @@ struct SplitASTNode; using SplitAST = mpark::variant; struct SplitASTNode { - SplitASTNode(SplitType type): type(type) {} + SplitASTNode(SplitType type) : type(type) {} SplitASTNode(SplitType, SplitAST const &, SplitAST const &); SplitASTNode(SplitType, std::vector const &); diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index 2497369910..b35c19e062 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -48,19 +48,21 @@ std::unordered_set UndirectedGraph::UndirectedGraph(std::unique_ptr _ptr) : ptr(std::move(_ptr)) {} -UndirectedGraph:: operator UndirectedGraphView() const { - return UndirectedGraphView(ptr.get()); +UndirectedGraph::operator UndirectedGraphView() const { + return UndirectedGraphView(ptr.get()); } -std::unordered_set UndirectedGraphView::query_edges(UndirectedEdgeQuery const& q) const { +std::unordered_set + UndirectedGraphView::query_edges(UndirectedEdgeQuery const &q) const { return this->ptr->query_edges(q); } -std::unordered_set UndirectedGraphView::query_nodes(NodeQuery const & q) const { +std::unordered_set + UndirectedGraphView::query_nodes(NodeQuery const &q) const { return this->ptr->query_nodes(q); } -UndirectedGraphView::operator GraphView const&() const { +UndirectedGraphView::operator GraphView const &() const { return GraphView(this->ptr); } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index b87120044f..b6bc8a7085 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -21,40 +21,41 @@ std::unordered_set return this->g.query_nodes(query); } -bool JoinNodeKey::operator==(JoinNodeKey const & jnk) const { - return node== jnk.node && direction == jnk.direction; +bool JoinNodeKey::operator==(JoinNodeKey const &jnk) const { + return node == jnk.node && direction == jnk.direction; } -std::unordered_set ContractNodeView::query_edges(DirectedEdgeQuery const & q) const { - return g.query_edges(q); +std::unordered_set + ContractNodeView::query_edges(DirectedEdgeQuery const &q) const { + return g.query_edges(q); } -std::unordered_set ContractNodeView::query_nodes(NodeQuery const& q) const { +std::unordered_set + ContractNodeView::query_nodes(NodeQuery const &q) const { return g.query_nodes(q); } - -std::unordered_set ViewOpenMultiDiGraphAsMultiDiGraph::query_nodes(NodeQuery const & query) const { +std::unordered_set ViewOpenMultiDiGraphAsMultiDiGraph::query_nodes( + NodeQuery const &query) const { return g.query_nodes(query); } -std::unordered_set - ViewOpenMultiDiGraphAsMultiDiGraph::query_edges(MultiDiEdgeQuery const & query) const { - OpenMultiDiEdgeQuery q; - q.standard_edge_query = query; - std::unordered_set edges = g.query_edges(q); - std::unordered_set result; - - for (const auto& edge : edges) { - if(holds_alternative(edge)){ - result.insert(get(edge)); - } - - } +std::unordered_set ViewOpenMultiDiGraphAsMultiDiGraph::query_edges( + MultiDiEdgeQuery const &query) const { + OpenMultiDiEdgeQuery q; + q.standard_edge_query = query; + std::unordered_set edges = g.query_edges(q); + std::unordered_set result; - return result; + for (auto const &edge : edges) { + if (holds_alternative(edge)) { + result.insert(get(edge)); + } } + return result; +} + DirectedEdge flipped(DirectedEdge const &e) { return {e.src, e.dst}; } @@ -229,8 +230,10 @@ std::unordered_set std::unordered_set JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const { <<<<<<< HEAD - std::unordered_set srcs = query.srcs.value_or(get_nodes(unsafe_create(*this))); - std::unordered_set dsts = query.dsts.value_or(get_nodes(unsafe_create(*this))); + std::unordered_set srcs = + query.srcs.value_or(get_nodes(unsafe_create(*this))); + std::unordered_set dsts = + query.dsts.value_or(get_nodes(unsafe_create(*this))); ======= std::unordered_set srcs = this->query_nodes(query.srcs); std::unordered_set dsts = this->query_nodes(query.dsts); @@ -273,8 +276,10 @@ std::unordered_set std::unordered_set JoinedMultiDigraphView::query_edges(MultiDiEdgeQuery const &query) const { <<<<<<< HEAD - std::unordered_set srcs = query.srcs.value_or(get_nodes(unsafe_create(*this))); - std::unordered_set dsts = query.dsts.value_or(get_nodes(unsafe_create(*this))); + std::unordered_set srcs = + query.srcs.value_or(get_nodes(unsafe_create(*this))); + std::unordered_set dsts = + query.dsts.value_or(get_nodes(unsafe_create(*this))); ======= std::unordered_set srcs = this->query_nodes(query.srcs); std::unordered_set dsts = this->query_nodes(query.dsts); diff --git a/lib/utils/test/src/test_adjacency_multidigraph.cc b/lib/utils/test/src/test_adjacency_multidigraph.cc index 1278f6656b..badeb7d572 100644 --- a/lib/utils/test/src/test_adjacency_multidigraph.cc +++ b/lib/utils/test/src/test_adjacency_multidigraph.cc @@ -10,7 +10,7 @@ TEST_CASE("AdjacencyMultiDiGraph:basic_test") { Node n2 = g.add_node(); NodePort p0 = g.add_node_port(); NodePort p1 = g.add_node_port(); - NodePort p2 = g.add_node_port(); + NodePort p2 = g.add_node_port(); MultiDiEdge e1{n0, n1, p0, p1}; MultiDiEdge e2{n0, n2, p0, p2}; MultiDiEdge e3{n2, n0, p2, p0}; @@ -24,11 +24,10 @@ TEST_CASE("AdjacencyMultiDiGraph:basic_test") { CHECK(g.query_nodes(NodeQuery{{n0, n2}}) == std::unordered_set{n0, n2}); - CHECK(g.query_edges({}) == - std::unordered_set{e1, e2, e3, e4}); + CHECK(g.query_edges({}) == std::unordered_set{e1, e2, e3, e4}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n1)) == - std::unordered_set{}); + std::unordered_set{}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n1)) == std::unordered_set{e1, e4}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idx(p1)) == @@ -60,7 +59,7 @@ TEST_CASE("AdjacencyMultiDiGraph:remove_node") { Node n2 = g.add_node(); NodePort p0 = g.add_node_port(); NodePort p1 = g.add_node_port(); - NodePort p2 = g.add_node_port(); + NodePort p2 = g.add_node_port(); MultiDiEdge e1{n0, n1, p0, p1}; MultiDiEdge e2{n0, n2, p0, p2}; MultiDiEdge e3{n2, n0, p2, p0}; @@ -74,20 +73,18 @@ TEST_CASE("AdjacencyMultiDiGraph:remove_node") { CHECK(g.query_nodes({}) == std::unordered_set{n1, n2}); - CHECK(g.query_edges({}) == - std::unordered_set{e3, e4}); + CHECK(g.query_edges({}) == std::unordered_set{e3, e4}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n0)) == - std::unordered_set{}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n0)) == - std::unordered_set{e3}); + std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p2})) == - std::unordered_set{e3, e4}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p0})) == + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n0)) == std::unordered_set{e3}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p2})) == + std::unordered_set{e3, e4}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p0})) == + std::unordered_set{e3}); } TEST_CASE("AdjacencyMultiDiGraph:remove_edge") { @@ -97,7 +94,7 @@ TEST_CASE("AdjacencyMultiDiGraph:remove_edge") { Node n2 = g.add_node(); NodePort p0 = g.add_node_port(); NodePort p1 = g.add_node_port(); - NodePort p2 = g.add_node_port(); + NodePort p2 = g.add_node_port(); MultiDiEdge e1{n0, n1, p0, p1}; MultiDiEdge e2{n0, n2, p0, p2}; MultiDiEdge e3{n2, n0, p2, p0}; @@ -108,15 +105,13 @@ TEST_CASE("AdjacencyMultiDiGraph:remove_edge") { g.add_edge(e4); g.remove_edge(e1); - - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n0).with_dst_node(n1)) == - std::unordered_set{}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n0).with_dst_node( + n1)) == std::unordered_set{}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n2)) == - std::unordered_set{e2}); + std::unordered_set{e2}); CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idx(p2)) == - std::unordered_set{e3, e4}); - + std::unordered_set{e3, e4}); } diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 8adca1340c..7d64f6f0d1 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -1,9 +1,9 @@ #include "doctest.h" +#include "utils/containers.h" #include "utils/graph/adjacency_digraph.h" #include "utils/graph/adjacency_multidigraph.h" #include "utils/graph/algorithms.h" #include "utils/graph/construction.h" -#include "utils/containers.h" #include using namespace FlexFlow; @@ -17,7 +17,7 @@ TEST_CASE("MultiDiGraph") { NodePort p0 = g.add_node_port(); NodePort p1 = g.add_node_port(); NodePort p2 = g.add_node_port(); - NodePort p3 = g.add_node_port(); + NodePort p3 = g.add_node_port(); MultiDiEdge e0{n0, n3, p0, p3}; MultiDiEdge e1{n1, n2, p0, p2}; MultiDiEdge e2{n1, n3, p1, p3}; @@ -35,12 +35,12 @@ TEST_CASE("MultiDiGraph") { CHECK(get_outgoing_edges(g, {n2, n3}) == std::unordered_set{e3}); auto res = get_predecessors(g, {n1, n2, n3}); auto expected_result = std::unordered_map>{ - {n1, {}}, - {n2, {n1}}, - {n3, {n0,n1, n2}}, - }; + {n1, {}}, + {n2, {n1}}, + {n3, {n0, n1, n2}}, + }; - for(auto kv : res) { + for (auto kv : res) { CHECK(expected_result[kv.first] == kv.second); } } @@ -63,15 +63,14 @@ TEST_CASE("DiGraph") { CHECK(g.query_edges({}) == std::unordered_set{e0, e1, e2, e3}); CHECK(get_incoming_edges(g, {n2, n3}) == std::unordered_set{e0, e2, e3}); - CHECK(get_outgoing_edges(g, {n2, n3}) == - std::unordered_set{}); - auto expected_result = std::unordered_map>{ - {n1, {n0}}, - {n2, {n0, n1}}, - {n3, {n0}}, + CHECK(get_outgoing_edges(g, {n2, n3}) == std::unordered_set{}); + auto expected_result = std::unordered_map>{ + {n1, {n0}}, + {n2, {n0, n1}}, + {n3, {n0}}, }; auto res = get_predecessors(g, {n1, n2, n3}); - for(auto kv : res) { + for (auto kv : res) { CHECK(expected_result[kv.first] == kv.second); } } @@ -88,46 +87,49 @@ TEST_CASE("traversal") { CHECK(get_sources(g) == std::unordered_set{n[0]}); CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_bfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); CHECK(is_acyclic(g) == true); SUBCASE("with root") { g.add_edge({n[3], n[2]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); CHECK(is_acyclic(g) == false); } SUBCASE("without root") { g.add_edge({n[3], n[0]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); CHECK(is_acyclic(g) == false); } -// SUBCASE("nonlinear") { -// g.add_edge({n[1], n[3]}); -// CHECK(is_acyclic(g) == true);//TODO, maybe a bug about the unchecked_dfs -// } + // SUBCASE("nonlinear") { + // g.add_edge({n[1], n[3]}); + // CHECK(is_acyclic(g) == true);//TODO, maybe a bug about the + // unchecked_dfs + // } } TEST_CASE("bfs") { DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 7); - - std::vector edges= - { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }; - - for(DirectedEdge edge: edges) { + std::vector const n = add_nodes(g, 7); + + std::vector edges = { + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[6]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}, + {n[5], n[6]}, + {n[6], n[0]}, + }; + + for (DirectedEdge edge : edges) { g.add_edge(edge); } std::vector ordering = get_bfs_ordering(g, {n[0]}); @@ -154,19 +156,18 @@ TEST_CASE("bfs") { TEST_CASE("topological_ordering") { DiGraph g = DiGraph::create(); - std::vector n ; - for(int i = 0; i < 6; i++) { + std::vector n; + for (int i = 0; i < 6; i++) { n.push_back(g.add_node()); } - std::vector edges = - {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}; - - for(DirectedEdge const & edge: edges) { + std::vector edges = {{n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[5]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}}; + + for (DirectedEdge const &edge : edges) { g.add_edge(edge); }