Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
lambda7xx committed Jul 15, 2023
1 parent 05bd453 commit 6691b9e
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 59 deletions.
2 changes: 1 addition & 1 deletion lib/utils/include/utils/graph/multidiedge.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct MultiDiEdge {
NodePort srcIdx, dstIdx;
};
FF_VISITABLE_STRUCT(MultiDiEdge, src, dst, srcIdx, dstIdx);
std::ostream& operator<<(std::ostream& os, const MultiDiEdge& edge);
std::ostream &operator<<(std::ostream &os, MultiDiEdge const &edge);

struct MultiDiInput {
Node node;
Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/multidigraph_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct MultiDiEdgeQuery {
MultiDiEdgeQuery with_src_idxs(query_set<NodePort> const &) const;
MultiDiEdgeQuery with_dst_idxs(query_set<NodePort> const &) const;

MultiDiEdgeQuery with_dst_node(Node const &) const;
MultiDiEdgeQuery with_dst_node(Node const &) const;
MultiDiEdgeQuery with_src_node(Node const &) const;
MultiDiEdgeQuery with_src_idx(NodePort const &) const;
MultiDiEdgeQuery with_dst_idx(NodePort const &) const;
Expand All @@ -40,7 +40,7 @@ struct IMultiDiGraphView : public IGraphView {
using EdgeQuery = MultiDiEdgeQuery;

virtual std::unordered_set<Edge> query_edges(EdgeQuery const &) const = 0;
virtual ~IMultiDiGraphView()=default;
virtual ~IMultiDiGraphView() = default;
};
CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraphView);

Expand Down
2 changes: 1 addition & 1 deletion lib/utils/include/utils/graph/open_graphs.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct OpenMultiDiGraph {
}

OpenMultiDiGraph(std::unique_ptr<IOpenMultiDiGraph> ptr);

private:
cow_ptr_t<IOpenMultiDiGraph> ptr;
};
Expand Down
4 changes: 2 additions & 2 deletions lib/utils/include/utils/graph/query_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ struct query_set {
query_set(T const &) {
NOT_IMPLEMENTED();
}
query_set(std::unordered_set<T> const & query): query(query) {}
query_set(std::unordered_set<T> const &query) : query(query) {}

query_set(optional<std::unordered_set<T>> const &) {
NOT_IMPLEMENTED();
}
Expand Down
2 changes: 1 addition & 1 deletion lib/utils/src/graph/adjacency_multidigraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace FlexFlow {
Node AdjacencyMultiDiGraph::add_node() {
Node node{this->next_node_idx};
adjacency[node];
std::cout<<"add node "<<node.value()<<std::endl;
std::cout << "add node " << node.value() << std::endl;
this->next_node_idx++;
return node;
}
Expand Down
15 changes: 9 additions & 6 deletions lib/utils/src/graph/digraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,20 @@ DiGraphView unsafe_create(IDiGraphView const &graphView) {
}

DirectedEdgeQuery DirectedEdgeQuery::all() {
return {matchall<Node>(),
matchall<Node>()};
return {matchall<Node>(), matchall<Node>()};
}

DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, DirectedEdgeQuery const &rhs){
DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs,
DirectedEdgeQuery const &rhs) {
assert(lhs != tl::nullopt);
assert(rhs != tl::nullopt);
assert (lhs.srcs.has_value() && lhs.dsts.has_value() && rhs.srcs.has_value() && rhs.dsts.has_value());
assert(lhs.srcs.has_value() && lhs.dsts.has_value() && rhs.srcs.has_value() &&
rhs.dsts.has_value());

std::unordered_set<Node> srcs_t1 = intersection(allowed_values(lhs.srcs), allowed_values(rhs.srcs));
std::unordered_set<Node> dsts_t1 = intersection(allowed_values(lhs.dsts), allowed_values(rhs.dsts));
std::unordered_set<Node> srcs_t1 =
intersection(allowed_values(lhs.srcs), allowed_values(rhs.srcs));
std::unordered_set<Node> dsts_t1 =
intersection(allowed_values(lhs.dsts), allowed_values(rhs.dsts));

DirectedEdgeQuery result = DirectedEdgeQuery::all();
result.srcs = srcs_t1;
Expand Down
24 changes: 14 additions & 10 deletions lib/utils/src/graph/multidigraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,19 @@ MultiDiEdgeQuery
}

MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs,
MultiDiEdgeQuery const &rhs) {
MultiDiEdgeQuery const &rhs) {
assert(lhs != tl::nullopt);
assert(rhs != tl::nullopt);

std::unordered_set<Node> srcs_t1 = intersection(allowed_values(lhs.srcs), allowed_values(rhs.srcs));
std::unordered_set<Node> dsts_t1 = intersection(allowed_values(lhs.dsts), allowed_values(rhs.dsts));
std::unordered_set<Node> srcs_t1 =
intersection(allowed_values(lhs.srcs), allowed_values(rhs.srcs));
std::unordered_set<Node> dsts_t1 =
intersection(allowed_values(lhs.dsts), allowed_values(rhs.dsts));

std::unordered_set<NodePort> srcIdxs_t1 = intersection(allowed_values(lhs.srcIdxs), allowed_values(rhs.srcIdxs));
std::unordered_set<NodePort> dstIdxs_t1 = intersection(allowed_values(lhs.dstIdxs), allowed_values(rhs.dstIdxs));
std::unordered_set<NodePort> srcIdxs_t1 =
intersection(allowed_values(lhs.srcIdxs), allowed_values(rhs.srcIdxs));
std::unordered_set<NodePort> dstIdxs_t1 =
intersection(allowed_values(lhs.dstIdxs), allowed_values(rhs.dstIdxs));

MultiDiEdgeQuery e = MultiDiEdgeQuery::all();
e.srcs = srcs_t1;
Expand All @@ -61,19 +65,19 @@ MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs,
MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_node(Node const &n) const {
return this->with_dst_nodes({n});
}
MultiDiEdgeQuery MultiDiEdgeQuery::with_src_idx(NodePort const & p) const {
MultiDiEdgeQuery MultiDiEdgeQuery::with_src_idx(NodePort const &p) const {
return this->with_src_idxs({p});
}

MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_idx(NodePort const & p) const {
MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_idx(NodePort const &p) const {
return this->with_dst_idxs({p});
}

MultiDiEdgeQuery MultiDiEdgeQuery::with_src_idxs(
query_set<NodePort> const &idxs) const {
MultiDiEdgeQuery
MultiDiEdgeQuery::with_src_idxs(query_set<NodePort> const &idxs) const {
MultiDiEdgeQuery e{*this};
if (is_matchall(e.srcIdxs)) {
throw mk_runtime_error("Expected matchall previous value");
throw mk_runtime_error("Expected matchall previous value");
}
e.srcIdxs = idxs;
return e;
Expand Down
11 changes: 6 additions & 5 deletions lib/utils/src/graph/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ std::ostream &operator<<(std::ostream &os, Node const &node) {
return os << fmt::format("Node({})", node.value());
}

NodeQuery NodeQuery::all() {
return {matchall<Node>()} ;
}
NodeQuery NodeQuery::all() {
return {matchall<Node>()};
}

NodeQuery query_intersection(NodeQuery const & lhs, NodeQuery const & rhs) {
NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) {
assert(lhs != tl::nullopt && rhs != tl::nullopt);
std::unordered_set<Node> nodes = intersection(allowed_values(lhs.nodes), allowed_values(rhs.nodes));
std::unordered_set<Node> nodes =
intersection(allowed_values(lhs.nodes), allowed_values(rhs.nodes));
NodeQuery intersection_result = NodeQuery::all();
intersection_result.nodes = nodes;
return intersection_result;
Expand Down
2 changes: 1 addition & 1 deletion lib/utils/src/graph/views.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ std::unordered_set<Node> ViewOpenMultiDiGraphAsMultiDiGraph::query_nodes(

std::unordered_set<MultiDiEdge> ViewOpenMultiDiGraphAsMultiDiGraph::query_edges(
MultiDiEdgeQuery const &query) const {

// OpenMultiDiEdgeQuery q{query};
// // q.standard_edge_query = query;
// std::unordered_set<OpenMultiDiEdge> edges = g.query_edges(q);
Expand Down
40 changes: 21 additions & 19 deletions lib/utils/test/src/test_adjacency_multidigraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ TEST_CASE("AdjacencyMultiDiGraph:basic_test") {
g.add_edge(e3);
g.add_edge(e4);

CHECK(g.query_nodes(NodeQuery::all()) == std::unordered_set<Node>{n0, n1, n2});

CHECK(g.query_nodes(NodeQuery::all()) ==
std::unordered_set<Node>{n0, n1, n2});

std::unordered_set<Node> nodes = {n0, n2};
query_set<Node> q{nodes};
CHECK(g.query_nodes(NodeQuery(q)) == std::unordered_set<Node>{n0, n2});

//CHECK(g.query_edges({}) == std::unordered_set<MultiDiEdge>{e1, e2, e3, e4});
// CHECK(g.query_edges({}) == std::unordered_set<MultiDiEdge>{e1, e2, e3,
// e4});

CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n1)) ==
std::unordered_set<MultiDiEdge>{});
Expand All @@ -37,22 +39,22 @@ TEST_CASE("AdjacencyMultiDiGraph:basic_test") {
std::unordered_set<MultiDiEdge>{});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(p1)) ==
std::unordered_set<MultiDiEdge>{e1, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n1, n2})) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n0, n2})) ==
// std::unordered_set<MultiDiEdge>{e2, e3});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p1, p2})) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p0, p2})) ==
// std::unordered_set<MultiDiEdge>{e2, e3});
// CHECK(g.query_edges(MultiDiEdgeQuery::all()
// .with_src_node(n1)
// .with_dst_node(n2)
// .with_src_idx(p1)
// .with_dst_idx(p2)) ==
// std::unordered_set<MultiDiEdge>{});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(p2)) ==
// std::unordered_set<MultiDiEdge>{e2});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n1, n2})) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n0, n2})) ==
// std::unordered_set<MultiDiEdge>{e2, e3});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p1, p2})) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p0, p2})) ==
// std::unordered_set<MultiDiEdge>{e2, e3});
// CHECK(g.query_edges(MultiDiEdgeQuery::all()
// .with_src_node(n1)
// .with_dst_node(n2)
// .with_src_idx(p1)
// .with_dst_idx(p2)) ==
// std::unordered_set<MultiDiEdge>{});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(p2)) ==
// std::unordered_set<MultiDiEdge>{e2});
}

// TEST_CASE("AdjacencyMultiDiGraph:remove_node") {
Expand Down
26 changes: 15 additions & 11 deletions lib/utils/test/src/test_algorithms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@
// g.add_edge(e3);

// CHECK(g.query_nodes({}) == std::unordered_set<Node>{n0, n1, n2, n3});
// CHECK(g.query_edges({}) == std::unordered_set<MultiDiEdge>{e0, e1, e2, e3});
// CHECK(get_incoming_edges(g, {n1, n3}) ==
// CHECK(g.query_edges({}) == std::unordered_set<MultiDiEdge>{e0, e1, e2,
// e3}); CHECK(get_incoming_edges(g, {n1, n3}) ==
// std::unordered_set<MultiDiEdge>{e0, e2, e3});
// CHECK(get_incoming_edges(g, {n1}) == std::unordered_set<MultiDiEdge>{});
// CHECK(get_outgoing_edges(g, {n2, n3}) == std::unordered_set<MultiDiEdge>{e3});
// auto res = get_predecessors(g, {n1, n2, n3});
// auto expected_result = std::unordered_map<Node, std::unordered_set<Node>>{
// CHECK(get_outgoing_edges(g, {n2, n3}) ==
// std::unordered_set<MultiDiEdge>{e3}); auto res = get_predecessors(g, {n1,
// n2, n3}); auto expected_result = std::unordered_map<Node,
// std::unordered_set<Node>>{
// {n1, {}},
// {n2, {n1}},
// {n3, {n0, n1, n2}},
Expand All @@ -60,11 +61,12 @@
// g.add_edge(e2);
// g.add_edge(e3);

// CHECK(g.query_edges({}) == std::unordered_set<DirectedEdge>{e0, e1, e2, e3});
// CHECK(get_incoming_edges(g, {n2, n3}) ==
// CHECK(g.query_edges({}) == std::unordered_set<DirectedEdge>{e0, e1, e2,
// e3}); CHECK(get_incoming_edges(g, {n2, n3}) ==
// std::unordered_set<DirectedEdge>{e0, e2, e3});
// CHECK(get_outgoing_edges(g, {n2, n3}) == std::unordered_set<DirectedEdge>{});
// auto expected_result = std::unordered_map<Node, std::unordered_set<Node>>{
// CHECK(get_outgoing_edges(g, {n2, n3}) ==
// std::unordered_set<DirectedEdge>{}); auto expected_result =
// std::unordered_map<Node, std::unordered_set<Node>>{
// {n1, {n0}},
// {n2, {n0, n1}},
// {n3, {n0}},
Expand All @@ -82,7 +84,8 @@
// g.add_edge({n[1], n[2]});
// g.add_edge({n[2], n[3]});

// /* CHECK(get_incoming_edges(g, n[0]) == std::unordered_set<DirectedEdge>{});
// /* CHECK(get_incoming_edges(g, n[0]) ==
// std::unordered_set<DirectedEdge>{});
// */
// CHECK(get_sources(g) == std::unordered_set<Node>{n[0]});
// CHECK(get_unchecked_dfs_ordering(g, {n[0]}) ==
Expand Down Expand Up @@ -136,7 +139,8 @@
// auto CHECK_BEFORE = [&](int l, int r) {
// CHECK(index_of(ordering, n[l]).has_value());
// CHECK(index_of(ordering, n[r]).has_value());
// CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value());
// CHECK(index_of(ordering, n[l]).value() < index_of(ordering,
// n[r]).value());
// };

// CHECK(ordering.size() == n.size());
Expand Down

0 comments on commit 6691b9e

Please sign in to comment.