From cd7fef3efb51bd9f61a278d08fe3dbf805f6b08b Mon Sep 17 00:00:00 2001 From: lambda7xx Date: Mon, 17 Jul 2023 07:50:58 +0000 Subject: [PATCH] implement the query_values and query_keys for std::unordered_map and bidict fix the bug of containers --- lib/utils/include/utils/containers.h | 34 +++++++++++++++++-- lib/utils/include/utils/graph/query_set.h | 30 ++++++++++++++-- lib/utils/src/graph/adjacency_multidigraph.cc | 3 +- lib/utils/src/graph/views.cc | 2 +- 4 files changed, 63 insertions(+), 6 deletions(-) diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index c75876ecf1..ae274f473a 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -110,6 +110,11 @@ typename It::value_type product(It begin, It end) { }); } +// template +// bool contains(Container const &c, typename Container::C::key_type const &e) { +// return find(c, e) != c.cend(); +// } + template bool contains(Container const &c, typename Container::value_type const &e) { return find(c, e) != c.cend(); @@ -149,7 +154,7 @@ template ()(std::declval()))> bidict map_keys(bidict const &m, F const &f) { bidict result; - for (auto const &kv : f) { + for (auto const &kv : m) { result.equate(f(kv.first), kv.second); } return result; @@ -159,7 +164,17 @@ template std::unordered_map filter_keys(std::unordered_map const &m, F const &f) { std::unordered_map result; - for (auto const &kv : f) { + for (auto const &kv : m) { + if (f(kv.first)) { + result.insert(kv); + } + } + return result; +} + +template std::unordered_map filter_keys(bidict const & m, F const & f) { + std::unordered_map result; + for (auto const &kv : m) { if (f(kv.first)) { result.insert(kv); } @@ -167,6 +182,18 @@ std::unordered_map filter_keys(std::unordered_map const &m, return result; } + +template std::unordered_map filter_values(bidict const & m, F const & f) { + std::unordered_map result; + for (auto const &kv : m) { + if (f(kv.second)) { + result.insert(kv); + } + } + return result; +} + + template filter_values(std::unordered_map const &m, return result; } + + + template std::vector keys(C const &c) { std::vector result; diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index 957f73cb1a..d16099a080 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -4,6 +4,7 @@ #include "utils/containers.h" #include "utils/exception.h" #include "utils/optional.h" +#include "utils/bidict.h" #include namespace FlexFlow { @@ -71,15 +72,40 @@ template std::unordered_map query_keys(query_set const &q, C const &m) { std::cout<<"3"< q_set = allowed_values(q); + + auto filter_lambda = [&q_set](const K& key) { + return q_set.find(key) != q_set.end(); + }; + + return filter_keys(m, filter_lambda); }//TODO +template +std::unordered_map query_keys(query_set const &q, bidict const &m) { + std::unordered_set q_set = allowed_values(q); + + auto filter_lambda = [&q_set](const V& value) { + return q_set.find(value) != q_set.end(); + }; + + return filter_values(m, filter_lambda); +} + + template std::unordered_map query_values(query_set const &q, C const &m) { std::cout<<"4"< q_set = allowed_values(q); + + auto filter_lambda = [&q_set](const V& value) { + return q_set.find(value) != q_set.end(); + }; + + return filter_values(m, filter_lambda); } template diff --git a/lib/utils/src/graph/adjacency_multidigraph.cc b/lib/utils/src/graph/adjacency_multidigraph.cc index ce9fd33d8d..0b8ee7f4f1 100644 --- a/lib/utils/src/graph/adjacency_multidigraph.cc +++ b/lib/utils/src/graph/adjacency_multidigraph.cc @@ -42,10 +42,11 @@ void AdjacencyMultiDiGraph::remove_edge(MultiDiEdge const &e) { std::unordered_set AdjacencyMultiDiGraph::query_edges(MultiDiEdgeQuery const &q) const { std::unordered_set result; - for (auto const &src_kv : query_keys(q.srcs, this->adjacency)) { + for (auto const &src_kv : query_keys(q.srcs, this->adjacency)) { //query_keys(q.srcs, this->adjacency) return map>>> for (auto const &dst_kv : query_keys(q.dsts, src_kv.second)) { for (auto const &srcIdx_kv : query_keys(q.srcIdxs, dst_kv.second)) { for (auto const &dstIdx : apply_query(q.dstIdxs, srcIdx_kv.second)) { + std::cout<<"add edge "< JoinedNodeView::query_nodes(NodeQuery const &query) const { - return unique(values(query_keys(query.nodes, this->mapping))); + return unique(values(query_keys(query.nodes, this->mapping)));//query_keys(query.nodes, this->mapping)->std::unordered_map } std::pair, std::unordered_set>