Skip to content

Commit

Permalink
add test_adjacency_multidigraph.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
lambda7xx committed Jul 17, 2023
1 parent 42178e7 commit 6486f20
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 115 deletions.
1 change: 0 additions & 1 deletion lib/utils/include/utils/containers.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ std::unordered_map<K, V> filter_keys(std::unordered_map<K, V> const &m,
result.insert(kv);
}
}
std::cout<<"filter_keys,result.size():"<<result.size()<<std::endl;
return result;
}

Expand Down
7 changes: 3 additions & 4 deletions lib/utils/include/utils/graph/query_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ namespace FlexFlow {
template <typename T>
struct query_set {
query_set() = delete;
query_set(T const &query) : query({query}) {
std::cout << "1" << std::endl;
}
query_set(T const &query) : query({query}) {}

query_set(std::unordered_set<T> const &query) : query(query) {}

query_set(optional<std::unordered_set<T>> const &query) : query(query) {}
Expand Down Expand Up @@ -71,10 +70,10 @@ template <typename C,
typename K = typename C::key_type,
typename V = typename C::mapped_type>
std::unordered_map<K, V> query_keys(query_set<K> const &q, C const &m) {
std::cout << "3" << std::endl;
if(is_matchall(q)) {
return m;
}

std::unordered_set<K> q_set = allowed_values(q);
auto filter_lambda = [&q_set](K const &key) {
return q_set.find(key) != q_set.end();
Expand Down
23 changes: 0 additions & 23 deletions lib/utils/src/graph/adjacency_multidigraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ namespace FlexFlow {
Node AdjacencyMultiDiGraph::add_node() {
Node node{this->next_node_idx};
adjacency[node];
std::cout << "add node " << node.value() << std::endl;
this->next_node_idx++;
return node;
}
Expand Down Expand Up @@ -42,27 +41,6 @@ void AdjacencyMultiDiGraph::remove_edge(MultiDiEdge const &e) {
std::unordered_set<MultiDiEdge>
AdjacencyMultiDiGraph::query_edges(MultiDiEdgeQuery const &q) const {
std::unordered_set<MultiDiEdge> result;
// //std::cout<<"q.srcs:"<<q.srcs.value()<<" and q.dsts:"<<q.dsts.value()<< " and q.srcIdxs:"<<q.srcIdxs.value()<<" and q.dstIdxs:"<<q.dstIdxs.value()<<std::endl;
// std::cout<<"AdjacencyMultiDiGraph::query_edges"<<std::endl;
// for(auto kv : this->adjacency) {
// std::cout<<"this->adjacency node1:"<<kv.first.value()<<std::endl;
// for(auto node_nodeport: kv.second) {
// std::cout<<"this->adjacency node2:"<<node_nodeport.first.value()<<std::endl;
// for(auto nodeport_nodeport: node_nodeport.second) {
// std::cout<<"this->adjacency NodePort1:"<<nodeport_nodeport.first.value()<<std::endl;
// for(auto nodeport: nodeport_nodeport.second) {
// std::cout<<"this->adjacency NodePort2:"<<nodeport.value()<<std::endl;
// }
// }
// }
// std::cout<<"*********"<<std::endl;
// }

// auto src_kvs = query_keys(q.srcs, this->adjacency);
// std::cout<<"x.size:"<<src_kvs.size()<<std::endl;
// for(auto const &x:src_kvs) {
// std::cout<<"x.first:"<<x.first.value() <<" and x.second.size()"<<x.second.size()<<std::endl;
// }
for (auto const &src_kv : query_keys(q.srcs, this->adjacency)) {
for (auto const &dst_kv : query_keys(q.dsts, src_kv.second)) {
for (auto const &srcIdx_kv : query_keys(q.srcIdxs, dst_kv.second)) {
Expand All @@ -73,7 +51,6 @@ std::unordered_set<MultiDiEdge>
}

}
std::cout<<"query_edges, result.size():"<<result.size()<<std::endl;
return result;
}

Expand Down
139 changes: 52 additions & 87 deletions lib/utils/test/src/test_adjacency_multidigraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,11 @@ TEST_CASE("AdjacencyMultiDiGraph:basic_test") {
CHECK(g.query_nodes(NodeQuery::all()) ==
std::unordered_set<Node>{n0, n1, n2});

std::unordered_set<Node> nodes = {n0, n2};
query_set<Node> q{nodes};
CHECK(g.query_nodes(NodeQuery(q)) == std::unordered_set<Node>{n0, n2});
CHECK(g.query_nodes(NodeQuery(query_set<Node>({n0, n2}))) == std::unordered_set<Node>{n0, n2});

// CHECK(g.query_edges(MultiDiEdgeQuery::all()) ==
// std::unordered_set<MultiDiEdge>{e1, e2, e3, e4});
CHECK(g.query_edges(MultiDiEdgeQuery::all()) ==
std::unordered_set<MultiDiEdge>{e1, e2, e3, e4});

// CHECK(g.query_edges({}) == std::unordered_set<MultiDiEdge>{e1, e2, e3,
// e4});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n1)) ==
std::unordered_set<MultiDiEdge>{});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n1)) ==
Expand All @@ -41,84 +37,53 @@ TEST_CASE("AdjacencyMultiDiGraph:basic_test") {
std::unordered_set<MultiDiEdge>{});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(p1)) ==
std::unordered_set<MultiDiEdge>{e1, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n1, n2})) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n0, n2})) ==
// std::unordered_set<MultiDiEdge>{e2, e3});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p1, p2})) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p0, p2})) ==
// std::unordered_set<MultiDiEdge>{e2, e3});
// CHECK(g.query_edges(MultiDiEdgeQuery::all()
// .with_src_node(n1)
// .with_dst_node(n2)
// .with_src_idx(p1)
// .with_dst_idx(p2)) ==
// std::unordered_set<MultiDiEdge>{});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(p2)) ==
// std::unordered_set<MultiDiEdge>{e2});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set<Node>({n1, n2}))) ==
std::unordered_set<MultiDiEdge>{e3, e4});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set<Node>({n0, n2}))) ==
std::unordered_set<MultiDiEdge>{e2, e3});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set<NodePort>({p1, p2}))) ==
std::unordered_set<MultiDiEdge>{e3, e4});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set<NodePort>({p0, p2}))) ==
std::unordered_set<MultiDiEdge>{e2, e3});
CHECK(g.query_edges(MultiDiEdgeQuery::all()
.with_src_node(n1)
.with_dst_node(n2)
.with_src_idx(p1)
.with_dst_idx(p2)) ==
std::unordered_set<MultiDiEdge>{});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(p2)) ==
std::unordered_set<MultiDiEdge>{e2});

SUBCASE("remove node" ) {
g.remove_node_unsafe(n0);

CHECK(g.query_nodes(NodeQuery::all()) == std::unordered_set<Node>{n1, n2});

CHECK(g.query_edges(MultiDiEdgeQuery::all()) == std::unordered_set<MultiDiEdge>{e3, e4});

CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n0)) ==
std::unordered_set<MultiDiEdge>{});

CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n0)) ==
std::unordered_set<MultiDiEdge>{e3});

CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idx(p2)) ==
std::unordered_set<MultiDiEdge>{e3, e4});
CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idx(p0)) ==
std::unordered_set<MultiDiEdge>{e3});
}

SUBCASE("remove_edge") {
g.remove_edge(e1);

CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n0).with_dst_node(
n1)) == std::unordered_set<MultiDiEdge>{});

CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n2)) ==
std::unordered_set<MultiDiEdge>{e2});

CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idx(p2)) ==
std::unordered_set<MultiDiEdge>{e3, e4});
}

}

// TEST_CASE("AdjacencyMultiDiGraph:remove_node") {
// MultiDiGraph g = MultiDiGraph::create<AdjacencyMultiDiGraph>();
// Node n0 = g.add_node();
// Node n1 = g.add_node();
// Node n2 = g.add_node();
// NodePort p0 = g.add_node_port();
// NodePort p1 = g.add_node_port();
// NodePort p2 = g.add_node_port();
// MultiDiEdge e1{n0, n1, p0, p1};
// MultiDiEdge e2{n0, n2, p0, p2};
// MultiDiEdge e3{n2, n0, p2, p0};
// MultiDiEdge e4{n2, n1, p2, p1};
// g.add_edge(e1);
// g.add_edge(e2);
// g.add_edge(e3);
// g.add_edge(e4);

// g.remove_node_unsafe(n0);

// CHECK(g.query_nodes({}) == std::unordered_set<Node>{n1, n2});

// CHECK(g.query_edges({}) == std::unordered_set<MultiDiEdge>{e3, e4});

// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n0)) ==
// std::unordered_set<MultiDiEdge>{});

// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n0)) ==
// std::unordered_set<MultiDiEdge>{e3});

// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p2})) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p0})) ==
// std::unordered_set<MultiDiEdge>{e3});
// }

// TEST_CASE("AdjacencyMultiDiGraph:remove_edge") {
// AdjacencyMultiDiGraph g;
// Node n0 = g.add_node();
// Node n1 = g.add_node();
// Node n2 = g.add_node();
// NodePort p0 = g.add_node_port();
// NodePort p1 = g.add_node_port();
// NodePort p2 = g.add_node_port();
// MultiDiEdge e1{n0, n1, p0, p1};
// MultiDiEdge e2{n0, n2, p0, p2};
// MultiDiEdge e3{n2, n0, p2, p0};
// MultiDiEdge e4{n2, n1, p2, p1};
// g.add_edge(e1);
// g.add_edge(e2);
// g.add_edge(e3);
// g.add_edge(e4);

// g.remove_edge(e1);

// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_node(n0).with_dst_node(
// n1)) == std::unordered_set<MultiDiEdge>{});

// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_node(n2)) ==
// std::unordered_set<MultiDiEdge>{e2});

// CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idx(p2)) ==
// std::unordered_set<MultiDiEdge>{e3, e4});
// }

0 comments on commit 6486f20

Please sign in to comment.