From 8f48d5e649ad86a3def676095cb5dd7ec619b9c2 Mon Sep 17 00:00:00 2001 From: wmdi Date: Tue, 3 Oct 2023 16:53:25 -0400 Subject: [PATCH] finish rewriting all the graphs --- lib/utils/graph/include/utils/graph/digraph.h | 9 +- .../utils/graph/labelled/node_labelled.h | 2 +- .../utils/graph/labelled/standard_labelled.h | 2 +- .../graph/include/utils/graph/multidiedge.h | 3 - .../graph/include/utils/graph/multidigraph.h | 11 +- lib/utils/graph/include/utils/graph/node.h | 4 +- .../graph/include/utils/graph/open_edge.h | 43 ++--- .../graph/include/utils/graph/open_graphs.h | 40 ++--- .../graph/include/utils/graph/undirected.h | 38 ++--- .../include/utils/graph/undirected_edge.h | 1 + lib/utils/graph/src/utils/graph/diedge.cc | 39 +++++ lib/utils/graph/src/utils/graph/digraph.cc | 93 ++++------- .../graph/src/utils/graph/multidiedge.cc | 113 ++++++++++++- .../graph/src/utils/graph/multidigraph.cc | 152 +++--------------- lib/utils/graph/src/utils/graph/node.cc | 35 ++-- lib/utils/graph/src/utils/graph/open_edge.cc | 54 +++++++ .../graph/src/utils/graph/open_graphs.cc | 136 ++++++++++------ lib/utils/graph/src/utils/graph/undirected.cc | 65 ++------ .../graph/src/utils/graph/undirected_edge.cc | 23 +++ 19 files changed, 471 insertions(+), 392 deletions(-) create mode 100644 lib/utils/graph/src/utils/graph/diedge.cc create mode 100644 lib/utils/graph/src/utils/graph/open_edge.cc create mode 100644 lib/utils/graph/src/utils/graph/undirected_edge.cc diff --git a/lib/utils/graph/include/utils/graph/digraph.h b/lib/utils/graph/include/utils/graph/digraph.h index 92bcf714fe..b4b30518bc 100644 --- a/lib/utils/graph/include/utils/graph/digraph.h +++ b/lib/utils/graph/include/utils/graph/digraph.h @@ -31,11 +31,10 @@ struct DiGraphView : virtual public GraphView { return DiGraphView(make_cow_ptr(std::forward(args)...)); } - static DiGraphView - unsafe_create_without_ownership(IDiGraphView const &graphView); +protected: + DiGraphView(cow_ptr_t ptr); private: - DiGraphView(cow_ptr_t ptr); cow_ptr_t get_ptr() const; friend struct GraphInternal; @@ -70,8 +69,10 @@ struct DiGraph : virtual DiGraphView { return DiGraph(make_cow_ptr()); } -private: +protected: DiGraph(cow_ptr_t); + +private: cow_ptr_t get_ptr(); friend struct GraphInternal; diff --git a/lib/utils/graph/include/utils/graph/labelled/node_labelled.h b/lib/utils/graph/include/utils/graph/labelled/node_labelled.h index c7902615be..5bb288bc16 100644 --- a/lib/utils/graph/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/graph/include/utils/graph/labelled/node_labelled.h @@ -117,7 +117,7 @@ struct NodeLabelledMultiDiGraph : virtual NodeLabelledMultiDiGraphView ptr, cow_ptr_t nl) : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} - cow_ptr_t get_ptr() { + cow_ptr_t get_ptr() const { return static_cast>(ptr); } diff --git a/lib/utils/graph/include/utils/graph/labelled/standard_labelled.h b/lib/utils/graph/include/utils/graph/labelled/standard_labelled.h index 9e2d86c819..9f88315038 100644 --- a/lib/utils/graph/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/graph/include/utils/graph/labelled/standard_labelled.h @@ -131,7 +131,7 @@ struct LabelledMultiDiGraph : virtual LabelledMultiDiGraphView nl, cow_ptr_t el) : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} - cow_ptr_t get_ptr() { + cow_ptr_t get_ptr() const { return static_cast>(ptr); } diff --git a/lib/utils/graph/include/utils/graph/multidiedge.h b/lib/utils/graph/include/utils/graph/multidiedge.h index 74c8413070..ab86e30cb8 100644 --- a/lib/utils/graph/include/utils/graph/multidiedge.h +++ b/lib/utils/graph/include/utils/graph/multidiedge.h @@ -19,9 +19,6 @@ struct MultiDiOutput : virtual DiOutput { }; FF_VISITABLE_STRUCT(MultiDiOutput, src, src_idx); -MultiDiInput get_input(MultiDiEdge const &); -MultiDiOutput get_output(MultiDiEdge const &); - using edge_uid_t = std::pair; struct InputMultiDiEdge : virtual MultiDiInput { diff --git a/lib/utils/graph/include/utils/graph/multidigraph.h b/lib/utils/graph/include/utils/graph/multidigraph.h index 68d2d0b573..29e70054d2 100644 --- a/lib/utils/graph/include/utils/graph/multidigraph.h +++ b/lib/utils/graph/include/utils/graph/multidigraph.h @@ -25,11 +25,10 @@ struct MultiDiGraphView : virtual DiGraphView { make_cow_ptr(std::forward(args)...)); } - static MultiDiGraphView - unsafe_create_without_ownership(IMultiDiGraphView const &); - -private: +protected: MultiDiGraphView(cow_ptr_t ptr); + +private: cow_ptr_t get_ptr() const; friend struct GraphInternal; @@ -66,8 +65,10 @@ struct MultiDiGraph : virtual MultiDiGraphView { return MultiDiGraph(make_cow_ptr()); } -private: +protected: MultiDiGraph(cow_ptr_t); + +private: cow_ptr_t get_ptr() const; friend struct GraphInternal; diff --git a/lib/utils/graph/include/utils/graph/node.h b/lib/utils/graph/include/utils/graph/node.h index 47146601c5..349021489d 100644 --- a/lib/utils/graph/include/utils/graph/node.h +++ b/lib/utils/graph/include/utils/graph/node.h @@ -51,8 +51,6 @@ struct GraphView { std::unordered_set query_nodes(NodeQuery const &) const; - static GraphView unsafe_create_without_ownership(IGraphView const &); - template static typename std::enable_if::value, GraphView>::type @@ -87,7 +85,7 @@ struct Graph : virtual GraphView { Graph() = delete; Graph(Graph const &) = default; - Graph &operator=(Graph); + Graph &operator=(Graph) = default; friend void swap(Graph &, Graph &); diff --git a/lib/utils/graph/include/utils/graph/open_edge.h b/lib/utils/graph/include/utils/graph/open_edge.h index 829a135f8a..3f439834a8 100644 --- a/lib/utils/graph/include/utils/graph/open_edge.h +++ b/lib/utils/graph/include/utils/graph/open_edge.h @@ -20,26 +20,17 @@ struct OpenMultiDiEdgeQuery { OpenMultiDiEdgeQuery() = delete; OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &input_edge_query, MultiDiEdgeQuery const &standard_edge_query, - OutputMultiDiEdgeQuery const &output_edge_query) - : input_edge_query(input_edge_query), - standard_edge_query(standard_edge_query), - output_edge_query(output_edge_query) {} - - OpenMultiDiEdgeQuery(MultiDiEdgeQuery const &q) - : OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), q, OutputMultiDiEdgeQuery::none()) {} - OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &q) - : OpenMultiDiEdgeQuery( - q, MultiDiEdgeQuery::none(), OutputMultiDiEdgeQuery::none()) {} - OpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &q) - : OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), MultiDiEdgeQuery::none(), q) {} + OutputMultiDiEdgeQuery const &output_edge_query); + + OpenMultiDiEdgeQuery(MultiDiEdgeQuery const &q); + OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &q); + OpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &q); InputMultiDiEdgeQuery input_edge_query; MultiDiEdgeQuery standard_edge_query; OutputMultiDiEdgeQuery output_edge_query; }; -FF_VISITABLE_STRUCT(OpenMultiDiEdgeQuery, +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(OpenMultiDiEdgeQuery, input_edge_query, standard_edge_query, output_edge_query); @@ -47,19 +38,11 @@ FF_VISITABLE_STRUCT(OpenMultiDiEdgeQuery, struct DownwardOpenMultiDiEdgeQuery { DownwardOpenMultiDiEdgeQuery() = delete; DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query, - MultiDiEdgeQuery const &standard_edge_query) - : output_edge_query(output_edge_query), - standard_edge_query(standard_edge_query) {} - DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query) - : DownwardOpenMultiDiEdgeQuery(output_edge_query, - MultiDiEdgeQuery::none()) {} - DownwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &standard_edge_query) - : DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery::all(), - standard_edge_query){}; - - operator OpenMultiDiEdgeQuery() const { - NOT_IMPLEMENTED(); - } + MultiDiEdgeQuery const &standard_edge_query); + DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query); + DownwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &standard_edge_query); + + operator OpenMultiDiEdgeQuery() const; OutputMultiDiEdgeQuery output_edge_query; MultiDiEdgeQuery standard_edge_query; @@ -74,9 +57,7 @@ struct UpwardOpenMultiDiEdgeQuery { MultiDiEdgeQuery const &); UpwardOpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &); UpwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &); - operator OpenMultiDiEdgeQuery() const { - NOT_IMPLEMENTED(); - } + operator OpenMultiDiEdgeQuery() const; InputMultiDiEdgeQuery input_edge_query; MultiDiEdgeQuery standard_edge_query; diff --git a/lib/utils/graph/include/utils/graph/open_graphs.h b/lib/utils/graph/include/utils/graph/open_graphs.h index fadc79d75c..fc917632f8 100644 --- a/lib/utils/graph/include/utils/graph/open_graphs.h +++ b/lib/utils/graph/include/utils/graph/open_graphs.h @@ -32,8 +32,10 @@ struct OpenMultiDiGraphView : virtual MultiDiGraphView { make_cow_ptr(std::forward(args)...)); } -private: +protected: OpenMultiDiGraphView(cow_ptr_t ptr); + +private: cow_ptr_t get_ptr() const; friend struct GraphInternal; @@ -45,7 +47,7 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { using EdgeQuery = OpenMultiDiEdgeQuery; OpenMultiDiGraph() = delete; - OpenMultiDiGraph(OpenMultiDiGraph const &); + OpenMultiDiGraph(OpenMultiDiGraph const &) = default; friend void swap(OpenMultiDiGraph &, OpenMultiDiGraph &); @@ -65,8 +67,10 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { return make_cow_ptr(); } -private: +protected: OpenMultiDiGraph(cow_ptr_t ptr); + +private: cow_ptr_t get_ptr() const; friend struct GraphInternal; @@ -94,7 +98,7 @@ struct UpwardOpenMultiDiGraphView : virtual MultiDiGraphView { cow_ptr_t(std::forward(args)...)); } -private: +protected: UpwardOpenMultiDiGraphView( cow_ptr_t); @@ -103,7 +107,7 @@ struct UpwardOpenMultiDiGraphView : virtual MultiDiGraphView { }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraphView); -struct UpwardOpenMultiDiGraph { +struct UpwardOpenMultiDiGraph : UpwardOpenMultiDiGraphView { public: using Edge = UpwardOpenMultiDiEdge; using EdgeQuery = UpwardOpenMultiDiEdgeQuery; @@ -132,15 +136,15 @@ struct UpwardOpenMultiDiGraph { return UpwardOpenMultiDiGraph(make_cow_ptr()); } -private: - UpwardOpenMultiDiGraph(std::unique_ptr); +protected: + UpwardOpenMultiDiGraph(cow_ptr_t); private: - cow_ptr_t get_ptr(); + cow_ptr_t get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UpwardOpenMultiDiGraph); -struct DownwardOpenMultiDiGraphView : virtual OpenMultiDiSubgraphView { +struct DownwardOpenMultiDiGraphView : virtual MultiDiGraphView { public: using Edge = DownwardOpenMultiDiEdge; using EdgeQuery = DownwardOpenMultiDiEdgeQuery; @@ -150,8 +154,8 @@ struct DownwardOpenMultiDiGraphView : virtual OpenMultiDiSubgraphView { friend void swap(OpenMultiDiGraphView &, OpenMultiDiGraphView &); - std::unordered_set query_nodes(NodeQuery const &); - std::unordered_set query_edges(EdgeQuery const &); + std::unordered_set query_nodes(NodeQuery const &) const; + std::unordered_set query_edges(EdgeQuery const &) const; template static typename std::enable_if< @@ -162,7 +166,7 @@ struct DownwardOpenMultiDiGraphView : virtual OpenMultiDiSubgraphView { make_cow_ptr(std::forward(args)...)); } -private: +protected: DownwardOpenMultiDiGraphView( cow_ptr_t); @@ -171,15 +175,14 @@ struct DownwardOpenMultiDiGraphView : virtual OpenMultiDiSubgraphView { }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraphView); -struct DownwardOpenMultiDiGraph { +struct DownwardOpenMultiDiGraph : virtual DownwardOpenMultiDiGraphView { public: using Edge = DownwardOpenMultiDiEdge; using EdgeQuery = DownwardOpenMultiDiEdgeQuery; DownwardOpenMultiDiGraph() = delete; - DownwardOpenMultiDiGraph(DownwardOpenMultiDiGraph const &); - - DownwardOpenMultiDiGraph &operator=(DownwardOpenMultiDiGraph); + DownwardOpenMultiDiGraph(DownwardOpenMultiDiGraph const &) = default; + DownwardOpenMultiDiGraph &operator=(DownwardOpenMultiDiGraph const &) = default; friend void swap(DownwardOpenMultiDiGraph &, DownwardOpenMultiDiGraph &); @@ -190,6 +193,7 @@ struct DownwardOpenMultiDiGraph { void add_edge(Edge const &); void remove_edge(Edge const &); + std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(EdgeQuery const &) const; template @@ -200,11 +204,11 @@ struct DownwardOpenMultiDiGraph { return DownwardOpenMultiDiGraph(make_cow_ptr()); } -private: +protected: DownwardOpenMultiDiGraph(cow_ptr_t); private: - cow_ptr_t get_ptr(); + cow_ptr_t get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DownwardOpenMultiDiGraph); diff --git a/lib/utils/graph/include/utils/graph/undirected.h b/lib/utils/graph/include/utils/graph/undirected.h index eb4ab0832e..5259ba53f0 100644 --- a/lib/utils/graph/include/utils/graph/undirected.h +++ b/lib/utils/graph/include/utils/graph/undirected.h @@ -27,17 +27,13 @@ struct IUndirectedGraphView : public IGraphView { }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IUndirectedGraphView); -struct UndirectedGraphView { +struct UndirectedGraphView : virtual GraphView { public: using Edge = UndirectedEdge; using EdgeQuery = UndirectedEdgeQuery; UndirectedGraphView() = delete; - operator GraphView() const; - - friend void swap(UndirectedGraphView &, UndirectedGraphView &); - std::unordered_set query_nodes(NodeQuery const &) const; std::unordered_set query_edges(EdgeQuery const &query) const; @@ -47,47 +43,37 @@ struct UndirectedGraphView { UndirectedGraphView>::type create(Args &&...args) { return UndirectedGraphView( - std::make_shared(std::forward(args)...)); + make_cow_ptr(std::forward(args)...)); } - static UndirectedGraphView - unsafe_create_without_ownership(IUndirectedGraphView const &); - -private: - UndirectedGraphView(std::shared_ptr ptr); +protected: + UndirectedGraphView(cow_ptr_t ptr); friend struct GraphInternal; private: - std::shared_ptr ptr; + cow_ptr_t get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraphView); -struct IUndirectedGraph : public IUndirectedGraphView, public IGraph { +struct IUndirectedGraph : public IUndirectedGraphView { virtual void add_edge(UndirectedEdge const &) = 0; virtual void remove_edge(UndirectedEdge const &) = 0; virtual std::unordered_set - query_nodes(NodeQuery const &query) const override { - return static_cast(this)->query_nodes(query); - } + query_nodes(NodeQuery const &query) const = 0; virtual IUndirectedGraph *clone() const override = 0; }; -struct UndirectedGraph { +struct UndirectedGraph : virtual UndirectedGraphView { public: using Edge = UndirectedEdge; using EdgeQuery = UndirectedEdgeQuery; UndirectedGraph() = delete; - UndirectedGraph(UndirectedGraph const &); - - UndirectedGraph &operator=(UndirectedGraph); - - operator UndirectedGraphView() const; - - friend void swap(UndirectedGraph &, UndirectedGraph &); + UndirectedGraph(UndirectedGraph const &) = default; + UndirectedGraph &operator=(UndirectedGraph const &) = default; Node add_node(); void add_node_unsafe(Node const &); @@ -106,13 +92,13 @@ struct UndirectedGraph { return UndirectedGraph(make_cow_ptr()); } -private: +protected: UndirectedGraph(cow_ptr_t); friend struct GraphInternal; private: - cow_ptr_t ptr; + cow_ptr_t get_ptr() const; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(UndirectedGraph); diff --git a/lib/utils/graph/include/utils/graph/undirected_edge.h b/lib/utils/graph/include/utils/graph/undirected_edge.h index 67ff88c39b..66c8f27966 100644 --- a/lib/utils/graph/include/utils/graph/undirected_edge.h +++ b/lib/utils/graph/include/utils/graph/undirected_edge.h @@ -1,6 +1,7 @@ #ifndef UTILS_GRAPH_INCLUDE_UTILS_GRAPH_UNDIRECTED_EDGE #define UTILS_GRAPH_INCLUDE_UTILS_GRAPH_UNDIRECTED_EDGE +#include "node.h" namespace FlexFlow { diff --git a/lib/utils/graph/src/utils/graph/diedge.cc b/lib/utils/graph/src/utils/graph/diedge.cc new file mode 100644 index 0000000000..e442856e6b --- /dev/null +++ b/lib/utils/graph/src/utils/graph/diedge.cc @@ -0,0 +1,39 @@ +#include "utils/graph/diedge.h" + +namespace FlexFlow { + +DirectedEdgeQuery DirectedEdgeQuery::all() { + return {matchall(), matchall()}; +} + +bool matches_edge(DirectedEdgeQuery const &q, DirectedEdge const &e) { + return includes(q.srcs, e.src) && includes(q.dsts, e.dst); +} + +DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, + DirectedEdgeQuery const &rhs) { + std::unordered_set srcs_tl; + if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_tl = allowed_values(rhs.srcs); + } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { + srcs_tl = allowed_values(lhs.srcs); + } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_tl = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); + } + + std::unordered_set dsts_tl; + if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_tl = allowed_values(rhs.dsts); + } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { + dsts_tl = allowed_values(lhs.dsts); + } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_tl = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); + } + + DirectedEdgeQuery result = DirectedEdgeQuery::all(); + result.srcs = srcs_tl; + result.dsts = dsts_tl; + return result; +} + +} \ No newline at end of file diff --git a/lib/utils/graph/src/utils/graph/digraph.cc b/lib/utils/graph/src/utils/graph/digraph.cc index 341005bd08..d5c72dae48 100644 --- a/lib/utils/graph/src/utils/graph/digraph.cc +++ b/lib/utils/graph/src/utils/graph/digraph.cc @@ -5,7 +5,7 @@ namespace FlexFlow { -void swap(DiGraph &lhs, DiGraph &rhs) { +void swap(DiGraphView &lhs, DiGraphView &rhs) { using std::swap; swap(lhs.ptr, rhs.ptr); @@ -15,93 +15,56 @@ bool is_ptr_equal(DiGraphView const &lhs, DiGraphView const &rhs) { return lhs.ptr == rhs.ptr; } -Node DiGraph::add_node() { - return this->ptr.get_mutable()->add_node(); -} - -void DiGraph::add_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->add_node_unsafe(n); +std::unordered_set DiGraphView::query_nodes(NodeQuery const &q) const { + return this->get_ptr()->query_nodes(q); } -void DiGraph::remove_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->remove_node_unsafe(n); +std::unordered_set + DiGraphView::query_edges(EdgeQuery const &query) const { + return get_ptr()->query_edges(query); } -void DiGraph::add_edge(DirectedEdge const &e) { - return this->ptr.get_mutable()->add_edge(e); -} +DiGraphView::DiGraphView(cow_ptr_t ptr) : GraphView(ptr) {} -void DiGraph::remove_edge(DirectedEdge const &e) { - return this->ptr.get_mutable()->remove_edge(e); +cow_ptr_t IDiGraphView::get_ptr() const { + return static_cast>(ptr); } -std::unordered_set - DiGraph::query_edges(DirectedEdgeQuery const &q) const { - return this->ptr->query_edges(q); -} - -DiGraph::DiGraph(cow_ptr_t _ptr) : ptr(std::move(_ptr)) {} +void swap(DiGraph &lhs, DiGraph &rhs) { + using std::swap; -DiGraphView::operator GraphView() const { - return GraphInternal::create_graphview(this->ptr); + swap(lhs.ptr, rhs.ptr); } -std::unordered_set DiGraphView::query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); +Node DiGraph::add_node() { + return this->get_ptr().get_mutable()->add_node(); } -std::unordered_set - DiGraphView::query_edges(EdgeQuery const &query) const { - return ptr->query_edges(query); +void DiGraph::add_node_unsafe(Node const &n) { + return this->get_ptr().get_mutable()->add_node_unsafe(n); } -// Set the shared_ptr's destructor to a nop so that effectively there is no -// ownership -DiGraphView DiGraphView::unsafe_create_without_ownership( - IDiGraphView const &graphView) { - std::shared_ptr ptr((&graphView), - [](IDiGraphView const *) {}); - return DiGraphView(ptr); +void DiGraph::remove_node_unsafe(Node const &n) { + return this->get_ptr().get_mutable()->remove_node_unsafe(n); } -DirectedEdgeQuery DirectedEdgeQuery::all() { - return {matchall(), matchall()}; +void DiGraph::add_edge(DirectedEdge const &e) { + return this->get_ptr().get_mutable()->add_edge(e); } -bool matches_edge(DirectedEdgeQuery const &q, DirectedEdge const &e) { - return includes(q.srcs, e.src) && includes(q.dsts, e.dst); +void DiGraph::remove_edge(DirectedEdge const &e) { + return this->get_ptr().get_mutable()->remove_edge(e); } -DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, - DirectedEdgeQuery const &rhs) { - std::unordered_set srcs_tl; - if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { - srcs_tl = allowed_values(rhs.srcs); - } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { - srcs_tl = allowed_values(lhs.srcs); - } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { - srcs_tl = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); - } - - std::unordered_set dsts_tl; - if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { - dsts_tl = allowed_values(rhs.dsts); - } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { - dsts_tl = allowed_values(lhs.dsts); - } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { - dsts_tl = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); - } - - DirectedEdgeQuery result = DirectedEdgeQuery::all(); - result.srcs = srcs_tl; - result.dsts = dsts_tl; - return result; +std::unordered_set DiGraph::query_nodes(NodeQuery const &q) const { + return this->get_ptr()->query_nodes(q); } -DiGraph::operator DiGraphView() const { - return GraphInternal::create_digraphview(this->ptr.get()); +std::unordered_set + DiGraph::query_edges(DirectedEdgeQuery const &q) const { + return this->get_ptr()->query_edges(q); } -DiGraphView::DiGraphView(std::shared_ptr ptr) : ptr(ptr) {} +DiGraph::DiGraph(cow_ptr_t _ptr) : DiGraphView(_ptr) {} } // namespace FlexFlow diff --git a/lib/utils/graph/src/utils/graph/multidiedge.cc b/lib/utils/graph/src/utils/graph/multidiedge.cc index c2d61fe179..7724f2ab99 100644 --- a/lib/utils/graph/src/utils/graph/multidiedge.cc +++ b/lib/utils/graph/src/utils/graph/multidiedge.cc @@ -2,12 +2,117 @@ namespace FlexFlow { -MultiDiInput get_input(MultiDiEdge const &e) { - return {e.dst, e.dstIdx}; +OutputMultiDiEdgeQuery OutputMultiDiEdgeQuery::all() { + return {matchall(), matchall()}; } -MultiDiOutput get_output(MultiDiEdge const &e) { - return {e.src, e.srcIdx}; +OutputMultiDiEdgeQuery OutputMultiDiEdgeQuery::none() { + return {{}, {}}; +} + +InputMultiDiEdgeQuery InputMultiDiEdgeQuery::all() { + return {matchall(), matchall()}; +} + +InputMultiDiEdgeQuery InputMultiDiEdgeQuery::none() { + return {{}, {}}; +} + +MultiDiEdgeQuery + MultiDiEdgeQuery::with_src_nodes(query_set const &nodes) const { + MultiDiEdgeQuery e = *this; + if (!is_matchall(e.srcs)) { + throw mk_runtime_error("Expected matchall previous value"); + } + e.srcs = nodes; + return e; +} + +MultiDiEdgeQuery + MultiDiEdgeQuery::with_dst_nodes(query_set const &nodes) const { + MultiDiEdgeQuery e = *this; + if (!is_matchall(e.dsts)) { + throw mk_runtime_error("Expected matchall previous value"); + } + e.dsts = nodes; + return e; +} + +MultiDiEdgeQuery + MultiDiEdgeQuery::with_src_idxs(query_set const &idxs) const { + MultiDiEdgeQuery e{*this}; + if (!is_matchall(e.srcIdxs)) { + throw mk_runtime_error("Expected matchall previous value"); + } + e.srcIdxs = idxs; + return e; +} + +MultiDiEdgeQuery + MultiDiEdgeQuery::with_dst_idxs(query_set const &idxs) const { + MultiDiEdgeQuery e = *this; + if (!is_matchall(e.dstIdxs)) { + throw mk_runtime_error("Expected matchall previous value"); + } + e.dstIdxs = idxs; + return e; +} + +MultiDiEdgeQuery MultiDiEdgeQuery::all() { + return {matchall(), + matchall(), + matchall(), + matchall()}; +} + +MultiDiEdgeQuery MultiDiEdgeQuery::none() { + return {{}, {}, {}, {}}; +} + +MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs, + MultiDiEdgeQuery const &rhs) { + std::unordered_set srcs_t1; + if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(rhs.srcs); + } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(lhs.srcs); + } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { + srcs_t1 = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); + } + + std::unordered_set dsts_t1; + if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(rhs.dsts); + } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(lhs.dsts); + } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { + dsts_t1 = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); + } + + std::unordered_set srcIdxs_t1; + if (is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { + srcIdxs_t1 = allowed_values(rhs.srcIdxs); + } else if (!is_matchall(lhs.srcIdxs) && is_matchall(rhs.srcIdxs)) { + srcIdxs_t1 = allowed_values(lhs.srcIdxs); + } else if (!is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { + srcIdxs_t1 = allowed_values(query_intersection(lhs.srcIdxs, rhs.srcIdxs)); + } + + std::unordered_set dstIdxs_t1; + if (is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { + dstIdxs_t1 = allowed_values(rhs.dstIdxs); + } else if (!is_matchall(lhs.dstIdxs) && is_matchall(rhs.dstIdxs)) { + dstIdxs_t1 = allowed_values(lhs.dstIdxs); + } else if (!is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { + dstIdxs_t1 = allowed_values(query_intersection(lhs.dstIdxs, rhs.dstIdxs)); + } + + MultiDiEdgeQuery e = MultiDiEdgeQuery::all(); + e.srcs = srcs_t1; + e.dsts = dsts_t1; + e.srcIdxs = srcIdxs_t1; + e.dstIdxs = dstIdxs_t1; + return e; } } // namespace FlexFlow diff --git a/lib/utils/graph/src/utils/graph/multidigraph.cc b/lib/utils/graph/src/utils/graph/multidigraph.cc index f21e35fbe7..4158b02fb6 100644 --- a/lib/utils/graph/src/utils/graph/multidigraph.cc +++ b/lib/utils/graph/src/utils/graph/multidigraph.cc @@ -4,179 +4,77 @@ namespace FlexFlow { -MultiDiEdgeQuery - MultiDiEdgeQuery::with_src_nodes(query_set const &nodes) const { - MultiDiEdgeQuery e = *this; - if (!is_matchall(e.srcs)) { - throw mk_runtime_error("Expected matchall previous value"); - } - e.srcs = nodes; - return e; -} - -MultiDiEdgeQuery - MultiDiEdgeQuery::with_dst_nodes(query_set const &nodes) const { - MultiDiEdgeQuery e = *this; - if (!is_matchall(e.dsts)) { - throw mk_runtime_error("Expected matchall previous value"); - } - e.dsts = nodes; - return e; -} - -MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs, - MultiDiEdgeQuery const &rhs) { - std::unordered_set srcs_t1; - if (is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { - srcs_t1 = allowed_values(rhs.srcs); - } else if (!is_matchall(lhs.srcs) && is_matchall(rhs.srcs)) { - srcs_t1 = allowed_values(lhs.srcs); - } else if (!is_matchall(lhs.srcs) && !is_matchall(rhs.srcs)) { - srcs_t1 = allowed_values(query_intersection(lhs.srcs, rhs.srcs)); - } - - std::unordered_set dsts_t1; - if (is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { - dsts_t1 = allowed_values(rhs.dsts); - } else if (!is_matchall(lhs.dsts) && is_matchall(rhs.dsts)) { - dsts_t1 = allowed_values(lhs.dsts); - } else if (!is_matchall(lhs.dsts) && !is_matchall(rhs.dsts)) { - dsts_t1 = allowed_values(query_intersection(lhs.dsts, rhs.dsts)); - } - - std::unordered_set srcIdxs_t1; - if (is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { - srcIdxs_t1 = allowed_values(rhs.srcIdxs); - } else if (!is_matchall(lhs.srcIdxs) && is_matchall(rhs.srcIdxs)) { - srcIdxs_t1 = allowed_values(lhs.srcIdxs); - } else if (!is_matchall(lhs.srcIdxs) && !is_matchall(rhs.srcIdxs)) { - srcIdxs_t1 = allowed_values(query_intersection(lhs.srcIdxs, rhs.srcIdxs)); - } - - std::unordered_set dstIdxs_t1; - if (is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { - dstIdxs_t1 = allowed_values(rhs.dstIdxs); - } else if (!is_matchall(lhs.dstIdxs) && is_matchall(rhs.dstIdxs)) { - dstIdxs_t1 = allowed_values(lhs.dstIdxs); - } else if (!is_matchall(lhs.dstIdxs) && !is_matchall(rhs.dstIdxs)) { - dstIdxs_t1 = allowed_values(query_intersection(lhs.dstIdxs, rhs.dstIdxs)); - } - - MultiDiEdgeQuery e = MultiDiEdgeQuery::all(); - e.srcs = srcs_t1; - e.dsts = dsts_t1; - e.srcIdxs = srcIdxs_t1; - e.dstIdxs = dstIdxs_t1; - return e; -} - -MultiDiEdgeQuery - MultiDiEdgeQuery::with_src_idxs(query_set const &idxs) const { - MultiDiEdgeQuery e{*this}; - if (!is_matchall(e.srcIdxs)) { - throw mk_runtime_error("Expected matchall previous value"); - } - e.srcIdxs = idxs; - return e; -} - -MultiDiEdgeQuery - MultiDiEdgeQuery::with_dst_idxs(query_set const &idxs) const { - MultiDiEdgeQuery e = *this; - if (!is_matchall(e.dstIdxs)) { - throw mk_runtime_error("Expected matchall previous value"); - } - e.dstIdxs = idxs; - return e; -} +void swap(MultiDiGraphView &lhs, MultiDiGraphView &rhs) { + using std::swap; -MultiDiEdgeQuery MultiDiEdgeQuery::all() { - return {matchall(), - matchall(), - matchall(), - matchall()}; + swap(lhs.ptr, rhs.ptr); } std::unordered_set MultiDiGraphView::query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); + return this->get_ptr()->query_nodes(q); } -MultiDiGraphView::MultiDiGraphView(std::shared_ptr ptr) - : ptr(ptr) {} - std::unordered_set MultiDiGraphView::query_edges(MultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->get_ptr()->query_edges(q); } -MultiDiGraphView::operator GraphView() const { - return GraphInternal::create_graphview(this->ptr); -} - -// Set the shared_ptr's destructor to a nop so that effectively there is no -// ownership -MultiDiGraphView MultiDiGraphView::unsafe_create_without_ownership( - IMultiDiGraphView const &graphView) { - std::shared_ptr ptr( - (&graphView), [](IMultiDiGraphView const *ptr) {}); - return MultiDiGraphView(ptr); -} +MultiDiGraphView::MultiDiGraphView(cow_ptr_t ptr) + : DiGraphView(ptr) {} -void swap(MultiDiGraphView &lhs, MultiDiGraphView &rhs) { - using std::swap; - - swap(lhs.ptr, rhs.ptr); +cow_ptr_t MultiDiGraphView::get_ptr() const { + return static_cast>(ptr); } -MultiDiGraph::MultiDiGraph(cow_ptr_t _ptr) - : ptr(std::move(_ptr)) {} - void swap(MultiDiGraph &lhs, MultiDiGraph &rhs) { using std::swap; swap(lhs.ptr, rhs.ptr); } -MultiDiGraph::operator MultiDiGraphView() const { - return GraphInternal::create_multidigraphview(this->ptr.get()); -} - Node MultiDiGraph::add_node() { - return this->ptr.get_mutable()->add_node(); + return this->get_ptr().get_mutable()->add_node(); } NodePort MultiDiGraph::add_node_port() { - return this->ptr.get_mutable()->add_node_port(); + return this->get_ptr().get_mutable()->add_node_port(); } void MultiDiGraph::add_node_port_unsafe(NodePort const &np) { - return this->ptr.get_mutable()->add_node_port_unsafe(np); + return this->get_ptr().get_mutable()->add_node_port_unsafe(np); } void MultiDiGraph::add_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->add_node_unsafe(n); + return this->get_ptr().get_mutable()->add_node_unsafe(n); } void MultiDiGraph::remove_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->remove_node_unsafe(n); + return this->get_ptr().get_mutable()->remove_node_unsafe(n); } void MultiDiGraph::add_edge(MultiDiEdge const &e) { - return this->ptr.get_mutable()->add_edge(e); + return this->get_ptr().get_mutable()->add_edge(e); } void MultiDiGraph::remove_edge(MultiDiEdge const &e) { - return this->ptr.get_mutable()->remove_edge(e); + return this->get_ptr().get_mutable()->remove_edge(e); } std::unordered_set MultiDiGraph::query_edges(MultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->get_ptr()->query_edges(q); } std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); + return this->get_ptr()->query_nodes(q); +} + +MultiDiGraph::MultiDiGraph(cow_ptr_t ptr) + : MultiDiGraphView(ptr) {} + +cow_ptr_t MultiDiGraph::get_ptr() const { + return static_cast>(ptr); } } // namespace FlexFlow diff --git a/lib/utils/graph/src/utils/graph/node.cc b/lib/utils/graph/src/utils/graph/node.cc index 0740cde3eb..ba15eccce2 100644 --- a/lib/utils/graph/src/utils/graph/node.cc +++ b/lib/utils/graph/src/utils/graph/node.cc @@ -26,27 +26,42 @@ NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) { return intersection_result; } +void swap(GraphView &lhs, GraphView &rhs) { + std::swap(lhs.ptr, rhs.ptr); +} + std::unordered_set GraphView::query_nodes(NodeQuery const &g) const { return this->ptr->query_nodes(g); } GraphView::GraphView(std::shared_ptr ptr) : ptr(ptr) {} -// Set the shared_ptr's destructor to a nop so that effectively there is no -// ownership -GraphView - GraphView::unsafe_create_without_ownership(IGraphView const &graphView) { - std::shared_ptr ptr((&graphView), - [](IGraphView const *) {}); - return GraphView(ptr); +void swap(Graph &lhs, Graph &rhs) { + std::swap(lhs.ptr, rhs.ptr); +} + +Node Graph::add_node() { + get_ptr()->add_node(); +} + +Node Graph::add_node_unsafe() { + get_ptr()->add_node_unsafe(); +} + +Node Graph::remove_node_unsafe() { + get_ptr()->remove_node_unsafe(); } -Graph::Graph(cow_ptr_t _ptr) : ptr(std::move(_ptr)) { +std::unordered_set Graph::query_nodes(NodeQuery const &q) const { + return get_ptr()->query_nodes(q); +} + +Graph::Graph(cow_ptr_t _ptr) : ptr(_ptr) { assert(this->ptr.get() != nullptr); } -Node Graph::add_node() { - return this->ptr.get_mutable()->add_node(); +cow_ptr_t Graph::get_ptr() const { + return static_cast>(ptr); } } // namespace FlexFlow diff --git a/lib/utils/graph/src/utils/graph/open_edge.cc b/lib/utils/graph/src/utils/graph/open_edge.cc new file mode 100644 index 0000000000..fc2fa97e84 --- /dev/null +++ b/lib/utils/graph/src/utils/graph/open_edge.cc @@ -0,0 +1,54 @@ +#include "utils/graph/open_edge.h" + +namespace FlexFlow { + +bool is_input_edge(OpenMultiDiEdge const &e) { + return holds_alternative(e); +} + +bool is_output_edge(OpenMultiDiEdge const &e) { + return holds_alternative(e); +} + +bool is_standard_edge(OpenMultiDiEdge const &e) { + return holds_alternative(e); +} + +OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &input_edge_query, + MultiDiEdgeQuery const &standard_edge_query, + OutputMultiDiEdgeQuery const &output_edge_query); + +OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery(MultiDiEdgeQuery const &q) + : OpenMultiDiEdgeQuery( + InputMultiDiEdgeQuery::none(), q, OutputMultiDiEdgeQuery::none()) {} +OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &q) + : OpenMultiDiEdgeQuery( + q, MultiDiEdgeQuery::none(), OutputMultiDiEdgeQuery::none()) {} +OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &q) + : OpenMultiDiEdgeQuery( + InputMultiDiEdgeQuery::none(), MultiDiEdgeQuery::none(), q) {} + +DownwardOpenMultiDiEdgeQuery::DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query, + MultiDiEdgeQuery const &standard_edge_query) + : output_edge_query(output_edge_query), + standard_edge_query(standard_edge_query) {} +DownwardOpenMultiDiEdgeQuery::DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery const &output_edge_query) + : DownwardOpenMultiDiEdgeQuery(output_edge_query, + MultiDiEdgeQuery::none()) {} +DownwardOpenMultiDiEdgeQuery::DownwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &standard_edge_query) + : DownwardOpenMultiDiEdgeQuery(OutputMultiDiEdgeQuery::all(), + standard_edge_query){}; + +DownwardOpenMultiDiEdgeQuery::operator OpenMultiDiEdgeQuery() const { + return OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery::none(), standard_edge_query, output_edge_query); +} + +UpwardOpenMultiDiEdgeQuery::UpwardOpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &input_edge_query, + MultiDiEdgeQuery const &standard_edge_query) : input_edge_query(input_edge_query), standard_edge_query(standard_edge_query) {} + +UpwardOpenMultiDiEdgeQuery::UpwardOpenMultiDiEdgeQuery(InputMultiDiEdgeQuery const &input_edge_query) + : input_edge_query(input_edge_query), standard_edge_query(MultiDiEdgeQuery::none()) {} +UpwardOpenMultiDiEdgeQuery::UpwardOpenMultiDiEdgeQuery(MultiDiEdgeQuery const &standard_edge_query) + : input_edge_query(InputMultiDiEdgeQuery::none()), standard_edge_query(standard_edge_query) {} + +} \ No newline at end of file diff --git a/lib/utils/graph/src/utils/graph/open_graphs.cc b/lib/utils/graph/src/utils/graph/open_graphs.cc index 0a41c45fa0..972cc0dd5f 100644 --- a/lib/utils/graph/src/utils/graph/open_graphs.cc +++ b/lib/utils/graph/src/utils/graph/open_graphs.cc @@ -6,33 +6,28 @@ namespace FlexFlow { -InputMultiDiEdgeQuery InputMultiDiEdgeQuery::all() { - return {matchall(), matchall()}; -} +void swap(OpenMultiDiGraphView &lhs, OpenMultiDiGraphView &rhs) { + using std::swap; -OutputMultiDiEdgeQuery OutputMultiDiEdgeQuery::all() { - return {matchall(), matchall()}; + swap(lhs.ptr, rhs.ptr); } std::unordered_set OpenMultiDiGraphView::query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); + return this->get_ptr()->query_nodes(q); } std::unordered_set OpenMultiDiGraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); -} - -OpenMultiDiGraphView::operator MultiDiGraphView() const { - return as_multidigraph(*this); + return this->get_ptr()->query_edges(q); } OpenMultiDiGraphView::OpenMultiDiGraphView( - std::shared_ptr ptr) - : ptr(ptr) {} + cow_ptr_t ptr) + : MultiDiGraphView(ptr) {} -OpenMultiDiGraph::OpenMultiDiGraph(OpenMultiDiGraph const &other) - : ptr(other.ptr) {} +cow_ptr_t OpenMultiDiGraphView::get_ptr() const { + return static_cast>(ptr); +} void swap(OpenMultiDiGraph &lhs, OpenMultiDiGraph &rhs) { using std::swap; @@ -41,37 +36,56 @@ void swap(OpenMultiDiGraph &lhs, OpenMultiDiGraph &rhs) { } Node OpenMultiDiGraph::add_node() { - return this->ptr.get_mutable()->add_node(); + return this->get_ptr().get_mutable()->add_node(); } void OpenMultiDiGraph::add_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->add_node_unsafe(n); + return this->get_ptr().get_mutable()->add_node_unsafe(n); } void OpenMultiDiGraph::remove_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->remove_node_unsafe(n); + return this->get_ptr().get_mutable()->remove_node_unsafe(n); } void OpenMultiDiGraph::add_edge(OpenMultiDiEdge const &e) { - return this->ptr.get_mutable()->add_edge(e); + return this->get_ptr().get_mutable()->add_edge(e); } void OpenMultiDiGraph::remove_edge(OpenMultiDiEdge const &e) { - return this->ptr.get_mutable()->remove_edge(e); + return this->get_ptr().get_mutable()->remove_edge(e); } std::unordered_set OpenMultiDiGraph::query_edges(OpenMultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->get_ptr()->query_edges(q); } OpenMultiDiGraph::OpenMultiDiGraph(cow_ptr_t _ptr) - : ptr(std::move(_ptr)) {} + : OpenMultiDiGraphView(_ptr) {} -UpwardOpenMultiDiGraph & - UpwardOpenMultiDiGraph::operator=(UpwardOpenMultiDiGraph other) { - swap(*this, other); - return *this; +cow_ptr_t OpenMultiDiGraph::get_ptr() const { + return static_cast>(ptr); +} + +void swap(UpwardOpenMultiDiGraphView &lhs, UpwardOpenMultiDiGraphView &rhs) { + using std::swap; + + swap(lhs.ptr, rhs.ptr); +} + +std::unordered_set UpwardOpenMultiDiGraphView::query_nodes(NodeQuery const &q) { + return get_ptr()->query_nodes(q); +} + +std::unordered_set UpwardOpenMultiDiGraphView::query_edges(UpwardOpenMultiDiEdgeQuery const &q) { + return get_ptr()->query_edges(q); +} + +UpwardOpenMultiDiGraphView::UpwardOpenMultiDiGraphView( + cow_ptr_t ptr) : MultiDiGraphView(ptr) {} + +cow_ptr_t UpwardOpenMultiDiGraphView::get_ptr() const { + return static_cast>(ptr); } void swap(UpwardOpenMultiDiGraph &lhs, UpwardOpenMultiDiGraph &rhs) { @@ -81,38 +95,58 @@ void swap(UpwardOpenMultiDiGraph &lhs, UpwardOpenMultiDiGraph &rhs) { } Node UpwardOpenMultiDiGraph::add_node() { - return this->ptr.get_mutable()->add_node(); + return this->get_ptr().get_mutable()->add_node(); } void UpwardOpenMultiDiGraph::add_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->add_node_unsafe(n); + return this->get_ptr().get_mutable()->add_node_unsafe(n); } void UpwardOpenMultiDiGraph::remove_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->remove_node_unsafe(n); + return this->get_ptr().get_mutable()->remove_node_unsafe(n); } void UpwardOpenMultiDiGraph::add_edge(UpwardOpenMultiDiEdge const &e) { - return this->ptr.get_mutable()->add_edge(e); + return this->get_ptr().get_mutable()->add_edge(e); } void UpwardOpenMultiDiGraph::remove_edge(UpwardOpenMultiDiEdge const &e) { - return this->ptr.get_mutable()->remove_edge(e); + return this->get_ptr().get_mutable()->remove_edge(e); } std::unordered_set UpwardOpenMultiDiGraph::query_edges( UpwardOpenMultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->get_ptr()->query_edges(q); } UpwardOpenMultiDiGraph::UpwardOpenMultiDiGraph( - std::unique_ptr _ptr) - : ptr(std::move(_ptr)) {} + cow_ptr_t _ptr) + : UpwardOpenMultiDiGraphView(_ptr) {} -DownwardOpenMultiDiGraph & - DownwardOpenMultiDiGraph::operator=(DownwardOpenMultiDiGraph other) { - swap(*this, other); - return *this; +cow_ptr_t UpwardOpenMultiDiGraph::get_ptr() const { + return static_cast>(ptr); +} + +void swap(DownwardOpenMultiDiGraphView &lhs, DownwardOpenMultiDiGraphView &rhs) { + using std::swap; + + swap(lhs.ptr, rhs.ptr); +} + +std::unordered_set DownwardOpenMultiDiGraphView::query_nodes(NodeQuery const &q) const { + return this->get_ptr()->query_nodes(q); +} + +std::unordered_set + DownwardOpenMultiDiGraphView::query_edges( + DownwardOpenMultiDiEdgeQuery const &q) const { + return this->get_ptr()->query_edges(q); +} + +DownwardOpenMultiDiGraphView::DownwardOpenMultiDiGraphView(cow_ptr_t ptr) : MultiDiGraphView(ptr) {} + +cow_ptr_t get_ptr() const { + return static_cast>(ptr); } void swap(DownwardOpenMultiDiGraph &lhs, DownwardOpenMultiDiGraph &rhs) { @@ -122,33 +156,43 @@ void swap(DownwardOpenMultiDiGraph &lhs, DownwardOpenMultiDiGraph &rhs) { } Node DownwardOpenMultiDiGraph::add_node() { - return this->ptr.get_mutable()->add_node(); + return this->get_ptr().get_mutable()->add_node(); } void DownwardOpenMultiDiGraph::add_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->add_node_unsafe(n); + return this->get_ptr().get_mutable()->add_node_unsafe(n); } void DownwardOpenMultiDiGraph::remove_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->remove_node_unsafe(n); + return this->get_ptr().get_mutable()->remove_node_unsafe(n); } void DownwardOpenMultiDiGraph::add_edge(DownwardOpenMultiDiEdge const &e) { - return this->ptr.get_mutable()->add_edge(e); + return this->get_ptr().get_mutable()->add_edge(e); } void DownwardOpenMultiDiGraph::remove_edge(DownwardOpenMultiDiEdge const &e) { - return this->ptr.get_mutable()->remove_edge(e); + return this->get_ptr().get_mutable()->remove_edge(e); +} + +std::unordered_set + DownwardOpenMultiDiGraph::query_nodes( + NodeQuery const &q) const { + return this->get_ptr()->query_nodes(q); } std::unordered_set DownwardOpenMultiDiGraph::query_edges( DownwardOpenMultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->get_ptr()->query_edges(q); } DownwardOpenMultiDiGraph::DownwardOpenMultiDiGraph( - std::unique_ptr _ptr) - : ptr(std::move(_ptr)) {} + cow_ptr_t _ptr) + : DownwardOpenMultiDiGraphView(ptr) {} + +cow_ptr_t DownwardOpenMultiDiGraph::get_ptr() const { + return static_cast>(ptr); +} } // namespace FlexFlow diff --git a/lib/utils/graph/src/utils/graph/undirected.cc b/lib/utils/graph/src/utils/graph/undirected.cc index 9aeeae9b63..f32c86070e 100644 --- a/lib/utils/graph/src/utils/graph/undirected.cc +++ b/lib/utils/graph/src/utils/graph/undirected.cc @@ -6,92 +6,61 @@ namespace FlexFlow { -UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) - : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} - -bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.bigger == n || e.smaller == n; -} - -UndirectedEdgeQuery UndirectedEdgeQuery::all() { - return {matchall()}; +std::unordered_set UndirectedGraphView::query_nodes(NodeQuery const &q) const { + return get_ptr()->query_nodes(q); } -UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, - UndirectedEdgeQuery const &rhs) { - return { - query_intersection(lhs.nodes, rhs.nodes), - }; +std::unordered_set UndirectedGraphView::query_edges(UndirectedEdgeQuery const &q) const { + return get_ptr()->query_edges(q); } -void swap(UndirectedGraph &lhs, UndirectedGraph &rhs) { - using std::swap; - - swap(lhs.ptr, rhs.ptr); -} +UndirectedGraphView UndirectedGraphView(cow_ptr_t ptr) : GraphView(ptr) {} Node UndirectedGraph::add_node() { - return this->ptr.get_mutable()->add_node(); + return this->get_ptr().get_mutable()->add_node(); } void UndirectedGraph::add_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->add_node_unsafe(n); + return this->get_ptr().get_mutable()->add_node_unsafe(n); } void UndirectedGraph::remove_node_unsafe(Node const &n) { - return this->ptr.get_mutable()->remove_node_unsafe(n); + return this->get_ptr().get_mutable()->remove_node_unsafe(n); } void UndirectedGraph::add_edge(UndirectedEdge const &e) { - return this->ptr.get_mutable()->add_edge(e); + return this->get_ptr().get_mutable()->add_edge(e); } void UndirectedGraph::remove_edge(UndirectedEdge const &e) { - return this->ptr.get_mutable()->remove_edge(e); + return this->get_ptr().get_mutable()->remove_edge(e); } std::unordered_set UndirectedGraph::query_edges(UndirectedEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->get_ptr()->query_edges(q); } std::unordered_set UndirectedGraph::query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); + return this->get_ptr()->query_nodes(q); } UndirectedGraph::UndirectedGraph(cow_ptr_t _ptr) - : ptr(std::move(_ptr)) {} - -UndirectedGraph::operator UndirectedGraphView() const { - return GraphInternal::create_undirectedgraphview(this->ptr.get()); -} - -UndirectedGraphView::UndirectedGraphView( - std::shared_ptr ptr) - : ptr(ptr) {} + : UndirectedGraphView(_ptr) {} std::unordered_set UndirectedGraphView::query_edges(UndirectedEdgeQuery const &q) const { - return this->ptr->query_edges(q); + return this->get_ptr()->query_edges(q); } std::unordered_set UndirectedGraphView::query_nodes(NodeQuery const &q) const { - return this->ptr->query_nodes(q); -} - -// Set the shared_ptr's destructor to a nop so that effectively there is no -// ownership -UndirectedGraphView UndirectedGraphView::unsafe_create_without_ownership( - IUndirectedGraphView const &g) { - std::shared_ptr ptr( - (&g), [](IUndirectedGraphView const *) {}); - return UndirectedGraphView(ptr); + return this->get_ptr()->query_nodes(q); } -UndirectedGraphView::operator GraphView() const { - return GraphInternal::create_graphview(this->ptr); +cow_ptr_t UndirectedGraphView::get_ptr() const { + return static_cast>(ptr); } } // namespace FlexFlow diff --git a/lib/utils/graph/src/utils/graph/undirected_edge.cc b/lib/utils/graph/src/utils/graph/undirected_edge.cc new file mode 100644 index 0000000000..a8c86888e3 --- /dev/null +++ b/lib/utils/graph/src/utils/graph/undirected_edge.cc @@ -0,0 +1,23 @@ +#include "utils/graph/undirected_edge.h" + +namespace FlexFlow { + +UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) + : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} + +bool is_connected_to(UndirectedEdge const &e, Node const &n) { + return e.bigger == n || e.smaller == n; +} + +UndirectedEdgeQuery UndirectedEdgeQuery::all() { + return {matchall()}; +} + +UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, + UndirectedEdgeQuery const &rhs) { + return { + query_intersection(lhs.nodes, rhs.nodes), + }; +} + +} \ No newline at end of file