Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph Library #1521

Open
wants to merge 7 commits into
base: repo-refactor
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 51 additions & 117 deletions lib/utils/include/utils/graph/README.md

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ struct DataflowGraph : virtual public DataflowGraphView {
private:
IDataflowGraph &get_interface();
IDataflowGraph const &get_interface() const;

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
15 changes: 15 additions & 0 deletions lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,22 @@

namespace FlexFlow {

/**
* @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory)
*
* @note By definition, the root node dominates every node and every node
* dominates itself.
*
*/
std::unordered_set<Node> get_dominators(DiGraphView const &, Node const &);

/**
* @brief Returns the intersection of the dominators of the given set of nodes.
* @note This is conceptually equivalent to merging the given set of nodes and
* then finding the set of dominators of the new merged node (where merged means
* that all edges belonging to the set of nodes now pass through a single
* unified node).
*/
std::unordered_set<Node> get_dominators(DiGraphView const &,
std::unordered_set<Node> const &);

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/digraph/digraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView {
private:
IDiGraph &get_ptr();
IDiGraph const &get_ptr() const;

friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph);

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/digraph/digraph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView {

private:
IDiGraphView const &get_ptr() const;

friend struct GraphInternal;
};
CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView);

Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/multidigraph/multidigraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView {
private:
IMultiDiGraph &get_interface();
IMultiDiGraph const &get_interface() const;

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/node/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ struct Graph : virtual GraphView {
private:
IGraph const &get_ptr() const;
IGraph &get_ptr();

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/node/graph_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ struct GraphView {
GraphView();
cow_ptr_t<IGraphView> ptr;
GraphView(cow_ptr_t<IGraphView> ptr);

friend struct GraphInternal;
};

} // namespace FlexFlow
Expand Down
1 change: 1 addition & 0 deletions lib/utils/include/utils/graph/node/node.struct.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ features = [
"hash",
"fmt",
"json",
"rapidcheck",
]

includes = [
Expand Down
27 changes: 3 additions & 24 deletions lib/utils/include/utils/graph/undirected/undirected_edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,12 @@
#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H

#include "utils/graph/node/node.dtg.h"
namespace FlexFlow {

struct UndirectedEdge {
public:
UndirectedEdge() = delete;
UndirectedEdge(Node const &src, Node const &dst);
#include "utils/graph/undirected/undirected_edge.dtg.h"

bool operator==(UndirectedEdge const &) const;
bool operator!=(UndirectedEdge const &) const;
bool operator<(UndirectedEdge const &) const;

public:
Node smaller;
Node bigger;
};
namespace FlexFlow {

bool is_connected_to(UndirectedEdge const &, Node const &);
bool is_connected_to(UndirectedEdge const &e, Node const &n);

} // namespace FlexFlow

namespace std {

template <>
struct hash<::FlexFlow::UndirectedEdge> {
size_t operator()(::FlexFlow::UndirectedEdge const &) const;
};

} // namespace std

#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
namespace = "FlexFlow"
name = "UndirectedEdge"
features = [
"eq",
"ord",
"hash",
"fmt",
"rapidcheck"
]

includes = [
"utils/commutative_pair.h",
"utils/graph/node/node.dtg.h",
]

[[fields]]
name = "endpoints"
type = "::FlexFlow::commutative_pair<::FlexFlow::Node>"
2 changes: 0 additions & 2 deletions lib/utils/include/utils/graph/undirected/undirected_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ struct UndirectedGraph : virtual UndirectedGraphView {

using UndirectedGraphView::UndirectedGraphView;

friend struct GraphInternal;

private:
IUndirectedGraph const &get_ptr() const;
IUndirectedGraph &get_ptr();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ struct UndirectedGraphView : virtual GraphView {

using GraphView::GraphView;

friend struct GraphInternal;

private:
IUndirectedGraphView const &get_ptr() const;
};
Expand Down
126 changes: 4 additions & 122 deletions lib/utils/include/utils/graph/views/views.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,104 +41,11 @@ struct DiSubgraphView : public IDiGraphView {
std::unordered_set<Node> subgraph_nodes;
};

struct JoinedNodeView {
public:
JoinedNodeView() = delete;
explicit JoinedNodeView(GraphView const &lhs, GraphView const &rhs);

std::unordered_set<Node> query_nodes(NodeQuery const &) const;
std::pair<std::unordered_set<Node>, std::unordered_set<Node>>
trace_nodes(std::unordered_set<Node> const &) const;

Node at_join_key(JoinNodeKey const &) const;
JoinNodeKey at_node(Node const &) const;

private:
bidict<JoinNodeKey, Node> mapping;
NodeSource node_source;
};

struct JoinedUndirectedGraphView : public IUndirectedGraphView {
public:
JoinedUndirectedGraphView() = delete;
explicit JoinedUndirectedGraphView(UndirectedGraphView const &lhs,
UndirectedGraphView const &rhs);

std::unordered_set<UndirectedEdge>
query_edges(UndirectedEdgeQuery const &) const override;
std::unordered_set<Node> query_nodes(NodeQuery const &) const override;

JoinedUndirectedGraphView *clone() const override;

private:
UndirectedEdge fix_lhs_edge(UndirectedEdge const &) const;
UndirectedEdge fix_rhs_edge(UndirectedEdge const &) const;

private:
UndirectedGraphView lhs;
UndirectedGraphView rhs;
JoinedNodeView joined_nodes;
};

struct JoinedDigraphView : virtual public IDiGraphView {
public:
JoinedDigraphView() = delete;
explicit JoinedDigraphView(DiGraphView const &lhs, DiGraphView const &rhs);

std::unordered_set<DirectedEdge>
query_edges(DirectedEdgeQuery const &) const override;
std::unordered_set<Node> query_nodes(NodeQuery const &) const override;

JoinedNodeView const &joined_nodes_view() const;
UndirectedGraphView view_subgraph(UndirectedGraphView const &,
std::unordered_set<Node> const &);

JoinedDigraphView *clone() const override;

private:
DirectedEdge fix_lhs_edge(DirectedEdge const &) const;
DirectedEdge fix_rhs_edge(DirectedEdge const &) const;

private:
DiGraphView lhs;
DiGraphView rhs;
JoinedNodeView joined_nodes;
};

struct AddDirectedEdgesView : public IDiGraphView {
public:
AddDirectedEdgesView() = delete;

explicit AddDirectedEdgesView(DiGraphView const &g,
std::unordered_set<DirectedEdge> const &edges);

std::unordered_set<DirectedEdge>
query_edges(DirectedEdgeQuery const &) const override;
std::unordered_set<Node> query_nodes(NodeQuery const &) const override;

AddDirectedEdgesView *clone() const override;

private:
DiGraphView g;
std::unordered_set<DirectedEdge> edges;
};

struct SingleSourceNodeView : public IDiGraphView {
public:
SingleSourceNodeView() = delete;

explicit SingleSourceNodeView(DiGraphView const &g) : g(g) {}

std::unordered_set<DirectedEdge>
query_edges(DirectedEdgeQuery const &) const override;
std::unordered_set<Node> query_nodes(NodeQuery const &) const override;

SingleSourceNodeView *clone() const override;

private:
DiGraphView g;
std::optional<AdjacencyDiGraph> singleton_src;
std::optional<JoinedDigraphView> joined_view;
std::unique_ptr<AddDirectedEdgesView> added_edges_view;
};
DiGraphView view_subgraph(DiGraphView const &,
std::unordered_set<Node> const &);

UndirectedEdge to_undirected_edge(DirectedEdge const &);
std::unordered_set<UndirectedEdge>
Expand Down Expand Up @@ -176,31 +83,6 @@ struct ViewUndirectedGraphAsDiGraph : public IDiGraphView {
UndirectedGraphView g;
};

std::unordered_map<Node, Node>
flatten_contraction(std::unordered_map<Node, Node> const &);

template <typename Impl, typename View>
Impl materialize_view(View const &g) {
Impl result;
for (Node const &n : get_nodes(g)) {
result.add_node_unsafe(n);
}
for (auto const &e : get_edges(g)) {
result.add_edge(e);
}
return result;
}

template <typename Impl>
Impl materialize_undirected_graph_view(IUndirectedGraphView const &g) {
return materialize_view<Impl, IUndirectedGraphView>(g);
}

template <typename Impl>
Impl materialize_digraph_view(IDiGraphView const &g) {
return materialize_view<Impl, IDiGraphView>(g);
}

} // namespace FlexFlow

#endif
14 changes: 3 additions & 11 deletions lib/utils/src/utils/graph/algorithms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@
}

bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) {
UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}};
UndirectedEdgeQuery q =
UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}};

Check warning on line 188 in lib/utils/src/utils/graph/algorithms.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/algorithms.cc#L187-L188

Added lines #L187 - L188 were not covered by tests
return contains(g.query_edges(q), e);
}

Expand Down Expand Up @@ -212,7 +213,7 @@
}

std::unordered_set<Node> get_endpoints(UndirectedEdge const &e) {
return {e.smaller, e.bigger};
return {e.endpoints.min(), e.endpoints.max()};

Check warning on line 216 in lib/utils/src/utils/graph/algorithms.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/algorithms.cc#L216

Added line #L216 was not covered by tests
}

// std::unordered_set<MultiDiEdge> get_edges(MultiDiGraphView const &g) {
Expand Down Expand Up @@ -480,15 +481,6 @@
// return MultiDiGraphView::create<JoinedMultiDigraphView>(lhs, rhs);
// }

DiGraphView join(DiGraphView const &lhs, DiGraphView const &rhs) {
return DiGraphView::create<JoinedDigraphView>(lhs, rhs);
}

UndirectedGraphView join(UndirectedGraphView const &lhs,
UndirectedGraphView const &rhs) {
return UndirectedGraphView::create<JoinedUndirectedGraphView>(lhs, rhs);
}

UndirectedGraphView as_undirected(DiGraphView const &g) {
return UndirectedGraphView::create<ViewDiGraphAsUndirectedGraph>(g);
}
Expand Down
29 changes: 16 additions & 13 deletions lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,34 @@
}

void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) {
if (!contains_key(this->adjacency, e.bigger)) {
throw mk_runtime_error(fmt::format(
"Could not add edge connected to non-existent node {}", e.bigger));
if (!contains_key(this->adjacency, e.endpoints.max())) {
throw mk_runtime_error(
fmt::format("Could not add edge connected to non-existent node {}",
e.endpoints.max()));

Check warning on line 28 in lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc#L26-L28

Added lines #L26 - L28 were not covered by tests
}
if (!contains_key(this->adjacency, e.smaller)) {
throw mk_runtime_error(fmt::format(
"Could not add edge connected to non-existent node {}", e.smaller));
if (!contains_key(this->adjacency, e.endpoints.min())) {
throw mk_runtime_error(
fmt::format("Could not add edge connected to non-existent node {}",
e.endpoints.min()));

Check warning on line 33 in lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc#L31-L33

Added lines #L31 - L33 were not covered by tests
}

this->adjacency.at(e.bigger).insert(e.smaller);
this->adjacency.at(e.smaller).insert(e.bigger);
this->adjacency.at(e.endpoints.max()).insert(e.endpoints.min());
this->adjacency.at(e.endpoints.min()).insert(e.endpoints.max());
}

void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) {
std::unordered_set<Node> &m = this->adjacency.at(e.bigger);
m.erase(e.smaller);
m.erase(e.bigger);
std::unordered_set<Node> &max_map = this->adjacency.at(e.endpoints.max());
max_map.erase(e.endpoints.min());
std::unordered_set<Node> &min_map = this->adjacency.at(e.endpoints.min());
min_map.erase(e.endpoints.max());

Check warning on line 44 in lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc#L41-L44

Added lines #L41 - L44 were not covered by tests
}

std::unordered_set<UndirectedEdge> HashmapUndirectedGraph::query_edges(
UndirectedEdgeQuery const &query) const {
std::unordered_set<UndirectedEdge> result;
for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) {
for (auto const &dst : src_kv.second) {
result.insert({src_kv.first, dst});
for (auto const &dst : apply_query(query.nodes, src_kv.second)) {
result.insert(UndirectedEdge{{src_kv.first, dst}});
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
}

void UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) {
assert(contains(this->nodes, e.bigger));
assert(contains(this->nodes, e.smaller));
assert(contains(this->nodes, e.endpoints.min()));
assert(contains(this->nodes, e.endpoints.max()));

Check warning on line 31 in lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc

View check run for this annotation

Codecov / codecov/patch

lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc#L30-L31

Added lines #L30 - L31 were not covered by tests
this->edges.insert(e);
}

Expand Down
Loading
Loading