diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 25b0103f9c..f3d31e7bc8 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -15,10 +15,16 @@ There is no single type of graph. Should it be directed? Allow multiple edges be Because there is no single answer to this question, similar to [networkx](https://networkx.org/) we provide a number of different graph variants. At their core, they are as follows: -- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected -- `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) -- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. - +- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. +- `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) +- `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. +- `DataflowGraph`: similar to `MultiDiGraph`, but with the following differences: + - The edges entering, exiting a given nodes now have a well-defined order. + - Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. + - Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. + +Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. + Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -37,7 +43,7 @@ flowchart TD D --- B ``` -Example of `DirectedGraph`: +Example of `DiGraph`: ```mermaid flowchart TD A(" ") @@ -58,98 +64,34 @@ flowchart TD Example of `MultiDiGraph`: ```mermaid flowchart TD - A("A") - B("B") - C("C") - D("D") - E("E") - F("F") - - A -->|"(■, ★)"| B - B -->|"(●, ★)"| C - C -->|"(♥, ▲)"| D - D -->|"(●, ■)"| A - B -->|"(★, ●)"| E - E -->|"(■, ■)"| B - D -->|"(●, ●)"| A - A -->|"(●, ■)"| E - D -->|"(■, ●)"| D - E -->|"(■, ■)"| E -``` -or visualized a different way, -```mermaid -flowchart TD - Acirc("●") - Asqua("■") - Bcirc("●") - Bstar("★") - Bsqua("■") - Chear("♥") - Cstar("★") - Dsqua("■") - Dcirc("●") - Dtria("▲") - Ecirc("●") - Esqua("■") - Fplaceholder(" ") - - style Fplaceholder fill:#0000,stroke:#0000 - - subgraph "A" - Acirc - Asqua - end - - subgraph "B" - Bsqua - Bcirc - Bstar - end - - subgraph "C" - Chear - Cstar - end - - subgraph "D" - Dsqua - Dcirc - Dtria - end - - subgraph "E" - Ecirc - Esqua - end - - subgraph "F" - Fplaceholder - end - - Asqua --> Bstar - Bcirc --> Cstar - Chear --> Dtria - Dcirc --> Asqua - Bstar --> Ecirc - Esqua --> Bsqua - Dcirc --> Acirc - Acirc --> Esqua - Dsqua --> Dcirc - Esqua --> Esqua + A + B + C + D + E + F + + A --> B + B --> C + C --> D + D --> A + B --> E + E --> B + D --> A + A --> E + D --> D + E --> E ``` -Note that the nodes and source/destination indices are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. -This is the case as well with `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`. +Note that the node names are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. +This is the case with all of the 4 core graph classes. Nodes are of type `Node`, and from a user perspective are simply opaque handles, and source and destination indices should similarly be considered opaque from a user point of view. In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. -In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. -`MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. -Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`. -`NodePort:` an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. `MultiDiGraph` will be able to distinguish between 2 edges that share the same source and destination as long as at at least one `NodePort` differs. Within the context of a PCG, `NodePorts` must be thought of as the various inputs and outputs of a single node. +In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -158,11 +100,13 @@ The argument to `query_nodes` is a `NodeQuery` (which is simply a set of `Node`s The set of nodes in the query is actually an `optional`, so `nullopt` could also be passed, which would simply retrieve all nodes from the target graph (essentially `nullopt` acts as the set of all nodes that could ever exist). `query_edges` functions similarly, but as with `add_edge` its behavior is differs slightly between the three graph variants. `UndirectedGraph::query_edges` simply takes an optional set of nodes and returns all edges that touch any of those nodes. -`DirectedGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. +`DiGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. In practice you will rarely ever use `query_nodes` and `query_edges` as the graph library provides a large number of algorithms that do that work for you, but it can be helpful to understand this base layer if you ever need to implement your own algorithms. -The layer users will most commonly interact with is the interface provided by [algorithms.h](./algorithms.h), which provides a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. -You may notice that the most of the functions declared in `algorithms.h` take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but actually operator on `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. +The layer users will most commonly interact with is the interface provided within either the `algorithms.h` header files or the `algorithms` folders, present in their respective graph class folders. +They provide a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. +Note that, due to the internal virtual inheritance structure, some functions for more privitive classes can be employed by the derived classes. (For example, `get_nodes` present in `node/algorithms.h` can be used by `DiGraph`). +You may notice that the most of algorithms present take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but rather `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. These `GraphView` objects represent read-only (i.e., immutable) graphs. Similar to C++'s `const` semantics, `Graph`s can be coerced[^2] to `GraphView`s but not the other way around. To transform a `GraphView` to a `Graph`, we can perform an explicit copy with `materialize_view`. @@ -171,40 +115,35 @@ This may seem wasteful (oftentimes graphs are large objects that are passed arou At this point, however, we still have not discussed how to create a graph. The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGraph` which internally uses a representation `MyDiGraphImpl`: +For example, to construct a `DiGDiraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: ```cpp -DiGraph g = DiGraph::create(); +DiGraph g = DiGraph::create(); ``` Generally users will use underlying representations provided by the graph library, but advanced users can create their own implementations (see the [Internals](#internals) section). [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ -### Open, Upward, Downward +### Open DataFlow Variant `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. -We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). - -![Open graphs inheritance diagram](docs/open.svg) +This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. -Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info, see [cow_ptr](#cow_ptr-and-interfaces)) - - -### Labelled Graphs +### Labelled Dataflow Variant As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. -Thus, FlexFlow's graph library provides the ability to add labels via [labelled\_graphs.h](./labelled_graphs.h): examples include `NodeLabelledMultiDiGraph` (nodes have labels of type `T` and edges are unlabelled) and `OutputLabelledMultiDiGraph` (nodes have labels of type `T` and source indices have labels of type `U`). -While the interfaces of these graphs differ slightly from the core graph variants, they still have corresponding `GraphView` types, `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. -Note that all of the labelled graph types require that each element of the labelled types have a label (e.g., every node in a `NodeLabelledMultiDiGraph` must have a label of type `T`)., which is enforced via the interfaces they provide. +Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGraph`, through the `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`, which allow users to label different components of the graph. +- `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s. +- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s, which represent the open inputs to the graph (i.e. the inputs for which their corresponding output is not present in the graph). + +While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node` methods, and `query_nodes`/`query_edges` methods. (Note that there is no `add_edge` method since, for `DataflowGraph`, edges are implicitly added when we add a node and specify its predecessors) +Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. -Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes (or other types depending in which labelled graph type is used) to labels. -As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants for use in functions provided by `algorithms.h`, etc. +Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes/edges to labels. +As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants. [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. -![Labelled Graphs Inheritance Diagram](docs/labelled.svg) - - ## Internals @@ -236,12 +175,7 @@ To address this, graph classes store a `cow_ptr` as a member variable, which poi All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers to their interface counterparts. -To create graphs within the library, we thus use the following syntax: -`BaseGraph obj = BaseGraph::create();` - -Resulting in an object that, while of type `BaseGraph`, can access at runtime the member functions defined in `DerivedGraph` - ### Virtual Inheritance -Due to the complexity of the graph library, diamond-style inheritance patterns emerge (consider, for example, the `OutputLabelledOpenMultiDiGraphView` class, which inherits from both `NodeLabelledOpenMultiDiGraphView` and `OutputLabelledMultiDiGraphView`, which in turn inherit from both `NodeLabelledMultiDiGraphView`). -In the case of a diamond inheritance pattern C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. +Due to the complexity of the graph library, diamond-style inheritance patterns emerge. +In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index 6a1898dd13..d73175c7dd 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -42,8 +42,6 @@ struct DataflowGraph : virtual public DataflowGraphView { private: IDataflowGraph &get_interface(); IDataflowGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h index 1e4d09d3ae..96e8864bc1 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h @@ -5,7 +5,22 @@ namespace FlexFlow { +/** + * @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory) + * + * @note By definition, the root node dominates every node and every node + * dominates itself. + * + */ std::unordered_set get_dominators(DiGraphView const &, Node const &); + +/** + * @brief Returns the intersection of the dominators of the given set of nodes. + * @note This is conceptually equivalent to merging the given set of nodes and + * then finding the set of dominators of the new merged node (where merged means + * that all edges belonging to the set of nodes now pass through a single + * unified node). + */ std::unordered_set get_dominators(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/digraph/digraph.h b/lib/utils/include/utils/graph/digraph/digraph.h index e36b90d4bf..3d320b1c06 100644 --- a/lib/utils/include/utils/graph/digraph/digraph.h +++ b/lib/utils/include/utils/graph/digraph/digraph.h @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView { private: IDiGraph &get_ptr(); IDiGraph const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph); diff --git a/lib/utils/include/utils/graph/digraph/digraph_view.h b/lib/utils/include/utils/graph/digraph/digraph_view.h index 54f84f8d2c..0380751c55 100644 --- a/lib/utils/include/utils/graph/digraph/digraph_view.h +++ b/lib/utils/include/utils/graph/digraph/digraph_view.h @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView { private: IDiGraphView const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph/multidigraph.h index 69080b9348..692ee33783 100644 --- a/lib/utils/include/utils/graph/multidigraph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph.h @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView { private: IMultiDiGraph &get_interface(); IMultiDiGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph.h b/lib/utils/include/utils/graph/node/graph.h index bddefdacb3..1d94d1a65e 100644 --- a/lib/utils/include/utils/graph/node/graph.h +++ b/lib/utils/include/utils/graph/node/graph.h @@ -31,8 +31,6 @@ struct Graph : virtual GraphView { private: IGraph const &get_ptr() const; IGraph &get_ptr(); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h index fce3177ef1..8d904e05f2 100644 --- a/lib/utils/include/utils/graph/node/graph_view.h +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -22,8 +22,6 @@ struct GraphView { GraphView(); cow_ptr_t ptr; GraphView(cow_ptr_t ptr); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index d5c22e5d3d..46e0255de3 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -6,6 +6,7 @@ features = [ "hash", "fmt", "json", + "rapidcheck", ] includes = [ diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h index 33d50192cb..d051413faa 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -2,33 +2,12 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H #include "utils/graph/node/node.dtg.h" -namespace FlexFlow { - -struct UndirectedEdge { -public: - UndirectedEdge() = delete; - UndirectedEdge(Node const &src, Node const &dst); +#include "utils/graph/undirected/undirected_edge.dtg.h" - bool operator==(UndirectedEdge const &) const; - bool operator!=(UndirectedEdge const &) const; - bool operator<(UndirectedEdge const &) const; - -public: - Node smaller; - Node bigger; -}; +namespace FlexFlow { -bool is_connected_to(UndirectedEdge const &, Node const &); +bool is_connected_to(UndirectedEdge const &e, Node const &n); } // namespace FlexFlow -namespace std { - -template <> -struct hash<::FlexFlow::UndirectedEdge> { - size_t operator()(::FlexFlow::UndirectedEdge const &) const; -}; - -} // namespace std - #endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml new file mode 100644 index 0000000000..f5258b0bfd --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "UndirectedEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck" +] + +includes = [ + "utils/commutative_pair.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "endpoints" +type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph.h b/lib/utils/include/utils/graph/undirected/undirected_graph.h index 69975991ce..09b6495699 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph.h @@ -34,8 +34,6 @@ struct UndirectedGraph : virtual UndirectedGraphView { using UndirectedGraphView::UndirectedGraphView; - friend struct GraphInternal; - private: IUndirectedGraph const &get_ptr() const; IUndirectedGraph &get_ptr(); diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h index c2df96abc0..90dd5dd5d8 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h @@ -29,8 +29,6 @@ struct UndirectedGraphView : virtual GraphView { using GraphView::GraphView; - friend struct GraphInternal; - private: IUndirectedGraphView const &get_ptr() const; }; diff --git a/lib/utils/include/utils/graph/views/views.h b/lib/utils/include/utils/graph/views/views.h index aaa1e033f4..5e0109ed5b 100644 --- a/lib/utils/include/utils/graph/views/views.h +++ b/lib/utils/include/utils/graph/views/views.h @@ -41,104 +41,11 @@ struct DiSubgraphView : public IDiGraphView { std::unordered_set subgraph_nodes; }; -struct JoinedNodeView { -public: - JoinedNodeView() = delete; - explicit JoinedNodeView(GraphView const &lhs, GraphView const &rhs); - - std::unordered_set query_nodes(NodeQuery const &) const; - std::pair, std::unordered_set> - trace_nodes(std::unordered_set const &) const; - - Node at_join_key(JoinNodeKey const &) const; - JoinNodeKey at_node(Node const &) const; - -private: - bidict mapping; - NodeSource node_source; -}; - -struct JoinedUndirectedGraphView : public IUndirectedGraphView { -public: - JoinedUndirectedGraphView() = delete; - explicit JoinedUndirectedGraphView(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs); - - std::unordered_set - query_edges(UndirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedUndirectedGraphView *clone() const override; - -private: - UndirectedEdge fix_lhs_edge(UndirectedEdge const &) const; - UndirectedEdge fix_rhs_edge(UndirectedEdge const &) const; - -private: - UndirectedGraphView lhs; - UndirectedGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct JoinedDigraphView : virtual public IDiGraphView { -public: - JoinedDigraphView() = delete; - explicit JoinedDigraphView(DiGraphView const &lhs, DiGraphView const &rhs); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - JoinedNodeView const &joined_nodes_view() const; +UndirectedGraphView view_subgraph(UndirectedGraphView const &, + std::unordered_set const &); - JoinedDigraphView *clone() const override; - -private: - DirectedEdge fix_lhs_edge(DirectedEdge const &) const; - DirectedEdge fix_rhs_edge(DirectedEdge const &) const; - -private: - DiGraphView lhs; - DiGraphView rhs; - JoinedNodeView joined_nodes; -}; - -struct AddDirectedEdgesView : public IDiGraphView { -public: - AddDirectedEdgesView() = delete; - - explicit AddDirectedEdgesView(DiGraphView const &g, - std::unordered_set const &edges); - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - AddDirectedEdgesView *clone() const override; - -private: - DiGraphView g; - std::unordered_set edges; -}; - -struct SingleSourceNodeView : public IDiGraphView { -public: - SingleSourceNodeView() = delete; - - explicit SingleSourceNodeView(DiGraphView const &g) : g(g) {} - - std::unordered_set - query_edges(DirectedEdgeQuery const &) const override; - std::unordered_set query_nodes(NodeQuery const &) const override; - - SingleSourceNodeView *clone() const override; - -private: - DiGraphView g; - std::optional singleton_src; - std::optional joined_view; - std::unique_ptr added_edges_view; -}; +DiGraphView view_subgraph(DiGraphView const &, + std::unordered_set const &); UndirectedEdge to_undirected_edge(DirectedEdge const &); std::unordered_set @@ -176,31 +83,6 @@ struct ViewUndirectedGraphAsDiGraph : public IDiGraphView { UndirectedGraphView g; }; -std::unordered_map - flatten_contraction(std::unordered_map const &); - -template -Impl materialize_view(View const &g) { - Impl result; - for (Node const &n : get_nodes(g)) { - result.add_node_unsafe(n); - } - for (auto const &e : get_edges(g)) { - result.add_edge(e); - } - return result; -} - -template -Impl materialize_undirected_graph_view(IUndirectedGraphView const &g) { - return materialize_view(g); -} - -template -Impl materialize_digraph_view(IDiGraphView const &g) { - return materialize_view(g); -} - } // namespace FlexFlow #endif diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 6ed41daf43..d7cd979f14 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -184,7 +184,8 @@ bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}}; + UndirectedEdgeQuery q = + UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}}; return contains(g.query_edges(q), e); } @@ -212,7 +213,7 @@ void remove_edges(UndirectedGraph &g, } std::unordered_set get_endpoints(UndirectedEdge const &e) { - return {e.smaller, e.bigger}; + return {e.endpoints.min(), e.endpoints.max()}; } // std::unordered_set get_edges(MultiDiGraphView const &g) { @@ -480,15 +481,6 @@ DiGraphView get_subgraph(DiGraphView const &g, // return MultiDiGraphView::create(lhs, rhs); // } -DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs) { - return DiGraphView::create(lhs, rhs); -} - -UndirectedGraphView join(UndirectedGraphView const &lhs, - UndirectedGraphView const &rhs) { - return UndirectedGraphView::create(lhs, rhs); -} - UndirectedGraphView as_undirected(DiGraphView const &g) { return UndirectedGraphView::create(g); } diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 61c4f80763..5d16304701 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -22,31 +22,34 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { } void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { - if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.bigger)); + if (!contains_key(this->adjacency, e.endpoints.max())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.max())); } - if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error(fmt::format( - "Could not add edge connected to non-existent node {}", e.smaller)); + if (!contains_key(this->adjacency, e.endpoints.min())) { + throw mk_runtime_error( + fmt::format("Could not add edge connected to non-existent node {}", + e.endpoints.min())); } - this->adjacency.at(e.bigger).insert(e.smaller); - this->adjacency.at(e.smaller).insert(e.bigger); + this->adjacency.at(e.endpoints.max()).insert(e.endpoints.min()); + this->adjacency.at(e.endpoints.min()).insert(e.endpoints.max()); } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.bigger); - m.erase(e.smaller); - m.erase(e.bigger); + std::unordered_set &max_map = this->adjacency.at(e.endpoints.max()); + max_map.erase(e.endpoints.min()); + std::unordered_set &min_map = this->adjacency.at(e.endpoints.min()); + min_map.erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( UndirectedEdgeQuery const &query) const { std::unordered_set result; for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { - for (auto const &dst : src_kv.second) { - result.insert({src_kv.first, dst}); + for (auto const &dst : apply_query(query.nodes, src_kv.second)) { + result.insert(UndirectedEdge{{src_kv.first, dst}}); } } return result; diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc index 6f6722f635..cb44f4636d 100644 --- a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -27,8 +27,8 @@ void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { } void UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { - assert(contains(this->nodes, e.bigger)); - assert(contains(this->nodes, e.smaller)); + assert(contains(this->nodes, e.endpoints.min())); + assert(contains(this->nodes, e.endpoints.max())); this->edges.insert(e); } diff --git a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 12a6630bf0..78265f6856 100644 --- a/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,5 +1,9 @@ #include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/multidigraph/algorithms/get_edge_counts.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -10,13 +14,21 @@ ParallelReduction make_parallel_reduction(MultiDiEdge const &e1, std::optional find_parallel_reduction(MultiDiGraphView const &g) { - std::unordered_set edges = get_edges(g); - for (MultiDiEdge const &e1 : edges) { - for (MultiDiEdge const &e2 : edges) { - if (e1 != e2 && g.get_multidiedge_src(e1) == g.get_multidiedge_src(e2) && - g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { - return make_parallel_reduction(e1, e2); + for (auto const &[directed_edge, count] : get_edge_counts(g)) { + + if (count <= 1) { + continue; + } + + std::unordered_set const &outgoing_edges = + get_outgoing_edges(g, directed_edge.src); + for (MultiDiEdge const &e1 : outgoing_edges) { + for (MultiDiEdge const &e2 : outgoing_edges) { + if (e1 != e2 && + g.get_multidiedge_dst(e1) == g.get_multidiedge_dst(e2)) { + return make_parallel_reduction(e1, e2); + } } } } diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc index 3c05b9d5d5..726fda8af7 100644 --- a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -10,7 +10,7 @@ std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, std::unordered_set result = set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { - return std::unordered_set{e.bigger, e.smaller}; + return std::unordered_set{e.endpoints.max(), e.endpoints.max()}; })); result.erase(n); return result; diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 0a575e115c..4cfc6aaaa8 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -1,40 +1,11 @@ #include "utils/graph/undirected/undirected_edge.h" #include "utils/hash/tuple.h" +#include namespace FlexFlow { -UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) - : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} - -static std::tuple tie(UndirectedEdge const &e) { - return std::tie(e.smaller, e.bigger); -} - -bool UndirectedEdge::operator==(UndirectedEdge const &other) const { - return tie(*this) == tie(other); -} - -bool UndirectedEdge::operator!=(UndirectedEdge const &other) const { - return tie(*this) != tie(other); -} - -bool UndirectedEdge::operator<(UndirectedEdge const &other) const { - return tie(*this) < tie(other); -} - bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.bigger == n || e.smaller == n; + return e.endpoints.min() == n || e.endpoints.max() == n; } } // namespace FlexFlow - -namespace std { - -using namespace FlexFlow; - -size_t hash::operator()(UndirectedEdge const &e) const { - std::tuple members = ::FlexFlow::tie(e); - return std::hash{}(members); -} - -} // namespace std diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 3cccf1c6eb..e9e948aa40 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -7,7 +7,8 @@ UndirectedEdgeQuery undirected_edge_query_all() { } bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { - return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); + return includes(q.nodes, e.endpoints.max()) && + includes(q.nodes, e.endpoints.min()); } UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 9b5353de9f..7bb039d314 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,4 +1,5 @@ #include "utils/graph/views/views.h" +#include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" #include "utils/disjoint_set.h" @@ -7,8 +8,8 @@ #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" +#include "utils/graph/query_set.h" #include "utils/graph/undirected/undirected_edge_query.h" - namespace FlexFlow { UndirectedSubgraphView::UndirectedSubgraphView( @@ -65,150 +66,8 @@ DiGraphView view_subgraph(DiGraphView const &g, return DiGraphView::create(g, subgraph_nodes); } -JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { - for (Node const &n : get_nodes(lhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::LEFT}, - this->node_source.new_node()); - } - for (Node const &n : get_nodes(rhs)) { - this->mapping.equate(JoinNodeKey{n, LRDirection::RIGHT}, - this->node_source.new_node()); - } -} - -std::unordered_set - JoinedNodeView::query_nodes(NodeQuery const &query) const { - // TODO @lockshaw this is going to be reimplemented in 984, so don't bother - // fixing it for now - NOT_IMPLEMENTED(); -} - -std::pair, std::unordered_set> - JoinedNodeView::trace_nodes(std::unordered_set const &nodes) const { - std::unordered_set left_nodes, right_nodes; - - for (Node const &n : nodes) { - JoinNodeKey k = this->at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - return {left_nodes, right_nodes}; -} - -Node JoinedNodeView::at_join_key(JoinNodeKey const &k) const { - return this->mapping.at_l(k); -} - -JoinNodeKey JoinedNodeView::at_node(Node const &n) const { - return this->mapping.at_r(n); -} - -JoinedUndirectedGraphView::JoinedUndirectedGraphView( - UndirectedGraphView const &lhs, UndirectedGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -std::unordered_set - JoinedUndirectedGraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set JoinedUndirectedGraphView::query_edges( - UndirectedEdgeQuery const &query) const { - std::unordered_set nodes = this->query_nodes(NodeQuery{query.nodes}); - std::unordered_set left_nodes, right_nodes; - for (Node const &n : nodes) { - JoinNodeKey k = this->joined_nodes.at_node(n); - if (k.direction == LRDirection::LEFT) { - left_nodes.insert(k.node); - } else { - assert(k.direction == LRDirection::RIGHT); - right_nodes.insert(k.node); - } - } - - std::unordered_set result; - for (UndirectedEdge const &e : - this->lhs.query_edges(UndirectedEdgeQuery{left_nodes})) { - result.insert(this->fix_lhs_edge(e)); - } - for (UndirectedEdge const &e : - this->rhs.query_edges(UndirectedEdgeQuery{right_nodes})) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return { - this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::LEFT})}; -} - -UndirectedEdge - JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return {this->joined_nodes.at_join_key( - JoinNodeKey{e.smaller, LRDirection::RIGHT}), - this->joined_nodes.at_join_key( - JoinNodeKey{e.bigger, LRDirection::RIGHT})}; -} - -JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, - DiGraphView const &rhs) - : lhs(lhs), rhs(rhs), joined_nodes(lhs, rhs) {} - -JoinedDigraphView *JoinedDigraphView::clone() const { - return new JoinedDigraphView(lhs, rhs); -} - -std::unordered_set - JoinedDigraphView::query_nodes(NodeQuery const &query) const { - return this->joined_nodes.query_nodes(query); -} - -std::unordered_set - JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const { - - std::unordered_set srcs = this->query_nodes(NodeQuery{query.srcs}); - std::unordered_set dsts = this->query_nodes(NodeQuery{query.dsts}); - auto traced_srcs = this->joined_nodes.trace_nodes(srcs); - auto traced_dsts = this->joined_nodes.trace_nodes(dsts); - DirectedEdgeQuery left_query = - DirectedEdgeQuery{traced_srcs.first, traced_dsts.first}; - DirectedEdgeQuery right_query = - DirectedEdgeQuery{traced_srcs.second, traced_dsts.second}; - - std::unordered_set result; - for (DirectedEdge const &e : this->lhs.query_edges(left_query)) { - result.insert(this->fix_lhs_edge(e)); - } - for (DirectedEdge const &e : this->rhs.query_edges(right_query)) { - result.insert(this->fix_rhs_edge(e)); - } - - return result; -} - -DirectedEdge JoinedDigraphView::fix_lhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::LEFT})}; -} - -DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { - return DirectedEdge{ - this->joined_nodes.at_join_key(JoinNodeKey{e.src, LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.dst, LRDirection::RIGHT})}; -} - UndirectedEdge to_undirected_edge(DirectedEdge const &e) { - return {e.src, e.dst}; + return UndirectedEdge{{e.src, e.dst}}; } std::unordered_set to_undirected_edges( @@ -218,8 +77,9 @@ std::unordered_set to_undirected_edges( } std::unordered_set to_directed_edges(UndirectedEdge const &e) { - return std::unordered_set{DirectedEdge{e.smaller, e.bigger}, - DirectedEdge{e.bigger, e.smaller}}; + return std::unordered_set{ + DirectedEdge{e.endpoints.min(), e.endpoints.max()}, + DirectedEdge{e.endpoints.max(), e.endpoints.min()}}; } std::unordered_set to_directed_edges( @@ -258,8 +118,8 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - intersection(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + set_union(g.query_edges(UndirectedEdgeQuery{q.srcs}), + g.query_edges(UndirectedEdgeQuery{q.dsts})); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); @@ -272,8 +132,4 @@ std::unordered_set return g.query_nodes(q); } -JoinedUndirectedGraphView *JoinedUndirectedGraphView::clone() const { - return new JoinedUndirectedGraphView(lhs, rhs); -} - } // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc new file mode 100644 index 0000000000..0817c69e06 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/algorithms.cc @@ -0,0 +1,106 @@ +#include "utils/graph/digraph/algorithms.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DiGraph - algorithms.cc") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + }; + add_edges(g, e); + + SUBCASE("get_edges") { + SUBCASE("Base") { + std::unordered_set correct = unordered_set_of(e); + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[3], n[1]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge") { + g.remove_edge(DirectedEdge{n[0], n[3]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + } + + SUBCASE("get_sinks") { + SUBCASE("Base") { + std::unordered_set correct = {n[2], n[3]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a sink") { + g.add_edge(DirectedEdge{n[3], n[2]}); + std::unordered_set correct = {n[2]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sinks(g); + std::unordered_set correct = {n[3]}; + CHECK(result == correct); + } + } + + SUBCASE("get_sources") { + SUBCASE("Base") { + std::unordered_set correct = {n[0]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a source") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set correct = {}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge to create a new source") { + g.remove_edge(DirectedEdge{n[0], n[1]}); + std::unordered_set correct = {n[0], n[1]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sources(g); + std::unordered_set correct = {}; + CHECK(result.empty()); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc new file mode 100644 index 0000000000..3a3648eec8 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/digraph.cc @@ -0,0 +1,85 @@ +#include "utils/graph/digraph/digraph.h" +#include "utils/containers/repeat.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/node_query.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DiGraph implementations", T, AdjacencyDiGraph) { + /* + graph TD + + n0 --> n1 + n0 --> n2 + n1 --> n2 + n2 --> n4 + n1 --> n3 + */ + + DiGraph g = DiGraph::create(); + std::vector n = repeat(5, [&] { return g.add_node(); }); + std::vector e = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[1], n[3]}}; + for (DirectedEdge const &edge : e) { + g.add_edge(edge); + } + + SUBCASE("query_nodes") { + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); + + std::unordered_set queried_edges = + g.query_edges(directed_edge_query_all()); + std::unordered_set expected = { + e[0], e[1], e[2], e[3], e[4]}; + CHECK(queried_edges == expected); + + queried_edges = g.query_edges( + DirectedEdgeQuery{query_set{{n[0]}}, query_set{{n[1]}}}); + expected = std::unordered_set{e[0]}; + CHECK(queried_edges == expected); + } + SUBCASE("remove_node_unsafe") { + g.remove_node_unsafe(n[0]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[1], n[2], n[3], n[4]}); + + // removing a node also removes its adjacent edges + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[3], e[4]}); + + g.remove_node_unsafe(n[1]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[2], n[3], n[4]}); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[3]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[1], e[2], e[3], e[4]}); + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + g.remove_edge(e[1]); + g.remove_edge(e[3]); + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[4]}); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc new file mode 100644 index 0000000000..1dde5c8f69 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/directed_edge_query.cc @@ -0,0 +1,70 @@ +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("directed_edge_query") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}}); + + SUBCASE("directed_edge_query_all") { + + DirectedEdgeQuery result = directed_edge_query_all(); + + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(1)})); + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); + } + + SUBCASE("matches_edge") { + DirectedEdgeQuery q = + DirectedEdgeQuery{query_set{n.at(0)}, query_set{n.at(1)}}; + + CHECK(matches_edge(q, DirectedEdge{n.at(0), n.at(1)})); + CHECK_FALSE(matches_edge(q, DirectedEdge{n.at(1), n.at(2)})); + } + + SUBCASE("query_intersection") { + SUBCASE("standard intersection") { + DirectedEdgeQuery q1 = DirectedEdgeQuery{ + query_set{n.at(0), n.at(1)}, query_set{n.at(1), n.at(2), n.at(4)}}; + DirectedEdgeQuery q2 = DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, + query_set{n.at(2), n.at(3)}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1)}, + query_set{n.at(2)}, + }; + + CHECK(result == correct); + } + SUBCASE("intersection with std::nullopt") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, matchall()}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1), n.at(2)}, query_set{n.at(3), n.at(4)}}; + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc new file mode 100644 index 0000000000..e9151b53e5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -0,0 +1,68 @@ +#include "utils/graph/digraph/algorithms/get_dominators.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dominators") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + }; + add_edges(g, e); + + SUBCASE("single node") { + Node node = n.at(2); + std::unordered_set correct = {n.at(0), n.at(2)}; + std::unordered_set result = get_dominators(g, node); + CHECK(correct == result); + } + + SUBCASE("multiple nodes") { + std::unordered_set nodes = {n.at(1), n.at(3)}; + std::unordered_set result = get_dominators(g, nodes); + std::unordered_set correct = {n.at(0)}; + CHECK(correct == result); + } + + SUBCASE("graph with cycles") { + // example from + // https://en.wikipedia.org/w/index.php?title=Dominator_(graph_theory)&oldid=1189814332 + + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 6); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(1)}, + }); + + SUBCASE("node 1") { + std::unordered_set result = get_dominators(g, n.at(1)); + std::unordered_set correct = {n.at(0), n.at(1)}; + CHECK(result == correct); + } + + SUBCASE("node 3") { + std::unordered_set result = get_dominators(g, n.at(3)); + std::unordered_set correct = {n.at(0), n.at(1), n.at(3)}; + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc new file mode 100644 index 0000000000..5adc0cc4df --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -0,0 +1,36 @@ +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/containers/index_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc new file mode 100644 index 0000000000..0d8e7ca53a --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/traversal.cc @@ -0,0 +1,112 @@ +#include "utils/graph/traversal.h" +#include "utils/fmt/vector.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/hash/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_unchecked_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_unchecked_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + } + + TEST_CASE("get_bfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}}); + + SUBCASE("branching path") { + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3], n[4], n[5]}, + {n[0], n[2], n[1], n[3], n[4], n[5]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + + SUBCASE("isolated node") { + std::vector correct = {n[5]}; + std::vector result = get_bfs_ordering(g, {n[5]}); + CHECK(correct == result); + } + + SUBCASE("graph with cycle") { + g = DiGraph::create(); + n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[0]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[0]}, + DirectedEdge{n[2], n[1]}}); + std::unordered_set> corrects = {{n[0], n[1], n[2]}, + {n[0], n[2], n[1]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + } + + TEST_CASE("get_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("with cycle") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("branching") { + g.add_edge(DirectedEdge{n[1], n[3]}); + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3]}, {n[0], n[1], n[3], n[2]}}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + + SUBCASE("disconnected") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + std::vector correct = {n[0], n[1], n[2]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("isolated node") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + std::vector correct = {n[3]}; + std::vector result = get_dfs_ordering(g, {n[3]}); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc new file mode 100644 index 0000000000..b5943cd99f --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -0,0 +1,36 @@ +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_incoming_edges(MultiDiGraphView, Node)") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector edges = add_edges(g, + {{n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}}); + + SUBCASE("node has incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(1)); + std::unordered_set correct = {edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc new file mode 100644 index 0000000000..d4748e8422 --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -0,0 +1,40 @@ +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_outgoing_edges(MultiDiGraph, Node)") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector> input = { + {n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}, + }; + + std::vector edges = add_edges(g, input); + + SUBCASE("node has outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(0)); + std::unordered_set correct = { + edges.at(0), edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no outgoing edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc new file mode 100644 index 0000000000..179cce7db7 --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -0,0 +1,61 @@ +#include "utils/graph/undirected/algorithms/get_connected_components.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/undirected/undirected_graph.h" +#include + +using namespace FlexFlow; + +TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + + SUBCASE("disjoint nodes") { + std::vector n = add_nodes(g, 3); + + std::unordered_set> correct = { + {n[0]}, + {n[1]}, + {n[2]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("2 components") { + std::vector n = add_nodes(g, 4); + add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); + + std::unordered_set> correct = { + {n[0], n[1], n[2]}, + {n[3]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } + + SUBCASE("3 components") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + UndirectedEdge{{n[0], n[1]}}, + UndirectedEdge{{n[0], n[2]}}, + UndirectedEdge{{n[1], n[2]}}, + UndirectedEdge{{n[3], n[4]}}, + }); + + std::unordered_set> correct = { + {n[0], n[1], n[2]}, + {n[3], n[4]}, + {n[5]}, + }; + std::unordered_set> result = + get_connected_components(g); + + CHECK(correct == result); + } +} diff --git a/lib/utils/test/src/utils/graph/undirected/undirected.cc b/lib/utils/test/src/utils/graph/undirected/undirected.cc new file mode 100644 index 0000000000..7973cf8af5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/undirected/undirected.cc @@ -0,0 +1,75 @@ +#include "test/utils/rapidcheck.h" +#include "test/utils/rapidcheck/visitable.h" +#include "utils/commutative_pair.h" +#include "utils/containers/repeat.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" +#include "utils/graph/undirected/undirected_graph.h" + +using namespace FlexFlow; + +using namespace rc; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + RC_SUBCASE("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct( + gen::construct>(gen::elementOf(n), + gen::elementOf(n)))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); + }); + } +} +/* static_assert(is_fmtable::value, ""); */ + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + RC_SUBCASE("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct( + gen::construct>(gen::elementOf(n), + gen::elementOf(n)))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); + }); + } +} diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc new file mode 100644 index 0000000000..8a6a44d1cc --- /dev/null +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -0,0 +1,153 @@ +#include "utils/graph/views/views.h" +#include "utils/containers/set_union.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/undirected/undirected_graph.h" +#include "utils/graph/undirected/undirected_graph_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("UndirectedSubgraphView") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(4)}}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + UndirectedGraphView view = view_subgraph(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, + }; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("DiSubgraphView") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; + DiGraphView view = view_subgraph(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("ViewDiGraphAsUndirectedGraph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}); + + UndirectedGraphView view = as_undirected(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("ViewUndirectedGraphAsDiGraph") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {UndirectedEdge{{n.at(0), n.at(0)}}, + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}); + + DiGraphView view = as_digraph(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(0)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } +}