From 932979d5db86218d057731f29252ccb889cd536a Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Fri, 9 Aug 2024 22:42:13 -0700 Subject: [PATCH 01/21] removed GraphInternal --- lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h | 2 -- lib/utils/include/utils/graph/digraph/digraph.h | 2 -- lib/utils/include/utils/graph/digraph/digraph_view.h | 2 -- lib/utils/include/utils/graph/multidigraph/multidigraph.h | 2 -- lib/utils/include/utils/graph/node/graph.h | 2 -- lib/utils/include/utils/graph/node/graph_view.h | 2 -- lib/utils/include/utils/graph/undirected/undirected_graph.h | 2 -- .../include/utils/graph/undirected/undirected_graph_view.h | 2 -- 8 files changed, 16 deletions(-) diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index 7974c033c3..164f304805 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -43,8 +43,6 @@ struct DataflowGraph : virtual public DataflowGraphView { private: IDataflowGraph &get_interface(); IDataflowGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/digraph.h b/lib/utils/include/utils/graph/digraph/digraph.h index e36b90d4bf..3d320b1c06 100644 --- a/lib/utils/include/utils/graph/digraph/digraph.h +++ b/lib/utils/include/utils/graph/digraph/digraph.h @@ -40,8 +40,6 @@ struct DiGraph : virtual DiGraphView { private: IDiGraph &get_ptr(); IDiGraph const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph); diff --git a/lib/utils/include/utils/graph/digraph/digraph_view.h b/lib/utils/include/utils/graph/digraph/digraph_view.h index 54f84f8d2c..0380751c55 100644 --- a/lib/utils/include/utils/graph/digraph/digraph_view.h +++ b/lib/utils/include/utils/graph/digraph/digraph_view.h @@ -31,8 +31,6 @@ struct DiGraphView : virtual public GraphView { private: IDiGraphView const &get_ptr() const; - - friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph/multidigraph.h index 69080b9348..692ee33783 100644 --- a/lib/utils/include/utils/graph/multidigraph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph.h @@ -40,8 +40,6 @@ struct MultiDiGraph : virtual public MultiDiGraphView { private: IMultiDiGraph &get_interface(); IMultiDiGraph const &get_interface() const; - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph.h b/lib/utils/include/utils/graph/node/graph.h index bddefdacb3..1d94d1a65e 100644 --- a/lib/utils/include/utils/graph/node/graph.h +++ b/lib/utils/include/utils/graph/node/graph.h @@ -31,8 +31,6 @@ struct Graph : virtual GraphView { private: IGraph const &get_ptr() const; IGraph &get_ptr(); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h index fce3177ef1..8d904e05f2 100644 --- a/lib/utils/include/utils/graph/node/graph_view.h +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -22,8 +22,6 @@ struct GraphView { GraphView(); cow_ptr_t ptr; GraphView(cow_ptr_t ptr); - - friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph.h b/lib/utils/include/utils/graph/undirected/undirected_graph.h index 69975991ce..09b6495699 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph.h @@ -34,8 +34,6 @@ struct UndirectedGraph : virtual UndirectedGraphView { using UndirectedGraphView::UndirectedGraphView; - friend struct GraphInternal; - private: IUndirectedGraph const &get_ptr() const; IUndirectedGraph &get_ptr(); diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h index c2df96abc0..90dd5dd5d8 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h @@ -29,8 +29,6 @@ struct UndirectedGraphView : virtual GraphView { using GraphView::GraphView; - friend struct GraphInternal; - private: IUndirectedGraphView const &get_ptr() const; }; From cbc3de4761bb26855165388d1a5cfd87b2a85b71 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sat, 10 Aug 2024 21:28:48 -0700 Subject: [PATCH 02/21] minor changes --- .../instances/hashmap_undirected_graph.h | 0 .../utils/graph/undirected/undirected_edge.h | 5 ++ .../utils/graph/undirected/undirected_edge.cc | 15 ++++ lib/utils/test/src/test_algorithms.cc | 10 --- .../digraph/algorithms/get_dominators.cc | 33 +++++++ .../test/src/utils/graph/digraph/digraph.cc | 85 +++++++++++++++++++ .../graph/digraph/directed_edge_query.cc | 54 ++++++++++++ 7 files changed, 192 insertions(+), 10 deletions(-) rename lib/utils/{src => include}/utils/graph/instances/hashmap_undirected_graph.h (100%) create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/digraph.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h b/lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h similarity index 100% rename from lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h rename to lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h index 33d50192cb..235cf45c4a 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -31,4 +31,9 @@ struct hash<::FlexFlow::UndirectedEdge> { } // namespace std +namespace FlexFlow { +std::string format_as(UndirectedEdge const &); +std::ostream &operator<<(std::ostream &, UndirectedEdge const &); +} // namespace FlexFlow + #endif diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 0a575e115c..2e79a8016d 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -1,5 +1,6 @@ #include "utils/graph/undirected/undirected_edge.h" #include "utils/hash/tuple.h" +#include namespace FlexFlow { @@ -38,3 +39,17 @@ size_t hash::operator()(UndirectedEdge const &e) const { } } // namespace std + +namespace FlexFlow { +std::string format_as(UndirectedEdge const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, UndirectedEdge const &x) { + return s << fmt::to_string(x); +} +} // namespace FlexFlow diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index a1dd75504e..f3a48882a0 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -75,16 +75,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == expected_result); } - SUBCASE("get_dominators") { - std::unordered_map> expected = { - {n[0], {n[0]}}, - {n[1], {n[0], n[1]}}, - {n[2], {n[0], n[2]}}, - {n[3], {n[0], n[3]}}, - }; - CHECK(get_dominators(g) == expected); - } - SUBCASE("get_sinks") { auto expected = std::unordered_set{n[2], n[3]}; CHECK(get_sinks(g) == expected); diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc new file mode 100644 index 0000000000..25c2384959 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -0,0 +1,33 @@ +#include "utils/graph/digraph/algorithms/get_dominators.h" +#include "test/utils/doctest.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dominators") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + }; + add_edges(g, e); + + SUBCASE("single node") { + std::unordered_set expected_dominators = {n[0], n[2]}; + CHECK(get_dominators(g, n[2]) == expected_dominators); + } + + SUBCASE("multiple nodes") { + std::unordered_set nodes = {n[1], n[3]}; + std::unordered_set expected_dominators = {n[0]}; + CHECK(get_dominators(g, nodes) == expected_dominators); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/digraph.cc b/lib/utils/test/src/utils/graph/digraph/digraph.cc new file mode 100644 index 0000000000..3a3648eec8 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/digraph.cc @@ -0,0 +1,85 @@ +#include "utils/graph/digraph/digraph.h" +#include "utils/containers/repeat.h" +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/node_query.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DiGraph implementations", T, AdjacencyDiGraph) { + /* + graph TD + + n0 --> n1 + n0 --> n2 + n1 --> n2 + n2 --> n4 + n1 --> n3 + */ + + DiGraph g = DiGraph::create(); + std::vector n = repeat(5, [&] { return g.add_node(); }); + std::vector e = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[1], n[3]}}; + for (DirectedEdge const &edge : e) { + g.add_edge(edge); + } + + SUBCASE("query_nodes") { + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); + + std::unordered_set queried_edges = + g.query_edges(directed_edge_query_all()); + std::unordered_set expected = { + e[0], e[1], e[2], e[3], e[4]}; + CHECK(queried_edges == expected); + + queried_edges = g.query_edges( + DirectedEdgeQuery{query_set{{n[0]}}, query_set{{n[1]}}}); + expected = std::unordered_set{e[0]}; + CHECK(queried_edges == expected); + } + SUBCASE("remove_node_unsafe") { + g.remove_node_unsafe(n[0]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[1], n[2], n[3], n[4]}); + + // removing a node also removes its adjacent edges + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[3], e[4]}); + + g.remove_node_unsafe(n[1]); + + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[2], n[3], n[4]}); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[3]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[1], e[2], e[3], e[4]}); + CHECK(g.query_nodes(node_query_all()) == + std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); + + g.remove_edge(e[1]); + g.remove_edge(e[3]); + CHECK(g.query_edges(directed_edge_query_all()) == + std::unordered_set{e[2], e[4]}); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc new file mode 100644 index 0000000000..c17d2b8fd7 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc @@ -0,0 +1,54 @@ +#include "utils/graph/digraph/directed_edge_query.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("directed_edge_query") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[1], n[3]}}); + + SUBCASE("directed_edge_query_all") { + + DirectedEdgeQuery result = directed_edge_query_all(); + + CHECK(matches_edge(result, DirectedEdge{n[0], n[1]})); + CHECK(matches_edge(result, DirectedEdge{n[0], n[2]})); + CHECK(matches_edge(result, DirectedEdge{n[1], n[2]})); + CHECK(matches_edge(result, DirectedEdge{n[2], n[4]})); + CHECK(matches_edge(result, DirectedEdge{n[1], n[3]})); + } + + SUBCASE("matches_edge") { + DirectedEdgeQuery q{{n[0]}, {n[1]}}; + + CHECK(matches_edge(q, DirectedEdge{n[0], n[1]})); + CHECK_FALSE(matches_edge(q, DirectedEdge{n[1], n[2]})); + } + + SUBCASE("query_intersection") { + DirectedEdgeQuery q1{{n[0], n[1]}, {n[1], n[2], n[4]}}; + DirectedEdgeQuery q2{{n[1], n[2]}, {n[2], n[3]}}; + + DirectedEdgeQuery result = query_intersection(q1, q2); + + std::unordered_set expected_srcs = {n[1]}; + std::unordered_set expected_dsts = {n[2]}; + + CHECK(result.srcs == expected_srcs); + CHECK(result.dsts == expected_dsts); + } + } +} From 018bb40f11c36af27fafea85650e9ac54a8cba6a Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sat, 10 Aug 2024 22:54:52 -0700 Subject: [PATCH 03/21] Added views.cc testing --- lib/utils/test/src/utils/graph/views/views.cc | 249 ++++++++++++++++++ 1 file changed, 249 insertions(+) create mode 100644 lib/utils/test/src/utils/graph/views/views.cc diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc new file mode 100644 index 0000000000..684738e7ba --- /dev/null +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -0,0 +1,249 @@ +#include "utils/graph/views/views.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/undirected/undirected_graph.h" +#include "utils/graph/undirected/undirected_graph_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("UndirectedSubgraphView") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {{n[0], n[3]}, + {n[1], n[1]}, + {n[1], n[2]}, + {n[1], n[3]}, + {n[2], n[3]}, + {n[2], n[4]}}); + std::unordered_set sub_nodes = {n[0], n[1], n[3]}; + UndirectedGraphView view = + UndirectedGraphView::create(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n[0], n[1], n[3]}; + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + {n[0], n[3]}, + {n[1], n[1]}, + {n[1], n[3]}, + }; + + std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + } + } + + TEST_CASE("DiSubgraphView") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + {DirectedEdge{n[0], n[3]}, + DirectedEdge{n[3], n[0]}, + DirectedEdge{n[1], n[1]}, + DirectedEdge{n[2], n[1]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[2]}, + DirectedEdge{n[2], n[4]}}); + std::unordered_set sub_nodes = {n[0], n[1], n[3]}; + DiGraphView view = DiGraphView::create(g, sub_nodes); + + SUBCASE("get_nodes") { + std::unordered_set expected = {n[0], n[1], n[3]}; + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[3], n[0]}, + DirectedEdge{n[1], n[1]}, + DirectedEdge{n[1], n[3]}, + }; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + // TEST_CASE("JoinedUndirectedGraphView") { + // UndirectedGraph g1 = + // UndirectedGraph::create(); UndirectedGraph g2 + // = UndirectedGraph::create(); + + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 3); + + // add_edges(g1, {{n1[0], n1[1]}, {n1[1], n1[2]}}); + // add_edges(g2, {{n2[0], n2[2]}, {n2[1], n2[2]}}); + + // UndirectedGraphView view = + // UndirectedGraphView::create(g1, g2); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // {n1[0], n1[1]}, {n1[1], n1[2]}, + // {n2[0], n2[2]}, {n2[1], n2[2]} + // }; + + // std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + // } + // } + + // TEST_CASE("JoinedDigraphView") { + // DiGraph g1 = DiGraph::create(); + // DiGraph g2 = DiGraph::create(); + + // std::vector n1 = add_nodes(g1, 3); + // std::vector n2 = add_nodes(g2, 3); + + // add_edges(g1, {DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], + // n1[2]}}); add_edges(g2, {DirectedEdge{n2[0], n2[2]}, + // DirectedEdge{n2[1], n2[2]}}); + + // DiGraphView view = DiGraphView::create(g1, g2); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = set_union(unordered_set_of(n1), + // unordered_set_of(n2)); + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], n1[2]}, + // DirectedEdge{n2[0], n2[2]}, DirectedEdge{n2[1], n2[2]} + // }; + + // std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + // } + // } + + // TEST_CASE("AddDirectedEdgesView") { + // DiGraph g = DiGraph::create(); + // std::vector n = add_nodes(g, 4); + // add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); + + // std::unordered_set additional_edges = { + // DirectedEdge{n[2], n[3]}, DirectedEdge{n[3], n[0]} + // }; + + // DiGraphView view = DiGraphView::create(g, + // additional_edges); + + // SUBCASE("get_nodes") { + // std::unordered_set expected = unordered_set_of(n); + + // std::unordered_set result = get_nodes(view); + + // CHECK(result == expected); + // } + + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}, + // DirectedEdge{n[2], n[3]}, DirectedEdge{n[3], n[0]} + // }; + + // std::unordered_set result = get_edges(view); + + // CHECK(result == expected); + // } + // } + + TEST_CASE("ViewDiGraphAsUndirectedGraph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[0]}, + DirectedEdge{n[0], n[2]}}); + + UndirectedGraphView view = + UndirectedGraphView::create(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = { + {n[0], n[1]}, {n[1], n[2]}, {n[2], n[0]}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } + + TEST_CASE("ViewUndirectedGraphAsDiGraph") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, {{n[0], n[0]}, {n[0], n[1]}, {n[1], n[2]}, {n[2], n[0]}}); + + DiGraphView view = DiGraphView::create(g); + + SUBCASE("get_nodes") { + std::unordered_set expected = unordered_set_of(n); + + std::unordered_set result = get_nodes(view); + + CHECK(result == expected); + } + + SUBCASE("get_edges") { + std::unordered_set expected = {DirectedEdge{n[0], n[0]}, + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[0]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[1]}, + DirectedEdge{n[2], n[0]}, + DirectedEdge{n[0], n[2]}}; + + std::unordered_set result = get_edges(view); + + CHECK(result == expected); + } + } +} From 72d61c773d47f0c2e1cdfcdd4e57be980bdc1d26 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 14:03:49 -0700 Subject: [PATCH 04/21] Fixed containers testing --- lib/utils/include/utils/containers.decl.h | 35 +- lib/utils/include/utils/containers.h | 18 +- .../include/utils/containers/vector_split.h | 1 + lib/utils/test/src/test_containers.cc | 393 ------------------ lib/utils/test/src/test_hash.cc | 20 - lib/utils/test/src/utils/containers.cc | 185 +++++++++ lib/utils/test/src/utils/containers/all_of.cc | 13 + .../test/src/utils/containers/are_disjoint.cc | 15 + .../test/src/utils/containers/as_vector.cc | 14 + .../test/src/utils/containers/contains.cc | 13 + lib/utils/test/src/utils/containers/count.cc | 13 + .../test/src/utils/containers/enumerate.cc | 17 + .../test/src/utils/containers/filter_keys.cc | 17 + lib/utils/test/src/utils/containers/find.cc | 12 + .../test/src/utils/containers/get_first.cc | 12 + .../test/src/utils/containers/is_subset_eq.cc | 16 + lib/utils/test/src/utils/containers/keys.cc | 17 + .../test/src/utils/containers/map_keys.cc | 17 + .../test/src/utils/containers/map_values.cc | 16 + .../src/utils/containers/restrict_keys.cc | 18 + .../test/src/utils/containers/reversed.cc | 16 + .../test/src/utils/containers/set_union.cc | 15 + .../test/src/utils/containers/sorted_by.cc | 18 + lib/utils/test/src/utils/containers/subvec.cc | 17 + lib/utils/test/src/utils/containers/values.cc | 17 + .../test/src/utils/containers/vector_split.cc | 16 + .../deduplicated_priority_queue.cc} | 2 +- .../disjoint_set.cc} | 20 +- .../{test_dot_file.cc => utils/dot_file.cc} | 2 +- lib/utils/test/src/utils/hash-utils.cc | 98 +++++ lib/utils/test/src/utils/join_strings.cc | 20 + .../random_utils.cc} | 2 +- .../record_formatter.cc} | 2 +- .../{test_sequence.cc => utils/sequence.cc} | 2 +- .../{test_stack_map.cc => utils/stack_map.cc} | 2 +- .../stack_string.cc} | 7 +- .../stack_vector.cc} | 2 +- .../src/{test_tuple.cc => utils/tuple.cc} | 2 +- .../type_index.cc} | 8 +- .../src/{test_variant.cc => utils/variant.cc} | 2 +- .../src/{test_vector.cc => utils/vector.cc} | 2 +- 41 files changed, 655 insertions(+), 479 deletions(-) delete mode 100644 lib/utils/test/src/test_containers.cc delete mode 100644 lib/utils/test/src/test_hash.cc create mode 100644 lib/utils/test/src/utils/containers.cc create mode 100644 lib/utils/test/src/utils/containers/all_of.cc create mode 100644 lib/utils/test/src/utils/containers/are_disjoint.cc create mode 100644 lib/utils/test/src/utils/containers/as_vector.cc create mode 100644 lib/utils/test/src/utils/containers/contains.cc create mode 100644 lib/utils/test/src/utils/containers/count.cc create mode 100644 lib/utils/test/src/utils/containers/enumerate.cc create mode 100644 lib/utils/test/src/utils/containers/filter_keys.cc create mode 100644 lib/utils/test/src/utils/containers/find.cc create mode 100644 lib/utils/test/src/utils/containers/get_first.cc create mode 100644 lib/utils/test/src/utils/containers/is_subset_eq.cc create mode 100644 lib/utils/test/src/utils/containers/keys.cc create mode 100644 lib/utils/test/src/utils/containers/map_keys.cc create mode 100644 lib/utils/test/src/utils/containers/map_values.cc create mode 100644 lib/utils/test/src/utils/containers/restrict_keys.cc create mode 100644 lib/utils/test/src/utils/containers/reversed.cc create mode 100644 lib/utils/test/src/utils/containers/set_union.cc create mode 100644 lib/utils/test/src/utils/containers/sorted_by.cc create mode 100644 lib/utils/test/src/utils/containers/subvec.cc create mode 100644 lib/utils/test/src/utils/containers/values.cc create mode 100644 lib/utils/test/src/utils/containers/vector_split.cc rename lib/utils/test/src/{test_deduplicated_priority_queue.cc => utils/deduplicated_priority_queue.cc} (100%) rename lib/utils/test/src/{test_disjoint_set.cc => utils/disjoint_set.cc} (74%) rename lib/utils/test/src/{test_dot_file.cc => utils/dot_file.cc} (100%) create mode 100644 lib/utils/test/src/utils/hash-utils.cc create mode 100644 lib/utils/test/src/utils/join_strings.cc rename lib/utils/test/src/{test_random_utils.cc => utils/random_utils.cc} (100%) rename lib/utils/test/src/{test_format.cc => utils/record_formatter.cc} (100%) rename lib/utils/test/src/{test_sequence.cc => utils/sequence.cc} (100%) rename lib/utils/test/src/{test_stack_map.cc => utils/stack_map.cc} (100%) rename lib/utils/test/src/{test_stack_string.cc => utils/stack_string.cc} (94%) rename lib/utils/test/src/{test_stack_vector.cc => utils/stack_vector.cc} (100%) rename lib/utils/test/src/{test_tuple.cc => utils/tuple.cc} (100%) rename lib/utils/test/src/{test_type_index.cc => utils/type_index.cc} (80%) rename lib/utils/test/src/{test_variant.cc => utils/variant.cc} (100%) rename lib/utils/test/src/{test_vector.cc => utils/vector.cc} (100%) diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 81fdff8a40..6f6e1db37d 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -30,17 +30,9 @@ bool contains_l(bidict const &m, K const &k); template bool contains_r(bidict const &m, V const &v); -template -std::unordered_map filter_values(std::unordered_map const &m, - F const &f); - template std::optional index_of(Container const &c, Element const &e); -template -std::unordered_map restrict_keys(std::unordered_map const &m, - std::unordered_set const &mask); - template std::unordered_map merge_maps(std::unordered_map const &lhs, std::unordered_map const &rhs); @@ -48,9 +40,6 @@ std::unordered_map merge_maps(std::unordered_map const &lhs, template bidict merge_maps(bidict const &lhs, bidict const &rhs); -template -std::optional at_idx(std::vector const &v, size_t idx); - template std::function lookup_in(std::unordered_map const &m); @@ -61,16 +50,8 @@ template std::function lookup_in_r(bidict const &m); template -bool is_supserseteq_of(std::unordered_set const &l, - std::unordered_set const &r); - -template -std::unordered_set - map_over_unordered_set(std::function const &f, - std::unordered_set const &input); - -template -std::optional maybe_get_only(C const &c); +bool is_superseteq_of(std::unordered_set const &l, + std::unordered_set const &r); template std::optional optional_all_of(Container const &, Function const &); @@ -78,15 +59,21 @@ std::optional optional_all_of(Container const &, Function const &); template bool are_all_same(C const &c); +template +std::vector flatmap(std::vector const &v, F const &f); + +template +std::unordered_set flatmap(std::unordered_set const &v, F const &f); +template +std::unordered_set flatmap_v2(std::unordered_set const &v, + std::unordered_set (*f)(In const &)); + template std::function compare_by(F const &f); template typename C::value_type maximum(C const &v); -template -T reversed(T const &t); - template std::vector value_all(std::vector> const &v); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 6164699f2e..e0a5e43bb7 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -5,17 +5,21 @@ #include "required_core.h" #include "type_traits_core.h" #include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" #include "utils/containers/contains.h" #include "utils/containers/extend.h" #include "utils/containers/extend_vector.h" #include "utils/containers/filter.h" #include "utils/containers/intersection.h" #include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/restrict_keys.h" #include "utils/containers/sorted.h" #include "utils/containers/transform.h" #include "utils/containers/vector_transform.h" #include "utils/exception.h" #include "utils/hash/pair.h" +#include "utils/optional.h" #include "utils/type_traits.h" #include #include @@ -135,21 +139,11 @@ std::function lookup_in_r(bidict const &m) { } template -bool is_supserseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { +bool is_superseteq_of(std::unordered_set const &l, + std::unordered_set const &r) { return is_subseteq_of(r, l); } -template -std::unordered_set - map_over_unordered_set(std::function const &f, - std::unordered_set const &input) { - std::unordered_set result; - std::transform( - input.cbegin(), input.cend(), std::inserter(result, result.begin()), f); - return result; -} - template std::optional optional_all_of(Container const &container, Function const &func) { diff --git a/lib/utils/include/utils/containers/vector_split.h b/lib/utils/include/utils/containers/vector_split.h index a1ab12a070..b55eea19d4 100644 --- a/lib/utils/include/utils/containers/vector_split.h +++ b/lib/utils/include/utils/containers/vector_split.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H +#include #include namespace FlexFlow { diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc deleted file mode 100644 index af7792dc6d..0000000000 --- a/lib/utils/test/src/test_containers.cc +++ /dev/null @@ -1,393 +0,0 @@ -#include "test/utils/doctest.h" -#include "utils/containers.h" -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("join_strings") { - std::vector const v = {"Hello", "world", "!"}; - CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); - } - - TEST_CASE("join_strings with container") { - std::vector const v = {"Hello", "world"}; - CHECK(join_strings(v, " ") == "Hello world"); - } - - TEST_CASE("find") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(find(v, 3) != v.cend()); - CHECK(find(v, 6) == v.cend()); - } - - TEST_CASE("sum") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(sum(v) == 15); - } - - TEST_CASE("sum with condition") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Sum of even numbers only - CHECK(sum_where(v, condition) == 6); - } - - TEST_CASE("product") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(product(v) == 120); - } - - TEST_CASE("product_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Product of even numbers only - CHECK(product_where(v, condition) == 8); - } - - TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); - } - - TEST_CASE("contains_key") { - std::unordered_map m = { - {"one", 1}, {"two", 2}, {"three", 3}}; - CHECK(contains_key(m, "one")); - CHECK(!contains_key(m, "four")); - } - - TEST_CASE("map_keys") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; // Mapping function - auto result = map_keys(m, f); - CHECK(result.size() == 2); - CHECK(result[1] == "one"); - CHECK(result[4] == "two"); - } - - TEST_CASE("filter_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = [](int x) { return x % 2 == 1; }; // Filtering function - std::unordered_map result = filter_keys(m, f); - std::unordered_map expected = {{1, "one"}, {3, "three"}}; - CHECK(result == expected); - } - - TEST_CASE("map_values") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](std::string const &s) { return s.size(); }; // Mapping function - std::unordered_map result = map_values(m, f); - std::unordered_map expected = {{1, 3}, {2, 3}}; - CHECK(result == expected); - } - - TEST_CASE("keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {3, 2, 1}; - CHECK(result == expected); - } - - TEST_CASE("values") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::vector result = values(m); - std::vector expected = {"three", "two", "one"}; - CHECK(result == expected); - } - - // TEST_CASE("items") { - // std::unordered_map m = {{1, std::string("one")}, {2, - // std::string("two")}, {3,std::string("three")}}; - // std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; - std::unordered_set result = unique(v); - std::unordered_set expected = {1, 2, 3}; - CHECK(result == expected); - } - - TEST_CASE("without_order") { - std::vector v = {1, 4, 6, 4, 6}; - std::unordered_set expected = {1, 4, 6}; - CHECK(unordered_set_of(v) == expected); - } - - TEST_CASE("index_of") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3) == 2); - CHECK(!index_of(v, 6).has_value()); - } - - TEST_CASE("intersection") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {2, 3, 4}; - std::unordered_set result = intersection(l, r); - std::unordered_set expected = {2, 3}; - CHECK(result == expected); - } - - TEST_CASE("are_disjoint") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {4, 5, 6}; - CHECK(are_disjoint(l, r)); - r.insert(3); - CHECK_FALSE(are_disjoint(l, r)); - } - - TEST_CASE("restrict_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set mask = {2, 3, 4}; - std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); - } - - TEST_CASE("merge_maps(unordered_map)") { - std::unordered_map lhs = {{1, "one"}, {2, "two"}}; - std::unordered_map rhs = {{3, "three"}, {4, "four"}}; - std::unordered_map result = merge_maps(lhs, rhs); - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); - } - - TEST_CASE("merge_maps(bidict)") { - std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; - std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; - std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; - std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; - bidict lhs{fwd_map1, bwd_map1}; - bidict rhs{fwd_map2, bwd_map2}; - - std::unordered_map result = - merge_maps(lhs, rhs); // impicit conversion - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); - } - - TEST_CASE("lookup_in") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = lookup_in(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - CHECK(f(3) == "three"); - } - - TEST_CASE("lookup_in_l") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_l(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - } - - TEST_CASE("lookup_in_r") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_r(m); - CHECK(f("one") == 1); - CHECK(f("two") == 2); - } - - TEST_CASE("set_union") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {2, 3, 4}; - std::unordered_set result = set_union(s1, s2); - std::unordered_set expected = {1, 2, 3, 4}; - CHECK(result == expected); - } - - TEST_CASE("is_subseteq_of") { - std::unordered_set s1 = {1, 2}; - std::unordered_set s2 = {1, 2, 3}; - CHECK(is_subseteq_of(s1, s2) == true); - CHECK(is_subseteq_of(s2, s1) == false); - CHECK(is_subseteq_of(s1, s1) == true); - CHECK(is_subseteq_of(s2, s2) == true); - } - - TEST_CASE("is_superseteq_of") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {1, 2}; - CHECK(is_supserseteq_of(s1, s2) == true); - CHECK(is_supserseteq_of(s2, s1) == false); - } - - TEST_CASE("get_only") { - std::unordered_set s = {42}; - CHECK(get_only(s) == 42); - } - - TEST_CASE("get_first") { - std::unordered_set s = {1, 2, 3}; - CHECK(s.count(get_first(s)) == 1); - } - - TEST_CASE("extend") { - std::vector v = {1, 2, 3}; - std::unordered_set s = {4, 5, 6}; - extend(v, s); - CHECK(v.size() == 6); - std::vector expected = {1, 2, 3, 6, 5, 4}; - CHECK(v == expected); - } - - TEST_CASE("all_of") { - std::vector v = {2, 4, 6, 8}; - CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); - CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); - } - - TEST_CASE("count") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); - CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); - } - - TEST_CASE("are_all_same") { - std::vector v1 = {2, 2, 2, 2}; - std::vector v2 = {1, 2, 3, 4}; - CHECK(are_all_same(v1) == true); - CHECK(are_all_same(v2) == false); - } - - TEST_CASE("vector_transform") { - std::vector v = {1, 2, 3}; - auto result = vector_transform([](int x) { return x * 2; }, v); - CHECK(result == std::vector({2, 4, 6})); - } - - TEST_CASE("as_vector") { - std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); - CHECK(result == std::vector({3, 2, 1})); - } - - TEST_CASE("transform_vector") { - std::vector v = {1, 2, 3}; - auto result = transform(v, [](int x) { return x * 2; }); - CHECK(result == std::vector({2, 4, 6})); - } - - TEST_CASE("transform_unordered_set") { - std::unordered_set s = {1, 2, 3}; - auto result = transform(s, [](int x) { return x * 2; }); - CHECK(result == std::unordered_set({2, 4, 6})); - } - - TEST_CASE("transform_string") { - std::string s = "abc"; - auto result = transform(s, ::toupper); - CHECK(result == "ABC"); - } - - TEST_CASE("repeat") { - int ctr = 0; - std::vector result = repeat(5, [&] { return ctr++; }); - - CHECK(result == std::vector{0, 1, 2, 3, 4}); - } - - TEST_CASE("Testing the 'enumerate' function") { - std::unordered_set input_set = {1, 2, 3, 4, 5}; - std::unordered_map result = enumerate(input_set); - std::unordered_map expected = { - {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; - CHECK(result == expected); - } - - TEST_CASE("Testing the 'maximum' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - auto result = maximum(input_vec); - - // Checking the maximum is as expected - REQUIRE(result == 5); - } - - TEST_CASE("Testing the 'reversed' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - std::vector result = reversed(input_vec); - std::vector expected = {5, 4, 3, 2, 1}; - - // Checking the reversed sequence is as expected - CHECK(result == expected); - } - - TEST_CASE("Testing sorted_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); - CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); - - std::unordered_set s2 = {-5, -1, -3, -2, -4}; - auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); - CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); - } - - TEST_CASE("Testing compare_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - std::vector result = - sorted_by(s, compare_by([](int i) { return (-i); })); - CHECK(result == std::vector{5, 4, 3, 2, 1}); - } - - TEST_CASE("Testing vector_split function") { - std::vector v = {1, 2, 3, 4, 5}; - auto result = vector_split(v, 2); - std::vector prefix = result.first; - std::vector postfix = result.second; - CHECK(prefix == std::vector({1, 2})); - CHECK(postfix == std::vector({3, 4, 5})); - } - - TEST_CASE("Testing value_all function") { - std::vector> v = {1, 2, 3, 4, 5}; - auto value_all_v = value_all(v); - CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); - } - - TEST_CASE("Testing subvec function") { - std::vector v = {1, 2, 3, 4, 5}; - auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); - - CHECK(subvec_v == std::vector({2, 3, 4})); - - auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); - CHECK(subvec_v2 == std::vector({1, 2, 3})); - } - - auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); - } - } - return factors; - }; - - // Example for vector - TEST_CASE("Test for flatmap function on vectors") { - std::vector v = {2, 3, 4, 5}; - auto result = flatmap(v, get_factors); - CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); - } -} diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc deleted file mode 100644 index b38c43fe30..0000000000 --- a/lib/utils/test/src/test_hash.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "test/utils/doctest.h" -#include "utils/hash-utils.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("hash:unordered_map") { - std::unordered_map map1{{1, 2}}; - std::unordered_map map2{{1, 2}, {3, 4}}; - - size_t hash1 = get_std_hash(map1); - size_t hash2 = get_std_hash(map2); - - CHECK(hash1 != hash2); - - map1.insert({1, 2}); - hash1 = get_std_hash(map1); - CHECK(hash1 == hash2); - } -} diff --git a/lib/utils/test/src/utils/containers.cc b/lib/utils/test/src/utils/containers.cc new file mode 100644 index 0000000000..43f651ba09 --- /dev/null +++ b/lib/utils/test/src/utils/containers.cc @@ -0,0 +1,185 @@ +#include "utils/containers.h" +#include "test/utils/doctest.h" +#include "utils/bidict/bidict.h" +#include +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sum") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(sum(v) == 15); + } + + TEST_CASE("sum_where") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + CHECK(sum_where(v, condition) == 6); + } + + TEST_CASE("product_where") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + CHECK(product_where(v, condition) == 8); + } + + TEST_CASE("contains_l and contains_r") { + bidict bd; + bd.equate(1, "one"); + bd.equate(2, "two"); + + CHECK(contains_l(bd, 1) == true); + CHECK(contains_l(bd, 3) == false); + CHECK(contains_r(bd, std::string("one")) == true); + CHECK(contains_r(bd, std::string("three")) == false); + } + + TEST_CASE("is_submap") { + std::unordered_map m1 = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_map m2 = {{1, "one"}, {2, "two"}}; + std::unordered_map m3 = {{1, "one"}, {4, "four"}}; + + CHECK(is_submap(m1, m2) == true); + CHECK(is_submap(m1, m3) == false); + } + + TEST_CASE("index_of") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(index_of(v, 3).value() == 2); + CHECK(index_of(v, 6) == std::nullopt); + } + + TEST_CASE("merge_maps") { + SUBCASE("bidict") { + std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; + std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; + std::unordered_map fwd_map2 = {{3, "three"}, + {4, "four"}}; + std::unordered_map bwd_map2 = {{"three", 3}, + {"four", 4}}; + bidict lhs{fwd_map1, bwd_map1}; + bidict rhs{fwd_map2, bwd_map2}; + + std::unordered_map result = + merge_maps(lhs, rhs); // impicit conversion + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK(result == expected); + } + SUBCASE("unordered_map") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{3, "three"}, {4, "four"}}; + std::unordered_map result = merge_maps(lhs, rhs); + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + } + } + + TEST_CASE("lookup_in") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = lookup_in(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + } + + TEST_CASE("lookup_in_l") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_l(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + } + + TEST_CASE("lookup_in_r") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_r(m); + CHECK(f("one") == 1); + CHECK(f("two") == 2); + } + + TEST_CASE("is_superseteq_of") { + std::unordered_set s1 = {1, 2, 3, 4}; + std::unordered_set s2 = {1, 2, 3}; + std::unordered_set s3 = {1, 2, 5}; + + CHECK(is_superseteq_of(s1, s2) == true); + CHECK(is_superseteq_of(s1, s3) == false); + } + + TEST_CASE("restrict_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set mask = {2, 3, 4}; + std::unordered_map result = restrict_keys(m, mask); + std::unordered_map expected = {{2, "two"}, {3, "three"}}; + CHECK(result == expected); + } + + TEST_CASE("optional_all_of") { + std::vector v = {2, 4, 6, 8}; + auto f = [](int x) -> std::optional { return x % 2 == 0; }; + CHECK(optional_all_of(v, f) == true); + + auto f2 = [](int x) -> std::optional { + if (x == 6) { + return std::nullopt; + } + return x % 2 == 0; + }; + CHECK(optional_all_of(v, f2) == std::nullopt); + } + + TEST_CASE("are_all_same") { + std::vector v1 = {2, 2, 2, 2}; + std::vector v2 = {1, 2, 3, 4}; + CHECK(are_all_same(v1) == true); + CHECK(are_all_same(v2) == false); + } + + TEST_CASE("Test for flatmap function on vectors") { + + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } + } + return factors; + }; + + std::vector v = {2, 3, 4, 5}; + auto result = flatmap(v, get_factors); + CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); + } + + TEST_CASE("compare_by") { + std::vector v = {"abc", "a", "ab"}; + auto comp = compare_by( + [](std::string const &s) { return s.length(); }); + std::sort(v.begin(), v.end(), comp); + CHECK(v == std::vector{"a", "ab", "abc"}); + } + + TEST_CASE("maximum") { + std::vector v = {1, 5, 3, 4, 2}; + CHECK(maximum(v) == 5); + } + + TEST_CASE("value_all") { + std::vector> v = {1, 2, std::nullopt, 4, 5}; + CHECK_THROWS_AS(value_all(v), std::runtime_error); + + std::vector> v2 = {1, 2, 3, 4, 5}; + CHECK(value_all(v2) == std::vector{1, 2, 3, 4, 5}); + } +} diff --git a/lib/utils/test/src/utils/containers/all_of.cc b/lib/utils/test/src/utils/containers/all_of.cc new file mode 100644 index 0000000000..247dd62787 --- /dev/null +++ b/lib/utils/test/src/utils/containers/all_of.cc @@ -0,0 +1,13 @@ +#include "utils/containers/all_of.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("all_of") { + std::vector v = {2, 4, 6, 8}; + CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); + CHECK(all_of(v, [](int x) { return x % 4 == 0; }) == false); + } +} diff --git a/lib/utils/test/src/utils/containers/are_disjoint.cc b/lib/utils/test/src/utils/containers/are_disjoint.cc new file mode 100644 index 0000000000..1bec8af85f --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_disjoint.cc @@ -0,0 +1,15 @@ +#include "utils/containers/are_disjoint.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_disjoint") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {4, 5, 6}; + CHECK(are_disjoint(l, r)); + r.insert(3); + CHECK_FALSE(are_disjoint(l, r)); + } +} diff --git a/lib/utils/test/src/utils/containers/as_vector.cc b/lib/utils/test/src/utils/containers/as_vector.cc new file mode 100644 index 0000000000..af89f99245 --- /dev/null +++ b/lib/utils/test/src/utils/containers/as_vector.cc @@ -0,0 +1,14 @@ +#include "utils/containers/as_vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("as_vector") { + std::unordered_set s = {1, 2, 3}; + std::vector result = as_vector(s); + CHECK(result == std::vector({3, 2, 1})); + } +} diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc new file mode 100644 index 0000000000..6e0a84c7ab --- /dev/null +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -0,0 +1,13 @@ +#include "utils/containers/contains.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("contains") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } +} diff --git a/lib/utils/test/src/utils/containers/count.cc b/lib/utils/test/src/utils/containers/count.cc new file mode 100644 index 0000000000..4d5e05fb9d --- /dev/null +++ b/lib/utils/test/src/utils/containers/count.cc @@ -0,0 +1,13 @@ +#include "utils/containers/count.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("count") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); + CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); + } +} diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc new file mode 100644 index 0000000000..4d276e5670 --- /dev/null +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -0,0 +1,17 @@ +#include "utils/containers/enumerate.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("enumerate") { + std::unordered_set input_set = {1, 2, 3, 4, 5}; + std::unordered_map result = enumerate(input_set); + std::unordered_map expected = { + {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/filter_keys.cc b/lib/utils/test/src/utils/containers/filter_keys.cc new file mode 100644 index 0000000000..3f1c26c4f5 --- /dev/null +++ b/lib/utils/test/src/utils/containers/filter_keys.cc @@ -0,0 +1,17 @@ +#include "utils/containers/filter_keys.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filter_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = [](int x) { return x % 2 == 1; }; // Filtering function + std::unordered_map result = filter_keys(m, f); + std::unordered_map expected = {{1, "one"}, {3, "three"}}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/find.cc b/lib/utils/test/src/utils/containers/find.cc new file mode 100644 index 0000000000..b965182853 --- /dev/null +++ b/lib/utils/test/src/utils/containers/find.cc @@ -0,0 +1,12 @@ +#include "utils/containers/find.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(find(v, 3) != v.cend()); + CHECK(find(v, 6) == v.cend()); + } +} diff --git a/lib/utils/test/src/utils/containers/get_first.cc b/lib/utils/test/src/utils/containers/get_first.cc new file mode 100644 index 0000000000..d5df0f7c7a --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_first.cc @@ -0,0 +1,12 @@ +#include "utils/containers/get_first.h" +#include "utils/containers/contains.h" +#include +#include +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_first") { + std::unordered_set s = {1, 2, 3}; + CHECK(contains(s, get_first(s))); + } +} diff --git a/lib/utils/test/src/utils/containers/is_subset_eq.cc b/lib/utils/test/src/utils/containers/is_subset_eq.cc new file mode 100644 index 0000000000..d762f171b6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_subset_eq.cc @@ -0,0 +1,16 @@ +#include "utils/containers/is_subseteq_of.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_subseteq_of") { + std::unordered_set s1 = {1, 2}; + std::unordered_set s2 = {1, 2, 3}; + CHECK(is_subseteq_of(s1, s2) == true); + CHECK(is_subseteq_of(s2, s1) == false); + CHECK(is_subseteq_of(s1, s1) == true); + CHECK(is_subseteq_of(s2, s2) == true); + } +} diff --git a/lib/utils/test/src/utils/containers/keys.cc b/lib/utils/test/src/utils/containers/keys.cc new file mode 100644 index 0000000000..2f162ecaf1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/keys.cc @@ -0,0 +1,17 @@ +#include "utils/containers/keys.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set result = keys(m); + std::unordered_set expected = {3, 2, 1}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/map_keys.cc b/lib/utils/test/src/utils/containers/map_keys.cc new file mode 100644 index 0000000000..189840f318 --- /dev/null +++ b/lib/utils/test/src/utils/containers/map_keys.cc @@ -0,0 +1,17 @@ +#include "utils/containers/map_keys.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("map_keys") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](int x) { return x * x; }; // Mapping function + auto result = map_keys(m, f); + CHECK(result.size() == 2); + CHECK(result[1] == "one"); + CHECK(result[4] == "two"); + } +} diff --git a/lib/utils/test/src/utils/containers/map_values.cc b/lib/utils/test/src/utils/containers/map_values.cc new file mode 100644 index 0000000000..f854f17fc6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/map_values.cc @@ -0,0 +1,16 @@ +#include "utils/containers/map_values.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("map_values") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](std::string const &s) { return s.size(); }; // Mapping function + std::unordered_map result = map_values(m, f); + std::unordered_map expected = {{1, 3}, {2, 3}}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/restrict_keys.cc b/lib/utils/test/src/utils/containers/restrict_keys.cc new file mode 100644 index 0000000000..bacda218da --- /dev/null +++ b/lib/utils/test/src/utils/containers/restrict_keys.cc @@ -0,0 +1,18 @@ +#include "utils/containers/restrict_keys.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("restrict_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set mask = {2, 3, 4}; + std::unordered_map result = restrict_keys(m, mask); + std::unordered_map expected = {{2, "two"}, {3, "three"}}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/reversed.cc b/lib/utils/test/src/utils/containers/reversed.cc new file mode 100644 index 0000000000..8e5ce1804a --- /dev/null +++ b/lib/utils/test/src/utils/containers/reversed.cc @@ -0,0 +1,16 @@ +#include "utils/containers/reversed.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Testing the 'reversed' function") { + std::vector input_vec = {1, 2, 3, 4, 5}; + std::vector result = reversed(input_vec); + std::vector expected = {5, 4, 3, 2, 1}; + + // Checking the reversed sequence is as expected + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/set_union.cc b/lib/utils/test/src/utils/containers/set_union.cc new file mode 100644 index 0000000000..d69cb64e00 --- /dev/null +++ b/lib/utils/test/src/utils/containers/set_union.cc @@ -0,0 +1,15 @@ +#include "utils/containers/set_union.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("set_union") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {2, 3, 4}; + std::unordered_set result = set_union(s1, s2); + std::unordered_set expected = {1, 2, 3, 4}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/sorted_by.cc b/lib/utils/test/src/utils/containers/sorted_by.cc new file mode 100644 index 0000000000..a2e11f96e5 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sorted_by.cc @@ -0,0 +1,18 @@ +#include "utils/containers/sorted_by.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Testing sorted_by function") { + std::unordered_set s = {5, 2, 3, 4, 1}; + auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); + CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); + + std::unordered_set s2 = {-5, -1, -3, -2, -4}; + auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); + CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); + } +} diff --git a/lib/utils/test/src/utils/containers/subvec.cc b/lib/utils/test/src/utils/containers/subvec.cc new file mode 100644 index 0000000000..38668ea20d --- /dev/null +++ b/lib/utils/test/src/utils/containers/subvec.cc @@ -0,0 +1,17 @@ +#include "utils/containers/subvec.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Testing subvec function") { + std::vector v = {1, 2, 3, 4, 5}; + auto subvec_v = subvec(v, std::optional(1), std::optional(4)); + + CHECK(subvec_v == std::vector({2, 3, 4})); + + auto subvec_v2 = subvec(v, std::nullopt, std::optional(3)); + CHECK(subvec_v2 == std::vector({1, 2, 3})); + } +} diff --git a/lib/utils/test/src/utils/containers/values.cc b/lib/utils/test/src/utils/containers/values.cc new file mode 100644 index 0000000000..b921bcb0cc --- /dev/null +++ b/lib/utils/test/src/utils/containers/values.cc @@ -0,0 +1,17 @@ +#include "utils/containers/values.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("values") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::vector result = values(m); + std::vector expected = {"three", "two", "one"}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/vector_split.cc b/lib/utils/test/src/utils/containers/vector_split.cc new file mode 100644 index 0000000000..872742a429 --- /dev/null +++ b/lib/utils/test/src/utils/containers/vector_split.cc @@ -0,0 +1,16 @@ +#include "utils/containers/vector_split.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Testing vector_split function") { + std::vector v = {1, 2, 3, 4, 5}; + auto result = vector_split(v, 2); + std::vector prefix = result.first; + std::vector postfix = result.second; + CHECK(prefix == std::vector({1, 2})); + CHECK(postfix == std::vector({3, 4, 5})); + } +} diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/utils/deduplicated_priority_queue.cc similarity index 100% rename from lib/utils/test/src/test_deduplicated_priority_queue.cc rename to lib/utils/test/src/utils/deduplicated_priority_queue.cc index 66cfd395bc..2bf413675d 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/utils/deduplicated_priority_queue.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/deduplicated_priority_queue.h" +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/utils/disjoint_set.cc similarity index 74% rename from lib/utils/test/src/test_disjoint_set.cc rename to lib/utils/test/src/utils/disjoint_set.cc index 80fcf87d6b..dc5c1bb3cb 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/utils/disjoint_set.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/disjoint_set.h" +#include "test/utils/doctest.h" using namespace FlexFlow; @@ -18,10 +18,10 @@ std::string generate_element(int seed) { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { - disjoint_set> ds; + disjoint_set> ds; SUBCASE("SingleElementSets") { - optional element = generate_element(1); + std::optional element = generate_element(1); CHECK(ds.find(element) == element); element = generate_element(2); @@ -29,10 +29,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("UnionAndFind") { - optional element1 = generate_element(1); - optional element2 = generate_element(2); - optional element3 = generate_element(3); - optional element4 = generate_element(4); + std::optional element1 = generate_element(1); + std::optional element2 = generate_element(2); + std::optional element3 = generate_element(3); + std::optional element4 = generate_element(4); ds.m_union(element1, element2); CHECK(ds.find(element1) == ds.find(element2)); @@ -55,11 +55,11 @@ TEST_SUITE(FF_TEST_SUITE) { ds.m_union(1, 4); ds.m_union(5, 6); - std::map, optional, OptionalComparator> + std::map, std::optional, OptionalComparator> expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; - std::map, optional, OptionalComparator> mapping = - ds.get_mapping(); + std::map, std::optional, OptionalComparator> + mapping = ds.get_mapping(); for (auto const &kv : mapping) { CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/utils/dot_file.cc similarity index 100% rename from lib/utils/test/src/test_dot_file.cc rename to lib/utils/test/src/utils/dot_file.cc index ed4c32bb1c..d7f5dcb02c 100644 --- a/lib/utils/test/src/test_dot_file.cc +++ b/lib/utils/test/src/utils/dot_file.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/dot_file.h" +#include "test/utils/doctest.h" #include TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/utils/hash-utils.cc b/lib/utils/test/src/utils/hash-utils.cc new file mode 100644 index 0000000000..2981615d9e --- /dev/null +++ b/lib/utils/test/src/utils/hash-utils.cc @@ -0,0 +1,98 @@ +#include "utils/hash-utils.h" +#include "test/utils/doctest.h" +#include "utils/hash/map.h" +#include "utils/hash/set.h" +#include "utils/hash/tuple.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_set.h" +#include "utils/hash/vector.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("hash-utils") { + SUBCASE("vector") { + std::vector vec1{1, 2, 3}; + std::vector vec2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(vec1); + size_t hash2 = get_std_hash(vec2); + + CHECK(hash1 != hash2); + + vec1.push_back(4); + hash1 = get_std_hash(vec1); + CHECK(hash1 == hash2); + } + + SUBCASE("map") { + std::map map1{{1, 2}}; + std::map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({3, 4}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } + + SUBCASE("unordered_map") { + std::unordered_map map1{{1, 2}}; + std::unordered_map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({3, 4}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } + + SUBCASE("set") { + std::set set1{1, 2, 3}; + std::set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } + + SUBCASE("unordered_set") { + std::unordered_set set1{1, 2, 3}; + std::unordered_set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } + + SUBCASE("tuple") { + std::tuple tuple1{1, "test", 3.14}; + std::tuple tuple2{2, "test", 3.14}; + + size_t hash1 = get_std_hash(tuple1); + size_t hash2 = get_std_hash(tuple2); + + CHECK(hash1 != hash2); + + tuple1 = tuple2; + hash1 = get_std_hash(tuple1); + CHECK(hash1 == hash2); + } + } +} diff --git a/lib/utils/test/src/utils/join_strings.cc b/lib/utils/test/src/utils/join_strings.cc new file mode 100644 index 0000000000..1377923ab6 --- /dev/null +++ b/lib/utils/test/src/utils/join_strings.cc @@ -0,0 +1,20 @@ +#include "utils/join_strings.h" +#include "test/utils/doctest.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("join_strings") { + std::vector const v = {"Hello", "world", "!"}; + + SUBCASE("iterator") { + CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); + } + + SUBCASE("join_strings with container") { + CHECK(join_strings(v, " ") == "Hello world !"); + } + } +} diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/utils/random_utils.cc similarity index 100% rename from lib/utils/test/src/test_random_utils.cc rename to lib/utils/test/src/utils/random_utils.cc index 88a566a198..1563de512c 100644 --- a/lib/utils/test/src/test_random_utils.cc +++ b/lib/utils/test/src/utils/random_utils.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/random_utils.h" +#include "test/utils/doctest.h" #include void checkProbabilities(std::vector const &counts, diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/utils/record_formatter.cc similarity index 100% rename from lib/utils/test/src/test_format.cc rename to lib/utils/test/src/utils/record_formatter.cc index eeed2eae81..f211f5eb7b 100644 --- a/lib/utils/test/src/test_format.cc +++ b/lib/utils/test/src/utils/record_formatter.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/record_formatter.h" +#include "test/utils/doctest.h" std::string formatRecord(RecordFormatter const &formatter) { std::ostringstream oss; diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/utils/sequence.cc similarity index 100% rename from lib/utils/test/src/test_sequence.cc rename to lib/utils/test/src/utils/sequence.cc index ee72febe05..c0a0aaf64e 100644 --- a/lib/utils/test/src/test_sequence.cc +++ b/lib/utils/test/src/utils/sequence.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/sequence.h" +#include "test/utils/doctest.h" using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/utils/stack_map.cc similarity index 100% rename from lib/utils/test/src/test_stack_map.cc rename to lib/utils/test/src/utils/stack_map.cc index 21c1b07d1b..e03552e6fd 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/utils/stack_map.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/stack_map.h" +#include "test/utils/doctest.h" using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/utils/stack_string.cc similarity index 94% rename from lib/utils/test/src/test_stack_string.cc rename to lib/utils/test/src/utils/stack_string.cc index a044f85fe3..30e20c5fce 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/utils/stack_string.cc @@ -1,6 +1,6 @@ +#include "utils/stack_string.h" #include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" -#include "utils/stack_string.h" using namespace FlexFlow; @@ -81,9 +81,4 @@ TEST_SUITE(FF_TEST_SUITE) { std::string stdStr = static_cast(str); CHECK(stdStr == "Hello"); } - - TEST_CASE("Arbitrary") { - constexpr std::size_t MAXSIZE = 10; - RCSUBCASE([](stack_string const &s) {}); - } } diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/utils/stack_vector.cc similarity index 100% rename from lib/utils/test/src/test_stack_vector.cc rename to lib/utils/test/src/utils/stack_vector.cc index 1af43b6993..5243dc45fe 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/utils/stack_vector.cc @@ -1,6 +1,6 @@ +#include "utils/stack_vector.h" #include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" -#include "utils/stack_vector.h" #include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/utils/tuple.cc similarity index 100% rename from lib/utils/test/src/test_tuple.cc rename to lib/utils/test/src/utils/tuple.cc index 31308dec2c..2b8f4f65ad 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/utils/tuple.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/tuple.h" +#include "test/utils/doctest.h" #include #include diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/utils/type_index.cc similarity index 80% rename from lib/utils/test/src/test_type_index.cc rename to lib/utils/test/src/utils/type_index.cc index b2d8aea848..5cf7ca989f 100644 --- a/lib/utils/test/src/test_type_index.cc +++ b/lib/utils/test/src/utils/type_index.cc @@ -1,19 +1,19 @@ -#include "test/utils/doctest.h" #include "utils/type_index.h" +#include "test/utils/doctest.h" #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("type_index function") { + TEST_CASE("get_type_index_for_type") { SUBCASE("int type") { - std::type_index idx = type_index(); + std::type_index idx = get_type_index_for_type(); std::type_index expected_idx = typeid(int); CHECK(idx == expected_idx); } SUBCASE("string type") { - std::type_index idx = type_index(); + std::type_index idx = get_type_index_for_type(); std::type_index expected_idx = typeid(std::string); CHECK(idx == expected_idx); } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/utils/variant.cc similarity index 100% rename from lib/utils/test/src/test_variant.cc rename to lib/utils/test/src/utils/variant.cc index 98b28a48e9..cae75f3045 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/utils/variant.cc @@ -1,6 +1,6 @@ +#include "utils/variant.h" #include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" -#include "utils/variant.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/utils/vector.cc similarity index 100% rename from lib/utils/test/src/test_vector.cc rename to lib/utils/test/src/utils/vector.cc index 4bdc724dd8..873f9d7447 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/utils/vector.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/vector.h" +#include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") { From b47b801935c91fdf770c98509fca163c35e9b999 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 16:39:55 -0700 Subject: [PATCH 05/21] views.cc fix --- lib/utils/test/src/utils/graph/views/views.cc | 128 +++++++----------- 1 file changed, 47 insertions(+), 81 deletions(-) diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index 684738e7ba..29c63825a2 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -86,106 +86,72 @@ TEST_SUITE(FF_TEST_SUITE) { } } - // TEST_CASE("JoinedUndirectedGraphView") { - // UndirectedGraph g1 = - // UndirectedGraph::create(); UndirectedGraph g2 - // = UndirectedGraph::create(); + TEST_CASE("JoinedUndirectedGraphView") { + UndirectedGraph g1 = UndirectedGraph::create(); + UndirectedGraph g2 = UndirectedGraph::create(); - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 3); + std::vector n1 = add_nodes(g1, 3); + std::vector n2 = add_nodes(g2, 3); - // add_edges(g1, {{n1[0], n1[1]}, {n1[1], n1[2]}}); - // add_edges(g2, {{n2[0], n2[2]}, {n2[1], n2[2]}}); + add_edges(g1, {{n1[0], n1[1]}, {n1[1], n1[2]}}); + add_edges(g2, {{n2[0], n2[2]}, {n2[1], n2[2]}}); - // UndirectedGraphView view = - // UndirectedGraphView::create(g1, g2); - - // SUBCASE("get_nodes") { - // std::unordered_set expected = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); - - // std::unordered_set result = get_nodes(view); - - // CHECK(result == expected); - // } - - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // {n1[0], n1[1]}, {n1[1], n1[2]}, - // {n2[0], n2[2]}, {n2[1], n2[2]} - // }; - - // std::unordered_set result = get_edges(view); - - // CHECK(result == expected); - // } - // } - - // TEST_CASE("JoinedDigraphView") { - // DiGraph g1 = DiGraph::create(); - // DiGraph g2 = DiGraph::create(); - - // std::vector n1 = add_nodes(g1, 3); - // std::vector n2 = add_nodes(g2, 3); - - // add_edges(g1, {DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], - // n1[2]}}); add_edges(g2, {DirectedEdge{n2[0], n2[2]}, - // DirectedEdge{n2[1], n2[2]}}); + UndirectedGraphView view = + UndirectedGraphView::create(g1, g2); - // DiGraphView view = DiGraphView::create(g1, g2); + // SUBCASE("get_nodes") { + // std::unordered_set expected = + // set_union(unordered_set_of(n1), unordered_set_of(n2)); - // SUBCASE("get_nodes") { - // std::unordered_set expected = set_union(unordered_set_of(n1), - // unordered_set_of(n2)); + // std::unordered_set result = get_nodes(view); - // std::unordered_set result = get_nodes(view); + // CHECK(result == expected); + // } - // CHECK(result == expected); - // } + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // {n1[0], n1[1]}, {n1[1], n1[2]}, + // {n2[0], n2[2]}, {n2[1], n2[2]} + // }; - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], n1[2]}, - // DirectedEdge{n2[0], n2[2]}, DirectedEdge{n2[1], n2[2]} - // }; + // std::unordered_set result = get_edges(view); - // std::unordered_set result = get_edges(view); + // CHECK(result == expected); + // } + } - // CHECK(result == expected); - // } - // } + TEST_CASE("JoinedDigraphView") { + DiGraph g1 = DiGraph::create(); + DiGraph g2 = DiGraph::create(); - // TEST_CASE("AddDirectedEdgesView") { - // DiGraph g = DiGraph::create(); - // std::vector n = add_nodes(g, 4); - // add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); + std::vector n1 = add_nodes(g1, 3); + std::vector n2 = add_nodes(g2, 3); - // std::unordered_set additional_edges = { - // DirectedEdge{n[2], n[3]}, DirectedEdge{n[3], n[0]} - // }; + add_edges(g1, {DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], n1[2]}}); + add_edges(g2, {DirectedEdge{n2[0], n2[2]}, DirectedEdge{n2[1], n2[2]}}); - // DiGraphView view = DiGraphView::create(g, - // additional_edges); + DiGraphView view = DiGraphView::create(g1, g2); - // SUBCASE("get_nodes") { - // std::unordered_set expected = unordered_set_of(n); + // SUBCASE("get_nodes") { + // std::unordered_set expected = set_union(unordered_set_of(n1), + // unordered_set_of(n2)); - // std::unordered_set result = get_nodes(view); + // std::unordered_set result = get_nodes(view); - // CHECK(result == expected); - // } + // CHECK(result == expected); + // } - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}, - // DirectedEdge{n[2], n[3]}, DirectedEdge{n[3], n[0]} - // }; + // SUBCASE("get_edges") { + // std::unordered_set expected = { + // DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], n1[2]}, + // DirectedEdge{n2[0], n2[2]}, DirectedEdge{n2[1], n2[2]} + // }; - // std::unordered_set result = get_edges(view); + // std::unordered_set result = get_edges(view); - // CHECK(result == expected); - // } - // } + // CHECK(result == expected); + // } + } TEST_CASE("ViewDiGraphAsUndirectedGraph") { DiGraph g = DiGraph::create(); From 3cabb66b2a3584a23c2b560aaa5674abb37f7d51 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 17:13:15 -0700 Subject: [PATCH 06/21] test reorganizing --- lib/utils/test/src/test_algorithms.cc | 98 ------------------- .../graph/digraph/get_connected_components.cc | 22 +++++ .../graph/digraph/get_topological_ordering.cc | 35 +++++++ .../get_weakly_connected_components.cc | 23 +++++ lib/utils/test/src/utils/graph/traversal.cc | 44 +++++++++ 5 files changed, 124 insertions(+), 98 deletions(-) create mode 100644 lib/utils/test/src/utils/graph/digraph/get_connected_components.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc create mode 100644 lib/utils/test/src/utils/graph/traversal.cc diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index f3a48882a0..7101aad53f 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -95,47 +95,6 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("traversal") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 5); - std::vector edges = { - {n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; - add_edges(g, edges); - - CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); - CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == true); - CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); - CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); - - SUBCASE("with root") { - g.add_edge({n[3], n[2]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - - SUBCASE("without root") { - g.add_edge({n[3], n[0]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - SUBCASE("nonlinear") { - g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs - } - - SUBCASE("not connected") { - g.remove_edge({n[2], n[3]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); - } - } TEST_CASE("bfs") { DiGraph g = DiGraph::create(); @@ -176,61 +135,4 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK_BEFORE(4, 5); } - - TEST_CASE("get_topological_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - std::vector edges = {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}; - add_edges(g, edges); - std::vector ordering = get_topological_ordering(g); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - CHECK_BEFORE(1, 5); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(3, 4); - CHECK_BEFORE(4, 5); - } - - TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - - CHECK(get_connected_components(g) == expected_components); - } - - TEST_CASE("get_weakly_connected_components") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); - - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - - CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); - - CHECK(get_weakly_connected_components(g) == expected_components); - } } diff --git a/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc b/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc new file mode 100644 index 0000000000..bcbfdb0e77 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc @@ -0,0 +1,22 @@ +#include "utils/graph/undirected/algorithms/get_connected_components.h" +#include "test/utils/doctest.h" +#include "utils/graph/undirected/undirected_graph.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/algorithms.h" + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 4); + std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + + add_edges(g, edges); + std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; + + CHECK(get_connected_components(g) == expected_components); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc new file mode 100644 index 0000000000..507431163b --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc @@ -0,0 +1,35 @@ +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "test/utils/doctest.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/algorithms.h" +#include "utils/containers.h" + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[5]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).has_value()); + CHECK(index_of(ordering, n[r]).has_value()); + CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc b/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc new file mode 100644 index 0000000000..28443b6dea --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc @@ -0,0 +1,23 @@ +#include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" +#include "test/utils/doctest.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/algorithms.h" + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE(FF_TEST_SUITE) { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + + std::vector edges = {DirectedEdge{n[0], n[1]}, DirectedEdge{n[2], n[1]}}; + + add_edges(g, edges); + std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; + + CHECK(get_weakly_connected_components(g) == expected_components); + } +} diff --git a/lib/utils/test/src/utils/graph/traversal.cc b/lib/utils/test/src/utils/graph/traversal.cc new file mode 100644 index 0000000000..aa823484e7 --- /dev/null +++ b/lib/utils/test/src/utils/graph/traversal.cc @@ -0,0 +1,44 @@ +#include "utils/graph/traversal.h" +#include "test/utils/doctest.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/algorithms.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("traversal") { + DiGraph g = DiGraph::create(); + std::vector const n = add_nodes(g, 5); + std::vector edges = { + DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}, DirectedEdge{n[2], n[3]}}; + add_edges(g, edges); + + CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_bfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); + CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); + + SUBCASE("with root") { + g.add_edge(DirectedEdge{n[3], n[2]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + } + + SUBCASE("without root") { + g.add_edge(DirectedEdge{n[3], n[0]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + } + SUBCASE("nonlinear") { + g.add_edge(DirectedEdge{n[1], n[3]}); + } + + SUBCASE("not connected") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + } + } +} \ No newline at end of file From f3ae8a9a7e0fa62bace26feca2f5cb38434b7b23 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 17:13:49 -0700 Subject: [PATCH 07/21] fmt --- lib/utils/test/src/test_algorithms.cc | 1 - .../src/utils/graph/digraph/get_connected_components.cc | 4 ++-- .../src/utils/graph/digraph/get_topological_ordering.cc | 4 ++-- .../graph/digraph/get_weakly_connected_components.cc | 5 +++-- lib/utils/test/src/utils/graph/traversal.cc | 9 +++++---- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 7101aad53f..87f7d7a71b 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -95,7 +95,6 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("bfs") { DiGraph g = DiGraph::create(); std::vector const n = add_nodes(g, 7); diff --git a/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc b/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc index bcbfdb0e77..d9fd9e81af 100644 --- a/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc +++ b/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc @@ -1,8 +1,8 @@ #include "utils/graph/undirected/algorithms/get_connected_components.h" #include "test/utils/doctest.h" -#include "utils/graph/undirected/undirected_graph.h" -#include "utils/graph/instances/hashmap_undirected_graph.h" #include "utils/graph/algorithms.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/undirected/undirected_graph.h" TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc index 507431163b..9b5c243910 100644 --- a/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc +++ b/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc @@ -1,9 +1,9 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "test/utils/doctest.h" +#include "utils/containers.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/algorithms.h" -#include "utils/containers.h" TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc b/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc index 28443b6dea..c48b5e6cdf 100644 --- a/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc +++ b/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc @@ -1,8 +1,8 @@ #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "test/utils/doctest.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/algorithms.h" TEST_SUITE(FF_TEST_SUITE) { @@ -10,7 +10,8 @@ TEST_SUITE(FF_TEST_SUITE) { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); - std::vector edges = {DirectedEdge{n[0], n[1]}, DirectedEdge{n[2], n[1]}}; + std::vector edges = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[2], n[1]}}; add_edges(g, edges); std::unordered_set> expected_components = { diff --git a/lib/utils/test/src/utils/graph/traversal.cc b/lib/utils/test/src/utils/graph/traversal.cc index aa823484e7..185712cf8f 100644 --- a/lib/utils/test/src/utils/graph/traversal.cc +++ b/lib/utils/test/src/utils/graph/traversal.cc @@ -1,15 +1,16 @@ #include "utils/graph/traversal.h" #include "test/utils/doctest.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/algorithms.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("traversal") { DiGraph g = DiGraph::create(); std::vector const n = add_nodes(g, 5); - std::vector edges = { - DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}, DirectedEdge{n[2], n[3]}}; + std::vector edges = {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}; add_edges(g, edges); CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == @@ -41,4 +42,4 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); } } -} \ No newline at end of file +} From 2d148dbe47c891ef70e6613b34c35476c3dd31d4 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 17:33:06 -0700 Subject: [PATCH 08/21] test rearranging --- lib/utils/test/src/test_algorithms.cc | 31 -------------- .../src/utils/graph/digraph/algorithms.cc | 41 +++++++++++++++++++ 2 files changed, 41 insertions(+), 31 deletions(-) create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms.cc diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 87f7d7a71b..2542e2f68b 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -56,43 +56,12 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set{e[0], e[2], e[3]}); CHECK(get_outgoing_edges(g, {n[2], n[3]}) == std::unordered_set{}); - auto expected_result = std::unordered_map>{ - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - {n[3], {n[0]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - - SUBCASE("get_imm_dominators") { - std::unordered_map> result = get_imm_dominators(g); - - std::unordered_map> expected_result = { - {n[2], n[0]}, - {n[1], n[0]}, - {n[3], n[0]}, - {n[0], nullopt}, - }; - CHECK(result == expected_result); - } - - SUBCASE("get_sinks") { - auto expected = std::unordered_set{n[2], n[3]}; - CHECK(get_sinks(g) == expected); - } SUBCASE("get_bfs") { std::unordered_set start_points = std::unordered_set{n[0]}; auto expected = std::vector{n[0], n[2], n[1], n[3]}; CHECK(get_bfs_ordering(g, start_points) == expected); } - - SUBCASE("get_predecessors") { - std::unordered_map> expected_result = { - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); - } } TEST_CASE("bfs") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms.cc new file mode 100644 index 0000000000..47abd5186d --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms.cc @@ -0,0 +1,41 @@ +#include "utils/graph/digraph/algorithms.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DiGraph - algorithms.cc") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + std::vector e = { + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + }; + add_edges(g, e); + + SUBCASE("get_edges") { + std::unordered_set expected_edges = unordered_set_of(e); + std::unordered_set actual_edges = get_edges(g); + CHECK(actual_edges == expected_edges); + } + + SUBCASE("get_sinks") { + std::unordered_set expected_sinks = {n[2], n[3]}; + std::unordered_set actual_sinks = get_sinks(g); + CHECK(actual_sinks == expected_sinks); + } + + SUBCASE("get_sources") { + std::unordered_set expected_sources = {n[0]}; + std::unordered_set actual_sources = get_sources(g); + CHECK(actual_sources == expected_sources); + } + } +} From 86e7dd048ba02b7c7bd5c73254e5a6a67cf1278c Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 17:50:45 -0700 Subject: [PATCH 09/21] rearranging --- lib/utils/test/src/test_algorithms.cc | 106 ------------------ .../algorithms/get_incoming_edges.cc | 31 +++++ .../algorithms/get_outgoing_edges.cc | 31 +++++ 3 files changed, 62 insertions(+), 106 deletions(-) delete mode 100644 lib/utils/test/src/test_algorithms.cc create mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc create mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc deleted file mode 100644 index 2542e2f68b..0000000000 --- a/lib/utils/test/src/test_algorithms.cc +++ /dev/null @@ -1,106 +0,0 @@ -#include "test/utils/doctest.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/construction.h" -#include "utils/graph/hashmap_undirected_graph.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/undirected.h" -#include -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MultiDiGraph") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector p = add_node_ports(g, 4); - - MultiDiEdge e0{n[3], p[3], n[0], p[0]}; - MultiDiEdge e1{n[2], p[2], n[1], p[0]}; - MultiDiEdge e2{n[3], p[3], n[1], p[1]}; - MultiDiEdge e3{n[3], p[3], n[2], p[2]}; - - std::vector e = {e0, e1, e2, e3}; - - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[1], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{e[3]}); - std::unordered_map> expected_result = - std::unordered_map>{ - {n[1], {}}, - {n[2], {n[1]}}, - {n[3], {n[0], n[1], n[2]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - } - - TEST_CASE("DiGraph") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - {n[0], n[3]}, - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[2]}, - }; - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[2], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{}); - - SUBCASE("get_bfs") { - std::unordered_set start_points = std::unordered_set{n[0]}; - auto expected = std::vector{n[0], n[2], n[1], n[3]}; - CHECK(get_bfs_ordering(g, start_points) == expected); - } - } - - TEST_CASE("bfs") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 7); - - std::vector e = { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }; - - add_edges(g, e); - - std::vector ordering = get_bfs_ordering(g, {n[0]}); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < - index_of(ordering, n[r]).value()); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - - CHECK_BEFORE(1, 3); - CHECK_BEFORE(1, 6); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(2, 6); - - CHECK_BEFORE(3, 4); - CHECK_BEFORE(6, 4); - - CHECK_BEFORE(4, 5); - } -} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc new file mode 100644 index 0000000000..59141513c3 --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -0,0 +1,31 @@ +#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" +#include "test/utils/doctest.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph - get_incoming_edges") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector> input = { + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(1)}, + {n.at(0), n.at(0)}, + }; + + std::vector edges = add_edges(g, input); + + CHECK(get_incoming_edges(g, n[1]) == + std::unordered_set{edges[0], edges[1], edges[2]}); + CHECK(get_incoming_edges(g, n[2]) == std::unordered_set{}); + } +} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc new file mode 100644 index 0000000000..428d6589f1 --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -0,0 +1,31 @@ +#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" +#include "test/utils/doctest.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/algorithms/add_edges.h" +#include "utils/graph/multidigraph/algorithms/add_nodes.h" +#include "utils/graph/multidigraph/multidigraph.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph - get_outgoing_edges") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 3); + + std::vector> input = { + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(1)}, + {n.at(0), n.at(0)}, + }; + + std::vector edges = add_edges(g, input); + + CHECK(get_outgoing_edges(g, n[0]) == + std::unordered_set{edges[0], edges[1], edges[3]}); + CHECK(get_outgoing_edges(g, n[2]) == std::unordered_set{}); + } +} From c14a34b49ae39f8076784a9d0aecddf277cd5f58 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 18:20:34 -0700 Subject: [PATCH 10/21] graph fixes --- lib/utils/test/src/test_multidigraph.cc | 94 ------------------- .../utils/graph/multidigraph/multidigraph.cc | 91 ++++++++++++++++++ 2 files changed, 91 insertions(+), 94 deletions(-) delete mode 100644 lib/utils/test/src/test_multidigraph.cc create mode 100644 lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc deleted file mode 100644 index 90e1bb2187..0000000000 --- a/lib/utils/test/src/test_multidigraph.cc +++ /dev/null @@ -1,94 +0,0 @@ -#include "test/utils/doctest.h" -#include "utils/graph/adjacency_multidigraph.h" -#include "utils/graph/multidiedge.h" -#include "utils/graph/multidigraph_interfaces.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { - MultiDiGraph g = MultiDiGraph::create(); - - std::vector n = repeat(3, [&] { return g.add_node(); }); - std::vector p = repeat(3, [&] { return g.add_node_port(); }); - - std::vector e = {{n[1], p[1], n[0], p[0]}, - {n[2], p[2], n[0], p[0]}, - {n[0], p[0], n[2], p[2]}, - {n[1], p[1], n[2], p[2]}}; - for (MultiDiEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[0], n[1], n[2]}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[0], e[1], e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( - {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( - {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs( - query_set({p[1], p[2]}))) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs( - query_set({p[0], p[2]}))) == - std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all() - .with_src_nodes({n[1]}) - .with_dst_nodes({n[2]}) - .with_src_idxs({p[1]}) - .with_dst_idxs({p[2]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == - std::unordered_set{e[1]}); - - SUBCASE("remove node") { - g.remove_node_unsafe(n[0]); - - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[1], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == - std::unordered_set{}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == - std::unordered_set{e[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == - std::unordered_set{e[2]}); - } - - SUBCASE("remove_edge") { - g.remove_edge(e[0]); - - CHECK(g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( - {n[1]})) == std::unordered_set{}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == - std::unordered_set{e[1]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - } - } -} diff --git a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc new file mode 100644 index 0000000000..5a4e3649bd --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc @@ -0,0 +1,91 @@ +#include "utils/graph/multidigraph/multidigraph.h" +#include "test/utils/doctest.h" +#include "utils/containers/contains.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/multidiedge_query.h" +#include "utils/graph/query_set.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph") { + MultiDiGraph g = MultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + + MultiDiEdge e0 = g.add_edge(n1, n0); + MultiDiEdge e1 = g.add_edge(n2, n0); + MultiDiEdge e2 = g.add_edge(n0, n2); + MultiDiEdge e3 = g.add_edge(n1, n2); + + SUBCASE("add_node") { + Node n3 = g.add_node(); + std::unordered_set nodes = g.query_nodes(NodeQuery{{n3}}); + CHECK(contains(nodes, n3)); + } + + SUBCASE("add_edge") { + MultiDiEdge e4 = g.add_edge(n2, n1); + std::unordered_set edges = + g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + CHECK(contains(edges, e4)); + } + + SUBCASE("remove_node") { + g.remove_node(n0); + std::unordered_set nodes = g.query_nodes(NodeQuery{{n0}}); + CHECK_FALSE(contains(nodes, n0)); + + std::unordered_set edges_after_removal = + g.query_edges(MultiDiEdgeQuery({n0}, {n0})); + CHECK_FALSE(contains(edges_after_removal, e0)); + CHECK_FALSE(contains(edges_after_removal, e1)); + CHECK_FALSE(contains(edges_after_removal, e2)); + } + + SUBCASE("remove_edge") { + g.remove_edge(e3); + std::unordered_set edges = + g.query_edges(MultiDiEdgeQuery({n1}, {n2})); + CHECK_FALSE(contains(edges, e3)); + } + + SUBCASE("query_nodes") { + std::unordered_set all_nodes = + g.query_nodes(NodeQuery{{n0, n1, n2}}); + CHECK(all_nodes == std::unordered_set{n0, n1, n2}); + + std::unordered_set specific_nodes = + g.query_nodes(NodeQuery{{n0, n2}}); + CHECK(specific_nodes == std::unordered_set{n0, n2}); + } + + SUBCASE("query_edges") { + std::unordered_set all_edges = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n0, n1, n2})); + CHECK(all_edges == std::unordered_set{e0, e1, e2, e3}); + + std::unordered_set edges_from_n1 = + g.query_edges(MultiDiEdgeQuery({n1}, {n0, n1, n2})); + CHECK(edges_from_n1 == std::unordered_set{e0, e3}); + + std::unordered_set edges_to_n2 = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n2})); + CHECK(edges_to_n2 == std::unordered_set{e2, e3}); + } + + SUBCASE("get_multidiedge_src") { + CHECK(g.get_multidiedge_src(e0) == n1); + CHECK(g.get_multidiedge_dst(e0) == n0); + } + + SUBCASE("get_multidiedge_dst") { + CHECK(g.get_multidiedge_src(e2) == n0); + CHECK(g.get_multidiedge_dst(e2) == n2); + } + } +} From b870a447e73c3b9e1fdaaccfe5025e2f664fc98b Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Sun, 11 Aug 2024 18:31:05 -0700 Subject: [PATCH 11/21] undirected graph test fix --- .../graph/undirected/undirected_graph.cc} | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) rename lib/utils/test/src/{test_undirected_graph.cc => utils/graph/undirected/undirected_graph.cc} (82%) diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc similarity index 82% rename from lib/utils/test/src/test_undirected_graph.cc rename to lib/utils/test/src/utils/graph/undirected/undirected_graph.cc index 33b102bd3b..770f8395c8 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc @@ -1,7 +1,10 @@ +#include "utils/graph/undirected/undirected_graph.h" #include "test/utils/all.h" #include "test/utils/rapidcheck/visitable.h" -#include "utils/graph/hashmap_undirected_graph.h" -#include "utils/graph/undirected.h" +#include "utils/containers/repeat.h" +#include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" /* namespace rc { */ @@ -50,12 +53,12 @@ TEST_SUITE(FF_TEST_SUITE) { g.add_edge(edge); } - CHECK(g.query_nodes(NodeQuery::all()) == unordered_set_of(n)); + CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); auto subset = *rc::subset_of(n); CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - CHECK(g.query_edges(UndirectedEdgeQuery::all()) == unordered_set_of(e)); + CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); }); } } From 14c930bdd7812ec5abd67238eae518b9a88daddf Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 12 Aug 2024 13:05:59 -0700 Subject: [PATCH 12/21] Updated docs --- lib/utils/include/utils/graph/README.md | 159 +++++++----------------- 1 file changed, 42 insertions(+), 117 deletions(-) diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 25b0103f9c..0c5687766d 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -15,10 +15,10 @@ There is no single type of graph. Should it be directed? Allow multiple edges be Because there is no single answer to this question, similar to [networkx](https://networkx.org/) we provide a number of different graph variants. At their core, they are as follows: -- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected -- `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) -- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. - +- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. +- `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) +- `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. +- `DataFlowGraph`: similar to `MultiDiGraph`, but the edges entering, exiting a given nodes now have a well-defined order. Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -37,7 +37,7 @@ flowchart TD D --- B ``` -Example of `DirectedGraph`: +Example of `DiGraph`: ```mermaid flowchart TD A(" ") @@ -58,98 +58,34 @@ flowchart TD Example of `MultiDiGraph`: ```mermaid flowchart TD - A("A") - B("B") - C("C") - D("D") - E("E") - F("F") - - A -->|"(■, ★)"| B - B -->|"(●, ★)"| C - C -->|"(♥, ▲)"| D - D -->|"(●, ■)"| A - B -->|"(★, ●)"| E - E -->|"(■, ■)"| B - D -->|"(●, ●)"| A - A -->|"(●, ■)"| E - D -->|"(■, ●)"| D - E -->|"(■, ■)"| E -``` -or visualized a different way, -```mermaid -flowchart TD - Acirc("●") - Asqua("■") - Bcirc("●") - Bstar("★") - Bsqua("■") - Chear("♥") - Cstar("★") - Dsqua("■") - Dcirc("●") - Dtria("▲") - Ecirc("●") - Esqua("■") - Fplaceholder(" ") - - style Fplaceholder fill:#0000,stroke:#0000 - - subgraph "A" - Acirc - Asqua - end - - subgraph "B" - Bsqua - Bcirc - Bstar - end - - subgraph "C" - Chear - Cstar - end - - subgraph "D" - Dsqua - Dcirc - Dtria - end - - subgraph "E" - Ecirc - Esqua - end - - subgraph "F" - Fplaceholder - end - - Asqua --> Bstar - Bcirc --> Cstar - Chear --> Dtria - Dcirc --> Asqua - Bstar --> Ecirc - Esqua --> Bsqua - Dcirc --> Acirc - Acirc --> Esqua - Dsqua --> Dcirc - Esqua --> Esqua + A + B + C + D + E + F + + A --> B + B --> C + C --> D + D --> A + B --> E + E --> B + D --> A + A --> E + D --> D + E --> E ``` -Note that the nodes and source/destination indices are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. -This is the case as well with `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`. +Note that the node names are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. +This is the case with all of the 4 core graph classes. Nodes are of type `Node`, and from a user perspective are simply opaque handles, and source and destination indices should similarly be considered opaque from a user point of view. In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. -In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. -`MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. -Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`. -`NodePort:` an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. `MultiDiGraph` will be able to distinguish between 2 edges that share the same source and destination as long as at at least one `NodePort` differs. Within the context of a PCG, `NodePorts` must be thought of as the various inputs and outputs of a single node. +In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataFlowGraph`. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -158,11 +94,13 @@ The argument to `query_nodes` is a `NodeQuery` (which is simply a set of `Node`s The set of nodes in the query is actually an `optional`, so `nullopt` could also be passed, which would simply retrieve all nodes from the target graph (essentially `nullopt` acts as the set of all nodes that could ever exist). `query_edges` functions similarly, but as with `add_edge` its behavior is differs slightly between the three graph variants. `UndirectedGraph::query_edges` simply takes an optional set of nodes and returns all edges that touch any of those nodes. -`DirectedGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. +`DiGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. In practice you will rarely ever use `query_nodes` and `query_edges` as the graph library provides a large number of algorithms that do that work for you, but it can be helpful to understand this base layer if you ever need to implement your own algorithms. -The layer users will most commonly interact with is the interface provided by [algorithms.h](./algorithms.h), which provides a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. -You may notice that the most of the functions declared in `algorithms.h` take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but actually operator on `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. +The layer users will most commonly interact with is the interface provided within either the `algorithms.h` header files or the `algorithms` folders, present in their respective graph class folders. +They provide a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. +Note that, due to the internal virtual inheritance structure, some functions for more privitive classes can be employed by the derived classes. (For example, `get_nodes` present in `node/algorithms.h` can be used by `DiGraph`). +You may notice that the most of algorithms present take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but rather `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. These `GraphView` objects represent read-only (i.e., immutable) graphs. Similar to C++'s `const` semantics, `Graph`s can be coerced[^2] to `GraphView`s but not the other way around. To transform a `GraphView` to a `Graph`, we can perform an explicit copy with `materialize_view`. @@ -171,40 +109,32 @@ This may seem wasteful (oftentimes graphs are large objects that are passed arou At this point, however, we still have not discussed how to create a graph. The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGraph` which internally uses a representation `MyDiGraphImpl`: +For example, to construct a `DiGDiraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: ```cpp -DiGraph g = DiGraph::create(); +DiGraph g = DiGraph::create(); ``` Generally users will use underlying representations provided by the graph library, but advanced users can create their own implementations (see the [Internals](#internals) section). [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ -### Open, Upward, Downward +### Open DataFlow Variant `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. -We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). +This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. -![Open graphs inheritance diagram](docs/open.svg) - -Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info, see [cow_ptr](#cow_ptr-and-interfaces)) - - -### Labelled Graphs +### Labelled Dataflow Variant As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. -Thus, FlexFlow's graph library provides the ability to add labels via [labelled\_graphs.h](./labelled_graphs.h): examples include `NodeLabelledMultiDiGraph` (nodes have labels of type `T` and edges are unlabelled) and `OutputLabelledMultiDiGraph` (nodes have labels of type `T` and source indices have labels of type `U`). -While the interfaces of these graphs differ slightly from the core graph variants, they still have corresponding `GraphView` types, `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. -Note that all of the labelled graph types require that each element of the labelled types have a label (e.g., every node in a `NodeLabelledMultiDiGraph` must have a label of type `T`)., which is enforced via the interfaces they provide. +Thus, FlexFlow's graph library provides the ability to add labels to `DataFlowGraph`, through the `LabelleledDataFlowGraph` and `OpenLabelleledDataFlowGraph`, which allow users to label both nodes and edges. +While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. +Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. -Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes (or other types depending in which labelled graph type is used) to labels. -As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants for use in functions provided by `algorithms.h`, etc. +Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes/edges to labels. +As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants. [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. -![Labelled Graphs Inheritance Diagram](docs/labelled.svg) - - ## Internals @@ -236,12 +166,7 @@ To address this, graph classes store a `cow_ptr` as a member variable, which poi All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers to their interface counterparts. -To create graphs within the library, we thus use the following syntax: -`BaseGraph obj = BaseGraph::create();` - -Resulting in an object that, while of type `BaseGraph`, can access at runtime the member functions defined in `DerivedGraph` - ### Virtual Inheritance -Due to the complexity of the graph library, diamond-style inheritance patterns emerge (consider, for example, the `OutputLabelledOpenMultiDiGraphView` class, which inherits from both `NodeLabelledOpenMultiDiGraphView` and `OutputLabelledMultiDiGraphView`, which in turn inherit from both `NodeLabelledMultiDiGraphView`). -In the case of a diamond inheritance pattern C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. +Due to the complexity of the graph library, diamond-style inheritance patterns emerge. +In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. From 5f1bd6d7997080be179b1da38d8c1bd1bbb66874 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Thu, 12 Sep 2024 06:26:29 -0700 Subject: [PATCH 13/21] Test updates + bug fixes --- lib/compiler/src/machine_mapping.cc | 1 - lib/kernels/test/src/test_dropout.cc | 1 - lib/runtime/src/model.cc | 1 - .../src/task_invocation_compilation.cc | 1 - .../task_spec/task_invocation_compilation.cc | 1 - .../src/task_spec/tensor_args_format.cc | 2 +- .../unlabelled/multidigraph_pattern_match.cc | 1 - .../utils/bidict/algorithms/merge_bidicts.h | 11 +- lib/utils/include/utils/containers.decl.h | 95 ------ lib/utils/include/utils/containers.h | 280 ------------------ .../include/utils/containers/are_all_same.h | 24 ++ .../include/utils/containers/compare_by.h | 15 + .../include/utils/containers/get_first.h | 15 - .../include/utils/containers/get_one_of.h | 15 + lib/utils/include/utils/containers/index_of.h | 24 ++ .../include/utils/containers/is_submap.h | 18 ++ .../utils/containers/is_superseteq_of.h | 17 ++ lib/utils/include/utils/containers/map_keys.h | 7 + lib/utils/include/utils/containers/maximum.h | 21 ++ .../include/utils/containers/merge_maps.h | 7 +- .../utils/containers/optional_all_of.h | 26 ++ lib/utils/include/utils/containers/product.h | 16 + .../utils/containers/reversed_container.h | 85 ++++++ lib/utils/include/utils/containers/subvec.h | 7 +- lib/utils/include/utils/containers/sum.h | 33 +++ .../include/utils/containers/value_all.h | 32 ++ .../include/utils/containers/vector_split.h | 8 +- lib/utils/include/utils/fmt/optional.h | 19 ++ lib/utils/include/utils/graph/README.md | 11 +- .../graph/digraph/algorithms/get_dominators.h | 13 + .../include/utils/graph/node/node.struct.toml | 1 + .../utils/graph/undirected/undirected_edge.h | 32 +- .../undirected/undirected_edge.struct.toml | 18 ++ lib/utils/include/utils/join_strings.h | 6 + lib/utils/src/containers.cc | 1 - .../src/utils/containers/are_all_same.cc | 1 + lib/utils/src/utils/containers/compare_by.cc | 1 + lib/utils/src/utils/containers/get_first.cc | 1 - lib/utils/src/utils/containers/get_one_of.cc | 1 + lib/utils/src/utils/containers/index_of.cc | 1 + lib/utils/src/utils/containers/is_submap.cc | 1 + .../src/utils/containers/is_superseteq_of.cc | 1 + lib/utils/src/utils/containers/maximum.cc | 1 + .../src/utils/containers/optional_all_of.cc | 1 + .../utils/containers/reversed_container.cc | 1 + lib/utils/src/utils/containers/sum.cc | 1 + lib/utils/src/utils/containers/value_all.cc | 1 + lib/utils/src/utils/graph/algorithms.cc | 5 +- .../algorithms/find_isomorphism.cc | 4 +- .../get_cbc_decomposition.cc | 4 +- .../algorithms/get_imm_post_dominator.cc | 4 +- .../instances/hashmap_undirected_graph.cc | 27 +- .../algorithms/find_isomorphism.cc | 4 +- .../algorithms/find_isomorphisms.cc | 2 +- .../utils/graph/undirected/undirected_edge.cc | 46 +-- lib/utils/src/utils/graph/views/views.cc | 26 +- .../utils/bidict/algorithms/merge_bidicts.cc | 35 +++ lib/utils/test/src/utils/bidict/bidict.cc | 25 +- lib/utils/test/src/utils/containers.cc | 185 ------------ .../test/src/utils/containers/are_all_same.cc | 24 ++ .../test/src/utils/containers/are_disjoint.cc | 26 +- .../test/src/utils/containers/as_vector.cc | 9 +- .../test/src/utils/containers/compare_by.cc | 19 ++ .../test/src/utils/containers/enumerate.cc | 19 +- .../test/src/utils/containers/filter_keys.cc | 7 +- lib/utils/test/src/utils/containers/find.cc | 29 +- .../test/src/utils/containers/flatmap.cc | 26 ++ .../{get_first.cc => get_one_of.cc} | 6 +- .../test/src/utils/containers/index_of.cc | 22 ++ .../test/src/utils/containers/is_submap.cc | 40 +++ .../{is_subset_eq.cc => is_subset_eq_of.cc} | 0 .../src/utils/containers/is_superseteq_of.cc | 25 ++ lib/utils/test/src/utils/containers/keys.cc | 3 +- .../test/src/utils/containers/map_keys.cc | 10 +- .../test/src/utils/containers/map_values.cc | 9 +- .../test/src/utils/containers/maximum.cc | 25 ++ .../test/src/utils/containers/merge_maps.cc | 30 ++ .../src/utils/containers/optional_all_of.cc | 39 +++ .../test/src/utils/containers/product.cc | 40 +++ .../src/utils/containers/restrict_keys.cc | 8 +- .../test/src/utils/containers/reversed.cc | 6 +- .../test/src/utils/containers/set_union.cc | 5 +- .../test/src/utils/containers/sorted_by.cc | 23 +- lib/utils/test/src/utils/containers/subvec.cc | 59 +++- lib/utils/test/src/utils/containers/sum.cc | 41 +++ .../test/src/utils/containers/value_all.cc | 24 ++ lib/utils/test/src/utils/containers/values.cc | 8 +- .../test/src/utils/containers/vector_split.cc | 33 ++- lib/utils/test/src/utils/disjoint_set.cc | 22 +- .../src/utils/graph/digraph/algorithms.cc | 85 +++++- .../digraph/algorithms/get_dominators.cc | 47 ++- .../graph/digraph/directed_edge_query.cc | 58 ++-- .../graph/digraph/get_topological_ordering.cc | 19 +- .../get_weakly_connected_components.cc | 13 +- .../algorithms/get_incoming_edges.cc | 20 +- .../algorithms/get_outgoing_edges.cc | 21 +- .../utils/graph/multidigraph/multidigraph.cc | 169 ++++++++--- lib/utils/test/src/utils/graph/traversal.cc | 111 +++++-- .../algorithms}/get_connected_components.cc | 10 +- .../graph/undirected/undirected_graph.cc | 5 +- lib/utils/test/src/utils/graph/views/views.cc | 52 ++-- lib/utils/test/src/utils/hash-utils.cc | 2 +- lib/utils/test/src/utils/join_strings.cc | 12 + lib/utils/test/src/utils/random_utils.cc | 89 +++--- lib/utils/test/src/utils/stack_map.cc | 2 +- lib/utils/test/src/utils/stack_string.cc | 17 +- lib/utils/test/src/utils/stack_vector.cc | 15 +- lib/utils/test/src/utils/tuple.cc | 2 +- lib/utils/test/src/utils/type_index.cc | 10 +- lib/utils/test/src/utils/variant.cc | 2 + lib/utils/test/src/utils/vector.cc | 3 +- 111 files changed, 1633 insertions(+), 1012 deletions(-) delete mode 100644 lib/utils/include/utils/containers.decl.h delete mode 100644 lib/utils/include/utils/containers.h create mode 100644 lib/utils/include/utils/containers/are_all_same.h create mode 100644 lib/utils/include/utils/containers/compare_by.h delete mode 100644 lib/utils/include/utils/containers/get_first.h create mode 100644 lib/utils/include/utils/containers/get_one_of.h create mode 100644 lib/utils/include/utils/containers/index_of.h create mode 100644 lib/utils/include/utils/containers/is_submap.h create mode 100644 lib/utils/include/utils/containers/is_superseteq_of.h create mode 100644 lib/utils/include/utils/containers/maximum.h create mode 100644 lib/utils/include/utils/containers/optional_all_of.h create mode 100644 lib/utils/include/utils/containers/reversed_container.h create mode 100644 lib/utils/include/utils/containers/sum.h create mode 100644 lib/utils/include/utils/containers/value_all.h create mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml delete mode 100644 lib/utils/src/containers.cc create mode 100644 lib/utils/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/src/utils/containers/compare_by.cc delete mode 100644 lib/utils/src/utils/containers/get_first.cc create mode 100644 lib/utils/src/utils/containers/get_one_of.cc create mode 100644 lib/utils/src/utils/containers/index_of.cc create mode 100644 lib/utils/src/utils/containers/is_submap.cc create mode 100644 lib/utils/src/utils/containers/is_superseteq_of.cc create mode 100644 lib/utils/src/utils/containers/maximum.cc create mode 100644 lib/utils/src/utils/containers/optional_all_of.cc create mode 100644 lib/utils/src/utils/containers/reversed_container.cc create mode 100644 lib/utils/src/utils/containers/sum.cc create mode 100644 lib/utils/src/utils/containers/value_all.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc delete mode 100644 lib/utils/test/src/utils/containers.cc create mode 100644 lib/utils/test/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/test/src/utils/containers/compare_by.cc create mode 100644 lib/utils/test/src/utils/containers/flatmap.cc rename lib/utils/test/src/utils/containers/{get_first.cc => get_one_of.cc} (64%) create mode 100644 lib/utils/test/src/utils/containers/index_of.cc create mode 100644 lib/utils/test/src/utils/containers/is_submap.cc rename lib/utils/test/src/utils/containers/{is_subset_eq.cc => is_subset_eq_of.cc} (100%) create mode 100644 lib/utils/test/src/utils/containers/is_superseteq_of.cc create mode 100644 lib/utils/test/src/utils/containers/maximum.cc create mode 100644 lib/utils/test/src/utils/containers/merge_maps.cc create mode 100644 lib/utils/test/src/utils/containers/optional_all_of.cc create mode 100644 lib/utils/test/src/utils/containers/product.cc create mode 100644 lib/utils/test/src/utils/containers/sum.cc create mode 100644 lib/utils/test/src/utils/containers/value_all.cc rename lib/utils/test/src/utils/graph/{digraph => undirected/algorithms}/get_connected_components.cc (62%) diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index af7756c635..741be06e68 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -6,7 +6,6 @@ #include "pcg/machine_view.dtg.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/as_vector.h" #include "utils/containers/contains_key.h" diff --git a/lib/kernels/test/src/test_dropout.cc b/lib/kernels/test/src/test_dropout.cc index 981bc611d8..81f3c7183a 100644 --- a/lib/kernels/test/src/test_dropout.cc +++ b/lib/kernels/test/src/test_dropout.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "kernels/dropout_kernels.h" #include "test_utils.h" -#include "utils/containers.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/runtime/src/model.cc b/lib/runtime/src/model.cc index a655bcb050..22f0f2e98d 100644 --- a/lib/runtime/src/model.cc +++ b/lib/runtime/src/model.cc @@ -27,7 +27,6 @@ #include "parallel_tensor_mapping.h" #include "task_spec/task_argument_accessor.h" #include "test_utils.h" -#include "utils/containers.h" #include "utils/random_utils.h" #include #include diff --git a/lib/runtime/src/task_invocation_compilation.cc b/lib/runtime/src/task_invocation_compilation.cc index bfeb2be6d4..ff281370e7 100644 --- a/lib/runtime/src/task_invocation_compilation.cc +++ b/lib/runtime/src/task_invocation_compilation.cc @@ -1,5 +1,4 @@ #include "task_invocation_compilation.h" -#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/runtime/src/task_spec/task_invocation_compilation.cc b/lib/runtime/src/task_spec/task_invocation_compilation.cc index bfeb2be6d4..ff281370e7 100644 --- a/lib/runtime/src/task_spec/task_invocation_compilation.cc +++ b/lib/runtime/src/task_spec/task_invocation_compilation.cc @@ -1,5 +1,4 @@ #include "task_invocation_compilation.h" -#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/runtime/src/task_spec/tensor_args_format.cc b/lib/runtime/src/task_spec/tensor_args_format.cc index 22bd021996..8fb7fc029b 100644 --- a/lib/runtime/src/task_spec/tensor_args_format.cc +++ b/lib/runtime/src/task_spec/tensor_args_format.cc @@ -1,5 +1,5 @@ #include "tensor_args_format.h" -#include "utils/containers.h" +#include "utils/containers/flatmap.h" namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc index 8ce60fab4f..d69fd42b10 100644 --- a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc +++ b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc @@ -1,7 +1,6 @@ #include "substitutions/unlabelled/multidigraph_pattern_match.h" // #include "substitutions/unlabelled/edge_splits.h" // #include "substitutions/unlabelled/pattern_edge.h" -#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h index d388e35d75..efb1e74cd0 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h @@ -5,13 +5,20 @@ #include "utils/bidict/algorithms/right_entries.h" #include "utils/bidict/bidict.h" #include "utils/containers/are_disjoint.h" +#include "utils/exception.h" namespace FlexFlow { template bidict merge_bidicts(bidict const &lhs, bidict const &rhs) { - assert(are_disjoint(left_entries(lhs), left_entries(rhs))); - assert(are_disjoint(right_entries(lhs), right_entries(rhs))); + if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { + throw mk_runtime_error( + "Left entries of merge_bidicts parameters are non-disjoint"); + } + if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { + throw mk_runtime_error( + "Right entries of merge_bidicts parameters are non-disjoint"); + } bidict result; for (auto const &kv : lhs) { diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h deleted file mode 100644 index e3fdbeabb1..0000000000 --- a/lib/utils/include/utils/containers.decl.h +++ /dev/null @@ -1,95 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H - -#include "utils/bidict/bidict.h" -#include "utils/containers/get_element_type.h" -#include "utils/required_core.h" -#include "utils/type_traits_core.h" -#include -#include -#include - -namespace FlexFlow { - -template -Element sum(Container const &container); - -template -Element sum_where(Container const &container, ConditionF const &condition); - -template -Element product_where(Container const &container, ConditionF const &condition); - -template -bool contains_l(bidict const &m, K const &k); - -template -bool contains_r(bidict const &m, V const &v); - -template -std::optional index_of(Container const &c, Element const &e); - -template -std::unordered_map restrict_keys(std::unordered_map const &m, - std::unordered_set const &mask); - -template -std::unordered_map merge_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs); - -template -std::optional at_idx(std::vector const &v, size_t idx); - -template -std::function lookup_in(std::unordered_map const &m); - -template -std::function lookup_in_l(bidict const &m); - -template -std::function lookup_in_r(bidict const &m); - -template -bool is_superseteq_of(std::unordered_set const &l, - std::unordered_set const &r); - -template -std::optional optional_all_of(Container const &, Function const &); - -template -bool are_all_same(C const &c); - -template -std::vector flatmap(std::vector const &v, F const &f); - -template -std::unordered_set flatmap(std::unordered_set const &v, F const &f); -template -std::unordered_set flatmap_v2(std::unordered_set const &v, - std::unordered_set (*f)(In const &)); - -template -std::function compare_by(F const &f); - -template -typename C::value_type maximum(C const &v); - -template -std::vector value_all(std::vector> const &v); - -template -std::unordered_set value_all(std::unordered_set> const &v); - -template -struct reversed_container_t; - -template -reversed_container_t reversed_container(C const &c); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h deleted file mode 100644 index ae65850990..0000000000 --- a/lib/utils/include/utils/containers.h +++ /dev/null @@ -1,280 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL - -#include "containers.decl.h" -#include "required_core.h" -#include "type_traits_core.h" -#include "utils/bidict/bidict.h" -#include "utils/containers/are_disjoint.h" -#include "utils/containers/contains.h" -#include "utils/containers/extend.h" -#include "utils/containers/extend_vector.h" -#include "utils/containers/filter.h" -#include "utils/containers/intersection.h" -#include "utils/containers/is_subseteq_of.h" -#include "utils/containers/keys.h" -#include "utils/containers/restrict_keys.h" -#include "utils/containers/sorted.h" -#include "utils/containers/transform.h" -#include "utils/containers/vector_transform.h" -#include "utils/exception.h" -#include "utils/hash/pair.h" -#include "utils/optional.h" -#include "utils/type_traits.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace FlexFlow { - -template -Element sum(Container const &container) { - Element result = 0; - for (Element const &element : container) { - result += element; - } - return result; -} - -template -Element sum_where(Container const &container, ConditionF const &condition) { - Element result = 0; - for (Element const &element : container) { - if (condition(element)) { - result += element; - } - } - return result; -} - -template -Element product_where(Container const &container, ConditionF const &condition) { - Element result = 1; - for (Element const &element : container) { - if (condition(element)) { - result *= element; - } - } - return result; -} - -template -bool contains_l(bidict const &m, K const &k) { - return m.contains_l(k); -} - -template -bool contains_r(bidict const &m, V const &v) { - return m.contains_r(v); -} - -template -bool is_submap(std::unordered_map const &m, - std::unordered_map const &sub) { - return restrict_keys(m, keys(sub)) == sub; -} - -template -std::optional index_of(Container const &c, Element const &e) { - auto it = std::find(c.cbegin(), c.cend(), e); - if (it == c.cend()) { - return std::nullopt; - } else { - return std::distance(c.cbegin(), it); - } -} - -template -std::function lookup_in(std::unordered_map const &m) { - return [&m](K const &k) -> V { return m.at(k); }; -} - -template -std::function lookup_in_l(bidict const &m) { - return [&m](L const &l) -> R { return m.at_l(l); }; -} - -template -std::function lookup_in_r(bidict const &m) { - return [&m](R const &r) -> L { return m.at_r(r); }; -} - -template -bool is_superseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { - return is_subseteq_of(r, l); -} - -template -std::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - std::optional condition = func(element); - if (!condition.has_value()) { - return std::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -template -bool are_all_same(C const &c) { - auto const &first = *c.cbegin(); - for (auto const &v : c) { - if (v != first) { - return false; - } - } - return true; -} - -template -std::vector flatmap(std::vector const &v, F const &f) { - std::vector result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set flatmap(std::unordered_set const &v, F const &f) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set flatmap_v2(std::unordered_set const &v, - std::unordered_set (*f)(In const &)) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::function compare_by(F const &f) { - return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; -} - -template -typename C::value_type maximum(C const &v) { - return *std::max_element(v.begin(), v.end()); -} - -template -std::vector value_all(std::vector> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); - }); - }); -} - -template -std::unordered_set value_all(std::unordered_set> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); - }); - }); -} - -template -struct reversed_container_t { - reversed_container_t() = delete; - reversed_container_t(C const &c) : container(c) {} - - reversed_container_t(reversed_container_t const &) = delete; - reversed_container_t(reversed_container_t &&) = delete; - reversed_container_t &operator=(reversed_container_t const &) = delete; - reversed_container_t &operator=(reversed_container_t &&) = delete; - - using iterator = typename C::reverse_iterator; - using const_iterator = typename C::const_reverse_iterator; - using reverse_iterator = typename C::iterator; - using const_reverse_iterator = typename C::const_iterator; - using value_type = typename C::value_type; - using pointer = typename C::pointer; - using const_pointer = typename C::const_pointer; - using reference = typename C::reference; - using const_reference = typename C::const_reference; - - iterator begin() { - return this->container.rend(); - } - - iterator end() { - return this->container.rbegin(); - } - - const_iterator cbegin() const { - return this->container.crend(); - } - - const_iterator cend() const { - return this->container.crbegin(); - } - - const_iterator begin() const { - return this->cbegin(); - } - - const_iterator end() const { - return this->cend(); - } - - reverse_iterator rbegin() { - return this->container.begin(); - } - - reverse_iterator rend() { - return this->container.end(); - } - - const_reverse_iterator crbegin() const { - return this->container.cbegin(); - } - - const_reverse_iterator crend() const { - return this->container.cend(); - } - - const_reverse_iterator rbegin() const { - return this->crbegin(); - } - - const_reverse_iterator rend() const { - return this->crend(); - } - -private: - C const &container; -}; - -template -reversed_container_t reversed_container(C const &c) { - return reversed_container_t(c); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/are_all_same.h b/lib/utils/include/utils/containers/are_all_same.h new file mode 100644 index 0000000000..26716ac911 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_same.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H + +#include "utils/exception.h" + +namespace FlexFlow { + +template +bool are_all_same(C const &c) { + if (c.size() == 0) { + throw mk_runtime_error("input to are_all_same must be non-empty container"); + } + auto const &first = *c.cbegin(); + for (auto const &v : c) { + if (v != first) { + return false; + } + } + return true; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/compare_by.h b/lib/utils/include/utils/containers/compare_by.h new file mode 100644 index 0000000000..d6cb7f48cd --- /dev/null +++ b/lib/utils/include/utils/containers/compare_by.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_COMPARE_BY_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_COMPARE_BY_H + +#include + +namespace FlexFlow { + +template +std::function compare_by(F const &f) { + return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_first.h b/lib/utils/include/utils/containers/get_first.h deleted file mode 100644 index ce2a483401..0000000000 --- a/lib/utils/include/utils/containers/get_first.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H - -#include - -namespace FlexFlow { - -template -T get_first(std::unordered_set const &s) { - return *s.cbegin(); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/get_one_of.h b/lib/utils/include/utils/containers/get_one_of.h new file mode 100644 index 0000000000..274bc27239 --- /dev/null +++ b/lib/utils/include/utils/containers/get_one_of.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H + +#include + +namespace FlexFlow { + +template +T get_one_of(std::unordered_set const &s) { + return *s.cbegin(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/index_of.h b/lib/utils/include/utils/containers/index_of.h new file mode 100644 index 0000000000..d78888decd --- /dev/null +++ b/lib/utils/include/utils/containers/index_of.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INDEX_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INDEX_OF_H + +#include +#include +namespace FlexFlow { + +/** + * @details If multiple `e` are present within the container, the function + *returns the index of the first appearance + **/ +template +std::optional index_of(Container const &c, Element const &e) { + auto it = std::find(c.cbegin(), c.cend(), e); + if (it == c.cend()) { + return std::nullopt; + } else { + return std::distance(c.cbegin(), it); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/is_submap.h b/lib/utils/include/utils/containers/is_submap.h new file mode 100644 index 0000000000..8cca9f6924 --- /dev/null +++ b/lib/utils/include/utils/containers/is_submap.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H + +#include "utils/containers/keys.h" +#include "utils/containers/restrict_keys.h" +#include + +namespace FlexFlow { + +template +bool is_submap(std::unordered_map const &m, + std::unordered_map const &sub) { + return restrict_keys(m, keys(sub)) == sub; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/is_superseteq_of.h b/lib/utils/include/utils/containers/is_superseteq_of.h new file mode 100644 index 0000000000..1793201261 --- /dev/null +++ b/lib/utils/include/utils/containers/is_superseteq_of.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUPERSETEQ_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUPERSETEQ_OF_H + +#include "utils/containers/is_subseteq_of.h" +#include + +namespace FlexFlow { + +template +bool is_superseteq_of(std::unordered_set const &l, + std::unordered_set const &r) { + return is_subseteq_of(r, l); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index e252333e93..bed6c5e4c0 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -6,6 +6,13 @@ namespace FlexFlow { +/** + * @brief Applies the given function to all the keys within the given map and + * returns the updated map. + * + * @details Note that if multiple original keys are transformed to the same new + * key, which value will be associated to the new key is undefined. + */ template + +namespace FlexFlow { + +template +typename C::value_type maximum(C const &c) { + if (c.size() == 0) { + throw mk_runtime_error( + "input to function maximum must be non-empty container"); + } + + return *std::max_element(c.begin(), c.end()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h index 653c9d24f1..3d93863b2d 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -3,6 +3,8 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" +#include "utils/exception.h" +#include #include namespace FlexFlow { @@ -10,7 +12,10 @@ namespace FlexFlow { template std::unordered_map merge_maps(std::unordered_map const &lhs, std::unordered_map const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); + if (!are_disjoint(keys(lhs), keys(rhs))) { + throw mk_runtime_error( + "Key sets of merge_maps parameters are non-disjoint"); + } std::unordered_map result; for (auto const &kv : lhs) { diff --git a/lib/utils/include/utils/containers/optional_all_of.h b/lib/utils/include/utils/containers/optional_all_of.h new file mode 100644 index 0000000000..825633891b --- /dev/null +++ b/lib/utils/include/utils/containers/optional_all_of.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_OPTIONAL_ALL_OF_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_OPTIONAL_ALL_OF_H + +#include + +namespace FlexFlow { + +template +std::optional optional_all_of(Container const &container, + Function const &func) { + for (auto const &element : container) { + std::optional condition = func(element); + if (!condition.has_value()) { + return std::nullopt; + } + + if (!condition.value()) { + return false; + } + } + return true; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/product.h b/lib/utils/include/utils/containers/product.h index 52ff36e790..a9fd8b2e76 100644 --- a/lib/utils/include/utils/containers/product.h +++ b/lib/utils/include/utils/containers/product.h @@ -5,6 +5,9 @@ namespace FlexFlow { +/** + * @details An empty container vacuously has product 1 + **/ template Element product(Container const &container) { Element result = 1; @@ -23,6 +26,19 @@ typename It::value_type product(It begin, It end) { }); } +template +Element product_where(Container const &container, ConditionF const &condition) { + Element result = 1; + for (Element const &element : container) { + if (condition(element)) { + result *= element; + } + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/reversed_container.h b/lib/utils/include/utils/containers/reversed_container.h new file mode 100644 index 0000000000..cffef3c6e6 --- /dev/null +++ b/lib/utils/include/utils/containers/reversed_container.h @@ -0,0 +1,85 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_CONTAINER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_CONTAINER_H + +namespace FlexFlow { + +template +struct reversed_container_t { + reversed_container_t() = delete; + reversed_container_t(C const &c) : container(c) {} + + reversed_container_t(reversed_container_t const &) = delete; + reversed_container_t(reversed_container_t &&) = delete; + reversed_container_t &operator=(reversed_container_t const &) = delete; + reversed_container_t &operator=(reversed_container_t &&) = delete; + + using iterator = typename C::reverse_iterator; + using const_iterator = typename C::const_reverse_iterator; + using reverse_iterator = typename C::iterator; + using const_reverse_iterator = typename C::const_iterator; + using value_type = typename C::value_type; + using pointer = typename C::pointer; + using const_pointer = typename C::const_pointer; + using reference = typename C::reference; + using const_reference = typename C::const_reference; + + iterator begin() { + return this->container.rend(); + } + + iterator end() { + return this->container.rbegin(); + } + + const_iterator cbegin() const { + return this->container.crend(); + } + + const_iterator cend() const { + return this->container.crbegin(); + } + + const_iterator begin() const { + return this->cbegin(); + } + + const_iterator end() const { + return this->cend(); + } + + reverse_iterator rbegin() { + return this->container.begin(); + } + + reverse_iterator rend() { + return this->container.end(); + } + + const_reverse_iterator crbegin() const { + return this->container.cbegin(); + } + + const_reverse_iterator crend() const { + return this->container.cend(); + } + + const_reverse_iterator rbegin() const { + return this->crbegin(); + } + + const_reverse_iterator rend() const { + return this->crend(); + } + +private: + C const &container; +}; + +template +reversed_container_t reversed_container(C const &c) { + return reversed_container_t(c); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 52368f94ad..77eb6dc421 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -16,9 +16,9 @@ std::vector subvec(std::vector const &v, auto resolve_loc = [&](int idx) -> typename std::vector::iterator::difference_type { if (idx < 0) { - return v.size() + idx; + return std::max(0, (int)v.size() + idx); } else { - return idx; + return std::min(idx, (int)v.size()); } }; @@ -28,6 +28,9 @@ std::vector subvec(std::vector const &v, if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } + if (begin_iter >= end_iter) { + return {}; + } std::vector output(begin_iter, end_iter); return output; diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h new file mode 100644 index 0000000000..4e175ae792 --- /dev/null +++ b/lib/utils/include/utils/containers/sum.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H + +namespace FlexFlow { + +/** + * @details An empty container vacuously has sum 0 + **/ +template +Element sum(Container const &container) { + Element result = 0; + for (Element const &element : container) { + result += element; + } + return result; +} + +template +Element sum_where(Container const &container, ConditionF const &condition) { + Element result = 0; + for (Element const &element : container) { + if (condition(element)) { + result += element; + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/value_all.h b/lib/utils/include/utils/containers/value_all.h new file mode 100644 index 0000000000..996b2adfbd --- /dev/null +++ b/lib/utils/include/utils/containers/value_all.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUE_ALL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUE_ALL_H + +#include "utils/containers/transform.h" +#include "utils/exception.h" +#include "utils/optional.h" + +namespace FlexFlow { + +template +std::vector value_all(std::vector> const &v) { + return transform(v, [](std::optional const &element) { + return unwrap(element, [] { + throw mk_runtime_error( + "Encountered element without value in call to value_all"); + }); + }); +} + +template +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, [](std::optional const &element) { + return unwrap(element, [] { + throw mk_runtime_error( + "Encountered element without value in call to value_all"); + }); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/vector_split.h b/lib/utils/include/utils/containers/vector_split.h index b55eea19d4..a5dc098644 100644 --- a/lib/utils/include/utils/containers/vector_split.h +++ b/lib/utils/include/utils/containers/vector_split.h @@ -1,15 +1,19 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H +#include "utils/fmt/vector.h" #include +#include #include namespace FlexFlow { template std::pair, std::vector> vector_split(std::vector const &v, - std::size_t idx) { - assert(v.size() > idx); + int idx) { + if (idx < 0 || idx > static_cast(v.size())) { + throw std::out_of_range("Index out of range in vector_split"); + } std::vector prefix(v.begin(), v.begin() + idx); std::vector postfix(v.begin() + idx, v.end()); diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 45eebc2c58..2c6067cdb5 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -30,6 +30,14 @@ struct formatter< } }; +template +struct formatter : formatter { + template + auto format(std::nullopt_t, FormatContext &ctx) -> decltype(ctx.out()) { + return formatter::format("nullopt", ctx); + } +}; + } // namespace fmt namespace FlexFlow { @@ -41,6 +49,10 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { return s << fmt::to_string(t); } +inline std::ostream &operator<<(std::ostream &s, std::nullopt_t) { + return s << "nullopt"; +} + } // namespace FlexFlow namespace doctest { @@ -52,6 +64,13 @@ struct StringMaker> { } }; +template <> +struct StringMaker { + static String convert(std::nullopt_t) { + return toString("nullopt"); + } +}; + } // namespace doctest #endif diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 0c5687766d..90511ae343 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -18,7 +18,9 @@ At their core, they are as follows: - `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. - `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) - `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. -- `DataFlowGraph`: similar to `MultiDiGraph`, but the edges entering, exiting a given nodes now have a well-defined order. +- `DataflowGraph`: similar to `MultiDiGraph`, but the edges entering, exiting a given nodes now have a well-defined order. Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. +Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. + Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -85,7 +87,7 @@ In addition, nodes should only be used in the context of their graph, so compari All three core graph variants allow insertion and deletion of both edges and nodes. To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. -In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataFlowGraph`. +In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -126,7 +128,10 @@ This graph class is particularly useful for processing a sub-graph of a given gr ### Labelled Dataflow Variant As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. -Thus, FlexFlow's graph library provides the ability to add labels to `DataFlowGraph`, through the `LabelleledDataFlowGraph` and `OpenLabelleledDataFlowGraph`, which allow users to label both nodes and edges. +Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGraph`, through the `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`, which allow users to label different components of the graph. +- `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s. +- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s, which represent the open inputs to the graph (i.e. the inputs for which their corresponding output is not present in the graph). + While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h index 1e4d09d3ae..d1019ce70a 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h @@ -5,7 +5,20 @@ namespace FlexFlow { +/** + * @brief A node `d` is said to dominate a node `n` if every path from the root + * node to `n` must go through `d`. + * + * @note By definition, the root node dominates every node and every node + * dominates itself. + * + */ std::unordered_set get_dominators(DiGraphView const &, Node const &); + +/** + * @brief Returns the intersection of the dominators of the input nodes. + * + */ std::unordered_set get_dominators(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index 0b6f348ddf..cf7309c5a6 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "fmt", + "rapidcheck" ] includes = [ diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h index 235cf45c4a..d051413faa 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -2,38 +2,12 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H #include "utils/graph/node/node.dtg.h" -namespace FlexFlow { - -struct UndirectedEdge { -public: - UndirectedEdge() = delete; - UndirectedEdge(Node const &src, Node const &dst); - - bool operator==(UndirectedEdge const &) const; - bool operator!=(UndirectedEdge const &) const; - bool operator<(UndirectedEdge const &) const; - -public: - Node smaller; - Node bigger; -}; - -bool is_connected_to(UndirectedEdge const &, Node const &); +#include "utils/graph/undirected/undirected_edge.dtg.h" -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::UndirectedEdge> { - size_t operator()(::FlexFlow::UndirectedEdge const &) const; -}; +namespace FlexFlow { -} // namespace std +bool is_connected_to(UndirectedEdge const &e, Node const &n); -namespace FlexFlow { -std::string format_as(UndirectedEdge const &); -std::ostream &operator<<(std::ostream &, UndirectedEdge const &); } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml new file mode 100644 index 0000000000..f5258b0bfd --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "UndirectedEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck" +] + +includes = [ + "utils/commutative_pair.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "endpoints" +type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h index 9eb717b066..28a773b13d 100644 --- a/lib/utils/include/utils/join_strings.h +++ b/lib/utils/include/utils/join_strings.h @@ -38,6 +38,12 @@ std::string join_strings(Container const &c, std::string const &delimiter) { return join_strings(c.cbegin(), c.cend(), delimiter); } +template +std::string + join_strings(Container const &c, std::string const &delimiter, F const &f) { + return join_strings(c.cbegin(), c.cend(), delimiter, f); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/src/containers.cc b/lib/utils/src/containers.cc deleted file mode 100644 index 2af7fd1892..0000000000 --- a/lib/utils/src/containers.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers.h" diff --git a/lib/utils/src/utils/containers/are_all_same.cc b/lib/utils/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..c515bceee2 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_same.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_same.h" diff --git a/lib/utils/src/utils/containers/compare_by.cc b/lib/utils/src/utils/containers/compare_by.cc new file mode 100644 index 0000000000..b7df348f37 --- /dev/null +++ b/lib/utils/src/utils/containers/compare_by.cc @@ -0,0 +1 @@ +#include "utils/containers/compare_by.h" diff --git a/lib/utils/src/utils/containers/get_first.cc b/lib/utils/src/utils/containers/get_first.cc deleted file mode 100644 index ce8eb9cbea..0000000000 --- a/lib/utils/src/utils/containers/get_first.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/get_first.h" diff --git a/lib/utils/src/utils/containers/get_one_of.cc b/lib/utils/src/utils/containers/get_one_of.cc new file mode 100644 index 0000000000..4ce017f6b9 --- /dev/null +++ b/lib/utils/src/utils/containers/get_one_of.cc @@ -0,0 +1 @@ +#include "utils/containers/get_one_of.h" diff --git a/lib/utils/src/utils/containers/index_of.cc b/lib/utils/src/utils/containers/index_of.cc new file mode 100644 index 0000000000..d0c4b4dfd3 --- /dev/null +++ b/lib/utils/src/utils/containers/index_of.cc @@ -0,0 +1 @@ +#include "utils/containers/index_of.h" diff --git a/lib/utils/src/utils/containers/is_submap.cc b/lib/utils/src/utils/containers/is_submap.cc new file mode 100644 index 0000000000..be420a14be --- /dev/null +++ b/lib/utils/src/utils/containers/is_submap.cc @@ -0,0 +1 @@ +#include "utils/containers/is_submap.h" diff --git a/lib/utils/src/utils/containers/is_superseteq_of.cc b/lib/utils/src/utils/containers/is_superseteq_of.cc new file mode 100644 index 0000000000..0728c96f17 --- /dev/null +++ b/lib/utils/src/utils/containers/is_superseteq_of.cc @@ -0,0 +1 @@ +#include "utils/containers/is_superseteq_of.h" diff --git a/lib/utils/src/utils/containers/maximum.cc b/lib/utils/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..51d92cf951 --- /dev/null +++ b/lib/utils/src/utils/containers/maximum.cc @@ -0,0 +1 @@ +#include "utils/containers/maximum.h" diff --git a/lib/utils/src/utils/containers/optional_all_of.cc b/lib/utils/src/utils/containers/optional_all_of.cc new file mode 100644 index 0000000000..dd31ad78a5 --- /dev/null +++ b/lib/utils/src/utils/containers/optional_all_of.cc @@ -0,0 +1 @@ +#include "utils/containers/optional_all_of.h" diff --git a/lib/utils/src/utils/containers/reversed_container.cc b/lib/utils/src/utils/containers/reversed_container.cc new file mode 100644 index 0000000000..1a2fe3cf63 --- /dev/null +++ b/lib/utils/src/utils/containers/reversed_container.cc @@ -0,0 +1 @@ +#include "utils/containers/reversed_container.h" diff --git a/lib/utils/src/utils/containers/sum.cc b/lib/utils/src/utils/containers/sum.cc new file mode 100644 index 0000000000..088b5f1983 --- /dev/null +++ b/lib/utils/src/utils/containers/sum.cc @@ -0,0 +1 @@ +#include "utils/containers/sum.h" diff --git a/lib/utils/src/utils/containers/value_all.cc b/lib/utils/src/utils/containers/value_all.cc new file mode 100644 index 0000000000..1f863a20e4 --- /dev/null +++ b/lib/utils/src/utils/containers/value_all.cc @@ -0,0 +1 @@ +#include "utils/containers/value_all.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 323f444a22..4c5d61b9b9 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -184,7 +184,8 @@ bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}}; + UndirectedEdgeQuery q = + UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}}; return contains(g.query_edges(q), e); } @@ -212,7 +213,7 @@ void remove_edges(UndirectedGraph &g, } std::unordered_set get_endpoints(UndirectedEdge const &e) { - return {e.smaller, e.bigger}; + return {e.endpoints.min(), e.endpoints.max()}; } // std::unordered_set get_edges(MultiDiGraphView const &g) { diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc index d06a64597e..0e4d0c6759 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc @@ -1,5 +1,5 @@ #include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" namespace FlexFlow { @@ -13,7 +13,7 @@ std::optional if (all_isomorphisms.empty()) { return std::nullopt; } else { - return get_first(all_isomorphisms); + return get_one_of(all_isomorphisms); } } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 011d8b3ed9..f1f1cd7b21 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,7 +1,7 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/extend.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/set_minus.h" #include "utils/containers/values.h" #include "utils/graph/digraph/algorithms.h" @@ -28,7 +28,7 @@ std::optional CompleteBipartiteCompositeDecomposition{{}}; while (!edges_to_process.empty()) { - DirectedEdge e = get_first(edges_to_process); + DirectedEdge e = get_one_of(edges_to_process); std::unordered_set head = get_predecessors(g, e.dst); std::unordered_set tail = get_successors(g, e.src); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc index 1d98c44a9f..39523f2ec1 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc @@ -1,6 +1,6 @@ #include "utils/graph/digraph/algorithms/get_imm_post_dominator.h" #include "utils/containers/generate_map.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/get_only.h" #include "utils/containers/intersection.h" #include "utils/containers/restrict_keys.h" @@ -31,7 +31,7 @@ std::optional return get_imm_post_dominator(g, get_only(nodes)); } - Node contracted_node = get_first(nodes); + Node contracted_node = get_one_of(nodes); std::unordered_map contraction = generate_map(nodes, [&](Node const &) { return contracted_node; }); return get_imm_post_dominator(apply_contraction(g, contraction), diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 92a6d0b9eb..6da6cffd03 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -1,4 +1,5 @@ #include "utils/graph/instances/hashmap_undirected_graph.h" +#include "utils/containers/contains.h" #include "utils/containers/contains_key.h" #include "utils/containers/keys.h" #include "utils/exception.h" @@ -22,31 +23,37 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { } void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { - if (!contains_key(this->adjacency, e.bigger)) { + if (!contains_key(this->adjacency, e.endpoints.min())) { throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.bigger); + "Could not add edge connected to non-existent node {}", + e.endpoints.min()); } - if (!contains_key(this->adjacency, e.smaller)) { + if (!contains_key(this->adjacency, e.endpoints.max())) { throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.smaller); + "Could not add edge connected to non-existent node {}", + e.endpoints.max()); } - this->adjacency.at(e.bigger).insert(e.smaller); - this->adjacency.at(e.smaller).insert(e.bigger); + this->adjacency.at(e.endpoints.min()).insert(e.endpoints.max()); + this->adjacency.at(e.endpoints.max()).insert(e.endpoints.min()); } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.bigger); - m.erase(e.smaller); - m.erase(e.bigger); + std::unordered_set &m = this->adjacency.at(e.endpoints.max()); + m.erase(e.endpoints.min()); + m.erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( UndirectedEdgeQuery const &query) const { std::unordered_set result; + std::unordered_set nodes = + keys(query_keys(query.nodes, this->adjacency)); for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { for (auto const &dst : src_kv.second) { - result.insert({src_kv.first, dst}); + if (contains(nodes, dst)) { + result.insert(UndirectedEdge{{src_kv.first, dst}}); + } } } return result; diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc index d622497629..d75a447127 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc @@ -1,5 +1,5 @@ #include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" namespace FlexFlow { @@ -13,7 +13,7 @@ std::optional if (all_isomorphisms.empty()) { return std::nullopt; } else { - return get_first(all_isomorphisms); + return get_one_of(all_isomorphisms); } } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index d95a9b9565..387b5b8972 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -4,7 +4,7 @@ #include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/values.h" diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 2e79a8016d..4cfc6aaaa8 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -4,52 +4,8 @@ namespace FlexFlow { -UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) - : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} - -static std::tuple tie(UndirectedEdge const &e) { - return std::tie(e.smaller, e.bigger); -} - -bool UndirectedEdge::operator==(UndirectedEdge const &other) const { - return tie(*this) == tie(other); -} - -bool UndirectedEdge::operator!=(UndirectedEdge const &other) const { - return tie(*this) != tie(other); -} - -bool UndirectedEdge::operator<(UndirectedEdge const &other) const { - return tie(*this) < tie(other); -} - bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.bigger == n || e.smaller == n; + return e.endpoints.min() == n || e.endpoints.max() == n; } } // namespace FlexFlow - -namespace std { - -using namespace FlexFlow; - -size_t hash::operator()(UndirectedEdge const &e) const { - std::tuple members = ::FlexFlow::tie(e); - return std::hash{}(members); -} - -} // namespace std - -namespace FlexFlow { -std::string format_as(UndirectedEdge const &x) { - std::ostringstream oss; - oss << ""; - return oss.str(); -} -std::ostream &operator<<(std::ostream &s, UndirectedEdge const &x) { - return s << fmt::to_string(x); -} -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 9b5353de9f..3ada561dc9 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -146,17 +146,18 @@ std::unordered_set JoinedUndirectedGraphView::query_edges( UndirectedEdge JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return { - this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::LEFT})}; + return UndirectedEdge{{this->joined_nodes.at_join_key( + JoinNodeKey{e.endpoints.min(), LRDirection::LEFT}), + this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.max(), LRDirection::LEFT})}}; } UndirectedEdge JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return {this->joined_nodes.at_join_key( - JoinNodeKey{e.smaller, LRDirection::RIGHT}), - this->joined_nodes.at_join_key( - JoinNodeKey{e.bigger, LRDirection::RIGHT})}; + return UndirectedEdge{{this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.min(), LRDirection::RIGHT}), + this->joined_nodes.at_join_key(JoinNodeKey{ + e.endpoints.max(), LRDirection::RIGHT})}}; } JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, @@ -208,7 +209,7 @@ DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { } UndirectedEdge to_undirected_edge(DirectedEdge const &e) { - return {e.src, e.dst}; + return UndirectedEdge{{e.src, e.dst}}; } std::unordered_set to_undirected_edges( @@ -218,8 +219,9 @@ std::unordered_set to_undirected_edges( } std::unordered_set to_directed_edges(UndirectedEdge const &e) { - return std::unordered_set{DirectedEdge{e.smaller, e.bigger}, - DirectedEdge{e.bigger, e.smaller}}; + return std::unordered_set{ + DirectedEdge{e.endpoints.min(), e.endpoints.max()}, + DirectedEdge{e.endpoints.max(), e.endpoints.min()}}; } std::unordered_set to_directed_edges( @@ -258,8 +260,8 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - intersection(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + set_union(g.query_edges(UndirectedEdgeQuery{q.srcs}), + g.query_edges(UndirectedEdgeQuery{q.dsts})); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc new file mode 100644 index 0000000000..291b4eab73 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc @@ -0,0 +1,35 @@ +#include "utils/bidict/algorithms/merge_bidicts.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("merge_bidicts") { + + SUBCASE("disjoint keys and values") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{3, "three"}, {4, "four"}}; + + bidict result = merge_bidicts(bd1, bd2); + bidict correct = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + + CHECK(result == correct); + } + + SUBCASE("overlapping keys") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{2, "three"}, {3, "four"}}; + + CHECK_THROWS_AS(merge_bidicts(bd1, bd2), std::runtime_error); + } + + SUBCASE("overlapping values") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{3, "two"}, {4, "four"}}; + + CHECK_THROWS_AS(merge_bidicts(bd1, bd2), std::runtime_error); + } + } +} diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 5c2ffd5bba..019aaed7bd 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -10,7 +10,26 @@ TEST_SUITE(FF_TEST_SUITE) { dict.equate(1, "one"); dict.equate(2, "two"); - // Test the equate() function + SUBCASE("bidict::contains_l") { + CHECK(dict.contains_l(1)); + CHECK_FALSE(dict.contains_l(3)); + } + + SUBCASE("bidict::contains_r") { + CHECK(dict.contains_r(std::string("one"))); + CHECK_FALSE(dict.contains_r(std::string("three"))); + } + SUBCASE("bidict::contains_r, bidict::contains_r - same type") { + bidict bd; + bd.equate(1, 3); + bd.equate(2, 4); + + CHECK(bd.contains_l(1)); + CHECK_FALSE(bd.contains_l(3)); + CHECK(bd.contains_r(3)); + CHECK_FALSE(bd.contains_r(1)); + } + SUBCASE("bidict::equate") { CHECK(dict.at_l(1) == "one"); CHECK(dict.at_r("one") == 1); @@ -18,7 +37,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_r("two") == 2); } - // Test the erase_l() function SUBCASE("bidict::erase_l") { dict.erase_l(1); CHECK(dict.size() == 1); @@ -26,7 +44,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_r("two") == 2); } - // Test the erase_r() function SUBCASE("bidict::erase_r") { dict.erase_r("one"); CHECK(dict.size() == 1); @@ -34,14 +51,12 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_l(2) == "two"); } - // Test the reversed() function SUBCASE("bidict::reversed") { bidict reversed_dict = dict.reversed(); CHECK(reversed_dict.at_l("one") == 1); CHECK(reversed_dict.at_r(2) == "two"); } - // Test the size() function SUBCASE("bidict::size") { CHECK(dict.size() == 2); } diff --git a/lib/utils/test/src/utils/containers.cc b/lib/utils/test/src/utils/containers.cc deleted file mode 100644 index 43f651ba09..0000000000 --- a/lib/utils/test/src/utils/containers.cc +++ /dev/null @@ -1,185 +0,0 @@ -#include "utils/containers.h" -#include "test/utils/doctest.h" -#include "utils/bidict/bidict.h" -#include -#include -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("sum") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(sum(v) == 15); - } - - TEST_CASE("sum_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; - CHECK(sum_where(v, condition) == 6); - } - - TEST_CASE("product_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; - CHECK(product_where(v, condition) == 8); - } - - TEST_CASE("contains_l and contains_r") { - bidict bd; - bd.equate(1, "one"); - bd.equate(2, "two"); - - CHECK(contains_l(bd, 1) == true); - CHECK(contains_l(bd, 3) == false); - CHECK(contains_r(bd, std::string("one")) == true); - CHECK(contains_r(bd, std::string("three")) == false); - } - - TEST_CASE("is_submap") { - std::unordered_map m1 = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_map m2 = {{1, "one"}, {2, "two"}}; - std::unordered_map m3 = {{1, "one"}, {4, "four"}}; - - CHECK(is_submap(m1, m2) == true); - CHECK(is_submap(m1, m3) == false); - } - - TEST_CASE("index_of") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3).value() == 2); - CHECK(index_of(v, 6) == std::nullopt); - } - - TEST_CASE("merge_maps") { - SUBCASE("bidict") { - std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; - std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; - std::unordered_map fwd_map2 = {{3, "three"}, - {4, "four"}}; - std::unordered_map bwd_map2 = {{"three", 3}, - {"four", 4}}; - bidict lhs{fwd_map1, bwd_map1}; - bidict rhs{fwd_map2, bwd_map2}; - - std::unordered_map result = - merge_maps(lhs, rhs); // impicit conversion - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); - } - SUBCASE("unordered_map") { - std::unordered_map lhs = {{1, "one"}, {2, "two"}}; - std::unordered_map rhs = {{3, "three"}, {4, "four"}}; - std::unordered_map result = merge_maps(lhs, rhs); - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - } - } - - TEST_CASE("lookup_in") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = lookup_in(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - } - - TEST_CASE("lookup_in_l") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_l(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - } - - TEST_CASE("lookup_in_r") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_r(m); - CHECK(f("one") == 1); - CHECK(f("two") == 2); - } - - TEST_CASE("is_superseteq_of") { - std::unordered_set s1 = {1, 2, 3, 4}; - std::unordered_set s2 = {1, 2, 3}; - std::unordered_set s3 = {1, 2, 5}; - - CHECK(is_superseteq_of(s1, s2) == true); - CHECK(is_superseteq_of(s1, s3) == false); - } - - TEST_CASE("restrict_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set mask = {2, 3, 4}; - std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); - } - - TEST_CASE("optional_all_of") { - std::vector v = {2, 4, 6, 8}; - auto f = [](int x) -> std::optional { return x % 2 == 0; }; - CHECK(optional_all_of(v, f) == true); - - auto f2 = [](int x) -> std::optional { - if (x == 6) { - return std::nullopt; - } - return x % 2 == 0; - }; - CHECK(optional_all_of(v, f2) == std::nullopt); - } - - TEST_CASE("are_all_same") { - std::vector v1 = {2, 2, 2, 2}; - std::vector v2 = {1, 2, 3, 4}; - CHECK(are_all_same(v1) == true); - CHECK(are_all_same(v2) == false); - } - - TEST_CASE("Test for flatmap function on vectors") { - - auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); - } - } - return factors; - }; - - std::vector v = {2, 3, 4, 5}; - auto result = flatmap(v, get_factors); - CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); - } - - TEST_CASE("compare_by") { - std::vector v = {"abc", "a", "ab"}; - auto comp = compare_by( - [](std::string const &s) { return s.length(); }); - std::sort(v.begin(), v.end(), comp); - CHECK(v == std::vector{"a", "ab", "abc"}); - } - - TEST_CASE("maximum") { - std::vector v = {1, 5, 3, 4, 2}; - CHECK(maximum(v) == 5); - } - - TEST_CASE("value_all") { - std::vector> v = {1, 2, std::nullopt, 4, 5}; - CHECK_THROWS_AS(value_all(v), std::runtime_error); - - std::vector> v2 = {1, 2, 3, 4, 5}; - CHECK(value_all(v2) == std::vector{1, 2, 3, 4, 5}); - } -} diff --git a/lib/utils/test/src/utils/containers/are_all_same.cc b/lib/utils/test/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..38e01a13a6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_same.cc @@ -0,0 +1,24 @@ +#include "utils/containers/are_all_same.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_same") { + SUBCASE("All elements are the same") { + std::vector input = {2, 2, 2, 2}; + CHECK(are_all_same(input)); + } + + SUBCASE("Not all elements are the same") { + std::vector input = {1, 2, 3, 4}; + CHECK_FALSE(are_all_same(input)); + } + + SUBCASE("Empty Container") { + std::vector input = {}; + CHECK_THROWS_AS(are_all_same(input), std::runtime_error); + } + } +} diff --git a/lib/utils/test/src/utils/containers/are_disjoint.cc b/lib/utils/test/src/utils/containers/are_disjoint.cc index 1bec8af85f..17516dbf13 100644 --- a/lib/utils/test/src/utils/containers/are_disjoint.cc +++ b/lib/utils/test/src/utils/containers/are_disjoint.cc @@ -6,10 +6,26 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("are_disjoint") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {4, 5, 6}; - CHECK(are_disjoint(l, r)); - r.insert(3); - CHECK_FALSE(are_disjoint(l, r)); + SUBCASE("disjoint") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {4, 5, 6}; + CHECK(are_disjoint(l, r)); + } + SUBCASE("not disjoint") { + std::unordered_set l = {1, 2, 3, 4}; + std::unordered_set r = {3, 4, 5, 6}; + CHECK_FALSE(are_disjoint(l, r)); + } + + SUBCASE("one empty set") { + std::unordered_set l = {1, 2}; + std::unordered_set r = {}; + CHECK(are_disjoint(l, r)); + } + SUBCASE("both empty sets") { + std::unordered_set l = {}; + std::unordered_set r = {}; + CHECK(are_disjoint(l, r)); + } } } diff --git a/lib/utils/test/src/utils/containers/as_vector.cc b/lib/utils/test/src/utils/containers/as_vector.cc index af89f99245..fc4567b585 100644 --- a/lib/utils/test/src/utils/containers/as_vector.cc +++ b/lib/utils/test/src/utils/containers/as_vector.cc @@ -1,4 +1,6 @@ #include "utils/containers/as_vector.h" +#include "utils/containers/sorted.h" +#include "utils/fmt/vector.h" #include #include #include @@ -7,8 +9,9 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("as_vector") { - std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); - CHECK(result == std::vector({3, 2, 1})); + std::unordered_set input = {1, 2, 3}; + std::vector result = as_vector(input); + std::vector correct_sorted = {1, 2, 3}; + CHECK(sorted(result) == correct_sorted); } } diff --git a/lib/utils/test/src/utils/containers/compare_by.cc b/lib/utils/test/src/utils/containers/compare_by.cc new file mode 100644 index 0000000000..52e71ba175 --- /dev/null +++ b/lib/utils/test/src/utils/containers/compare_by.cc @@ -0,0 +1,19 @@ +#include "utils/containers/compare_by.h" +#include "utils/fmt/vector.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("compare_by") { + std::vector input = {"abc", "a", "ab"}; + auto comp = compare_by( + [](std::string const &s) { return s.length(); }); + std::vector correct = {"a", "ab", "abc"}; + std::sort(input.begin(), input.end(), comp); + CHECK(correct == input); + } +} diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 2be5f1ef93..25eece0d45 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,10 +1,16 @@ #include "utils/containers/enumerate.h" #include "utils/containers/as_vector.h" +#include "utils/containers/keys.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/values.h" #include "utils/fmt/map.h" #include "utils/fmt/pair.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/fmt/unordered_set.h" #include "utils/fmt/vector.h" #include #include +#include using namespace ::FlexFlow; @@ -40,11 +46,12 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("enumerate(std::unordered_set)") { std::unordered_set input = {"zero", "one", "two", "three"}; - std::map correct = { - {0, "zero"}, - {1, "one"}, - {2, "two"}, - {3, "three"}, - }; + std::unordered_set correct_keys = {0, 1, 2, 3}; + std::unordered_multiset correct_values = { + "zero", "one", "two", "three"}; + std::map result = enumerate(input); + + CHECK(keys(result) == correct_keys); + CHECK(unordered_multiset_of(values(result)) == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/filter_keys.cc b/lib/utils/test/src/utils/containers/filter_keys.cc index 3f1c26c4f5..c7a1e8ab0d 100644 --- a/lib/utils/test/src/utils/containers/filter_keys.cc +++ b/lib/utils/test/src/utils/containers/filter_keys.cc @@ -1,4 +1,5 @@ #include "utils/containers/filter_keys.h" +#include "utils/fmt/unordered_map.h" #include #include #include @@ -9,9 +10,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("filter_keys") { std::unordered_map m = { {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = [](int x) { return x % 2 == 1; }; // Filtering function + auto f = [](int x) { return x % 2 == 1; }; std::unordered_map result = filter_keys(m, f); - std::unordered_map expected = {{1, "one"}, {3, "three"}}; - CHECK(result == expected); + std::unordered_map correct = {{1, "one"}, {3, "three"}}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/find.cc b/lib/utils/test/src/utils/containers/find.cc index b965182853..60260bcb95 100644 --- a/lib/utils/test/src/utils/containers/find.cc +++ b/lib/utils/test/src/utils/containers/find.cc @@ -1,12 +1,33 @@ #include "utils/containers/find.h" -#include +#include "test/utils/doctest.h" +#include +#include +#include +#include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("find") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(find(v, 3) != v.cend()); - CHECK(find(v, 6) == v.cend()); + + SUBCASE("vector") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); + CHECK_WITHOUT_STRINGIFY(find(v, 6) == std::find(v.begin(), v.end(), 6)); + } + + SUBCASE("unordered_set") { + std::unordered_set us = {1, 2, 3, 4, 5}; + CHECK_WITHOUT_STRINGIFY(find(us, 3) == + std::find(us.begin(), us.end(), 3)); + CHECK_WITHOUT_STRINGIFY(find(us, 6) == + std::find(us.begin(), us.end(), 6)); + } + + SUBCASE("set") { + std::set s = {1, 2, 3, 4, 5}; + CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); + CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + } } } diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc new file mode 100644 index 0000000000..e395ce5583 --- /dev/null +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -0,0 +1,26 @@ +#include "utils/containers/flatmap.h" +#include "utils/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Test for flatmap function on vectors") { + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } + } + return factors; + }; + + std::vector input = {2, 3, 4, 5}; + std::vector result = flatmap(input, get_factors); + std::vector correct = {1, 2, 1, 3, 1, 2, 4, 1, 5}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/get_first.cc b/lib/utils/test/src/utils/containers/get_one_of.cc similarity index 64% rename from lib/utils/test/src/utils/containers/get_first.cc rename to lib/utils/test/src/utils/containers/get_one_of.cc index d5df0f7c7a..840bb47ebd 100644 --- a/lib/utils/test/src/utils/containers/get_first.cc +++ b/lib/utils/test/src/utils/containers/get_one_of.cc @@ -1,12 +1,12 @@ -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/contains.h" #include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_first") { + TEST_CASE("get_one_of") { std::unordered_set s = {1, 2, 3}; - CHECK(contains(s, get_first(s))); + CHECK(contains(s, get_one_of(s))); } } diff --git a/lib/utils/test/src/utils/containers/index_of.cc b/lib/utils/test/src/utils/containers/index_of.cc new file mode 100644 index 0000000000..d522b91d62 --- /dev/null +++ b/lib/utils/test/src/utils/containers/index_of.cc @@ -0,0 +1,22 @@ +#include "utils/containers/index_of.h" +#include "utils/fmt/optional.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("index_of") { + SUBCASE("unique elements") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(index_of(v, 3).value() == 2); + CHECK(index_of(v, 6) == std::nullopt); + } + SUBCASE("duplicate elements") { + std::vector v = {1, 2, 3, 4, 3, 5}; + CHECK(index_of(v, 3).value() == 2); + CHECK(index_of(v, 6) == std::nullopt); + } + } +} diff --git a/lib/utils/test/src/utils/containers/is_submap.cc b/lib/utils/test/src/utils/containers/is_submap.cc new file mode 100644 index 0000000000..18ee8c677a --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_submap.cc @@ -0,0 +1,40 @@ +#include "utils/containers/is_submap.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_submap") { + std::unordered_map super = { + {1, "one"}, {2, "two"}, {3, "three"}}; + + SUBCASE("keys and values match") { + std::unordered_map sub = {{1, "one"}, {2, "two"}}; + CHECK(is_submap(super, sub)); + } + + SUBCASE("keys and values don't match") { + std::unordered_map sub = {{1, "one"}, {4, "four"}}; + CHECK_FALSE(is_submap(super, sub)); + } + + SUBCASE("keys match but values don't") { + std::unordered_map sub = {{1, "wrong_value"}, + {2, "two"}}; + CHECK_FALSE(is_submap(super, sub)); + } + + SUBCASE("values match but keys don't") { + std::unordered_map sub = {{5, "one"}, {6, "two"}}; + CHECK_FALSE(is_submap(super, sub)); + } + + SUBCASE("sub is a superset of super") { + std::unordered_map sub = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK_FALSE(is_submap(super, sub)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/is_subset_eq.cc b/lib/utils/test/src/utils/containers/is_subset_eq_of.cc similarity index 100% rename from lib/utils/test/src/utils/containers/is_subset_eq.cc rename to lib/utils/test/src/utils/containers/is_subset_eq_of.cc diff --git a/lib/utils/test/src/utils/containers/is_superseteq_of.cc b/lib/utils/test/src/utils/containers/is_superseteq_of.cc new file mode 100644 index 0000000000..e3b429fa64 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_superseteq_of.cc @@ -0,0 +1,25 @@ +#include "utils/containers/is_superseteq_of.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_superseteq_of") { + std::unordered_set super = {1, 2, 3, 4}; + + SUBCASE("true containment") { + std::unordered_set sub = {1, 2, 3}; + CHECK(is_superseteq_of(super, sub)); + } + + SUBCASE("false containment") { + std::unordered_set sub = {1, 2, 5}; + CHECK_FALSE(is_superseteq_of(super, sub)); + } + + SUBCASE("reflexive") { + CHECK(is_superseteq_of(super, super)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/keys.cc b/lib/utils/test/src/utils/containers/keys.cc index 2f162ecaf1..1244f5c292 100644 --- a/lib/utils/test/src/utils/containers/keys.cc +++ b/lib/utils/test/src/utils/containers/keys.cc @@ -1,4 +1,5 @@ #include "utils/containers/keys.h" +#include "utils/fmt/unordered_set.h" #include #include #include @@ -11,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_map m = { {1, "one"}, {2, "two"}, {3, "three"}}; std::unordered_set result = keys(m); - std::unordered_set expected = {3, 2, 1}; + std::unordered_set expected = {1, 2, 3}; CHECK(result == expected); } } diff --git a/lib/utils/test/src/utils/containers/map_keys.cc b/lib/utils/test/src/utils/containers/map_keys.cc index 189840f318..971a2017e7 100644 --- a/lib/utils/test/src/utils/containers/map_keys.cc +++ b/lib/utils/test/src/utils/containers/map_keys.cc @@ -1,4 +1,5 @@ #include "utils/containers/map_keys.h" +#include "utils/fmt/unordered_map.h" #include #include #include @@ -8,10 +9,9 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("map_keys") { std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; // Mapping function - auto result = map_keys(m, f); - CHECK(result.size() == 2); - CHECK(result[1] == "one"); - CHECK(result[4] == "two"); + auto f = [](int x) { return x * x; }; + std::unordered_map result = map_keys(m, f); + std::unordered_map correct = {{1, "one"}, {4, "two"}}; + CHECK(correct == result); } } diff --git a/lib/utils/test/src/utils/containers/map_values.cc b/lib/utils/test/src/utils/containers/map_values.cc index f854f17fc6..c4c9fcafe6 100644 --- a/lib/utils/test/src/utils/containers/map_values.cc +++ b/lib/utils/test/src/utils/containers/map_values.cc @@ -1,4 +1,5 @@ #include "utils/containers/map_values.h" +#include "utils/fmt/unordered_map.h" #include #include #include @@ -7,10 +8,10 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("map_values") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](std::string const &s) { return s.size(); }; // Mapping function + std::unordered_map m = {{1, "one"}, {3, "three"}}; + auto f = [](std::string const &s) { return s.size(); }; std::unordered_map result = map_values(m, f); - std::unordered_map expected = {{1, 3}, {2, 3}}; - CHECK(result == expected); + std::unordered_map correct = {{1, 3}, {3, 5}}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..fd695553cd --- /dev/null +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -0,0 +1,25 @@ +#include "utils/containers/maximum.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("maximum") { + + SUBCASE("non-empty container") { + std::vector input = {1, 5, 3, 4, 2}; + int correct = 5; + int result = maximum(input); + CHECK(correct == result); + } + + SUBCASE("empty container") { + std::vector input = {}; + + CHECK_THROWS_AS(maximum(input), std::runtime_error); + } + } +} diff --git a/lib/utils/test/src/utils/containers/merge_maps.cc b/lib/utils/test/src/utils/containers/merge_maps.cc new file mode 100644 index 0000000000..0a1fbd78a4 --- /dev/null +++ b/lib/utils/test/src/utils/containers/merge_maps.cc @@ -0,0 +1,30 @@ +#include "utils/containers/merge_maps.h" +#include "utils/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("merge_maps") { + + SUBCASE("disjoint keys") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{3, "three"}, {4, "four"}}; + + std::unordered_map result = merge_maps(lhs, rhs); + std::unordered_map correct = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + + CHECK(result == correct); + } + + SUBCASE("overlapping keys") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{2, "three"}, {3, "four"}}; + + CHECK_THROWS_AS(merge_maps(lhs, rhs), std::runtime_error); + } + } +} diff --git a/lib/utils/test/src/utils/containers/optional_all_of.cc b/lib/utils/test/src/utils/containers/optional_all_of.cc new file mode 100644 index 0000000000..2c7113d0da --- /dev/null +++ b/lib/utils/test/src/utils/containers/optional_all_of.cc @@ -0,0 +1,39 @@ +#include "utils/containers/optional_all_of.h" +#include "utils/fmt/optional.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("optional_all_of") { + std::vector input = {2, 4, 6, 8}; + + SUBCASE("All values satisfy condition") { + auto f = [](int x) -> std::optional { return x % 2 == 0; }; + std::optional correct = true; + std::optional result = optional_all_of(input, f); + CHECK(correct == result); + } + + SUBCASE("One value returns nullopt") { + auto f = [](int x) -> std::optional { + if (x == 6) { + return std::nullopt; + } + return x % 2 == 0; + }; + std::optional correct = std::nullopt; + std::optional result = optional_all_of(input, f); + CHECK(correct == result); + } + + SUBCASE("Empty container") { + auto f = [](int x) -> std::optional { return true; }; + std::optional correct = true; + std::optional result = optional_all_of(input, f); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/product.cc b/lib/utils/test/src/utils/containers/product.cc new file mode 100644 index 0000000000..3beed9d15c --- /dev/null +++ b/lib/utils/test/src/utils/containers/product.cc @@ -0,0 +1,40 @@ +#include "utils/containers/product.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("product") { + SUBCASE("non-empty container") { + std::vector input = {1, -2, 3, 5}; + int correct = -30; + int result = product(input); + CHECK(correct == result); + } + + SUBCASE("empty container") { + std::vector input = {}; + int correct = 1; + int result = product(input); + CHECK(correct == result); + } + } + + TEST_CASE("product_where") { + SUBCASE("non-empty filtered container") { + std::vector input = {1, -2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = -8; + int result = product_where(input, condition); + CHECK(correct == result); + } + SUBCASE("empty filtered container") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 1; + int result = product_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/restrict_keys.cc b/lib/utils/test/src/utils/containers/restrict_keys.cc index bacda218da..169386d899 100644 --- a/lib/utils/test/src/utils/containers/restrict_keys.cc +++ b/lib/utils/test/src/utils/containers/restrict_keys.cc @@ -1,9 +1,7 @@ #include "utils/containers/restrict_keys.h" +#include "utils/fmt/unordered_map.h" #include #include -#include -#include - using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { @@ -12,7 +10,7 @@ TEST_SUITE(FF_TEST_SUITE) { {1, "one"}, {2, "two"}, {3, "three"}}; std::unordered_set mask = {2, 3, 4}; std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); + std::unordered_map correct = {{2, "two"}, {3, "three"}}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/reversed.cc b/lib/utils/test/src/utils/containers/reversed.cc index 8e5ce1804a..4d71752678 100644 --- a/lib/utils/test/src/utils/containers/reversed.cc +++ b/lib/utils/test/src/utils/containers/reversed.cc @@ -1,4 +1,5 @@ #include "utils/containers/reversed.h" +#include "utils/fmt/vector.h" #include #include @@ -8,9 +9,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Testing the 'reversed' function") { std::vector input_vec = {1, 2, 3, 4, 5}; std::vector result = reversed(input_vec); - std::vector expected = {5, 4, 3, 2, 1}; + std::vector correct = {5, 4, 3, 2, 1}; - // Checking the reversed sequence is as expected - CHECK(result == expected); + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/set_union.cc b/lib/utils/test/src/utils/containers/set_union.cc index d69cb64e00..92b53f0845 100644 --- a/lib/utils/test/src/utils/containers/set_union.cc +++ b/lib/utils/test/src/utils/containers/set_union.cc @@ -1,4 +1,5 @@ #include "utils/containers/set_union.h" +#include "utils/fmt/unordered_set.h" #include #include @@ -9,7 +10,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set s1 = {1, 2, 3}; std::unordered_set s2 = {2, 3, 4}; std::unordered_set result = set_union(s1, s2); - std::unordered_set expected = {1, 2, 3, 4}; - CHECK(result == expected); + std::unordered_set correct = {1, 2, 3, 4}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/sorted_by.cc b/lib/utils/test/src/utils/containers/sorted_by.cc index a2e11f96e5..5fae04bd91 100644 --- a/lib/utils/test/src/utils/containers/sorted_by.cc +++ b/lib/utils/test/src/utils/containers/sorted_by.cc @@ -1,4 +1,5 @@ #include "utils/containers/sorted_by.h" +#include "utils/fmt/vector.h" #include #include #include @@ -6,13 +7,21 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("Testing sorted_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); - CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); + TEST_CASE("sorted_by") { + SUBCASE("sort increasing") { + std::unordered_set s = {5, 2, 3, 4, 1}; + std::vector result = + sorted_by(s, [](int a, int b) { return a < b; }); + std::vector correct = {1, 2, 3, 4, 5}; + CHECK(result == correct); + } - std::unordered_set s2 = {-5, -1, -3, -2, -4}; - auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); - CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); + SUBCASE("sort decreasing") { + std::unordered_set input = {-5, -1, -3, -2, -4}; + std::vector result = + sorted_by(input, [](int a, int b) { return a > b; }); + std::vector correct = {-1, -2, -3, -4, -5}; + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/containers/subvec.cc b/lib/utils/test/src/utils/containers/subvec.cc index 38668ea20d..5262ab3119 100644 --- a/lib/utils/test/src/utils/containers/subvec.cc +++ b/lib/utils/test/src/utils/containers/subvec.cc @@ -1,17 +1,66 @@ #include "utils/containers/subvec.h" +#include "utils/fmt/vector.h" #include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("Testing subvec function") { + TEST_CASE("subvec") { std::vector v = {1, 2, 3, 4, 5}; - auto subvec_v = subvec(v, std::optional(1), std::optional(4)); - CHECK(subvec_v == std::vector({2, 3, 4})); + SUBCASE("Basic subvector") { + auto result = subvec(v, 1, 4); + std::vector correct = {2, 3, 4}; + CHECK(result == correct); + } - auto subvec_v2 = subvec(v, std::nullopt, std::optional(3)); - CHECK(subvec_v2 == std::vector({1, 2, 3})); + SUBCASE("From beginning to index") { + auto result = subvec(v, std::nullopt, 3); + std::vector correct = {1, 2, 3}; + CHECK(result == correct); + } + + SUBCASE("From index to end") { + auto result = subvec(v, 2, std::nullopt); + std::vector correct = {3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("All of the vector") { + auto result = subvec(v, std::nullopt, std::nullopt); + std::vector correct = {1, 2, 3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("Start greater than end") { + auto result = subvec(v, 3, 1); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("Start equal to end") { + auto result = subvec(v, 3, 3); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("Negative indices") { + auto result = subvec(v, -3, -1); + std::vector correct = {3, 4}; + CHECK(result == correct); + } + + SUBCASE("Out of bounds index from above") { + auto result = subvec(v, 2, 100); + std::vector correct = {3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("Out of bounds index from below") { + auto result = subvec(v, -100, 2); + std::vector correct = {1, 2}; + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/containers/sum.cc b/lib/utils/test/src/utils/containers/sum.cc new file mode 100644 index 0000000000..f18a4d86c7 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sum.cc @@ -0,0 +1,41 @@ +#include "utils/containers/sum.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sum") { + SUBCASE("non-empty container") { + std::vector input = {1, 2, 3, 4, 5}; + int correct = 15; + int result = sum(input); + CHECK(correct == result); + } + SUBCASE("empty container") { + std::unordered_set input = {}; + int correct = 0; + int result = sum(input); + CHECK(correct == result); + } + } + + TEST_CASE("sum_where") { + SUBCASE("Resulting container is non-empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 6; + int result = sum_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("Resulting container is empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 0; + int result = sum_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/value_all.cc b/lib/utils/test/src/utils/containers/value_all.cc new file mode 100644 index 0000000000..620acc5dcd --- /dev/null +++ b/lib/utils/test/src/utils/containers/value_all.cc @@ -0,0 +1,24 @@ +#include "utils/containers/value_all.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/vector.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("value_all") { + SUBCASE("With nullopt") { + std::vector> input = {1, 2, std::nullopt, 4, 5}; + CHECK_THROWS_AS(value_all(input), std::runtime_error); + } + + SUBCASE("Without nullopt") { + std::vector> input = {1, 2, 3, 4, 5}; + std::vector correct = {1, 2, 3, 4, 5}; + std::vector result = value_all(input); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/values.cc b/lib/utils/test/src/utils/containers/values.cc index b921bcb0cc..245622c4ab 100644 --- a/lib/utils/test/src/utils/containers/values.cc +++ b/lib/utils/test/src/utils/containers/values.cc @@ -1,4 +1,6 @@ #include "utils/containers/values.h" +#include "utils/containers/sorted.h" +#include "utils/fmt/vector.h" #include #include #include @@ -9,9 +11,9 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("values") { std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; + {1, "one"}, {2, "two"}, {3, "three"}, {33, "three"}}; std::vector result = values(m); - std::vector expected = {"three", "two", "one"}; - CHECK(result == expected); + std::vector correct = {"one", "two", "three", "three"}; + CHECK(sorted(result) == sorted(correct)); } } diff --git a/lib/utils/test/src/utils/containers/vector_split.cc b/lib/utils/test/src/utils/containers/vector_split.cc index 872742a429..eb51ff1ab5 100644 --- a/lib/utils/test/src/utils/containers/vector_split.cc +++ b/lib/utils/test/src/utils/containers/vector_split.cc @@ -1,5 +1,7 @@ #include "utils/containers/vector_split.h" +#include "utils/fmt/vector.h" #include +#include #include using namespace FlexFlow; @@ -7,10 +9,31 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Testing vector_split function") { std::vector v = {1, 2, 3, 4, 5}; - auto result = vector_split(v, 2); - std::vector prefix = result.first; - std::vector postfix = result.second; - CHECK(prefix == std::vector({1, 2})); - CHECK(postfix == std::vector({3, 4, 5})); + + SUBCASE("Normal case: idx = 2") { + auto [prefix, postfix] = vector_split(v, 2); + CHECK(prefix == std::vector({1, 2})); + CHECK(postfix == std::vector({3, 4, 5})); + } + + SUBCASE("Boundary case: idx = 0") { + auto [prefix, postfix] = vector_split(v, 0); + CHECK(prefix.empty()); + CHECK(postfix == std::vector({1, 2, 3, 4, 5})); + } + + SUBCASE("Boundary case: idx = 5") { + auto [prefix, postfix] = vector_split(v, 5); + CHECK(prefix == std::vector({1, 2, 3, 4, 5})); + CHECK(postfix.empty()); + } + + SUBCASE("Out of bounds case: idx = -1") { + CHECK_THROWS_AS(vector_split(v, -1), std::out_of_range); + } + + SUBCASE("Out of bounds case: idx = 6") { + CHECK_THROWS_AS(vector_split(v, 6), std::out_of_range); + } } } diff --git a/lib/utils/test/src/utils/disjoint_set.cc b/lib/utils/test/src/utils/disjoint_set.cc index dc5c1bb3cb..5e5d118222 100644 --- a/lib/utils/test/src/utils/disjoint_set.cc +++ b/lib/utils/test/src/utils/disjoint_set.cc @@ -1,6 +1,5 @@ #include "utils/disjoint_set.h" #include "test/utils/doctest.h" - using namespace FlexFlow; template @@ -22,10 +21,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("SingleElementSets") { std::optional element = generate_element(1); - CHECK(ds.find(element) == element); + CHECK_WITHOUT_STRINGIFY(ds.find(element) == element); element = generate_element(2); - CHECK(ds.find(element) == element); + CHECK_WITHOUT_STRINGIFY(ds.find(element) == element); } SUBCASE("UnionAndFind") { @@ -35,16 +34,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional element4 = generate_element(4); ds.m_union(element1, element2); - CHECK(ds.find(element1) == ds.find(element2)); + CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element2)); ds.m_union(element3, element4); - CHECK(ds.find(element3) == ds.find(element4)); + CHECK_WITHOUT_STRINGIFY(ds.find(element3) == ds.find(element4)); ds.m_union(element1, element3); - CHECK(ds.find(element1) == ds.find(element3)); - CHECK(ds.find(element2) == ds.find(element4)); - CHECK(ds.find(element1) == ds.find(element2)); - CHECK(ds.find(element1) == ds.find(element4)); + CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element3)); + CHECK_WITHOUT_STRINGIFY(ds.find(element2) == ds.find(element4)); + CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element2)); + CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element4)); } } @@ -62,8 +61,9 @@ TEST_SUITE(FF_TEST_SUITE) { mapping = ds.get_mapping(); for (auto const &kv : mapping) { - CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values - // inside the optionals + CHECK_WITHOUT_STRINGIFY( + *kv.second == *expectedMapping[kv.first]); // Compare the values + // inside the optionals } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms.cc index 47abd5186d..0817c69e06 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms.cc @@ -13,29 +13,94 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n = add_nodes(g, 4); std::vector e = { - DirectedEdge{n[0], n[3]}, DirectedEdge{n[0], n[1]}, DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, DirectedEdge{n[1], n[2]}, }; add_edges(g, e); SUBCASE("get_edges") { - std::unordered_set expected_edges = unordered_set_of(e); - std::unordered_set actual_edges = get_edges(g); - CHECK(actual_edges == expected_edges); + SUBCASE("Base") { + std::unordered_set correct = unordered_set_of(e); + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[3], n[1]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge") { + g.remove_edge(DirectedEdge{n[0], n[3]}); + std::unordered_set correct = { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + }; + std::unordered_set result = get_edges(g); + CHECK(result == correct); + } } SUBCASE("get_sinks") { - std::unordered_set expected_sinks = {n[2], n[3]}; - std::unordered_set actual_sinks = get_sinks(g); - CHECK(actual_sinks == expected_sinks); + SUBCASE("Base") { + std::unordered_set correct = {n[2], n[3]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a sink") { + g.add_edge(DirectedEdge{n[3], n[2]}); + std::unordered_set correct = {n[2]}; + std::unordered_set result = get_sinks(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sinks(g); + std::unordered_set correct = {n[3]}; + CHECK(result == correct); + } } SUBCASE("get_sources") { - std::unordered_set expected_sources = {n[0]}; - std::unordered_set actual_sources = get_sources(g); - CHECK(actual_sources == expected_sources); + SUBCASE("Base") { + std::unordered_set correct = {n[0]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Adding an edge to remove a source") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set correct = {}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Removing an edge to create a new source") { + g.remove_edge(DirectedEdge{n[0], n[1]}); + std::unordered_set correct = {n[0], n[1]}; + std::unordered_set result = get_sources(g); + CHECK(result == correct); + } + + SUBCASE("Creating a cycle") { + g.add_edge(DirectedEdge{n[2], n[0]}); + std::unordered_set result = get_sources(g); + std::unordered_set correct = {}; + CHECK(result.empty()); + } } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc index 25c2384959..056280df8e 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -12,22 +12,53 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n = add_nodes(g, 4); std::vector e = { - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[2]}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, }; add_edges(g, e); SUBCASE("single node") { - std::unordered_set expected_dominators = {n[0], n[2]}; - CHECK(get_dominators(g, n[2]) == expected_dominators); + Node node = n[2]; + std::unordered_set correct = {n[0], n[2]}; + std::unordered_set result = get_dominators(g, node); + CHECK(correct == result); } SUBCASE("multiple nodes") { std::unordered_set nodes = {n[1], n[3]}; - std::unordered_set expected_dominators = {n[0]}; - CHECK(get_dominators(g, nodes) == expected_dominators); + std::unordered_set result = get_dominators(g, nodes); + std::unordered_set correct = {n[0]}; + CHECK(correct == result); + } + + SUBCASE("graph with cycles") { + // example from + // https://en.wikipedia.org/w/index.php?title=Dominator_(graph_theory)&oldid=1189814332 + + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 6); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(1)}, + }); + + std::unordered_set result = get_dominators(g, n.at(1)); + std::unordered_set correct = {n.at(0), n.at(1)}; + + result = get_dominators(g, n.at(3)); + correct = {n.at(0), n.at(1), n.at(3)}; + + CHECK(result == correct); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc index c17d2b8fd7..fe29fa4ef5 100644 --- a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc +++ b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc @@ -14,41 +14,59 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n = add_nodes(g, 5); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[4]}, - DirectedEdge{n[1], n[3]}}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}}); SUBCASE("directed_edge_query_all") { DirectedEdgeQuery result = directed_edge_query_all(); - CHECK(matches_edge(result, DirectedEdge{n[0], n[1]})); - CHECK(matches_edge(result, DirectedEdge{n[0], n[2]})); - CHECK(matches_edge(result, DirectedEdge{n[1], n[2]})); - CHECK(matches_edge(result, DirectedEdge{n[2], n[4]})); - CHECK(matches_edge(result, DirectedEdge{n[1], n[3]})); + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(1)})); + CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(2)})); + CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); } SUBCASE("matches_edge") { - DirectedEdgeQuery q{{n[0]}, {n[1]}}; + DirectedEdgeQuery q = + DirectedEdgeQuery{query_set{n.at(0)}, query_set{n.at(1)}}; - CHECK(matches_edge(q, DirectedEdge{n[0], n[1]})); - CHECK_FALSE(matches_edge(q, DirectedEdge{n[1], n[2]})); + CHECK(matches_edge(q, DirectedEdge{n.at(0), n.at(1)})); + CHECK_FALSE(matches_edge(q, DirectedEdge{n.at(1), n.at(2)})); } SUBCASE("query_intersection") { - DirectedEdgeQuery q1{{n[0], n[1]}, {n[1], n[2], n[4]}}; - DirectedEdgeQuery q2{{n[1], n[2]}, {n[2], n[3]}}; + SUBCASE("standard intersection") { + DirectedEdgeQuery q1 = DirectedEdgeQuery{ + query_set{n.at(0), n.at(1)}, query_set{n.at(1), n.at(2), n.at(4)}}; + DirectedEdgeQuery q2 = DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, + query_set{n.at(2), n.at(3)}}; - DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery result = query_intersection(q1, q2); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1)}, + query_set{n.at(2)}, + }; - std::unordered_set expected_srcs = {n[1]}; - std::unordered_set expected_dsts = {n[2]}; + CHECK(result == correct); + } + SUBCASE("intersection with std::nullopt") { + DirectedEdgeQuery q1 = + DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, matchall()}; + DirectedEdgeQuery q2 = + DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; - CHECK(result.srcs == expected_srcs); - CHECK(result.dsts == expected_dsts); + DirectedEdgeQuery result = query_intersection(q1, q2); + + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); + CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); + // CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(2), n.at(3)})); + // CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(1), n.at(4)})); + } } } } diff --git a/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc index 9b5c243910..954b6bd3b8 100644 --- a/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc +++ b/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc @@ -1,6 +1,6 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "test/utils/doctest.h" -#include "utils/containers.h" +#include "utils/containers/index_of.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" @@ -10,18 +10,17 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_topological_ordering") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 6); - std::vector edges = {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[5]}, - DirectedEdge{n[2], n[3]}, - DirectedEdge{n[3], n[4]}, - DirectedEdge{n[4], n[5]}}; + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(5)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}}; add_edges(g, edges); std::vector ordering = get_topological_ordering(g); auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); }; CHECK(ordering.size() == n.size()); diff --git a/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc b/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc index c48b5e6cdf..8b97c48a03 100644 --- a/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc +++ b/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc @@ -5,20 +5,19 @@ #include "utils/graph/instances/adjacency_digraph.h" TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE(FF_TEST_SUITE) { + TEST_CASE("get_weakly_connected_components") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); - std::vector edges = {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[2], n[1]}}; + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[2], n[1]}}); - add_edges(g, edges); - std::unordered_set> expected_components = { + std::unordered_set> correct = { {n[0], n[1], n[2]}, {n[3]}, }; + std::unordered_set> result = + get_weakly_connected_components(g); - CHECK(get_weakly_connected_components(g) == expected_components); + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index 59141513c3..6a74e4ce6e 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -11,21 +11,29 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MultiDiGraph - get_incoming_edges") { + TEST_CASE("get_incoming_edges(MultiDiGraphView, Node)") { MultiDiGraph g = MultiDiGraph::create(); std::vector n = add_nodes(g, 3); std::vector> input = { + {n.at(0), n.at(0)}, {n.at(0), n.at(1)}, {n.at(0), n.at(1)}, - {n.at(1), n.at(1)}, - {n.at(0), n.at(0)}, + {n.at(1), n.at(0)}, }; std::vector edges = add_edges(g, input); - CHECK(get_incoming_edges(g, n[1]) == - std::unordered_set{edges[0], edges[1], edges[2]}); - CHECK(get_incoming_edges(g, n[2]) == std::unordered_set{}); + SUBCASE("node has incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(1)); + std::unordered_set correct = {edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_incoming_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index 428d6589f1..f78311315d 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -11,21 +11,30 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MultiDiGraph - get_outgoing_edges") { + TEST_CASE("get_outgoing_edges(MultiDiGraph, Node)") { MultiDiGraph g = MultiDiGraph::create(); std::vector n = add_nodes(g, 3); std::vector> input = { + {n.at(0), n.at(0)}, {n.at(0), n.at(1)}, {n.at(0), n.at(1)}, - {n.at(1), n.at(1)}, - {n.at(0), n.at(0)}, + {n.at(1), n.at(0)}, }; std::vector edges = add_edges(g, input); - CHECK(get_outgoing_edges(g, n[0]) == - std::unordered_set{edges[0], edges[1], edges[3]}); - CHECK(get_outgoing_edges(g, n[2]) == std::unordered_set{}); + SUBCASE("node has incoming edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(0)); + std::unordered_set correct = { + edges.at(0), edges.at(1), edges.at(2)}; + CHECK(result == correct); + } + + SUBCASE("node has no incoming edges") { + std::unordered_set result = get_outgoing_edges(g, n.at(2)); + std::unordered_set correct = {}; + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc index 5a4e3649bd..def1ca12b9 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc @@ -16,76 +16,165 @@ TEST_SUITE(FF_TEST_SUITE) { Node n0 = g.add_node(); Node n1 = g.add_node(); Node n2 = g.add_node(); - - MultiDiEdge e0 = g.add_edge(n1, n0); - MultiDiEdge e1 = g.add_edge(n2, n0); - MultiDiEdge e2 = g.add_edge(n0, n2); + MultiDiEdge e0 = g.add_edge(n0, n2); + MultiDiEdge e1 = g.add_edge(n1, n0); + MultiDiEdge e2 = g.add_edge(n1, n0); MultiDiEdge e3 = g.add_edge(n1, n2); + MultiDiEdge e4 = g.add_edge(n1, n2); + MultiDiEdge e5 = g.add_edge(n2, n0); + MultiDiEdge e6 = g.add_edge(n2, n2); SUBCASE("add_node") { Node n3 = g.add_node(); - std::unordered_set nodes = g.query_nodes(NodeQuery{{n3}}); - CHECK(contains(nodes, n3)); + std::unordered_set result = g.query_nodes(NodeQuery{{n3}}); + std::unordered_set correct = {n3}; + CHECK(result == correct); } SUBCASE("add_edge") { - MultiDiEdge e4 = g.add_edge(n2, n1); - std::unordered_set edges = + MultiDiEdge e7 = g.add_edge(n2, n1); + std::unordered_set result = g.query_edges(MultiDiEdgeQuery({n2}, {n1})); - CHECK(contains(edges, e4)); + std::unordered_set correct = {e7}; + CHECK(result == correct); + } + + SUBCASE("add_edges") { + MultiDiEdge e7 = g.add_edge(n1, n2); + MultiDiEdge e8 = g.add_edge(n2, n1); + + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0, n1}, {n2})); + std::unordered_set correct = {e0, e3, e4, e7}; + CHECK(result == correct); } SUBCASE("remove_node") { g.remove_node(n0); - std::unordered_set nodes = g.query_nodes(NodeQuery{{n0}}); - CHECK_FALSE(contains(nodes, n0)); - std::unordered_set edges_after_removal = + std::unordered_set node_result = g.query_nodes(NodeQuery{{n0}}); + std::unordered_set node_correct = {}; + CHECK(node_result == node_correct); + + std::unordered_set edge_result = g.query_edges(MultiDiEdgeQuery({n0}, {n0})); - CHECK_FALSE(contains(edges_after_removal, e0)); - CHECK_FALSE(contains(edges_after_removal, e1)); - CHECK_FALSE(contains(edges_after_removal, e2)); + std::unordered_set edge_correct = {}; + CHECK(edge_result == edge_correct); } SUBCASE("remove_edge") { g.remove_edge(e3); - std::unordered_set edges = + std::unordered_set result = g.query_edges(MultiDiEdgeQuery({n1}, {n2})); - CHECK_FALSE(contains(edges, e3)); + std::unordered_set correct = {e4}; + CHECK(result == correct); } - SUBCASE("query_nodes") { - std::unordered_set all_nodes = - g.query_nodes(NodeQuery{{n0, n1, n2}}); - CHECK(all_nodes == std::unordered_set{n0, n1, n2}); + SUBCASE("remove_edges") { + g.remove_edge(e0); + g.remove_edge(e1); + g.remove_edge(e3); - std::unordered_set specific_nodes = - g.query_nodes(NodeQuery{{n0, n2}}); - CHECK(specific_nodes == std::unordered_set{n0, n2}); + SUBCASE("edges from n0 to n2") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0}, {n2})); + std::unordered_set correct = {}; + CHECK(result == correct); + } + + SUBCASE("edges from n1 to n0") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n0})); + std::unordered_set correct = {e2}; + CHECK(result == correct); + } + + SUBCASE("edges from n1 to n2") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n2})); + std::unordered_set correct = {e4}; + CHECK(result == correct); + } } - SUBCASE("query_edges") { - std::unordered_set all_edges = - g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n0, n1, n2})); - CHECK(all_edges == std::unordered_set{e0, e1, e2, e3}); - - std::unordered_set edges_from_n1 = - g.query_edges(MultiDiEdgeQuery({n1}, {n0, n1, n2})); - CHECK(edges_from_n1 == std::unordered_set{e0, e3}); - - std::unordered_set edges_to_n2 = - g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n2})); - CHECK(edges_to_n2 == std::unordered_set{e2, e3}); + SUBCASE("query_nodes") { + SUBCASE("all nodes") { + std::unordered_set result = + g.query_nodes(NodeQuery{{n0, n1, n2}}); + std::unordered_set correct = {n0, n1, n2}; + CHECK(result == correct); + } + + SUBCASE("specific nodes") { + std::unordered_set result = g.query_nodes(NodeQuery{{n0, n2}}); + std::unordered_set correct = {n0, n2}; + CHECK(result == correct); + } + + SUBCASE("matchall") { + std::unordered_set result = + g.query_nodes(NodeQuery{matchall()}); + std::unordered_set correct = {n0, n1, n2}; + CHECK(result == correct); + } + + SUBCASE("nodes not in graph") { + Node n3 = Node(3); + Node n4 = Node(4); + std::unordered_set result = g.query_nodes(NodeQuery{{n3, n4}}); + std::unordered_set correct = {}; + CHECK(result == correct); + } } + SUBCASE("query_edges") { + SUBCASE("all edges") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n0, n1, n2})); + std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; + CHECK(result == correct); + } + + SUBCASE("edges from n1") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n0, n1, n2})); + std::unordered_set correct = {e1, e2, e3, e4}; + CHECK(result == correct); + } + + SUBCASE("edges to n2") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n2})); + std::unordered_set correct = {e0, e3, e4, e6}; + CHECK(result == correct); + } + + SUBCASE("matchall") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery(matchall(), matchall())); + std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; + CHECK(result == correct); + } + + SUBCASE("nodes that don't exist") { + Node n3 = Node(3); + Node n4 = Node(4); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1, n3}, {n4})); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } SUBCASE("get_multidiedge_src") { - CHECK(g.get_multidiedge_src(e0) == n1); - CHECK(g.get_multidiedge_dst(e0) == n0); + Node result = g.get_multidiedge_src(e0); + Node correct = n0; + CHECK(result == correct); } SUBCASE("get_multidiedge_dst") { - CHECK(g.get_multidiedge_src(e2) == n0); - CHECK(g.get_multidiedge_dst(e2) == n2); + Node result = g.get_multidiedge_dst(e0); + Node correct = n2; + CHECK(result == correct); } } } diff --git a/lib/utils/test/src/utils/graph/traversal.cc b/lib/utils/test/src/utils/graph/traversal.cc index 185712cf8f..66daaf8400 100644 --- a/lib/utils/test/src/utils/graph/traversal.cc +++ b/lib/utils/test/src/utils/graph/traversal.cc @@ -1,45 +1,110 @@ #include "utils/graph/traversal.h" #include "test/utils/doctest.h" +#include "utils/fmt/vector.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" +#include "utils/hash/vector.h" TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("traversal") { + TEST_CASE("get_unchecked_dfs_ordering") { DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 5); - std::vector edges = {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[3]}}; - add_edges(g, edges); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); - CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); - CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_unchecked_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + } - SUBCASE("with root") { - g.add_edge(DirectedEdge{n[3], n[2]}); + TEST_CASE("get_bfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[3]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}}); - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); + SUBCASE("branching path") { + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3], n[4], n[5]}, + {n[0], n[2], n[1], n[3], n[4], n[5]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); } - SUBCASE("without root") { - g.add_edge(DirectedEdge{n[3], n[0]}); + SUBCASE("isolated node") { + std::vector correct = {n[5]}; + std::vector result = get_bfs_ordering(g, {n[5]}); + CHECK(correct == result); + } - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); + SUBCASE("graph with cycle") { + g = DiGraph::create(); + n = add_nodes(g, 3); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[0]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[0]}, + DirectedEdge{n[2], n[1]}}); + std::unordered_set> corrects = {{n[0], n[1], n[2]}, + {n[0], n[2], n[1]}}; + std::vector result = get_bfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); } - SUBCASE("nonlinear") { + } + + TEST_CASE("get_dfs_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + SUBCASE("simple path") { + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("with cycle") { + g.add_edge(DirectedEdge{n[3], n[1]}); + std::vector correct = {n[0], n[1], n[2], n[3]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); + } + + SUBCASE("branching") { g.add_edge(DirectedEdge{n[1], n[3]}); + std::unordered_set> corrects = { + {n[0], n[1], n[2], n[3]}, {n[0], n[1], n[3], n[2]}}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(contains(corrects, result)); + } + + SUBCASE("disconnected") { + g.remove_edge(DirectedEdge{n[2], n[3]}); + std::vector correct = {n[0], n[1], n[2]}; + std::vector result = get_dfs_ordering(g, {n[0]}); + CHECK(correct == result); } - SUBCASE("not connected") { + SUBCASE("isolated node") { g.remove_edge(DirectedEdge{n[2], n[3]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + std::vector correct = {n[3]}; + std::vector result = get_dfs_ordering(g, {n[3]}); + CHECK(correct == result); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc similarity index 62% rename from lib/utils/test/src/utils/graph/digraph/get_connected_components.cc rename to lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc index d9fd9e81af..20d2bfbf34 100644 --- a/lib/utils/test/src/utils/graph/digraph/get_connected_components.cc +++ b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc @@ -1,5 +1,6 @@ #include "utils/graph/undirected/algorithms/get_connected_components.h" #include "test/utils/doctest.h" +#include "utils/fmt/unordered_set.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/hashmap_undirected_graph.h" #include "utils/graph/undirected/undirected_graph.h" @@ -9,14 +10,15 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_connected_components") { UndirectedGraph g = UndirectedGraph::create(); std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); - add_edges(g, edges); - std::unordered_set> expected_components = { + std::unordered_set> correct = { {n[0], n[1], n[2]}, {n[3]}, }; + std::unordered_set> result = + get_connected_components(g); - CHECK(get_connected_components(g) == expected_components); + CHECK(correct == result); } } diff --git a/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc b/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc index 770f8395c8..b11cdd4753 100644 --- a/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc +++ b/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc @@ -46,8 +46,9 @@ TEST_SUITE(FF_TEST_SUITE) { if (num_nodes > 0) { e = *gen::unique>( num_edges, - gen::construct(gen::elementOf(n), - gen::elementOf(n))); + gen::construct( + gen::construct>(gen::elementOf(n), + gen::elementOf(n)))); } for (UndirectedEdge const &edge : e) { g.add_edge(edge); diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index 29c63825a2..e6cd562d8f 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -18,15 +18,14 @@ TEST_SUITE(FF_TEST_SUITE) { UndirectedGraph g = UndirectedGraph::create(); std::vector n = add_nodes(g, 5); add_edges(g, - {{n[0], n[3]}, - {n[1], n[1]}, - {n[1], n[2]}, - {n[1], n[3]}, - {n[2], n[3]}, - {n[2], n[4]}}); + {UndirectedEdge{{n[0], n[3]}}, + UndirectedEdge{{n[1], n[1]}}, + UndirectedEdge{{n[1], n[2]}}, + UndirectedEdge{{n[1], n[3]}}, + UndirectedEdge{{n[2], n[3]}}, + UndirectedEdge{{n[2], n[4]}}}); std::unordered_set sub_nodes = {n[0], n[1], n[3]}; - UndirectedGraphView view = - UndirectedGraphView::create(g, sub_nodes); + UndirectedGraphView view = get_subgraph(g, sub_nodes); SUBCASE("get_nodes") { std::unordered_set expected = {n[0], n[1], n[3]}; @@ -38,14 +37,15 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_edges") { std::unordered_set expected = { - {n[0], n[3]}, - {n[1], n[1]}, - {n[1], n[3]}, + UndirectedEdge{{n[0], n[3]}}, + UndirectedEdge{{n[1], n[1]}}, + UndirectedEdge{{n[1], n[3]}}, }; std::unordered_set result = get_edges(view); - // CHECK(result == expected); + // TODO(@pietro) TODO(@lockshaw) current BUG, get_edges also + CHECK(result == expected); } } @@ -62,7 +62,7 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n[3], n[2]}, DirectedEdge{n[2], n[4]}}); std::unordered_set sub_nodes = {n[0], n[1], n[3]}; - DiGraphView view = DiGraphView::create(g, sub_nodes); + DiGraphView view = get_subgraph(g, sub_nodes); SUBCASE("get_nodes") { std::unordered_set expected = {n[0], n[1], n[3]}; @@ -93,11 +93,12 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n1 = add_nodes(g1, 3); std::vector n2 = add_nodes(g2, 3); - add_edges(g1, {{n1[0], n1[1]}, {n1[1], n1[2]}}); - add_edges(g2, {{n2[0], n2[2]}, {n2[1], n2[2]}}); + add_edges(g1, + {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], n1[2]}}}); + add_edges(g2, + {UndirectedEdge{{n2[0], n2[2]}}, UndirectedEdge{{n2[1], n2[2]}}}); - UndirectedGraphView view = - UndirectedGraphView::create(g1, g2); + UndirectedGraphView view = join(g1, g2); // SUBCASE("get_nodes") { // std::unordered_set expected = @@ -130,7 +131,7 @@ TEST_SUITE(FF_TEST_SUITE) { add_edges(g1, {DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], n1[2]}}); add_edges(g2, {DirectedEdge{n2[0], n2[2]}, DirectedEdge{n2[1], n2[2]}}); - DiGraphView view = DiGraphView::create(g1, g2); + DiGraphView view = join(g1, g2); // SUBCASE("get_nodes") { // std::unordered_set expected = set_union(unordered_set_of(n1), @@ -162,8 +163,7 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n[2], n[0]}, DirectedEdge{n[0], n[2]}}); - UndirectedGraphView view = - UndirectedGraphView::create(g); + UndirectedGraphView view = as_undirected(g); SUBCASE("get_nodes") { std::unordered_set expected = unordered_set_of(n); @@ -175,7 +175,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_edges") { std::unordered_set expected = { - {n[0], n[1]}, {n[1], n[2]}, {n[2], n[0]}}; + UndirectedEdge{{n[0], n[1]}}, + UndirectedEdge{{n[1], n[2]}}, + UndirectedEdge{{n[2], n[0]}}}; std::unordered_set result = get_edges(view); @@ -186,9 +188,13 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ViewUndirectedGraphAsDiGraph") { UndirectedGraph g = UndirectedGraph::create(); std::vector n = add_nodes(g, 3); - add_edges(g, {{n[0], n[0]}, {n[0], n[1]}, {n[1], n[2]}, {n[2], n[0]}}); + add_edges(g, + {UndirectedEdge{{n[0], n[0]}}, + UndirectedEdge{{n[0], n[1]}}, + UndirectedEdge{{n[1], n[2]}}, + UndirectedEdge{{n[2], n[0]}}}); - DiGraphView view = DiGraphView::create(g); + DiGraphView view = as_digraph(g); SUBCASE("get_nodes") { std::unordered_set expected = unordered_set_of(n); diff --git a/lib/utils/test/src/utils/hash-utils.cc b/lib/utils/test/src/utils/hash-utils.cc index 2981615d9e..077037d36a 100644 --- a/lib/utils/test/src/utils/hash-utils.cc +++ b/lib/utils/test/src/utils/hash-utils.cc @@ -90,7 +90,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(hash1 != hash2); - tuple1 = tuple2; + std::get<0>(tuple1) = 2; hash1 = get_std_hash(tuple1); CHECK(hash1 == hash2); } diff --git a/lib/utils/test/src/utils/join_strings.cc b/lib/utils/test/src/utils/join_strings.cc index 1377923ab6..ecb19c2f7d 100644 --- a/lib/utils/test/src/utils/join_strings.cc +++ b/lib/utils/test/src/utils/join_strings.cc @@ -6,6 +6,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("join_strings") { std::vector const v = {"Hello", "world", "!"}; @@ -16,5 +17,16 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("join_strings with container") { CHECK(join_strings(v, " ") == "Hello world !"); } + + SUBCASE("join_strings with transforming function") { + auto add_exclamation = [](std::string const &str) { return str + "!"; }; + CHECK(join_strings(v, " ", add_exclamation) == "Hello! world! !!"); + } + + SUBCASE("join_strings with transforming function, iterator") { + auto add_exclamation = [](std::string const &str) { return str + "!"; }; + CHECK(join_strings(v.begin(), v.end(), " ", add_exclamation) == + "Hello! world! !!"); + } } } diff --git a/lib/utils/test/src/utils/random_utils.cc b/lib/utils/test/src/utils/random_utils.cc index 1563de512c..ba994cefce 100644 --- a/lib/utils/test/src/utils/random_utils.cc +++ b/lib/utils/test/src/utils/random_utils.cc @@ -1,67 +1,56 @@ #include "utils/random_utils.h" #include "test/utils/doctest.h" +#include "utils/containers/contains.h" +#include "utils/containers/filter.h" +#include "utils/containers/repeat.h" +#include "utils/containers/sum.h" +#include "utils/containers/zip.h" #include -void checkProbabilities(std::vector const &counts, - int numIterations, - std::vector const &weights, - float totalWeight) { - for (int i = 0; i < counts.size(); i++) { - float expectedProbability = weights[i] / totalWeight; - float observedProbability = static_cast(counts[i]) / numIterations; - CHECK(observedProbability == - doctest::Approx(expectedProbability).epsilon(0.01f)); - } -} - TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("select_random") { + TEST_CASE("select_random(std::vector)") { std::vector values = {1, 2, 3, 4, 5}; - SUBCASE("Select random value") { - int result = select_random(values); - - CHECK(std::find(values.begin(), values.end(), result) != values.end()); - } - - SUBCASE("Invalid arguments") { - std::vector weights = {0.1f, 0.3f, 0.2f}; - CHECK(select_random(values, weights) == 2); - } - } - - TEST_CASE("select_random - Weighted Random Selection") { - SUBCASE("Test with equal weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; + SUBCASE("selected value is in container") { + SUBCASE("equal weights") { + int result = select_random(values); + CHECK(contains(values, result)); } - checkProbabilities(counts, numIterations, weights, values.size()); + SUBCASE("unequal weights") { + std::vector weights = {0.1f, 0.3f, 0.2f, 0.2f, 0.2f}; + int result = select_random(values, weights); + CHECK(contains(values, result)); + } } - SUBCASE("Test with different weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; - - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; + SUBCASE("correct distribution") { + auto check_probabilities = [](std::vector values, + std::vector const &weights) { + int num_iterations = 10'000; + std::vector trials = repeat( + num_iterations, [&]() { return select_random(values, weights); }); + + for (auto const [v, w] : zip(values, weights)) { + float expectedProbability = w / sum(weights); + int num_occurrences = + filter(trials, [&](int c) { return (c == v); }).size(); + float observedProbability = + static_cast(num_occurrences) / num_iterations; + CHECK(observedProbability == + doctest::Approx(expectedProbability).epsilon(0.01f)); + } + }; + + SUBCASE("equal weights") { + std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + check_probabilities(values, weights); } - float totalWeight = 0.0f; - for (float weight : weights) { - totalWeight += weight; + SUBCASE("unequal weights") { + std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; + check_probabilities(values, weights); } - - checkProbabilities(counts, numIterations, weights, totalWeight); } } } diff --git a/lib/utils/test/src/utils/stack_map.cc b/lib/utils/test/src/utils/stack_map.cc index e03552e6fd..57ef1c850a 100644 --- a/lib/utils/test/src/utils/stack_map.cc +++ b/lib/utils/test/src/utils/stack_map.cc @@ -46,7 +46,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; std::vector> actual = map; - CHECK(actual == expected); + CHECK_WITHOUT_STRINGIFY(actual == expected); } } } diff --git a/lib/utils/test/src/utils/stack_string.cc b/lib/utils/test/src/utils/stack_string.cc index 30e20c5fce..01bfa60d15 100644 --- a/lib/utils/test/src/utils/stack_string.cc +++ b/lib/utils/test/src/utils/stack_string.cc @@ -49,11 +49,11 @@ TEST_SUITE(FF_TEST_SUITE) { StackString str2{"def"}; StackString str3{"abc"}; - CHECK(str1 == str1); - CHECK(str1 == str3); - CHECK(str1 != str2); - CHECK(str2 != str3); - CHECK(str1 < str2); + CHECK_WITHOUT_STRINGIFY(str1 == str1); + CHECK_WITHOUT_STRINGIFY(str1 == str3); + CHECK_WITHOUT_STRINGIFY(str1 != str2); + CHECK_WITHOUT_STRINGIFY(str2 != str3); + CHECK_WITHOUT_STRINGIFY(str1 < str2); } TEST_CASE_TEMPLATE("StackStringSize", T, char) { @@ -81,4 +81,11 @@ TEST_SUITE(FF_TEST_SUITE) { std::string stdStr = static_cast(str); CHECK(stdStr == "Hello"); } + + TEST_CASE("Arbitrary") { + constexpr std::size_t MAXSIZE = 10; + RC_SUBCASE([&](stack_string const &s) { + RC_ASSERT(s.size() <= MAXSIZE); + }); + } } diff --git a/lib/utils/test/src/utils/stack_vector.cc b/lib/utils/test/src/utils/stack_vector.cc index 5243dc45fe..775f088e28 100644 --- a/lib/utils/test/src/utils/stack_vector.cc +++ b/lib/utils/test/src/utils/stack_vector.cc @@ -1,6 +1,7 @@ #include "utils/stack_vector.h" #include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include "utils/fmt/vector.h" #include using namespace FlexFlow; @@ -12,14 +13,14 @@ TEST_SUITE(FF_TEST_SUITE) { StackVector vector; vector.push_back(10); - std::vector res = vector; - std::vector expected = {10}; - CHECK(res == expected); + std::vector result = vector; + std::vector correct = {10}; + CHECK(result == correct); vector.push_back(20); - expected = {10, 20}; - res = vector; - CHECK(res == expected); + correct = {10, 20}; + result = vector; + CHECK(result == correct); } TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { @@ -63,7 +64,7 @@ TEST_SUITE(FF_TEST_SUITE) { vector2.push_back(15); vector2.push_back(20); - CHECK(vector1 == vector2); + CHECK_WITHOUT_STRINGIFY(vector1 == vector2); } TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { diff --git a/lib/utils/test/src/utils/tuple.cc b/lib/utils/test/src/utils/tuple.cc index 2b8f4f65ad..bf822cf871 100644 --- a/lib/utils/test/src/utils/tuple.cc +++ b/lib/utils/test/src/utils/tuple.cc @@ -43,7 +43,7 @@ TEST_SUITE(FF_TEST_SUITE) { auto result = tuple_prepend(value, t1); std::tuple expected(42, 3.14f, 2.71828); - CHECK(result == expected); + CHECK(tuple_compare(result, expected)); } TEST_CASE("Testing tuple_head_t") { diff --git a/lib/utils/test/src/utils/type_index.cc b/lib/utils/test/src/utils/type_index.cc index 5cf7ca989f..873ed33ae7 100644 --- a/lib/utils/test/src/utils/type_index.cc +++ b/lib/utils/test/src/utils/type_index.cc @@ -9,13 +9,13 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("int type") { std::type_index idx = get_type_index_for_type(); std::type_index expected_idx = typeid(int); - CHECK(idx == expected_idx); + CHECK_WITHOUT_STRINGIFY(idx == expected_idx); } SUBCASE("string type") { std::type_index idx = get_type_index_for_type(); std::type_index expected_idx = typeid(std::string); - CHECK(idx == expected_idx); + CHECK_WITHOUT_STRINGIFY(idx == expected_idx); } } @@ -23,13 +23,11 @@ TEST_SUITE(FF_TEST_SUITE) { std::type_index idx = typeid(float); SUBCASE("matching type") { - bool result = matches(idx); - CHECK(result == true); + CHECK(matches(idx)); } SUBCASE("non-matching type") { - bool result = matches(idx); - CHECK(result == false); + CHECK_FALSE(matches(idx)); } } } diff --git a/lib/utils/test/src/utils/variant.cc b/lib/utils/test/src/utils/variant.cc index cae75f3045..9a55539ccf 100644 --- a/lib/utils/test/src/utils/variant.cc +++ b/lib/utils/test/src/utils/variant.cc @@ -1,6 +1,8 @@ #include "utils/variant.h" #include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/variant.h" TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { diff --git a/lib/utils/test/src/utils/vector.cc b/lib/utils/test/src/utils/vector.cc index 873f9d7447..e4a8d742f0 100644 --- a/lib/utils/test/src/utils/vector.cc +++ b/lib/utils/test/src/utils/vector.cc @@ -1,5 +1,6 @@ #include "utils/vector.h" -#include "test/utils/doctest.h" +#include "utils/fmt/vector.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") { From a371c027d43785a236e210051c243b1fe0c026f9 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Thu, 12 Sep 2024 06:32:04 -0700 Subject: [PATCH 14/21] minor fix --- .../test/src/utils/graph/digraph/directed_edge_query.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc index fe29fa4ef5..695731f1b6 100644 --- a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc +++ b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc @@ -64,8 +64,10 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); - // CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(2), n.at(3)})); - // CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(1), n.at(4)})); + CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(3)})); + CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(4)})); + CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(0), n.at(4)})); + CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(2), n.at(0)})); } } } From 475ace53df51a33147d8ff583128ea0f11527eac Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Mon, 23 Sep 2024 22:23:46 -0700 Subject: [PATCH 15/21] PR fixes --- lib/compiler/test/src/test_machine_mapping.cc | 4 +- .../unlabelled/multidigraph_pattern_match.cc | 55 ------- .../utils/containers/are_all_distinct.h | 17 +++ .../include/utils/containers/are_all_same.h | 4 +- .../include/utils/containers/get_one_of.h | 4 + lib/utils/include/utils/containers/index_of.h | 3 +- .../{is_submap.h => is_submapeq_of.h} | 4 +- .../include/utils/containers/is_subseteq_of.h | 10 +- .../utils/containers/is_superseteq_of.h | 6 +- lib/utils/include/utils/containers/map_keys.h | 14 +- lib/utils/include/utils/containers/maximum.h | 2 +- .../include/utils/containers/merge_maps.h | 6 +- .../utils/containers/optional_all_of.h | 26 ---- lib/utils/include/utils/containers/product.h | 13 -- .../include/utils/containers/product_where.h | 26 ++++ lib/utils/include/utils/containers/subvec.h | 4 +- lib/utils/include/utils/containers/sum.h | 13 -- .../include/utils/containers/sum_where.h | 24 ++++ .../include/utils/containers/transform.h | 12 ++ .../include/utils/containers/value_all.h | 18 +-- .../include/utils/containers/vector_split.h | 3 +- lib/utils/include/utils/fmt/unordered_map.h | 2 +- lib/utils/include/utils/graph/README.md | 8 +- .../graph/digraph/algorithms/get_dominators.h | 10 +- lib/utils/include/utils/join_strings.h | 3 - lib/utils/include/utils/stack_map.h | 11 ++ lib/utils/include/utils/stack_string.h | 11 ++ lib/utils/include/utils/stack_vector.h | 11 ++ .../src/utils/containers/are_all_distinct.cc | 1 + lib/utils/src/utils/containers/is_submap.cc | 1 - .../src/utils/containers/is_submapeq_of.cc | 1 + .../src/utils/containers/optional_all_of.cc | 1 - .../src/utils/containers/product_where.cc | 1 + lib/utils/src/utils/containers/sum_where.cc | 1 + .../instances/hashmap_undirected_graph.cc | 18 ++- lib/utils/src/utils/graph/views/views.cc | 21 ++- .../utils/bidict/algorithms/merge_bidicts.cc | 13 +- lib/utils/test/src/utils/bidict/bidict.cc | 29 +++- .../src/utils/containers/are_all_distinct.cc | 24 ++++ .../test/src/utils/containers/are_all_same.cc | 2 +- .../test/src/utils/containers/as_vector.cc | 4 +- .../test/src/utils/containers/enumerate.cc | 5 +- lib/utils/test/src/utils/containers/find.cc | 40 ++++-- .../test/src/utils/containers/flatmap.cc | 43 ++++-- .../test/src/utils/containers/get_one_of.cc | 11 +- .../test/src/utils/containers/index_of.cc | 16 ++- .../{is_submap.cc => is_submapeq_of.cc} | 14 +- .../{is_subset_eq_of.cc => is_subseteq_of.cc} | 0 .../test/src/utils/containers/map_keys.cc | 19 ++- .../test/src/utils/containers/maximum.cc | 2 +- .../test/src/utils/containers/merge_maps.cc | 2 +- .../src/utils/containers/optional_all_of.cc | 39 ----- .../test/src/utils/containers/product.cc | 40 +++--- .../src/utils/containers/product_where.cc | 25 ++++ .../test/src/utils/containers/reversed.cc | 2 +- .../test/src/utils/containers/sorted_by.cc | 8 ++ lib/utils/test/src/utils/containers/sum.cc | 18 --- .../test/src/utils/containers/sum_where.cc | 27 ++++ .../test/src/utils/containers/value_all.cc | 2 +- .../test/src/utils/containers/vector_split.cc | 4 +- lib/utils/test/src/utils/disjoint_set.cc | 23 +-- .../digraph/algorithms/get_dominators.cc | 24 ++-- .../graph/digraph/directed_edge_query.cc | 10 +- .../algorithms/get_incoming_edges.cc | 13 +- .../algorithms/get_outgoing_edges.cc | 4 +- .../utils/graph/multidigraph/multidigraph.cc | 49 +++---- lib/utils/test/src/utils/graph/views/views.cc | 135 ++++++++++-------- lib/utils/test/src/utils/join_strings.cc | 27 +++- lib/utils/test/src/utils/stack_map.cc | 17 ++- lib/utils/test/src/utils/stack_string.cc | 10 +- lib/utils/test/src/utils/stack_vector.cc | 18 ++- lib/utils/test/src/utils/tuple.cc | 27 ++-- lib/utils/test/src/utils/type_index.cc | 2 +- 73 files changed, 636 insertions(+), 481 deletions(-) delete mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc create mode 100644 lib/utils/include/utils/containers/are_all_distinct.h rename lib/utils/include/utils/containers/{is_submap.h => is_submapeq_of.h} (76%) delete mode 100644 lib/utils/include/utils/containers/optional_all_of.h create mode 100644 lib/utils/include/utils/containers/product_where.h create mode 100644 lib/utils/include/utils/containers/sum_where.h create mode 100644 lib/utils/src/utils/containers/are_all_distinct.cc delete mode 100644 lib/utils/src/utils/containers/is_submap.cc create mode 100644 lib/utils/src/utils/containers/is_submapeq_of.cc delete mode 100644 lib/utils/src/utils/containers/optional_all_of.cc create mode 100644 lib/utils/src/utils/containers/product_where.cc create mode 100644 lib/utils/src/utils/containers/sum_where.cc create mode 100644 lib/utils/test/src/utils/containers/are_all_distinct.cc rename lib/utils/test/src/utils/containers/{is_submap.cc => is_submapeq_of.cc} (75%) rename lib/utils/test/src/utils/containers/{is_subset_eq_of.cc => is_subseteq_of.cc} (100%) delete mode 100644 lib/utils/test/src/utils/containers/optional_all_of.cc create mode 100644 lib/utils/test/src/utils/containers/product_where.cc create mode 100644 lib/utils/test/src/utils/containers/sum_where.cc diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc index 4f9b879574..a4caceef66 100644 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -10,8 +10,8 @@ TEST_SUITE(FF_TEST_SUITE) { // RC_ASSERT(comb.machine_views.size() == // m0.machine_views.size() + m1.machine_views.size()); - // RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); - // RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); + // RC_ASSERT(is_submapeq_of(comb.machine_views, m0.machine_views)); + // RC_ASSERT(is_submapeq_of(comb.machine_views, m1.machine_views)); // }); // } diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc deleted file mode 100644 index d69fd42b10..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc +++ /dev/null @@ -1,55 +0,0 @@ -#include "substitutions/unlabelled/multidigraph_pattern_match.h" -// #include "substitutions/unlabelled/edge_splits.h" -// #include "substitutions/unlabelled/pattern_edge.h" - -namespace FlexFlow { - -// MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { -// return MultiDiGraphPatternMatch{ -// bidict{}, -// bidict{}, -// }; -// } - -// std::optional -// unsplit_matches(MultiDiGraphPatternMatch const &prefix, -// MultiDiGraphPatternMatch const &postfix, -// UnlabelledPatternEdgeSplits const &edge_splits) { -// -// MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); -// -// std::unordered_set handled; -// for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { -// ClosedPatternEdge closed_edge = std::get(coi); -// OutputPatternEdge output_edge = std::get(coi); -// InputPatternEdge input_edge = std::get(coi); -// -// handled.insert(pattern_edge_from_output_edge(output_edge)); -// handled.insert(pattern_edge_from_input_edge(input_edge)); -// -// OpenMultiDiEdge output_graph_edge = -// prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); -// OpenMultiDiEdge input_graph_edge = -// postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); -// if (output_graph_edge == input_graph_edge) { -// result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), -// output_graph_edge); -// } else { -// return std::nullopt; -// } -// } -// -// for (auto const &kv : -// merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { -// if (!contains(handled, kv.first)) { -// result.edge_assignment.equate(kv.first, kv.second); -// } -// } -// -// result.node_assignment = -// merge_maps(prefix.node_assignment, postfix.node_assignment); -// -// return result; -// } - -} // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/are_all_distinct.h b/lib/utils/include/utils/containers/are_all_distinct.h new file mode 100644 index 0000000000..0990d76ab7 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_distinct.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_DISTINCT_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_DISTINCT_H + +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bool are_all_distinct(C const &c) { + return unordered_set_of(c).size() == unordered_multiset_of(c).size(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/are_all_same.h b/lib/utils/include/utils/containers/are_all_same.h index 26716ac911..79c8705999 100644 --- a/lib/utils/include/utils/containers/are_all_same.h +++ b/lib/utils/include/utils/containers/are_all_same.h @@ -7,8 +7,8 @@ namespace FlexFlow { template bool are_all_same(C const &c) { - if (c.size() == 0) { - throw mk_runtime_error("input to are_all_same must be non-empty container"); + if (c.empty()) { + return true; } auto const &first = *c.cbegin(); for (auto const &v : c) { diff --git a/lib/utils/include/utils/containers/get_one_of.h b/lib/utils/include/utils/containers/get_one_of.h index 274bc27239..eb772d917a 100644 --- a/lib/utils/include/utils/containers/get_one_of.h +++ b/lib/utils/include/utils/containers/get_one_of.h @@ -1,12 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H +#include "utils/exception.h" #include namespace FlexFlow { template T get_one_of(std::unordered_set const &s) { + if (s.empty()) { + throw mk_runtime_error("input to are_all_same must be non-empty container"); + } return *s.cbegin(); } diff --git a/lib/utils/include/utils/containers/index_of.h b/lib/utils/include/utils/containers/index_of.h index d78888decd..1792490d0c 100644 --- a/lib/utils/include/utils/containers/index_of.h +++ b/lib/utils/include/utils/containers/index_of.h @@ -3,11 +3,12 @@ #include #include + namespace FlexFlow { /** * @details If multiple `e` are present within the container, the function - *returns the index of the first appearance + * returns the index of the first appearance **/ template std::optional index_of(Container const &c, Element const &e) { diff --git a/lib/utils/include/utils/containers/is_submap.h b/lib/utils/include/utils/containers/is_submapeq_of.h similarity index 76% rename from lib/utils/include/utils/containers/is_submap.h rename to lib/utils/include/utils/containers/is_submapeq_of.h index 8cca9f6924..03cb5ccd78 100644 --- a/lib/utils/include/utils/containers/is_submap.h +++ b/lib/utils/include/utils/containers/is_submapeq_of.h @@ -8,8 +8,8 @@ namespace FlexFlow { template -bool is_submap(std::unordered_map const &m, - std::unordered_map const &sub) { +bool is_submapeq_of(std::unordered_map const &sub, + std::unordered_map const &m) { return restrict_keys(m, keys(sub)) == sub; } diff --git a/lib/utils/include/utils/containers/is_subseteq_of.h b/lib/utils/include/utils/containers/is_subseteq_of.h index 26543ca75b..705c092962 100644 --- a/lib/utils/include/utils/containers/is_subseteq_of.h +++ b/lib/utils/include/utils/containers/is_subseteq_of.h @@ -7,14 +7,14 @@ namespace FlexFlow { template -bool is_subseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { - if (l.size() > r.size()) { +bool is_subseteq_of(std::unordered_set const &sub, + std::unordered_set const &super) { + if (sub.size() > super.size()) { return false; } - for (auto const &ll : l) { - if (!contains(r, ll)) { + for (auto const &s : sub) { + if (!contains(super, s)) { return false; } } diff --git a/lib/utils/include/utils/containers/is_superseteq_of.h b/lib/utils/include/utils/containers/is_superseteq_of.h index 1793201261..23b16d92f9 100644 --- a/lib/utils/include/utils/containers/is_superseteq_of.h +++ b/lib/utils/include/utils/containers/is_superseteq_of.h @@ -7,9 +7,9 @@ namespace FlexFlow { template -bool is_superseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { - return is_subseteq_of(r, l); +bool is_superseteq_of(std::unordered_set const &super, + std::unordered_set const &sub) { + return is_subseteq_of(sub, super); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index bed6c5e4c0..d3ef658c32 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -1,6 +1,10 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H +#include "utils/containers/are_all_distinct.h" +#include "utils/containers/keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" #include #include @@ -9,9 +13,6 @@ namespace FlexFlow { /** * @brief Applies the given function to all the keys within the given map and * returns the updated map. - * - * @details Note that if multiple original keys are transformed to the same new - * key, which value will be associated to the new key is undefined. */ template > std::unordered_map map_keys(std::unordered_map const &m, F const &f) { + std::unordered_multiset transformed_keys = + transform(unordered_multiset_of(keys(m)), f); + if (!are_all_distinct(transformed_keys)) { + throw mk_runtime_error(fmt::format( + "keys passed to map_keys must be transformed into distinct keys")); + } + std::unordered_map result; for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h index 70d95b1293..9953c44d0d 100644 --- a/lib/utils/include/utils/containers/maximum.h +++ b/lib/utils/include/utils/containers/maximum.h @@ -8,7 +8,7 @@ namespace FlexFlow { template typename C::value_type maximum(C const &c) { - if (c.size() == 0) { + if (c.empty()) { throw mk_runtime_error( "input to function maximum must be non-empty container"); } diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h index 3d93863b2d..4176a36762 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -13,8 +13,10 @@ template std::unordered_map merge_maps(std::unordered_map const &lhs, std::unordered_map const &rhs) { if (!are_disjoint(keys(lhs), keys(rhs))) { - throw mk_runtime_error( - "Key sets of merge_maps parameters are non-disjoint"); + throw mk_runtime_error(fmt::format("Key sets of merge_maps parameters are " + "non-disjoint: lhs = {}, rhs = {}", + lhs, + rhs)); } std::unordered_map result; diff --git a/lib/utils/include/utils/containers/optional_all_of.h b/lib/utils/include/utils/containers/optional_all_of.h deleted file mode 100644 index 825633891b..0000000000 --- a/lib/utils/include/utils/containers/optional_all_of.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_OPTIONAL_ALL_OF_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_OPTIONAL_ALL_OF_H - -#include - -namespace FlexFlow { - -template -std::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - std::optional condition = func(element); - if (!condition.has_value()) { - return std::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/product.h b/lib/utils/include/utils/containers/product.h index a9fd8b2e76..af04edcb81 100644 --- a/lib/utils/include/utils/containers/product.h +++ b/lib/utils/include/utils/containers/product.h @@ -26,19 +26,6 @@ typename It::value_type product(It begin, It end) { }); } -template -Element product_where(Container const &container, ConditionF const &condition) { - Element result = 1; - for (Element const &element : container) { - if (condition(element)) { - result *= element; - } - } - return result; -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/product_where.h b/lib/utils/include/utils/containers/product_where.h new file mode 100644 index 0000000000..51af47c2fa --- /dev/null +++ b/lib/utils/include/utils/containers/product_where.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PRODUCT_WHERE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PRODUCT_WHERE_H + +#include + +namespace FlexFlow { + +/** + * @details An empty container vacuously has product 1 + **/ +template +Element product_where(Container const &container, ConditionF const &condition) { + Element result = 1; + for (Element const &element : container) { + if (condition(element)) { + result *= element; + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 77eb6dc421..21bec9ad24 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -16,9 +16,9 @@ std::vector subvec(std::vector const &v, auto resolve_loc = [&](int idx) -> typename std::vector::iterator::difference_type { if (idx < 0) { - return std::max(0, (int)v.size() + idx); + return std::max(0, static_cast(v.size()) + idx); } else { - return std::min(idx, (int)v.size()); + return std::min(idx, static_cast(v.size())); } }; diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h index 4e175ae792..135e704045 100644 --- a/lib/utils/include/utils/containers/sum.h +++ b/lib/utils/include/utils/containers/sum.h @@ -15,19 +15,6 @@ Element sum(Container const &container) { return result; } -template -Element sum_where(Container const &container, ConditionF const &condition) { - Element result = 0; - for (Element const &element : container) { - if (condition(element)) { - result += element; - } - } - return result; -} - } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/sum_where.h b/lib/utils/include/utils/containers/sum_where.h new file mode 100644 index 0000000000..214f51c1c9 --- /dev/null +++ b/lib/utils/include/utils/containers/sum_where.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_WHERE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_WHERE_H + +namespace FlexFlow { + +/** + * @details An empty container vacuously has sum 0 + **/ +template +Element sum_where(Container const &container, ConditionF const &condition) { + Element result = 0; + for (Element const &element : container) { + if (condition(element)) { + result += element; + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index ec3d5f5612..02dadb5352 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -33,6 +33,18 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { return result; } +template ()(std::declval()))> +std::unordered_multiset transform(std::unordered_multiset const &v, + F const &f) { + std::unordered_multiset result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + template ()(std::declval()))> diff --git a/lib/utils/include/utils/containers/value_all.h b/lib/utils/include/utils/containers/value_all.h index 996b2adfbd..5727bd8396 100644 --- a/lib/utils/include/utils/containers/value_all.h +++ b/lib/utils/include/utils/containers/value_all.h @@ -9,20 +9,22 @@ namespace FlexFlow { template std::vector value_all(std::vector> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); + return transform(v, [&](std::optional const &element) { + return unwrap(element, [&] { + throw mk_runtime_error(fmt::format( + "value_all expected all elements to have values, but received {}", + v)); }); }); } template std::unordered_set value_all(std::unordered_set> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); + return transform(v, [&](std::optional const &element) { + return unwrap(element, [&] { + throw mk_runtime_error(fmt::format( + "value_all expected all elements to have values, but received {}", + v)); }); }); } diff --git a/lib/utils/include/utils/containers/vector_split.h b/lib/utils/include/utils/containers/vector_split.h index a5dc098644..872733e3ce 100644 --- a/lib/utils/include/utils/containers/vector_split.h +++ b/lib/utils/include/utils/containers/vector_split.h @@ -12,7 +12,8 @@ template std::pair, std::vector> vector_split(std::vector const &v, int idx) { if (idx < 0 || idx > static_cast(v.size())) { - throw std::out_of_range("Index out of range in vector_split"); + throw std::out_of_range(fmt::format( + "Index out of range in vector_split: index = {}, vector = {}", idx, v)); } std::vector prefix(v.begin(), v.begin() + idx); diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 75bbb4cb8a..93089c02e7 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -19,7 +19,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_map const &m, FormatContext &ctx) + auto format(::std::unordered_map const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(K); CHECK_FMTABLE(V); diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index 90511ae343..f3d31e7bc8 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -18,7 +18,11 @@ At their core, they are as follows: - `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. - `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) - `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. -- `DataflowGraph`: similar to `MultiDiGraph`, but the edges entering, exiting a given nodes now have a well-defined order. Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. +- `DataflowGraph`: similar to `MultiDiGraph`, but with the following differences: + - The edges entering, exiting a given nodes now have a well-defined order. + - Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. + - Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. + Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. Examples of the different graph variants are shown below. @@ -132,7 +136,7 @@ Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGr - `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s. - `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s, which represent the open inputs to the graph (i.e. the inputs for which their corresponding output is not present in the graph). -While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. +While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node` methods, and `query_nodes`/`query_edges` methods. (Note that there is no `add_edge` method since, for `DataflowGraph`, edges are implicitly added when we add a node and specify its predecessors) Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes/edges to labels. diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h index d1019ce70a..96e8864bc1 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h @@ -6,8 +6,7 @@ namespace FlexFlow { /** - * @brief A node `d` is said to dominate a node `n` if every path from the root - * node to `n` must go through `d`. + * @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory) * * @note By definition, the root node dominates every node and every node * dominates itself. @@ -16,8 +15,11 @@ namespace FlexFlow { std::unordered_set get_dominators(DiGraphView const &, Node const &); /** - * @brief Returns the intersection of the dominators of the input nodes. - * + * @brief Returns the intersection of the dominators of the given set of nodes. + * @note This is conceptually equivalent to merging the given set of nodes and + * then finding the set of dominators of the new merged node (where merged means + * that all edges belonging to the set of nodes now pass through a single + * unified node). */ std::unordered_set get_dominators(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h index 28a773b13d..7b53384b39 100644 --- a/lib/utils/include/utils/join_strings.h +++ b/lib/utils/include/utils/join_strings.h @@ -13,15 +13,12 @@ std::string join_strings(InputIt first, F const &f) { std::ostringstream oss; bool first_iter = true; - /* int i = 0; */ for (; first != last; first++) { if (!first_iter) { oss << delimiter; } oss << f(*first); - /* break; */ first_iter = false; - /* i++; */ } return oss.str(); } diff --git a/lib/utils/include/utils/stack_map.h b/lib/utils/include/utils/stack_map.h index c70842de7e..ae8ed5cc52 100644 --- a/lib/utils/include/utils/stack_map.h +++ b/lib/utils/include/utils/stack_map.h @@ -130,4 +130,15 @@ CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(stack_map); } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(FlexFlow::stack_map const &map) { + return toString(fmt::to_string(map)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 19743b8301..bdc10a98db 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -134,4 +134,15 @@ CHECK_HASHABLE(stack_string<1>); } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(FlexFlow::stack_string const &s) { + return toString(fmt::to_string(s)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 1d654e3415..21bcdf34ed 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -369,4 +369,15 @@ struct Arbitrary<::FlexFlow::stack_vector> { } // namespace rc +namespace doctest { + +template +struct StringMaker> { + static String convert(FlexFlow::stack_vector const &v) { + return toString(fmt::to_string(v)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/src/utils/containers/are_all_distinct.cc b/lib/utils/src/utils/containers/are_all_distinct.cc new file mode 100644 index 0000000000..52c665d191 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_distinct.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_distinct.h" diff --git a/lib/utils/src/utils/containers/is_submap.cc b/lib/utils/src/utils/containers/is_submap.cc deleted file mode 100644 index be420a14be..0000000000 --- a/lib/utils/src/utils/containers/is_submap.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/is_submap.h" diff --git a/lib/utils/src/utils/containers/is_submapeq_of.cc b/lib/utils/src/utils/containers/is_submapeq_of.cc new file mode 100644 index 0000000000..567d94fac5 --- /dev/null +++ b/lib/utils/src/utils/containers/is_submapeq_of.cc @@ -0,0 +1 @@ +#include "utils/containers/is_submapeq_of.h" diff --git a/lib/utils/src/utils/containers/optional_all_of.cc b/lib/utils/src/utils/containers/optional_all_of.cc deleted file mode 100644 index dd31ad78a5..0000000000 --- a/lib/utils/src/utils/containers/optional_all_of.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/optional_all_of.h" diff --git a/lib/utils/src/utils/containers/product_where.cc b/lib/utils/src/utils/containers/product_where.cc new file mode 100644 index 0000000000..3a435e7d80 --- /dev/null +++ b/lib/utils/src/utils/containers/product_where.cc @@ -0,0 +1 @@ +#include "utils/containers/product_where.h" diff --git a/lib/utils/src/utils/containers/sum_where.cc b/lib/utils/src/utils/containers/sum_where.cc new file mode 100644 index 0000000000..c09f9c573e --- /dev/null +++ b/lib/utils/src/utils/containers/sum_where.cc @@ -0,0 +1 @@ +#include "utils/containers/sum_where.h" diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 6da6cffd03..649d372dba 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -39,9 +39,8 @@ void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.endpoints.max()); - m.erase(e.endpoints.min()); - m.erase(e.endpoints.max()); + this->adjacency.at(e.endpoints.max()).erase(e.endpoints.min()); + this->adjacency.at(e.endpoints.min()).erase(e.endpoints.max()); } std::unordered_set HashmapUndirectedGraph::query_edges( @@ -49,13 +48,20 @@ std::unordered_set HashmapUndirectedGraph::query_edges( std::unordered_set result; std::unordered_set nodes = keys(query_keys(query.nodes, this->adjacency)); - for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { - for (auto const &dst : src_kv.second) { + for (auto const &[src, n] : query_keys(query.nodes, this->adjacency)) { + for (auto const &dst : n) { + result.insert(UndirectedEdge{{src, dst}}); + } + } + for (auto const &[src, n] : + query_keys(query_set::matchall(), this->adjacency)) { + for (auto const &dst : n) { if (contains(nodes, dst)) { - result.insert(UndirectedEdge{{src_kv.first, dst}}); + result.insert(UndirectedEdge{{src, dst}}); } } } + return result; } diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index 3ada561dc9..eb61a704af 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,4 +1,5 @@ #include "utils/graph/views/views.h" +#include "utils/bidict/algorithms/left_entries.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" #include "utils/disjoint_set.h" @@ -22,9 +23,13 @@ UndirectedSubgraphView *UndirectedSubgraphView::clone() const { std::unordered_set UndirectedSubgraphView::query_edges( UndirectedEdgeQuery const &query) const { - UndirectedEdgeQuery subgraph_query = - UndirectedEdgeQuery{this->subgraph_nodes}; - return this->g.query_edges(query_intersection(query, subgraph_query)); + std::unordered_set edges_including_open_edges = + this->g.query_edges( + query_intersection(query, UndirectedEdgeQuery{this->subgraph_nodes})); + return filter(edges_including_open_edges, [&](UndirectedEdge const &e) { + return contains(this->subgraph_nodes, e.endpoints.min()) && + contains(this->subgraph_nodes, e.endpoints.max()); + }); } std::unordered_set @@ -78,9 +83,11 @@ JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { std::unordered_set JoinedNodeView::query_nodes(NodeQuery const &query) const { - // TODO @lockshaw this is going to be reimplemented in 984, so don't bother - // fixing it for now - NOT_IMPLEMENTED(); + std::unordered_set all_nodes = + transform(left_entries(this->mapping), + [&](JoinNodeKey const &key) { return key.node; }); + + return allowed_values(query_intersection(query, NodeQuery{all_nodes}).nodes); } std::pair, std::unordered_set> @@ -237,7 +244,7 @@ std::unordered_set ViewDiGraphAsUndirectedGraph::query_edges( DirectedEdgeQuery q1{undirected_query.nodes, query_set::matchall()}; DirectedEdgeQuery q2{query_set::matchall(), undirected_query.nodes}; return to_undirected_edges( - set_union(this->g.query_edges(q1), this->g.query_edges(q2))); + intersection(this->g.query_edges(q1), this->g.query_edges(q2))); } std::unordered_set ViewDiGraphAsUndirectedGraph::query_nodes( diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc index 291b4eab73..8c835558ad 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc @@ -18,18 +18,25 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("overlapping keys") { + SUBCASE("overlapping key, different associated value") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "three"}, {3, "four"}}; - CHECK_THROWS_AS(merge_bidicts(bd1, bd2), std::runtime_error); + CHECK_THROWS(merge_bidicts(bd1, bd2)); + } + + SUBCASE("overlapping key, same associated value") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{2, "two"}, {3, "three"}}; + + CHECK_THROWS(merge_bidicts(bd1, bd2)); } SUBCASE("overlapping values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "two"}, {4, "four"}}; - CHECK_THROWS_AS(merge_bidicts(bd1, bd2), std::runtime_error); + CHECK_THROWS(merge_bidicts(bd1, bd2)); } } } diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 019aaed7bd..fb58f1738e 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -11,14 +11,35 @@ TEST_SUITE(FF_TEST_SUITE) { dict.equate(2, "two"); SUBCASE("bidict::contains_l") { - CHECK(dict.contains_l(1)); - CHECK_FALSE(dict.contains_l(3)); + SUBCASE("L type is not the same as R type") { + CHECK(dict.contains_l(1)); + CHECK_FALSE(dict.contains_l(3)); + } + + SUBCASE("L type is the same as R type") { + bidict bd; + bd.equate(1, 3); + + CHECK(bd.contains_l(1)); + CHECK_FALSE(bd.contains_l(3)); + } } SUBCASE("bidict::contains_r") { - CHECK(dict.contains_r(std::string("one"))); - CHECK_FALSE(dict.contains_r(std::string("three"))); + SUBCASE("L type is not the same as R type") { + CHECK(dict.contains_r(std::string("one"))); + CHECK_FALSE(dict.contains_r(std::string("three"))); + } + + SUBCASE("L type is the same as R type") { + bidict bd; + bd.equate(1, 3); + + CHECK(bd.contains_r(3)); + CHECK_FALSE(bd.contains_r(1)); + } } + SUBCASE("bidict::contains_r, bidict::contains_r - same type") { bidict bd; bd.equate(1, 3); diff --git a/lib/utils/test/src/utils/containers/are_all_distinct.cc b/lib/utils/test/src/utils/containers/are_all_distinct.cc new file mode 100644 index 0000000000..6c2e9ea445 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_distinct.cc @@ -0,0 +1,24 @@ +#include "utils/containers/are_all_distinct.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_distinct") { + + SUBCASE("Empty Container") { + std::vector input = {}; + CHECK(are_all_distinct(input)); + } + SUBCASE("All elements are distinct") { + std::vector input = {1, 2, 3, 4}; + CHECK(are_all_distinct(input)); + } + + SUBCASE("Not all elements are distinct") { + std::vector input = {2, 2, 3, 4}; + CHECK_FALSE(are_all_distinct(input)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/are_all_same.cc b/lib/utils/test/src/utils/containers/are_all_same.cc index 38e01a13a6..df671f8541 100644 --- a/lib/utils/test/src/utils/containers/are_all_same.cc +++ b/lib/utils/test/src/utils/containers/are_all_same.cc @@ -18,7 +18,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Empty Container") { std::vector input = {}; - CHECK_THROWS_AS(are_all_same(input), std::runtime_error); + CHECK(are_all_same(input)); } } } diff --git a/lib/utils/test/src/utils/containers/as_vector.cc b/lib/utils/test/src/utils/containers/as_vector.cc index fc4567b585..deef9c00fd 100644 --- a/lib/utils/test/src/utils/containers/as_vector.cc +++ b/lib/utils/test/src/utils/containers/as_vector.cc @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("as_vector") { std::unordered_set input = {1, 2, 3}; std::vector result = as_vector(input); - std::vector correct_sorted = {1, 2, 3}; - CHECK(sorted(result) == correct_sorted); + std::vector correct = {1, 2, 3}; + CHECK(sorted(result) == sorted(correct)); } } diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 25eece0d45..61d70f5aff 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -44,11 +44,10 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("enumerate(std::unordered_set)") { - std::unordered_set input = {"zero", "one", "two", "three"}; + std::unordered_set input = {"A", "B", "C", "D"}; std::unordered_set correct_keys = {0, 1, 2, 3}; - std::unordered_multiset correct_values = { - "zero", "one", "two", "three"}; + std::unordered_multiset correct_values = {"A", "B", "C", "D"}; std::map result = enumerate(input); CHECK(keys(result) == correct_keys); diff --git a/lib/utils/test/src/utils/containers/find.cc b/lib/utils/test/src/utils/containers/find.cc index 60260bcb95..8e5e93bcf9 100644 --- a/lib/utils/test/src/utils/containers/find.cc +++ b/lib/utils/test/src/utils/containers/find.cc @@ -11,23 +11,43 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("find") { SUBCASE("vector") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); - CHECK_WITHOUT_STRINGIFY(find(v, 6) == std::find(v.begin(), v.end(), 6)); + std::vector v = {1, 2, 3, 3, 4, 5, 3}; + + SUBCASE("element found") { + CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); + } + + SUBCASE("element not found") { + CHECK_WITHOUT_STRINGIFY(find(v, 6) == std::find(v.begin(), v.end(), 6)); + } + + SUBCASE("multiple occurrences of element") { + CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); + } } SUBCASE("unordered_set") { - std::unordered_set us = {1, 2, 3, 4, 5}; - CHECK_WITHOUT_STRINGIFY(find(us, 3) == - std::find(us.begin(), us.end(), 3)); - CHECK_WITHOUT_STRINGIFY(find(us, 6) == - std::find(us.begin(), us.end(), 6)); + std::unordered_set s = {1, 2, 3, 4, 5}; + + SUBCASE("element found") { + CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); + } + + SUBCASE("element not found") { + CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + } } SUBCASE("set") { std::set s = {1, 2, 3, 4, 5}; - CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); - CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + + SUBCASE("element found") { + CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); + } + + SUBCASE("element not found") { + CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + } } } } diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index e395ce5583..f96175b879 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -6,21 +6,36 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("Test for flatmap function on vectors") { - auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); + TEST_CASE("flatmap") { + SUBCASE("same data-type") { + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } } - } - return factors; - }; + return factors; + }; - std::vector input = {2, 3, 4, 5}; - std::vector result = flatmap(input, get_factors); - std::vector correct = {1, 2, 1, 3, 1, 2, 4, 1, 5}; - CHECK(result == correct); + std::vector input = {2, 3, 4, 5}; + std::vector result = flatmap(input, get_factors); + std::vector correct = {1, 2, 1, 3, 1, 2, 4, 1, 5}; + CHECK(result == correct); + } + + SUBCASE("different data-type") { + auto get_string_sequence = [](int x) -> std::vector { + return { + std::to_string(x - 1), std::to_string(x), std::to_string(2 * x)}; + }; + + std::vector input = {2, 4, 10}; + std::vector result = flatmap(input, get_string_sequence); + std::vector correct = { + "1", "2", "4", "3", "4", "8", "9", "10", "20"}; + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/containers/get_one_of.cc b/lib/utils/test/src/utils/containers/get_one_of.cc index 840bb47ebd..3470e16f37 100644 --- a/lib/utils/test/src/utils/containers/get_one_of.cc +++ b/lib/utils/test/src/utils/containers/get_one_of.cc @@ -6,7 +6,14 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_one_of") { - std::unordered_set s = {1, 2, 3}; - CHECK(contains(s, get_one_of(s))); + SUBCASE("non-empty set") { + std::unordered_set s = {1, 2, 3}; + CHECK(contains(s, get_one_of(s))); + } + + SUBCASE("empty set") { + std::unordered_set s = {}; + CHECK_THROWS_AS(get_one_of(s), std::runtime_error); + } } } diff --git a/lib/utils/test/src/utils/containers/index_of.cc b/lib/utils/test/src/utils/containers/index_of.cc index d522b91d62..9b6bbfe3f4 100644 --- a/lib/utils/test/src/utils/containers/index_of.cc +++ b/lib/utils/test/src/utils/containers/index_of.cc @@ -8,15 +8,17 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("index_of") { - SUBCASE("unique elements") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3).value() == 2); - CHECK(index_of(v, 6) == std::nullopt); + + std::vector v = {1, 2, 3, 4, 3, 5}; + + SUBCASE("unique element") { + CHECK(index_of(v, 4).value() == 3); } SUBCASE("duplicate elements") { - std::vector v = {1, 2, 3, 4, 3, 5}; - CHECK(index_of(v, 3).value() == 2); - CHECK(index_of(v, 6) == std::nullopt); + CHECK(index_of(v, 3).value() == 2); // Returns first occurrence + } + SUBCASE("element not present") { + CHECK(index_of(v, 7) == std::nullopt); } } } diff --git a/lib/utils/test/src/utils/containers/is_submap.cc b/lib/utils/test/src/utils/containers/is_submapeq_of.cc similarity index 75% rename from lib/utils/test/src/utils/containers/is_submap.cc rename to lib/utils/test/src/utils/containers/is_submapeq_of.cc index 18ee8c677a..df89444235 100644 --- a/lib/utils/test/src/utils/containers/is_submap.cc +++ b/lib/utils/test/src/utils/containers/is_submapeq_of.cc @@ -1,4 +1,4 @@ -#include "utils/containers/is_submap.h" +#include "utils/containers/is_submapeq_of.h" #include #include #include @@ -6,35 +6,35 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("is_submap") { + TEST_CASE("is_submapeq_of") { std::unordered_map super = { {1, "one"}, {2, "two"}, {3, "three"}}; SUBCASE("keys and values match") { std::unordered_map sub = {{1, "one"}, {2, "two"}}; - CHECK(is_submap(super, sub)); + CHECK(is_submapeq_of(sub, super)); } SUBCASE("keys and values don't match") { std::unordered_map sub = {{1, "one"}, {4, "four"}}; - CHECK_FALSE(is_submap(super, sub)); + CHECK_FALSE(is_submapeq_of(sub, super)); } SUBCASE("keys match but values don't") { std::unordered_map sub = {{1, "wrong_value"}, {2, "two"}}; - CHECK_FALSE(is_submap(super, sub)); + CHECK_FALSE(is_submapeq_of(sub, super)); } SUBCASE("values match but keys don't") { std::unordered_map sub = {{5, "one"}, {6, "two"}}; - CHECK_FALSE(is_submap(super, sub)); + CHECK_FALSE(is_submapeq_of(sub, super)); } SUBCASE("sub is a superset of super") { std::unordered_map sub = { {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK_FALSE(is_submap(super, sub)); + CHECK_FALSE(is_submapeq_of(sub, super)); } } } diff --git a/lib/utils/test/src/utils/containers/is_subset_eq_of.cc b/lib/utils/test/src/utils/containers/is_subseteq_of.cc similarity index 100% rename from lib/utils/test/src/utils/containers/is_subset_eq_of.cc rename to lib/utils/test/src/utils/containers/is_subseteq_of.cc diff --git a/lib/utils/test/src/utils/containers/map_keys.cc b/lib/utils/test/src/utils/containers/map_keys.cc index 971a2017e7..256a991a67 100644 --- a/lib/utils/test/src/utils/containers/map_keys.cc +++ b/lib/utils/test/src/utils/containers/map_keys.cc @@ -8,10 +8,19 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("map_keys") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; - std::unordered_map result = map_keys(m, f); - std::unordered_map correct = {{1, "one"}, {4, "two"}}; - CHECK(correct == result); + SUBCASE("Distinct keys after transformation") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](int x) { return x * x; }; + std::unordered_map result = map_keys(m, f); + std::unordered_map correct = {{1, "one"}, {4, "two"}}; + CHECK(correct == result); + } + + SUBCASE("Non-distinct keys after transformation") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {-1, "minus one"}}; + auto f = [](int x) { return std::abs(x); }; + CHECK_THROWS(map_keys(m, f)); + } } } diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc index fd695553cd..0baf26630c 100644 --- a/lib/utils/test/src/utils/containers/maximum.cc +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -19,7 +19,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("empty container") { std::vector input = {}; - CHECK_THROWS_AS(maximum(input), std::runtime_error); + CHECK_THROWS(maximum(input)); } } } diff --git a/lib/utils/test/src/utils/containers/merge_maps.cc b/lib/utils/test/src/utils/containers/merge_maps.cc index 0a1fbd78a4..148ad6fd87 100644 --- a/lib/utils/test/src/utils/containers/merge_maps.cc +++ b/lib/utils/test/src/utils/containers/merge_maps.cc @@ -24,7 +24,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_map lhs = {{1, "one"}, {2, "two"}}; std::unordered_map rhs = {{2, "three"}, {3, "four"}}; - CHECK_THROWS_AS(merge_maps(lhs, rhs), std::runtime_error); + CHECK_THROWS(merge_maps(lhs, rhs)); } } } diff --git a/lib/utils/test/src/utils/containers/optional_all_of.cc b/lib/utils/test/src/utils/containers/optional_all_of.cc deleted file mode 100644 index 2c7113d0da..0000000000 --- a/lib/utils/test/src/utils/containers/optional_all_of.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "utils/containers/optional_all_of.h" -#include "utils/fmt/optional.h" -#include -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("optional_all_of") { - std::vector input = {2, 4, 6, 8}; - - SUBCASE("All values satisfy condition") { - auto f = [](int x) -> std::optional { return x % 2 == 0; }; - std::optional correct = true; - std::optional result = optional_all_of(input, f); - CHECK(correct == result); - } - - SUBCASE("One value returns nullopt") { - auto f = [](int x) -> std::optional { - if (x == 6) { - return std::nullopt; - } - return x % 2 == 0; - }; - std::optional correct = std::nullopt; - std::optional result = optional_all_of(input, f); - CHECK(correct == result); - } - - SUBCASE("Empty container") { - auto f = [](int x) -> std::optional { return true; }; - std::optional correct = true; - std::optional result = optional_all_of(input, f); - CHECK(correct == result); - } - } -} diff --git a/lib/utils/test/src/utils/containers/product.cc b/lib/utils/test/src/utils/containers/product.cc index 3beed9d15c..3fa94c8e9e 100644 --- a/lib/utils/test/src/utils/containers/product.cc +++ b/lib/utils/test/src/utils/containers/product.cc @@ -1,39 +1,31 @@ #include "utils/containers/product.h" #include +#include +#include #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("product") { + + TEST_CASE_TEMPLATE("product", + C, + std::vector, + std::vector, + std::set, + std::unordered_set) { + SUBCASE("non-empty container") { - std::vector input = {1, -2, 3, 5}; - int correct = -30; - int result = product(input); + C input = {1, -2, 3, 5}; + auto correct = -30; + auto result = product(input); CHECK(correct == result); } SUBCASE("empty container") { - std::vector input = {}; - int correct = 1; - int result = product(input); - CHECK(correct == result); - } - } - - TEST_CASE("product_where") { - SUBCASE("non-empty filtered container") { - std::vector input = {1, -2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; - int correct = -8; - int result = product_where(input, condition); - CHECK(correct == result); - } - SUBCASE("empty filtered container") { - std::vector input = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x > 10; }; - int correct = 1; - int result = product_where(input, condition); + C input = {}; + auto correct = 1; + auto result = product(input); CHECK(correct == result); } } diff --git a/lib/utils/test/src/utils/containers/product_where.cc b/lib/utils/test/src/utils/containers/product_where.cc new file mode 100644 index 0000000000..d175beefd4 --- /dev/null +++ b/lib/utils/test/src/utils/containers/product_where.cc @@ -0,0 +1,25 @@ +#include "utils/containers/product_where.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("product_where") { + SUBCASE("non-empty filtered container") { + std::vector input = {1, -2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = -8; + int result = product_where(input, condition); + CHECK(correct == result); + } + SUBCASE("empty filtered container") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 1; + int result = product_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/reversed.cc b/lib/utils/test/src/utils/containers/reversed.cc index 4d71752678..50900d936e 100644 --- a/lib/utils/test/src/utils/containers/reversed.cc +++ b/lib/utils/test/src/utils/containers/reversed.cc @@ -6,7 +6,7 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("Testing the 'reversed' function") { + TEST_CASE("reversed(std::vector)") { std::vector input_vec = {1, 2, 3, 4, 5}; std::vector result = reversed(input_vec); std::vector correct = {5, 4, 3, 2, 1}; diff --git a/lib/utils/test/src/utils/containers/sorted_by.cc b/lib/utils/test/src/utils/containers/sorted_by.cc index 5fae04bd91..525b1980a6 100644 --- a/lib/utils/test/src/utils/containers/sorted_by.cc +++ b/lib/utils/test/src/utils/containers/sorted_by.cc @@ -23,5 +23,13 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector correct = {-1, -2, -3, -4, -5}; CHECK(result == correct); } + + SUBCASE("duplicates") { + std::vector input = {3, 1, 3, -4, 1}; + std::vector result = + sorted_by(input, [](int a, int b) { return a < b; }); + std::vector correct = {-4, 1, 1, 3, 3}; + CHECK(result == correct); + } } } diff --git a/lib/utils/test/src/utils/containers/sum.cc b/lib/utils/test/src/utils/containers/sum.cc index f18a4d86c7..211c70ce7a 100644 --- a/lib/utils/test/src/utils/containers/sum.cc +++ b/lib/utils/test/src/utils/containers/sum.cc @@ -20,22 +20,4 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(correct == result); } } - - TEST_CASE("sum_where") { - SUBCASE("Resulting container is non-empty") { - std::vector input = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; - int correct = 6; - int result = sum_where(input, condition); - CHECK(correct == result); - } - - SUBCASE("Resulting container is empty") { - std::vector input = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x > 10; }; - int correct = 0; - int result = sum_where(input, condition); - CHECK(correct == result); - } - } } diff --git a/lib/utils/test/src/utils/containers/sum_where.cc b/lib/utils/test/src/utils/containers/sum_where.cc new file mode 100644 index 0000000000..5d97d199ed --- /dev/null +++ b/lib/utils/test/src/utils/containers/sum_where.cc @@ -0,0 +1,27 @@ +#include "utils/containers/sum_where.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("sum_where") { + SUBCASE("Resulting container is non-empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 6; + int result = sum_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("Resulting container is empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 0; + int result = sum_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/value_all.cc b/lib/utils/test/src/utils/containers/value_all.cc index 620acc5dcd..07f96aa807 100644 --- a/lib/utils/test/src/utils/containers/value_all.cc +++ b/lib/utils/test/src/utils/containers/value_all.cc @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("value_all") { SUBCASE("With nullopt") { std::vector> input = {1, 2, std::nullopt, 4, 5}; - CHECK_THROWS_AS(value_all(input), std::runtime_error); + CHECK_THROWS(value_all(input)); } SUBCASE("Without nullopt") { diff --git a/lib/utils/test/src/utils/containers/vector_split.cc b/lib/utils/test/src/utils/containers/vector_split.cc index eb51ff1ab5..5cf8672ef4 100644 --- a/lib/utils/test/src/utils/containers/vector_split.cc +++ b/lib/utils/test/src/utils/containers/vector_split.cc @@ -22,7 +22,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(postfix == std::vector({1, 2, 3, 4, 5})); } - SUBCASE("Boundary case: idx = 5") { + SUBCASE("Boundary case: idx is the last index in the list") { auto [prefix, postfix] = vector_split(v, 5); CHECK(prefix == std::vector({1, 2, 3, 4, 5})); CHECK(postfix.empty()); @@ -32,7 +32,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK_THROWS_AS(vector_split(v, -1), std::out_of_range); } - SUBCASE("Out of bounds case: idx = 6") { + SUBCASE("Out of bounds case: idx == list_size + 1") { CHECK_THROWS_AS(vector_split(v, 6), std::out_of_range); } } diff --git a/lib/utils/test/src/utils/disjoint_set.cc b/lib/utils/test/src/utils/disjoint_set.cc index 5e5d118222..283b0a8983 100644 --- a/lib/utils/test/src/utils/disjoint_set.cc +++ b/lib/utils/test/src/utils/disjoint_set.cc @@ -1,5 +1,7 @@ #include "utils/disjoint_set.h" #include "test/utils/doctest.h" +#include "utils/fmt/optional.h" + using namespace FlexFlow; template @@ -21,10 +23,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("SingleElementSets") { std::optional element = generate_element(1); - CHECK_WITHOUT_STRINGIFY(ds.find(element) == element); + CHECK(ds.find(element) == element); element = generate_element(2); - CHECK_WITHOUT_STRINGIFY(ds.find(element) == element); + CHECK(ds.find(element) == element); } SUBCASE("UnionAndFind") { @@ -34,16 +36,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional element4 = generate_element(4); ds.m_union(element1, element2); - CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element2)); + CHECK(ds.find(element1) == ds.find(element2)); ds.m_union(element3, element4); - CHECK_WITHOUT_STRINGIFY(ds.find(element3) == ds.find(element4)); + CHECK(ds.find(element3) == ds.find(element4)); ds.m_union(element1, element3); - CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element3)); - CHECK_WITHOUT_STRINGIFY(ds.find(element2) == ds.find(element4)); - CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element2)); - CHECK_WITHOUT_STRINGIFY(ds.find(element1) == ds.find(element4)); + CHECK(ds.find(element1) == ds.find(element3)); + CHECK(ds.find(element2) == ds.find(element4)); + CHECK(ds.find(element1) == ds.find(element2)); + CHECK(ds.find(element1) == ds.find(element4)); } } @@ -61,9 +63,8 @@ TEST_SUITE(FF_TEST_SUITE) { mapping = ds.get_mapping(); for (auto const &kv : mapping) { - CHECK_WITHOUT_STRINGIFY( - *kv.second == *expectedMapping[kv.first]); // Compare the values - // inside the optionals + CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values + // inside the optionals } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc index 056280df8e..a330d09839 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc @@ -20,16 +20,16 @@ TEST_SUITE(FF_TEST_SUITE) { add_edges(g, e); SUBCASE("single node") { - Node node = n[2]; - std::unordered_set correct = {n[0], n[2]}; + Node node = n.at(2); + std::unordered_set correct = {n.at(0), n.at(2)}; std::unordered_set result = get_dominators(g, node); CHECK(correct == result); } SUBCASE("multiple nodes") { - std::unordered_set nodes = {n[1], n[3]}; + std::unordered_set nodes = {n.at(1), n.at(3)}; std::unordered_set result = get_dominators(g, nodes); - std::unordered_set correct = {n[0]}; + std::unordered_set correct = {n.at(0)}; CHECK(correct == result); } @@ -52,13 +52,17 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(4), n.at(1)}, }); - std::unordered_set result = get_dominators(g, n.at(1)); - std::unordered_set correct = {n.at(0), n.at(1)}; + SUBCASE("node 1") { + std::unordered_set result = get_dominators(g, n.at(1)); + std::unordered_set correct = {n.at(0), n.at(1)}; + CHECK(result == correct); + } - result = get_dominators(g, n.at(3)); - correct = {n.at(0), n.at(1), n.at(3)}; - - CHECK(result == correct); + SUBCASE("node 3") { + std::unordered_set result = get_dominators(g, n.at(3)); + std::unordered_set correct = {n.at(0), n.at(1), n.at(3)}; + CHECK(result == correct); + } } } } diff --git a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc index 695731f1b6..1dde5c8f69 100644 --- a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc +++ b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc @@ -61,13 +61,9 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; DirectedEdgeQuery result = query_intersection(q1, q2); - - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); - CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); - CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(3)})); - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(4)})); - CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(0), n.at(4)})); - CHECK_FALSE(matches_edge(result, DirectedEdge{n.at(2), n.at(0)})); + DirectedEdgeQuery correct = DirectedEdgeQuery{ + query_set{n.at(1), n.at(2)}, query_set{n.at(3), n.at(4)}}; + CHECK(result == correct); } } } diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index 6a74e4ce6e..89d38111ea 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -15,14 +15,11 @@ TEST_SUITE(FF_TEST_SUITE) { MultiDiGraph g = MultiDiGraph::create(); std::vector n = add_nodes(g, 3); - std::vector> input = { - {n.at(0), n.at(0)}, - {n.at(0), n.at(1)}, - {n.at(0), n.at(1)}, - {n.at(1), n.at(0)}, - }; - - std::vector edges = add_edges(g, input); + std::vector edges = add_edges(g, + {{n.at(0), n.at(0)}, + {n.at(0), n.at(1)}, + {n.at(0), n.at(1)}, + {n.at(1), n.at(0)}}); SUBCASE("node has incoming edges") { std::unordered_set result = get_incoming_edges(g, n.at(1)); diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index f78311315d..9e5790bb5a 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -24,14 +24,14 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector edges = add_edges(g, input); - SUBCASE("node has incoming edges") { + SUBCASE("node has outgoing edges") { std::unordered_set result = get_outgoing_edges(g, n.at(0)); std::unordered_set correct = { edges.at(0), edges.at(1), edges.at(2)}; CHECK(result == correct); } - SUBCASE("node has no incoming edges") { + SUBCASE("node has no outgoing edges") { std::unordered_set result = get_outgoing_edges(g, n.at(2)); std::unordered_set correct = {}; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc index def1ca12b9..16ee406255 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc @@ -32,21 +32,23 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("add_edge") { - MultiDiEdge e7 = g.add_edge(n2, n1); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n2}, {n1})); - std::unordered_set correct = {e7}; - CHECK(result == correct); - } + SUBCASE("non-duplicate edge") { + MultiDiEdge e7 = g.add_edge(n2, n1); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + std::unordered_set correct = {e7}; + CHECK(result == correct); + } - SUBCASE("add_edges") { - MultiDiEdge e7 = g.add_edge(n1, n2); - MultiDiEdge e8 = g.add_edge(n2, n1); + SUBCASE("duplicate edge") { + MultiDiEdge e7 = g.add_edge(n2, n1); + MultiDiEdge e8 = g.add_edge(n2, n1); - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n0, n1}, {n2})); - std::unordered_set correct = {e0, e3, e4, e7}; - CHECK(result == correct); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + std::unordered_set correct = {e7, e8}; + CHECK(result == correct); + } } SUBCASE("remove_node") { @@ -57,7 +59,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(node_result == node_correct); std::unordered_set edge_result = - g.query_edges(MultiDiEdgeQuery({n0}, {n0})); + g.query_edges(MultiDiEdgeQuery({n0}, {n1, n2})); std::unordered_set edge_correct = {}; CHECK(edge_result == edge_correct); } @@ -68,33 +70,22 @@ TEST_SUITE(FF_TEST_SUITE) { g.query_edges(MultiDiEdgeQuery({n1}, {n2})); std::unordered_set correct = {e4}; CHECK(result == correct); - } - - SUBCASE("remove_edges") { - g.remove_edge(e0); - g.remove_edge(e1); - g.remove_edge(e3); - SUBCASE("edges from n0 to n2") { + SUBCASE("remove non-duplicate edge") { + g.remove_edge(e0); std::unordered_set result = g.query_edges(MultiDiEdgeQuery({n0}, {n2})); std::unordered_set correct = {}; CHECK(result == correct); } - SUBCASE("edges from n1 to n0") { + SUBCASE("remove duplicate edge") { + g.remove_edge(e1); std::unordered_set result = g.query_edges(MultiDiEdgeQuery({n1}, {n0})); std::unordered_set correct = {e2}; CHECK(result == correct); } - - SUBCASE("edges from n1 to n2") { - std::unordered_set result = - g.query_edges(MultiDiEdgeQuery({n1}, {n2})); - std::unordered_set correct = {e4}; - CHECK(result == correct); - } } SUBCASE("query_nodes") { diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc index e6cd562d8f..6bc4fdbcea 100644 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ b/lib/utils/test/src/utils/graph/views/views.cc @@ -18,17 +18,17 @@ TEST_SUITE(FF_TEST_SUITE) { UndirectedGraph g = UndirectedGraph::create(); std::vector n = add_nodes(g, 5); add_edges(g, - {UndirectedEdge{{n[0], n[3]}}, - UndirectedEdge{{n[1], n[1]}}, - UndirectedEdge{{n[1], n[2]}}, - UndirectedEdge{{n[1], n[3]}}, - UndirectedEdge{{n[2], n[3]}}, - UndirectedEdge{{n[2], n[4]}}}); - std::unordered_set sub_nodes = {n[0], n[1], n[3]}; + {UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(3)}}, + UndirectedEdge{{n.at(2), n.at(4)}}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; UndirectedGraphView view = get_subgraph(g, sub_nodes); SUBCASE("get_nodes") { - std::unordered_set expected = {n[0], n[1], n[3]}; + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; std::unordered_set result = get_nodes(view); @@ -37,9 +37,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_edges") { std::unordered_set expected = { - UndirectedEdge{{n[0], n[3]}}, - UndirectedEdge{{n[1], n[1]}}, - UndirectedEdge{{n[1], n[3]}}, + UndirectedEdge{{n.at(0), n.at(3)}}, + UndirectedEdge{{n.at(1), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(3)}}, }; std::unordered_set result = get_edges(view); @@ -53,19 +53,19 @@ TEST_SUITE(FF_TEST_SUITE) { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 5); add_edges(g, - {DirectedEdge{n[0], n[3]}, - DirectedEdge{n[3], n[0]}, - DirectedEdge{n[1], n[1]}, - DirectedEdge{n[2], n[1]}, - DirectedEdge{n[1], n[3]}, - DirectedEdge{n[2], n[3]}, - DirectedEdge{n[3], n[2]}, - DirectedEdge{n[2], n[4]}}); - std::unordered_set sub_nodes = {n[0], n[1], n[3]}; + {DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}, + DirectedEdge{n.at(2), n.at(4)}}); + std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; DiGraphView view = get_subgraph(g, sub_nodes); SUBCASE("get_nodes") { - std::unordered_set expected = {n[0], n[1], n[3]}; + std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; std::unordered_set result = get_nodes(view); @@ -74,10 +74,10 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_edges") { std::unordered_set expected = { - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[3], n[0]}, - DirectedEdge{n[1], n[1]}, - DirectedEdge{n[1], n[3]}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}, + DirectedEdge{n.at(1), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, }; std::unordered_set result = get_edges(view); @@ -94,25 +94,28 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n2 = add_nodes(g2, 3); add_edges(g1, - {UndirectedEdge{{n1[0], n1[1]}}, UndirectedEdge{{n1[1], n1[2]}}}); + {UndirectedEdge{{n1.at(0), n1.at(1)}}, + UndirectedEdge{{n1.at(1), n1.at(2)}}}); add_edges(g2, - {UndirectedEdge{{n2[0], n2[2]}}, UndirectedEdge{{n2[1], n2[2]}}}); + {UndirectedEdge{{n2.at(0), n2.at(2)}}, + UndirectedEdge{{n2.at(1), n2.at(2)}}}); UndirectedGraphView view = join(g1, g2); - // SUBCASE("get_nodes") { - // std::unordered_set expected = - // set_union(unordered_set_of(n1), unordered_set_of(n2)); + SUBCASE("get_nodes") { + std::unordered_set expected = + set_union(unordered_set_of(n1), unordered_set_of(n2)); - // std::unordered_set result = get_nodes(view); + std::unordered_set result = get_nodes(view); - // CHECK(result == expected); - // } + CHECK(result == expected); + } // SUBCASE("get_edges") { // std::unordered_set expected = { - // {n1[0], n1[1]}, {n1[1], n1[2]}, - // {n2[0], n2[2]}, {n2[1], n2[2]} + // UndirectedEdge{{n1.at(0), n1.at(1)}}, UndirectedEdge{{n1.at(1), + // n1.at(2)}}, UndirectedEdge{{n2.at(0), n2.at(2)}}, + // UndirectedEdge{{n2.at(1), n2.at(2)}} // }; // std::unordered_set result = get_edges(view); @@ -128,24 +131,29 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector n1 = add_nodes(g1, 3); std::vector n2 = add_nodes(g2, 3); - add_edges(g1, {DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], n1[2]}}); - add_edges(g2, {DirectedEdge{n2[0], n2[2]}, DirectedEdge{n2[1], n2[2]}}); + add_edges( + g1, + {DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), n1.at(2)}}); + add_edges( + g2, + {DirectedEdge{n2.at(0), n2.at(2)}, DirectedEdge{n2.at(1), n2.at(2)}}); DiGraphView view = join(g1, g2); - // SUBCASE("get_nodes") { - // std::unordered_set expected = set_union(unordered_set_of(n1), - // unordered_set_of(n2)); + SUBCASE("get_nodes") { + std::unordered_set expected = + set_union(unordered_set_of(n1), unordered_set_of(n2)); - // std::unordered_set result = get_nodes(view); + std::unordered_set result = get_nodes(view); - // CHECK(result == expected); - // } + CHECK(result == expected); + } // SUBCASE("get_edges") { // std::unordered_set expected = { - // DirectedEdge{n1[0], n1[1]}, DirectedEdge{n1[1], n1[2]}, - // DirectedEdge{n2[0], n2[2]}, DirectedEdge{n2[1], n2[2]} + // DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), + // n1.at(2)}, DirectedEdge{n2.at(0), n2.at(2)}, + // DirectedEdge{n2.at(1), n2.at(2)} // }; // std::unordered_set result = get_edges(view); @@ -158,10 +166,10 @@ TEST_SUITE(FF_TEST_SUITE) { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 3); add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[0]}, - DirectedEdge{n[0], n[2]}}); + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}); UndirectedGraphView view = as_undirected(g); @@ -175,9 +183,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("get_edges") { std::unordered_set expected = { - UndirectedEdge{{n[0], n[1]}}, - UndirectedEdge{{n[1], n[2]}}, - UndirectedEdge{{n[2], n[0]}}}; + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}; std::unordered_set result = get_edges(view); @@ -189,10 +197,10 @@ TEST_SUITE(FF_TEST_SUITE) { UndirectedGraph g = UndirectedGraph::create(); std::vector n = add_nodes(g, 3); add_edges(g, - {UndirectedEdge{{n[0], n[0]}}, - UndirectedEdge{{n[0], n[1]}}, - UndirectedEdge{{n[1], n[2]}}, - UndirectedEdge{{n[2], n[0]}}}); + {UndirectedEdge{{n.at(0), n.at(0)}}, + UndirectedEdge{{n.at(0), n.at(1)}}, + UndirectedEdge{{n.at(1), n.at(2)}}, + UndirectedEdge{{n.at(2), n.at(0)}}}); DiGraphView view = as_digraph(g); @@ -205,13 +213,14 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("get_edges") { - std::unordered_set expected = {DirectedEdge{n[0], n[0]}, - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[0]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[1]}, - DirectedEdge{n[2], n[0]}, - DirectedEdge{n[0], n[2]}}; + std::unordered_set expected = { + DirectedEdge{n.at(0), n.at(0)}, + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(0)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(2), n.at(0)}, + DirectedEdge{n.at(0), n.at(2)}}; std::unordered_set result = get_edges(view); diff --git a/lib/utils/test/src/utils/join_strings.cc b/lib/utils/test/src/utils/join_strings.cc index ecb19c2f7d..84f2364fd8 100644 --- a/lib/utils/test/src/utils/join_strings.cc +++ b/lib/utils/test/src/utils/join_strings.cc @@ -8,25 +8,40 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("join_strings") { - std::vector const v = {"Hello", "world", "!"}; + std::vector v = {"Hello", "world", "!"}; SUBCASE("iterator") { - CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); + std::string result = join_strings(v.begin(), v.end(), " "); + std::string correct = "Hello world !"; + CHECK(result == correct); } SUBCASE("join_strings with container") { - CHECK(join_strings(v, " ") == "Hello world !"); + std::string result = join_strings(v, " "); + std::string correct = "Hello world !"; + CHECK(result == correct); } SUBCASE("join_strings with transforming function") { auto add_exclamation = [](std::string const &str) { return str + "!"; }; - CHECK(join_strings(v, " ", add_exclamation) == "Hello! world! !!"); + std::string result = join_strings(v, " ", add_exclamation); + std::string correct = "Hello! world! !!"; + CHECK(result == correct); } SUBCASE("join_strings with transforming function, iterator") { auto add_exclamation = [](std::string const &str) { return str + "!"; }; - CHECK(join_strings(v.begin(), v.end(), " ", add_exclamation) == - "Hello! world! !!"); + std::string result = + join_strings(v.begin(), v.end(), " ", add_exclamation); + std::string correct = "Hello! world! !!"; + CHECK(result == correct); + } + + SUBCASE("empty sequence") { + v = {}; + std::string result = join_strings(v, "!"); + std::string correct = ""; + CHECK(result == correct); } } } diff --git a/lib/utils/test/src/utils/stack_map.cc b/lib/utils/test/src/utils/stack_map.cc index 57ef1c850a..20659d49a5 100644 --- a/lib/utils/test/src/utils/stack_map.cc +++ b/lib/utils/test/src/utils/stack_map.cc @@ -1,13 +1,15 @@ #include "utils/stack_map.h" #include "test/utils/doctest.h" +#include "utils/fmt/pair.h" +#include "utils/fmt/vector.h" using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("stack_map") { stack_map map; - // Test the [] operator to insert and access elements - SUBCASE("BracketOperator") { + + SUBCASE("bracket operator") { map[1] = 10; map[2] = 20; @@ -15,8 +17,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(map[2] == 20); } - // Test the insert() function - SUBCASE("Insert") { + SUBCASE("insert") { map.insert(1, 10); map.insert(2, 20); @@ -24,21 +25,19 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(map[2] == 20); } - // Test the at() function to access elements - SUBCASE("At") { + SUBCASE("at") { map[1] = 10; map[2] = 20; CHECK(map.at(1) == 10); CHECK(map.at(2) == 20); CHECK(map.at(1) != 20); - // Test const version of at() function + stack_map const &const_map = map; CHECK(const_map.at(1) == 10); CHECK(const_map.at(2) == 20); } - // Test the begin() and end() functions for iterator SUBCASE("Iterator") { map[1] = 10; map[2] = 20; @@ -46,7 +45,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; std::vector> actual = map; - CHECK_WITHOUT_STRINGIFY(actual == expected); + CHECK(actual == expected); } } } diff --git a/lib/utils/test/src/utils/stack_string.cc b/lib/utils/test/src/utils/stack_string.cc index 01bfa60d15..08659375d1 100644 --- a/lib/utils/test/src/utils/stack_string.cc +++ b/lib/utils/test/src/utils/stack_string.cc @@ -49,11 +49,11 @@ TEST_SUITE(FF_TEST_SUITE) { StackString str2{"def"}; StackString str3{"abc"}; - CHECK_WITHOUT_STRINGIFY(str1 == str1); - CHECK_WITHOUT_STRINGIFY(str1 == str3); - CHECK_WITHOUT_STRINGIFY(str1 != str2); - CHECK_WITHOUT_STRINGIFY(str2 != str3); - CHECK_WITHOUT_STRINGIFY(str1 < str2); + CHECK(str1 == str1); + CHECK(str1 == str3); + CHECK(str1 != str2); + CHECK(str2 != str3); + CHECK(str1 < str2); } TEST_CASE_TEMPLATE("StackStringSize", T, char) { diff --git a/lib/utils/test/src/utils/stack_vector.cc b/lib/utils/test/src/utils/stack_vector.cc index 775f088e28..68e2af3d12 100644 --- a/lib/utils/test/src/utils/stack_vector.cc +++ b/lib/utils/test/src/utils/stack_vector.cc @@ -7,7 +7,8 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::push_back", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -23,7 +24,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::operator[]", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -37,7 +39,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector[2] == 30); } - TEST_CASE_TEMPLATE("Size", T, int, double, char) { + TEST_CASE_TEMPLATE("stack_vector::size", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -51,7 +53,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector.size() == 2); } - TEST_CASE_TEMPLATE("==", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::operator==", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector1, vector2; @@ -64,10 +67,10 @@ TEST_SUITE(FF_TEST_SUITE) { vector2.push_back(15); vector2.push_back(20); - CHECK_WITHOUT_STRINGIFY(vector1 == vector2); + CHECK(vector1 == vector2); } - TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { + TEST_CASE_TEMPLATE("stack_vector::back", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -79,7 +82,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector.back() == 20); } - TEST_CASE_TEMPLATE("Arbitrary", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector - check for size bound", T, int, double, char) { constexpr std::size_t MAXSIZE = 10; RC_SUBCASE("within bound", [&](stack_vector v) { RC_ASSERT(v.size() <= MAXSIZE); diff --git a/lib/utils/test/src/utils/tuple.cc b/lib/utils/test/src/utils/tuple.cc index bf822cf871..e01d0c333a 100644 --- a/lib/utils/test/src/utils/tuple.cc +++ b/lib/utils/test/src/utils/tuple.cc @@ -1,6 +1,5 @@ #include "utils/tuple.h" #include "test/utils/doctest.h" - #include #include @@ -37,23 +36,23 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("tuple_prepend function") { - std::tuple t1(3.14f, 2.71828); + TEST_CASE("tuple_prepend") { + std::tuple t1 = {3.14f, 2.71828}; int value = 42; - auto result = tuple_prepend(value, t1); - std::tuple expected(42, 3.14f, 2.71828); - CHECK(tuple_compare(result, expected)); + std::tuple result = tuple_prepend(value, t1); + std::tuple correct = {42, 3.14f, 2.71828}; + CHECK(tuple_compare(result, correct)); } - TEST_CASE("Testing tuple_head_t") { + TEST_CASE("tuple_head_t") { CHECK(std::is_same>, std::tuple>::value); CHECK(std::is_same>, std::tuple<>>::value); } - TEST_CASE("Testing tuple_slice_t") { + TEST_CASE("tuple_slice_t") { CHECK(std::is_same>, std::tuple>::value); CHECK(std::is_same>, @@ -62,17 +61,17 @@ TEST_SUITE(FF_TEST_SUITE) { std::tuple>::value); } - TEST_CASE("Testing tuple_compare function") { - std::tuple tup1{1, 3.14, 'a'}; - std::tuple tup2{1, 3.14, 'a'}; - std::tuple tup3{2, 3.14, 'b'}; + TEST_CASE("tuple_compare") { + std::tuple tup1 = {1, 3.14, 'a'}; + std::tuple tup2 = {1, 3.14, 'a'}; + std::tuple tup3 = {2, 3.14, 'b'}; CHECK(tuple_compare(tup1, tup2)); CHECK(!tuple_compare(tup1, tup3)); } - TEST_CASE("Testing get function with valid index") { - std::tuple tup{1, 3.14, 'a'}; + TEST_CASE("get") { + std::tuple tup = {1, 3.14, 'a'}; CHECK(get(tup) == 1); CHECK(get(tup) == 3.14); diff --git a/lib/utils/test/src/utils/type_index.cc b/lib/utils/test/src/utils/type_index.cc index 873ed33ae7..a92b9578ea 100644 --- a/lib/utils/test/src/utils/type_index.cc +++ b/lib/utils/test/src/utils/type_index.cc @@ -19,7 +19,7 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("matches function") { + TEST_CASE("matches(std::type_index)") { std::type_index idx = typeid(float); SUBCASE("matching type") { From 3c711ac43c339e5384d2ad0cf1bf9eab192d8815 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 9 Oct 2024 14:39:12 -0700 Subject: [PATCH 16/21] moved graph-testing to other PR --- .../src/substitutions/pcg_pattern_match.cc | 4 +- .../output_expr_to_result_sub_pcg_mapping.cc | 4 +- ...rge_bidicts.h => merge_disjoint_bidicts.h} | 11 +- .../utils/containers/are_all_distinct.h | 1 - .../include/utils/containers/get_one_of.h | 5 +- lib/utils/include/utils/containers/map_keys.h | 13 +- lib/utils/include/utils/containers/subvec.h | 12 +- .../include/utils/containers/transform.h | 12 +- lib/utils/include/utils/fmt/expected.h | 2 +- lib/utils/include/utils/fmt/map.h | 2 +- lib/utils/include/utils/fmt/multiset.h | 2 +- lib/utils/include/utils/fmt/optional.h | 2 +- lib/utils/include/utils/fmt/pair.h | 2 +- lib/utils/include/utils/fmt/set.h | 2 +- .../include/utils/fmt/unordered_multiset.h | 2 +- lib/utils/include/utils/fmt/unordered_set.h | 2 +- lib/utils/include/utils/fmt/variant.h | 2 +- lib/utils/include/utils/fmt/vector.h | 2 +- lib/utils/include/utils/graph/README.md | 168 +++++++++---- .../graph/dataflow_graph/dataflow_graph.h | 2 + .../graph/digraph/algorithms/get_dominators.h | 15 -- .../include/utils/graph/digraph/digraph.h | 2 + .../utils/graph/digraph/digraph_view.h | 2 + .../utils/graph/multidigraph/multidigraph.h | 2 + lib/utils/include/utils/graph/node/graph.h | 2 + .../include/utils/graph/node/graph_view.h | 2 + .../include/utils/graph/node/node.struct.toml | 1 - .../utils/graph/undirected/undirected_edge.h | 27 +- .../undirected/undirected_edge.struct.toml | 18 -- .../utils/graph/undirected/undirected_graph.h | 2 + .../graph/undirected/undirected_graph_view.h | 2 + lib/utils/include/utils/stack_map.h | 11 - lib/utils/include/utils/stack_string.h | 18 +- lib/utils/include/utils/stack_vector.h | 11 - .../utils/bidict/algorithms/merge_bidicts.cc | 1 - .../algorithms/merge_disjoint_bidicts.cc | 1 + lib/utils/src/utils/graph/algorithms.cc | 5 +- .../instances/hashmap_undirected_graph.cc | 37 +-- .../utils/graph/undirected/undirected_edge.cc | 33 ++- lib/utils/src/utils/graph/views/views.cc | 47 ++-- ...e_bidicts.cc => merge_disjoint_bidicts.cc} | 12 +- lib/utils/test/src/utils/bidict/bidict.cc | 46 ++-- lib/utils/test/src/utils/containers/find.cc | 8 +- .../test/src/utils/containers/get_one_of.cc | 2 +- .../test/src/utils/containers/index_of.cc | 8 +- .../src/utils/containers/product_where.cc | 9 + .../test/src/utils/containers/sorted_by.cc | 2 +- lib/utils/test/src/utils/containers/subvec.cc | 8 +- .../test/src/utils/containers/sum_where.cc | 13 +- .../src/utils/graph/digraph/algorithms.cc | 106 -------- .../digraph/algorithms/get_dominators.cc | 68 ------ .../test/src/utils/graph/digraph/digraph.cc | 85 ------- .../graph/digraph/directed_edge_query.cc | 70 ------ .../graph/digraph/get_topological_ordering.cc | 34 --- .../get_weakly_connected_components.cc | 23 -- .../algorithms/get_incoming_edges.cc | 36 --- .../algorithms/get_outgoing_edges.cc | 40 --- lib/utils/test/src/utils/graph/traversal.cc | 110 --------- .../algorithms/get_connected_components.cc | 24 -- .../graph/undirected/undirected_graph.cc | 65 ----- lib/utils/test/src/utils/graph/views/views.cc | 230 ------------------ lib/utils/test/src/utils/random_utils.cc | 2 +- 62 files changed, 327 insertions(+), 1165 deletions(-) rename lib/utils/include/utils/bidict/algorithms/{merge_bidicts.h => merge_disjoint_bidicts.h} (71%) delete mode 100644 lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml delete mode 100644 lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc rename lib/utils/test/src/utils/bidict/algorithms/{merge_bidicts.cc => merge_disjoint_bidicts.cc} (74%) delete mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms.cc delete mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc delete mode 100644 lib/utils/test/src/utils/graph/digraph/digraph.cc delete mode 100644 lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc delete mode 100644 lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc delete mode 100644 lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc delete mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc delete mode 100644 lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc delete mode 100644 lib/utils/test/src/utils/graph/traversal.cc delete mode 100644 lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc delete mode 100644 lib/utils/test/src/utils/graph/undirected/undirected_graph.cc delete mode 100644 lib/utils/test/src/utils/graph/views/views.cc diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index f1f4e31d57..b701be65cf 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -2,7 +2,7 @@ #include "substitutions/pcg_pattern.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" -#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/containers/map_values.h" #include "utils/containers/zip.h" @@ -27,7 +27,7 @@ bidict bidict_from_keys_and_values(pattern_node_outputs, matched_layer_output_tensors); - result = merge_bidicts(result, mapping); + result = merge_disjoint_bidicts(result, mapping); } return result; diff --git a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc index 083334f0db..22e6a9f333 100644 --- a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc +++ b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc @@ -2,7 +2,7 @@ #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" -#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" namespace FlexFlow { @@ -23,7 +23,7 @@ bidict mapping_for_layer = bidict_from_keys_and_values( layer_outputs, output_graph_expr_outputs); - result = merge_bidicts(result, mapping_for_layer); + result = merge_disjoint_bidicts(result, mapping_for_layer); } return result; diff --git a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h similarity index 71% rename from lib/utils/include/utils/bidict/algorithms/merge_bidicts.h rename to lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index efb1e74cd0..60637349c7 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" @@ -10,14 +10,15 @@ namespace FlexFlow { template -bidict merge_bidicts(bidict const &lhs, bidict const &rhs) { +bidict merge_disjoint_bidicts(bidict const &lhs, + bidict const &rhs) { if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { throw mk_runtime_error( - "Left entries of merge_bidicts parameters are non-disjoint"); + fmt::format("Left entries of {} and {} are non-disjoint"), lhs, rhs); } if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { throw mk_runtime_error( - "Right entries of merge_bidicts parameters are non-disjoint"); + fmt::format("Right entries of {} and {} are non-disjoint"), lhs, rhs); } bidict result; diff --git a/lib/utils/include/utils/containers/are_all_distinct.h b/lib/utils/include/utils/containers/are_all_distinct.h index 0990d76ab7..d02845ba16 100644 --- a/lib/utils/include/utils/containers/are_all_distinct.h +++ b/lib/utils/include/utils/containers/are_all_distinct.h @@ -3,7 +3,6 @@ #include "utils/containers/unordered_multiset_of.h" #include "utils/containers/unordered_set_of.h" -#include "utils/exception.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/get_one_of.h b/lib/utils/include/utils/containers/get_one_of.h index eb772d917a..47c46fb1d6 100644 --- a/lib/utils/include/utils/containers/get_one_of.h +++ b/lib/utils/include/utils/containers/get_one_of.h @@ -2,14 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H #include "utils/exception.h" +#include "utils/fmt/unordered_set.h" #include - namespace FlexFlow { template T get_one_of(std::unordered_set const &s) { if (s.empty()) { - throw mk_runtime_error("input to are_all_same must be non-empty container"); + throw mk_runtime_error(fmt::format( + "get_one_of expected non-empty container but receieved {}", s)); } return *s.cbegin(); } diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index d3ef658c32..4e5352748d 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H -#include "utils/containers/are_all_distinct.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" +#include "utils/exception.h" #include #include @@ -20,17 +20,16 @@ template > std::unordered_map map_keys(std::unordered_map const &m, F const &f) { - std::unordered_multiset transformed_keys = - transform(unordered_multiset_of(keys(m)), f); - if (!are_all_distinct(transformed_keys)) { - throw mk_runtime_error(fmt::format( - "keys passed to map_keys must be transformed into distinct keys")); - } std::unordered_map result; for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); } + if (keys(m).size() != keys(result).size()) { + throw mk_runtime_error( + "keys passed to map_keys must be transformed into distinct keys"); + } + return result; } diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 21bec9ad24..13586d21da 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -1,7 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUBVEC_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUBVEC_H +#include "utils/exception.h" #include +#include #include namespace FlexFlow { @@ -15,11 +17,15 @@ std::vector subvec(std::vector const &v, auto resolve_loc = [&](int idx) -> typename std::vector::iterator::difference_type { + int size = static_cast(v.size()); + int new_idx = idx; if (idx < 0) { - return std::max(0, static_cast(v.size()) + idx); - } else { - return std::min(idx, static_cast(v.size())); + new_idx = size + idx; } + if (new_idx < 0 || new_idx > size) { + throw mk_runtime_error("Index {} is out of bounds for array {}"); + } + return new_idx; }; if (maybe_start.has_value()) { diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index 02dadb5352..ec4185f822 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -22,9 +22,7 @@ auto transform(req const &c, F const &f) return transform(static_cast(c), f); } -template ()(std::declval()))> +template > std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; for (auto const &e : v) { @@ -33,9 +31,7 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { return result; } -template ()(std::declval()))> +template > std::unordered_multiset transform(std::unordered_multiset const &v, F const &f) { std::unordered_multiset result; @@ -45,9 +41,7 @@ std::unordered_multiset transform(std::unordered_multiset const &v, return result; } -template ()(std::declval()))> +template > std::set transform(std::set const &v, F const &f) { std::set result; for (auto const &e : v) { diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index 21a6d28ca2..d339e6e251 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -16,7 +16,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::tl::expected const &m, FormatContext &ctx) + auto format(::tl::expected const &m, FormatContext &ctx) const -> decltype(ctx.out()) { std::string result; diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 8e186928fd..f01a8d63df 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -18,7 +18,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::map const &m, FormatContext &ctx) + auto format(::std::map const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(K); CHECK_FMTABLE(V); diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h index cff150dc29..b60471d8a8 100644 --- a/lib/utils/include/utils/fmt/multiset.h +++ b/lib/utils/include/utils/fmt/multiset.h @@ -16,7 +16,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::multiset const &m, FormatContext &ctx) + auto format(::std::multiset const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 2c6067cdb5..4917454a82 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::optional const &m, FormatContext &ctx) + auto format(::std::optional const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index 6f7e6f6b52..218764e485 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::pair const &m, FormatContext &ctx) + auto format(::std::pair const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(L); CHECK_FMTABLE(R); diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index 1f8012f240..9dd3438fae 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -17,7 +17,7 @@ struct formatter<::std::set, std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::set const &m, FormatContext &ctx) + auto format(::std::set const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index 41abbc925e..2648bc6589 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -16,7 +16,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_multiset const &m, FormatContext &ctx) + auto format(::std::unordered_multiset const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 646ef0c7c5..4000a2f098 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -17,7 +17,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_set const &m, FormatContext &ctx) + auto format(::std::unordered_set const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 867577f72a..81786f9cb0 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -12,7 +12,7 @@ struct formatter, Char> /* std::enable_if_t>::value>> */ : formatter<::std::string> { template - auto format(std::variant const &m, FormatContext &ctx) + auto format(std::variant const &m, FormatContext &ctx) const -> decltype(ctx.out()) { std::string result = diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 96526175a8..b46ac9423a 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -16,7 +16,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::vector const &m, FormatContext &ctx) + auto format(::std::vector const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/graph/README.md b/lib/utils/include/utils/graph/README.md index f3d31e7bc8..25b0103f9c 100644 --- a/lib/utils/include/utils/graph/README.md +++ b/lib/utils/include/utils/graph/README.md @@ -15,16 +15,10 @@ There is no single type of graph. Should it be directed? Allow multiple edges be Because there is no single answer to this question, similar to [networkx](https://networkx.org/) we provide a number of different graph variants. At their core, they are as follows: -- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected. -- `DiGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) -- `MultiDiGraph`: arbitrary numbers of directed edges allowed between every pair of nodes. -- `DataflowGraph`: similar to `MultiDiGraph`, but with the following differences: - - The edges entering, exiting a given nodes now have a well-defined order. - - Due to the interface used to construct them (where essentially a node can only be added to the graph after all of its predecessor nodes have been added) `DataflowGraph`s are directed acyclic graphs. - - Each node has an associated ordered sequence of inputs and outputs, with the restriction that one and only one edge can enter an individual input. - -Conceptually, `DataflowGraph` is used within FlexFlow to represent computation-style graphs, where edges represent value uses and nodes represent multivariate functions from tuples of inputs to tuples of outputs. - +- `UndirectedGraph`: at most one edge allowed between every pair of nodes, edges are undirected +- `DirectedGraph`: at most one edge allowed between every ordered pair of nodes, edges are directed (i.e., have a source node and a destination node) +- `MultiDiGraph`: arbitrary numbers of edges allowed between every pair of nodes, but each must have not only source/destination nodes but also _source/destination indices_, which serve to disambiguate different edges between the same nodes. There can exist at most one edge for every ordered tuple of source node, destination node, source index, and destination index. + Examples of the different graph variants are shown below. Example of `UndirectedGraph`: @@ -43,7 +37,7 @@ flowchart TD D --- B ``` -Example of `DiGraph`: +Example of `DirectedGraph`: ```mermaid flowchart TD A(" ") @@ -64,34 +58,98 @@ flowchart TD Example of `MultiDiGraph`: ```mermaid flowchart TD - A - B - C - D - E - F - - A --> B - B --> C - C --> D - D --> A - B --> E - E --> B - D --> A - A --> E - D --> D - E --> E + A("A") + B("B") + C("C") + D("D") + E("E") + F("F") + + A -->|"(■, ★)"| B + B -->|"(●, ★)"| C + C -->|"(♥, ▲)"| D + D -->|"(●, ■)"| A + B -->|"(★, ●)"| E + E -->|"(■, ■)"| B + D -->|"(●, ●)"| A + A -->|"(●, ■)"| E + D -->|"(■, ●)"| D + E -->|"(■, ■)"| E +``` +or visualized a different way, +```mermaid +flowchart TD + Acirc("●") + Asqua("■") + Bcirc("●") + Bstar("★") + Bsqua("■") + Chear("♥") + Cstar("★") + Dsqua("■") + Dcirc("●") + Dtria("▲") + Ecirc("●") + Esqua("■") + Fplaceholder(" ") + + style Fplaceholder fill:#0000,stroke:#0000 + + subgraph "A" + Acirc + Asqua + end + + subgraph "B" + Bsqua + Bcirc + Bstar + end + + subgraph "C" + Chear + Cstar + end + + subgraph "D" + Dsqua + Dcirc + Dtria + end + + subgraph "E" + Ecirc + Esqua + end + + subgraph "F" + Fplaceholder + end + + Asqua --> Bstar + Bcirc --> Cstar + Chear --> Dtria + Dcirc --> Asqua + Bstar --> Ecirc + Esqua --> Bsqua + Dcirc --> Acirc + Acirc --> Esqua + Dsqua --> Dcirc + Esqua --> Esqua ``` -Note that the node names are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. -This is the case with all of the 4 core graph classes. +Note that the nodes and source/destination indices are just nameless things: they have no apparent ordering or other meaning besides representing the topology of the graph. +This is the case as well with `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`. Nodes are of type `Node`, and from a user perspective are simply opaque handles, and source and destination indices should similarly be considered opaque from a user point of view. In addition, nodes should only be used in the context of their graph, so comparing or checking equality of nodes between different graphs (even of the same type) is undefined behavior[^1]. All three core graph variants allow insertion and deletion of both edges and nodes. To add a node to an `UndirectedGraph g`, simply call `g.add_node()` (the interface is identical for `DiGraph` and `MultiDiGraph`). To add an edge between two nodes `Node n1` and `Node n2` to an `UndirectedGraph g`, call `g.add_edge({n1, n2})`. -In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph`, `MultiDiGraph` and `DataflowGraph`. +In `UndirectedGraph` the order of the arguments of `add_edge` doesn't matter as edges are undirected, but the order does matter for `DiGraph` and `MultiDiGraph`. +`MultiDiGraph::add_edge` takes in two additional arguments of type `NodePort`, specifying the source and destination indices. +Similar to `Node`s, `NodePort`s can be generated via `g.add_node_port()`. +`NodePort:` an opaque object used within `MultiDiGraph` to disambiguate between multiple edges. `MultiDiGraph` will be able to distinguish between 2 edges that share the same source and destination as long as at at least one `NodePort` differs. Within the context of a PCG, `NodePorts` must be thought of as the various inputs and outputs of a single node. The last paragraph covered the base API used to write to graphs, but we also want to be able to read from graphs. Reading from graphs is implemented with the `query_nodes` and `query_edges` methods, which can be thought of as executing a database query over the nodes and edges of the target graph, respectively (where queries are restricted to an incredibly simple set of operations). @@ -100,13 +158,11 @@ The argument to `query_nodes` is a `NodeQuery` (which is simply a set of `Node`s The set of nodes in the query is actually an `optional`, so `nullopt` could also be passed, which would simply retrieve all nodes from the target graph (essentially `nullopt` acts as the set of all nodes that could ever exist). `query_edges` functions similarly, but as with `add_edge` its behavior is differs slightly between the three graph variants. `UndirectedGraph::query_edges` simply takes an optional set of nodes and returns all edges that touch any of those nodes. -`DiGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. +`DirectedGraph::query_edges` allows separate sets for source and destination nodes, and `MultiDiGraph::query_edges` adds the ability to filter by source and destination indices as well. In practice you will rarely ever use `query_nodes` and `query_edges` as the graph library provides a large number of algorithms that do that work for you, but it can be helpful to understand this base layer if you ever need to implement your own algorithms. -The layer users will most commonly interact with is the interface provided within either the `algorithms.h` header files or the `algorithms` folders, present in their respective graph class folders. -They provide a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. -Note that, due to the internal virtual inheritance structure, some functions for more privitive classes can be employed by the derived classes. (For example, `get_nodes` present in `node/algorithms.h` can be used by `DiGraph`). -You may notice that the most of algorithms present take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but rather `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. +The layer users will most commonly interact with is the interface provided by [algorithms.h](./algorithms.h), which provides a large number of pre-implemented algorithms on graphs, ranging from as simple as `get_nodes` to as complex as `get_transitive_reduction` and `get_dominators`. +You may notice that the most of the functions declared in `algorithms.h` take as arguments not `UndirectedGraph`, `DiGraph`, and `MultiDiGraph`, but actually operator on `UndirectedGraphView`, `DiGraphView`, and `MultiDiGraphView`. These `GraphView` objects represent read-only (i.e., immutable) graphs. Similar to C++'s `const` semantics, `Graph`s can be coerced[^2] to `GraphView`s but not the other way around. To transform a `GraphView` to a `Graph`, we can perform an explicit copy with `materialize_view`. @@ -115,35 +171,40 @@ This may seem wasteful (oftentimes graphs are large objects that are passed arou At this point, however, we still have not discussed how to create a graph. The user-facing graph interface is intentially separated from the underlying graph representations, so representations can be changed without requiring any user-side code modifications besides the choice of which implementation to use. -For example, to construct a `DiGDiraph` which internally uses a representation such as `AdjacencyDiGraph` we do the following: +For example, to construct a `DiGraph` which internally uses a representation `MyDiGraphImpl`: ```cpp -DiGraph g = DiGraph::create(); +DiGraph g = DiGraph::create(); ``` Generally users will use underlying representations provided by the graph library, but advanced users can create their own implementations (see the [Internals](#internals) section). [^1]: At some point we will likely add actual runtime checks on this, but for now we rely on the user not to mess up. Currently the implementation will keep going silently until the incorrectness grows so large that something breaks/crashes. [^2]: See if you're not familiar with the term _type coercion_ -### Open DataFlow Variant +### Open, Upward, Downward `Open` is to be intended similarly to the topological sense: that is, a graph that contains some edges where one of the 2 nodes is not present in the graph itself. -This graph class is particularly useful for processing a sub-graph of a given graph while still maintaining information regarding the edges that cross the cut. +We can further specify the "openeness" of a **directed** graph by specifying whether they are `UpwardOpen` (so some of the incoming edges are open) or `DownwardOpen` (so some of the outgoing edges are open). -### Labelled Dataflow Variant +![Open graphs inheritance diagram](docs/open.svg) -As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. -Thus, FlexFlow's graph library provides the ability to add labels to `DataflowGraph`, through the `LabelleledDataflowGraph` and `OpenLabelleledDataflowGraph`, which allow users to label different components of the graph. -- `LabelledDataflowGraph` allows for labelling of `Node`s and `DataflowOutput`s. -- `OpenLabelledDataflowGraph` allows for labelling of `Node`s and `OpenDataflowValue`s, which is a variant describing both `DataflowOutput`s and `DataflowGraphInput`s, which represent the open inputs to the graph (i.e. the inputs for which their corresponding output is not present in the graph). +Arrows with pointed tips indicate inheritance, while arrows with square tips indicate that the pointing class has a 'cow_ptr' of the type of the pointed class. (for more info, see [cow_ptr](#cow_ptr-and-interfaces)) -While the interfaces of these graphs differ slightly from the core graph variants, they still have the corresponding `add_node` methods, and `query_nodes`/`query_edges` methods. (Note that there is no `add_edge` method since, for `DataflowGraph`, edges are implicitly added when we add a node and specify its predecessors) -Note that all of the labelled graph types require that each element of the labelled types have a label, which is enforced via the interfaces they provide. + +### Labelled Graphs + +As nice as all of the above is, graphs without labels are mostly useless--in practice, nodes and edges represent some other system and the properties of that system (or at least a way to map the result of graph algorithms back to the underlying system) are necessary. +Thus, FlexFlow's graph library provides the ability to add labels via [labelled\_graphs.h](./labelled_graphs.h): examples include `NodeLabelledMultiDiGraph` (nodes have labels of type `T` and edges are unlabelled) and `OutputLabelledMultiDiGraph` (nodes have labels of type `T` and source indices have labels of type `U`). +While the interfaces of these graphs differ slightly from the core graph variants, they still have corresponding `GraphView` types, `add_node`/`add_edge` methods, and `query_nodes`/`query_edges` methods. +Note that all of the labelled graph types require that each element of the labelled types have a label (e.g., every node in a `NodeLabelledMultiDiGraph` must have a label of type `T`)., which is enforced via the interfaces they provide. Partial labelling can be implement via wrapping the label type in `optional`. -Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes/edges to labels. -As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants. +Interacting with `Node` and `Edge` objects is still necessary to use the labelled graph types: intuitively the labelled graph types can be thought of as a pair of a core graph variant and a hash map the maps nodes (or other types depending in which labelled graph type is used) to labels. +As such, the labelled graph types provide the typical `at` method (as on `std::unordered_map`[^3]) and can be coerced to their underlying core graph variants for use in functions provided by `algorithms.h`, etc. [^3]: `operator[]` currently is not present because all nodes must have labels and we don't require label types to be default constructible, though some simple template programming could probably add `operator[]` support in the cases where the label types _are_ default constructible. +![Labelled Graphs Inheritance Diagram](docs/labelled.svg) + + ## Internals @@ -175,7 +236,12 @@ To address this, graph classes store a `cow_ptr` as a member variable, which poi All member functions present in `ClassName` and `ClassNameView` delegate their calls to their corresponding interface classes (which implement the actual logic), meaning that these classes essentially act as wrappers to their interface counterparts. +To create graphs within the library, we thus use the following syntax: +`BaseGraph obj = BaseGraph::create();` + +Resulting in an object that, while of type `BaseGraph`, can access at runtime the member functions defined in `DerivedGraph` + ### Virtual Inheritance -Due to the complexity of the graph library, diamond-style inheritance patterns emerge. -In the case of a diamond inheritance pattern, C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. +Due to the complexity of the graph library, diamond-style inheritance patterns emerge (consider, for example, the `OutputLabelledOpenMultiDiGraphView` class, which inherits from both `NodeLabelledOpenMultiDiGraphView` and `OutputLabelledMultiDiGraphView`, which in turn inherit from both `NodeLabelledMultiDiGraphView`). +In the case of a diamond inheritance pattern C++ will instantiate multiple copies of the base class whenever we instantiate a derived class. To address this issue, we employ [Virtual Inheritance](https://en.wikipedia.org/wiki/Virtual_inheritance), which removes the ambiguity associated with the multiple copies. diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index d73175c7dd..6a1898dd13 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -42,6 +42,8 @@ struct DataflowGraph : virtual public DataflowGraphView { private: IDataflowGraph &get_interface(); IDataflowGraph const &get_interface() const; + + friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h index 96e8864bc1..1e4d09d3ae 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_dominators.h @@ -5,22 +5,7 @@ namespace FlexFlow { -/** - * @brief See https://en.wikipedia.org/wiki/Dominator_(graph_theory) - * - * @note By definition, the root node dominates every node and every node - * dominates itself. - * - */ std::unordered_set get_dominators(DiGraphView const &, Node const &); - -/** - * @brief Returns the intersection of the dominators of the given set of nodes. - * @note This is conceptually equivalent to merging the given set of nodes and - * then finding the set of dominators of the new merged node (where merged means - * that all edges belonging to the set of nodes now pass through a single - * unified node). - */ std::unordered_set get_dominators(DiGraphView const &, std::unordered_set const &); diff --git a/lib/utils/include/utils/graph/digraph/digraph.h b/lib/utils/include/utils/graph/digraph/digraph.h index 3d320b1c06..e36b90d4bf 100644 --- a/lib/utils/include/utils/graph/digraph/digraph.h +++ b/lib/utils/include/utils/graph/digraph/digraph.h @@ -40,6 +40,8 @@ struct DiGraph : virtual DiGraphView { private: IDiGraph &get_ptr(); IDiGraph const &get_ptr() const; + + friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraph); diff --git a/lib/utils/include/utils/graph/digraph/digraph_view.h b/lib/utils/include/utils/graph/digraph/digraph_view.h index 0380751c55..54f84f8d2c 100644 --- a/lib/utils/include/utils/graph/digraph/digraph_view.h +++ b/lib/utils/include/utils/graph/digraph/digraph_view.h @@ -31,6 +31,8 @@ struct DiGraphView : virtual public GraphView { private: IDiGraphView const &get_ptr() const; + + friend struct GraphInternal; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(DiGraphView); diff --git a/lib/utils/include/utils/graph/multidigraph/multidigraph.h b/lib/utils/include/utils/graph/multidigraph/multidigraph.h index 692ee33783..69080b9348 100644 --- a/lib/utils/include/utils/graph/multidigraph/multidigraph.h +++ b/lib/utils/include/utils/graph/multidigraph/multidigraph.h @@ -40,6 +40,8 @@ struct MultiDiGraph : virtual public MultiDiGraphView { private: IMultiDiGraph &get_interface(); IMultiDiGraph const &get_interface() const; + + friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph.h b/lib/utils/include/utils/graph/node/graph.h index 1d94d1a65e..bddefdacb3 100644 --- a/lib/utils/include/utils/graph/node/graph.h +++ b/lib/utils/include/utils/graph/node/graph.h @@ -31,6 +31,8 @@ struct Graph : virtual GraphView { private: IGraph const &get_ptr() const; IGraph &get_ptr(); + + friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/graph_view.h b/lib/utils/include/utils/graph/node/graph_view.h index 8d904e05f2..fce3177ef1 100644 --- a/lib/utils/include/utils/graph/node/graph_view.h +++ b/lib/utils/include/utils/graph/node/graph_view.h @@ -22,6 +22,8 @@ struct GraphView { GraphView(); cow_ptr_t ptr; GraphView(cow_ptr_t ptr); + + friend struct GraphInternal; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index cf7309c5a6..0b6f348ddf 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -5,7 +5,6 @@ features = [ "ord", "hash", "fmt", - "rapidcheck" ] includes = [ diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.h b/lib/utils/include/utils/graph/undirected/undirected_edge.h index d051413faa..33d50192cb 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge.h @@ -2,12 +2,33 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_H #include "utils/graph/node/node.dtg.h" -#include "utils/graph/undirected/undirected_edge.dtg.h" - namespace FlexFlow { -bool is_connected_to(UndirectedEdge const &e, Node const &n); +struct UndirectedEdge { +public: + UndirectedEdge() = delete; + UndirectedEdge(Node const &src, Node const &dst); + + bool operator==(UndirectedEdge const &) const; + bool operator!=(UndirectedEdge const &) const; + bool operator<(UndirectedEdge const &) const; + +public: + Node smaller; + Node bigger; +}; + +bool is_connected_to(UndirectedEdge const &, Node const &); } // namespace FlexFlow +namespace std { + +template <> +struct hash<::FlexFlow::UndirectedEdge> { + size_t operator()(::FlexFlow::UndirectedEdge const &) const; +}; + +} // namespace std + #endif diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml b/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml deleted file mode 100644 index f5258b0bfd..0000000000 --- a/lib/utils/include/utils/graph/undirected/undirected_edge.struct.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "UndirectedEdge" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck" -] - -includes = [ - "utils/commutative_pair.h", - "utils/graph/node/node.dtg.h", -] - -[[fields]] -name = "endpoints" -type = "::FlexFlow::commutative_pair<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph.h b/lib/utils/include/utils/graph/undirected/undirected_graph.h index 09b6495699..69975991ce 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph.h @@ -34,6 +34,8 @@ struct UndirectedGraph : virtual UndirectedGraphView { using UndirectedGraphView::UndirectedGraphView; + friend struct GraphInternal; + private: IUndirectedGraph const &get_ptr() const; IUndirectedGraph &get_ptr(); diff --git a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h index 90dd5dd5d8..c2df96abc0 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_graph_view.h +++ b/lib/utils/include/utils/graph/undirected/undirected_graph_view.h @@ -29,6 +29,8 @@ struct UndirectedGraphView : virtual GraphView { using GraphView::GraphView; + friend struct GraphInternal; + private: IUndirectedGraphView const &get_ptr() const; }; diff --git a/lib/utils/include/utils/stack_map.h b/lib/utils/include/utils/stack_map.h index ae8ed5cc52..c70842de7e 100644 --- a/lib/utils/include/utils/stack_map.h +++ b/lib/utils/include/utils/stack_map.h @@ -130,15 +130,4 @@ CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(stack_map); } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(FlexFlow::stack_map const &map) { - return toString(fmt::to_string(map)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index bdc10a98db..321e66bac8 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -82,6 +82,13 @@ void from_json(json const &j, stack_string &v) { v = stack_string{as_string}; } +template +std::basic_ostream & + operator<<(std::basic_ostream &s, + stack_basic_string const &v) { + return s << fmt::to_string(v); +} + } // namespace FlexFlow namespace std { @@ -134,15 +141,4 @@ CHECK_HASHABLE(stack_string<1>); } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(FlexFlow::stack_string const &s) { - return toString(fmt::to_string(s)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 21bcdf34ed..1d654e3415 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -369,15 +369,4 @@ struct Arbitrary<::FlexFlow::stack_vector> { } // namespace rc -namespace doctest { - -template -struct StringMaker> { - static String convert(FlexFlow::stack_vector const &v) { - return toString(fmt::to_string(v)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc deleted file mode 100644 index f70be2355f..0000000000 --- a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/bidict/algorithms/merge_bidicts.h" diff --git a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..754b8d2e90 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 4c5d61b9b9..323f444a22 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -184,8 +184,7 @@ bool contains_edge(DiGraphView const &g, DirectedEdge const &e) { } bool contains_edge(UndirectedGraphView const &g, UndirectedEdge const &e) { - UndirectedEdgeQuery q = - UndirectedEdgeQuery{{e.endpoints.max(), e.endpoints.min()}}; + UndirectedEdgeQuery q = UndirectedEdgeQuery{{e.bigger, e.smaller}}; return contains(g.query_edges(q), e); } @@ -213,7 +212,7 @@ void remove_edges(UndirectedGraph &g, } std::unordered_set get_endpoints(UndirectedEdge const &e) { - return {e.endpoints.min(), e.endpoints.max()}; + return {e.smaller, e.bigger}; } // std::unordered_set get_edges(MultiDiGraphView const &g) { diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 649d372dba..92a6d0b9eb 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -1,5 +1,4 @@ #include "utils/graph/instances/hashmap_undirected_graph.h" -#include "utils/containers/contains.h" #include "utils/containers/contains_key.h" #include "utils/containers/keys.h" #include "utils/exception.h" @@ -23,45 +22,33 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { } void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { - if (!contains_key(this->adjacency, e.endpoints.min())) { + if (!contains_key(this->adjacency, e.bigger)) { throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", - e.endpoints.min()); + "Could not add edge connected to non-existent node {}", e.bigger); } - if (!contains_key(this->adjacency, e.endpoints.max())) { + if (!contains_key(this->adjacency, e.smaller)) { throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", - e.endpoints.max()); + "Could not add edge connected to non-existent node {}", e.smaller); } - this->adjacency.at(e.endpoints.min()).insert(e.endpoints.max()); - this->adjacency.at(e.endpoints.max()).insert(e.endpoints.min()); + this->adjacency.at(e.bigger).insert(e.smaller); + this->adjacency.at(e.smaller).insert(e.bigger); } void HashmapUndirectedGraph::remove_edge(UndirectedEdge const &e) { - this->adjacency.at(e.endpoints.max()).erase(e.endpoints.min()); - this->adjacency.at(e.endpoints.min()).erase(e.endpoints.max()); + std::unordered_set &m = this->adjacency.at(e.bigger); + m.erase(e.smaller); + m.erase(e.bigger); } std::unordered_set HashmapUndirectedGraph::query_edges( UndirectedEdgeQuery const &query) const { std::unordered_set result; - std::unordered_set nodes = - keys(query_keys(query.nodes, this->adjacency)); - for (auto const &[src, n] : query_keys(query.nodes, this->adjacency)) { - for (auto const &dst : n) { - result.insert(UndirectedEdge{{src, dst}}); + for (auto const &src_kv : query_keys(query.nodes, this->adjacency)) { + for (auto const &dst : src_kv.second) { + result.insert({src_kv.first, dst}); } } - for (auto const &[src, n] : - query_keys(query_set::matchall(), this->adjacency)) { - for (auto const &dst : n) { - if (contains(nodes, dst)) { - result.insert(UndirectedEdge{{src, dst}}); - } - } - } - return result; } diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge.cc b/lib/utils/src/utils/graph/undirected/undirected_edge.cc index 4cfc6aaaa8..0a575e115c 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge.cc @@ -1,11 +1,40 @@ #include "utils/graph/undirected/undirected_edge.h" #include "utils/hash/tuple.h" -#include namespace FlexFlow { +UndirectedEdge::UndirectedEdge(Node const &n1, Node const &n2) + : smaller(std::min(n1, n2)), bigger(std::max(n1, n2)) {} + +static std::tuple tie(UndirectedEdge const &e) { + return std::tie(e.smaller, e.bigger); +} + +bool UndirectedEdge::operator==(UndirectedEdge const &other) const { + return tie(*this) == tie(other); +} + +bool UndirectedEdge::operator!=(UndirectedEdge const &other) const { + return tie(*this) != tie(other); +} + +bool UndirectedEdge::operator<(UndirectedEdge const &other) const { + return tie(*this) < tie(other); +} + bool is_connected_to(UndirectedEdge const &e, Node const &n) { - return e.endpoints.min() == n || e.endpoints.max() == n; + return e.bigger == n || e.smaller == n; } } // namespace FlexFlow + +namespace std { + +using namespace FlexFlow; + +size_t hash::operator()(UndirectedEdge const &e) const { + std::tuple members = ::FlexFlow::tie(e); + return std::hash{}(members); +} + +} // namespace std diff --git a/lib/utils/src/utils/graph/views/views.cc b/lib/utils/src/utils/graph/views/views.cc index eb61a704af..9b5353de9f 100644 --- a/lib/utils/src/utils/graph/views/views.cc +++ b/lib/utils/src/utils/graph/views/views.cc @@ -1,5 +1,4 @@ #include "utils/graph/views/views.h" -#include "utils/bidict/algorithms/left_entries.h" #include "utils/containers/flatmap.h" #include "utils/containers/transform.h" #include "utils/disjoint_set.h" @@ -23,13 +22,9 @@ UndirectedSubgraphView *UndirectedSubgraphView::clone() const { std::unordered_set UndirectedSubgraphView::query_edges( UndirectedEdgeQuery const &query) const { - std::unordered_set edges_including_open_edges = - this->g.query_edges( - query_intersection(query, UndirectedEdgeQuery{this->subgraph_nodes})); - return filter(edges_including_open_edges, [&](UndirectedEdge const &e) { - return contains(this->subgraph_nodes, e.endpoints.min()) && - contains(this->subgraph_nodes, e.endpoints.max()); - }); + UndirectedEdgeQuery subgraph_query = + UndirectedEdgeQuery{this->subgraph_nodes}; + return this->g.query_edges(query_intersection(query, subgraph_query)); } std::unordered_set @@ -83,11 +78,9 @@ JoinedNodeView::JoinedNodeView(GraphView const &lhs, GraphView const &rhs) { std::unordered_set JoinedNodeView::query_nodes(NodeQuery const &query) const { - std::unordered_set all_nodes = - transform(left_entries(this->mapping), - [&](JoinNodeKey const &key) { return key.node; }); - - return allowed_values(query_intersection(query, NodeQuery{all_nodes}).nodes); + // TODO @lockshaw this is going to be reimplemented in 984, so don't bother + // fixing it for now + NOT_IMPLEMENTED(); } std::pair, std::unordered_set> @@ -153,18 +146,17 @@ std::unordered_set JoinedUndirectedGraphView::query_edges( UndirectedEdge JoinedUndirectedGraphView::fix_lhs_edge(UndirectedEdge const &e) const { - return UndirectedEdge{{this->joined_nodes.at_join_key( - JoinNodeKey{e.endpoints.min(), LRDirection::LEFT}), - this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.max(), LRDirection::LEFT})}}; + return { + this->joined_nodes.at_join_key(JoinNodeKey{e.smaller, LRDirection::LEFT}), + this->joined_nodes.at_join_key(JoinNodeKey{e.bigger, LRDirection::LEFT})}; } UndirectedEdge JoinedUndirectedGraphView::fix_rhs_edge(UndirectedEdge const &e) const { - return UndirectedEdge{{this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.min(), LRDirection::RIGHT}), - this->joined_nodes.at_join_key(JoinNodeKey{ - e.endpoints.max(), LRDirection::RIGHT})}}; + return {this->joined_nodes.at_join_key( + JoinNodeKey{e.smaller, LRDirection::RIGHT}), + this->joined_nodes.at_join_key( + JoinNodeKey{e.bigger, LRDirection::RIGHT})}; } JoinedDigraphView::JoinedDigraphView(DiGraphView const &lhs, @@ -216,7 +208,7 @@ DirectedEdge JoinedDigraphView::fix_rhs_edge(DirectedEdge const &e) const { } UndirectedEdge to_undirected_edge(DirectedEdge const &e) { - return UndirectedEdge{{e.src, e.dst}}; + return {e.src, e.dst}; } std::unordered_set to_undirected_edges( @@ -226,9 +218,8 @@ std::unordered_set to_undirected_edges( } std::unordered_set to_directed_edges(UndirectedEdge const &e) { - return std::unordered_set{ - DirectedEdge{e.endpoints.min(), e.endpoints.max()}, - DirectedEdge{e.endpoints.max(), e.endpoints.min()}}; + return std::unordered_set{DirectedEdge{e.smaller, e.bigger}, + DirectedEdge{e.bigger, e.smaller}}; } std::unordered_set to_directed_edges( @@ -244,7 +235,7 @@ std::unordered_set ViewDiGraphAsUndirectedGraph::query_edges( DirectedEdgeQuery q1{undirected_query.nodes, query_set::matchall()}; DirectedEdgeQuery q2{query_set::matchall(), undirected_query.nodes}; return to_undirected_edges( - intersection(this->g.query_edges(q1), this->g.query_edges(q2))); + set_union(this->g.query_edges(q1), this->g.query_edges(q2))); } std::unordered_set ViewDiGraphAsUndirectedGraph::query_nodes( @@ -267,8 +258,8 @@ ViewUndirectedGraphAsDiGraph *ViewUndirectedGraphAsDiGraph::clone() const { std::unordered_set ViewUndirectedGraphAsDiGraph::query_edges( DirectedEdgeQuery const &q) const { std::unordered_set undirected_edges = - set_union(g.query_edges(UndirectedEdgeQuery{q.srcs}), - g.query_edges(UndirectedEdgeQuery{q.dsts})); + intersection(g.query_edges(UndirectedEdgeQuery{q.srcs}), + g.query_edges(UndirectedEdgeQuery{q.dsts})); std::unordered_set directed_edges = flatmap(undirected_edges, [](UndirectedEdge const &e) { return to_directed_edges(e); }); diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc similarity index 74% rename from lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc rename to lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc index 8c835558ad..0a1babd9f9 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/merge_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("merge_bidicts") { + TEST_CASE("merge_disjoint_bidicts") { SUBCASE("disjoint keys and values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "three"}, {4, "four"}}; - bidict result = merge_bidicts(bd1, bd2); + bidict result = merge_disjoint_bidicts(bd1, bd2); bidict correct = { {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; @@ -22,21 +22,21 @@ TEST_SUITE(FF_TEST_SUITE) { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "three"}, {3, "four"}}; - CHECK_THROWS(merge_bidicts(bd1, bd2)); + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping key, same associated value") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "two"}, {3, "three"}}; - CHECK_THROWS(merge_bidicts(bd1, bd2)); + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "two"}, {4, "four"}}; - CHECK_THROWS(merge_bidicts(bd1, bd2)); + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); } } } diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index fb58f1738e..739d2b72c7 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -10,45 +10,35 @@ TEST_SUITE(FF_TEST_SUITE) { dict.equate(1, "one"); dict.equate(2, "two"); - SUBCASE("bidict::contains_l") { - SUBCASE("L type is not the same as R type") { - CHECK(dict.contains_l(1)); - CHECK_FALSE(dict.contains_l(3)); - } - - SUBCASE("L type is the same as R type") { - bidict bd; - bd.equate(1, 3); + SUBCASE("L type is the same as R type") { + bidict bd; + bd.equate(1, 3); + SUBCASE("bidict::contains_l") { CHECK(bd.contains_l(1)); CHECK_FALSE(bd.contains_l(3)); } - } - - SUBCASE("bidict::contains_r") { - SUBCASE("L type is not the same as R type") { - CHECK(dict.contains_r(std::string("one"))); - CHECK_FALSE(dict.contains_r(std::string("three"))); - } - - SUBCASE("L type is the same as R type") { - bidict bd; - bd.equate(1, 3); + SUBCASE("bidict::contains_r") { CHECK(bd.contains_r(3)); CHECK_FALSE(bd.contains_r(1)); } } - SUBCASE("bidict::contains_r, bidict::contains_r - same type") { - bidict bd; - bd.equate(1, 3); - bd.equate(2, 4); + SUBCASE("L type is not the same as R type") { + bidict dict; + dict.equate(1, "one"); + dict.equate(2, "two"); + + SUBCASE("bidict::contains_l") { + CHECK(dict.contains_l(1)); + CHECK_FALSE(dict.contains_l(3)); + } - CHECK(bd.contains_l(1)); - CHECK_FALSE(bd.contains_l(3)); - CHECK(bd.contains_r(3)); - CHECK_FALSE(bd.contains_r(1)); + SUBCASE("bidict::contains_r") { + CHECK(dict.contains_r("one")); + CHECK_FALSE(dict.contains_r("three")); + } } SUBCASE("bidict::equate") { diff --git a/lib/utils/test/src/utils/containers/find.cc b/lib/utils/test/src/utils/containers/find.cc index 8e5e93bcf9..432f318170 100644 --- a/lib/utils/test/src/utils/containers/find.cc +++ b/lib/utils/test/src/utils/containers/find.cc @@ -29,11 +29,11 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("unordered_set") { std::unordered_set s = {1, 2, 3, 4, 5}; - SUBCASE("element found") { + SUBCASE("element in container") { CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); } - SUBCASE("element not found") { + SUBCASE("element not in container") { CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); } } @@ -41,11 +41,11 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("set") { std::set s = {1, 2, 3, 4, 5}; - SUBCASE("element found") { + SUBCASE("element in container") { CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); } - SUBCASE("element not found") { + SUBCASE("element not in container") { CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); } } diff --git a/lib/utils/test/src/utils/containers/get_one_of.cc b/lib/utils/test/src/utils/containers/get_one_of.cc index 3470e16f37..326a292560 100644 --- a/lib/utils/test/src/utils/containers/get_one_of.cc +++ b/lib/utils/test/src/utils/containers/get_one_of.cc @@ -13,7 +13,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("empty set") { std::unordered_set s = {}; - CHECK_THROWS_AS(get_one_of(s), std::runtime_error); + CHECK_THROWS(get_one_of(s)); } } } diff --git a/lib/utils/test/src/utils/containers/index_of.cc b/lib/utils/test/src/utils/containers/index_of.cc index 9b6bbfe3f4..8be3632ff6 100644 --- a/lib/utils/test/src/utils/containers/index_of.cc +++ b/lib/utils/test/src/utils/containers/index_of.cc @@ -11,13 +11,13 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector v = {1, 2, 3, 4, 3, 5}; - SUBCASE("unique element") { + SUBCASE("element occurs once in container") { CHECK(index_of(v, 4).value() == 3); } - SUBCASE("duplicate elements") { - CHECK(index_of(v, 3).value() == 2); // Returns first occurrence + SUBCASE("if element appears multiple times, return the first occurrence") { + CHECK(index_of(v, 3).value() == 2); } - SUBCASE("element not present") { + SUBCASE("element not in container") { CHECK(index_of(v, 7) == std::nullopt); } } diff --git a/lib/utils/test/src/utils/containers/product_where.cc b/lib/utils/test/src/utils/containers/product_where.cc index d175beefd4..098ae01252 100644 --- a/lib/utils/test/src/utils/containers/product_where.cc +++ b/lib/utils/test/src/utils/containers/product_where.cc @@ -7,6 +7,15 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("product_where") { + + SUBCASE("empty starting container") { + std::vector input = {}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 1; + int result = product_where(input, condition); + CHECK(correct == result); + } + SUBCASE("non-empty filtered container") { std::vector input = {1, -2, 3, 4, 5}; auto condition = [](int x) { return x % 2 == 0; }; diff --git a/lib/utils/test/src/utils/containers/sorted_by.cc b/lib/utils/test/src/utils/containers/sorted_by.cc index 525b1980a6..e5c7e646f9 100644 --- a/lib/utils/test/src/utils/containers/sorted_by.cc +++ b/lib/utils/test/src/utils/containers/sorted_by.cc @@ -24,7 +24,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("duplicates") { + SUBCASE("container contains duplicate elements") { std::vector input = {3, 1, 3, -4, 1}; std::vector result = sorted_by(input, [](int a, int b) { return a < b; }); diff --git a/lib/utils/test/src/utils/containers/subvec.cc b/lib/utils/test/src/utils/containers/subvec.cc index 5262ab3119..8ae7a94163 100644 --- a/lib/utils/test/src/utils/containers/subvec.cc +++ b/lib/utils/test/src/utils/containers/subvec.cc @@ -52,15 +52,11 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("Out of bounds index from above") { - auto result = subvec(v, 2, 100); - std::vector correct = {3, 4, 5}; - CHECK(result == correct); + CHECK_THROWS(subvec(v, 2, 100)); } SUBCASE("Out of bounds index from below") { - auto result = subvec(v, -100, 2); - std::vector correct = {1, 2}; - CHECK(result == correct); + CHECK_THROWS(subvec(v, -100, 2)); } } } diff --git a/lib/utils/test/src/utils/containers/sum_where.cc b/lib/utils/test/src/utils/containers/sum_where.cc index 5d97d199ed..7a909aea39 100644 --- a/lib/utils/test/src/utils/containers/sum_where.cc +++ b/lib/utils/test/src/utils/containers/sum_where.cc @@ -8,7 +8,16 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("sum_where") { - SUBCASE("Resulting container is non-empty") { + + SUBCASE("starting container is empty") { + std::vector input = {}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 0; + int result = sum_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("resulting container is non-empty") { std::vector input = {1, 2, 3, 4, 5}; auto condition = [](int x) { return x % 2 == 0; }; int correct = 6; @@ -16,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(correct == result); } - SUBCASE("Resulting container is empty") { + SUBCASE("resulting container is empty") { std::vector input = {1, 2, 3, 4, 5}; auto condition = [](int x) { return x > 10; }; int correct = 0; diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms.cc b/lib/utils/test/src/utils/graph/digraph/algorithms.cc deleted file mode 100644 index 0817c69e06..0000000000 --- a/lib/utils/test/src/utils/graph/digraph/algorithms.cc +++ /dev/null @@ -1,106 +0,0 @@ -#include "utils/graph/digraph/algorithms.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("DiGraph - algorithms.cc") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[1], n[2]}, - }; - add_edges(g, e); - - SUBCASE("get_edges") { - SUBCASE("Base") { - std::unordered_set correct = unordered_set_of(e); - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge") { - g.add_edge(DirectedEdge{n[3], n[1]}); - std::unordered_set correct = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[0], n[3]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[3], n[1]}, - }; - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } - - SUBCASE("Removing an edge") { - g.remove_edge(DirectedEdge{n[0], n[3]}); - std::unordered_set correct = { - DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[2]}, - }; - std::unordered_set result = get_edges(g); - CHECK(result == correct); - } - } - - SUBCASE("get_sinks") { - SUBCASE("Base") { - std::unordered_set correct = {n[2], n[3]}; - std::unordered_set result = get_sinks(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge to remove a sink") { - g.add_edge(DirectedEdge{n[3], n[2]}); - std::unordered_set correct = {n[2]}; - std::unordered_set result = get_sinks(g); - CHECK(result == correct); - } - - SUBCASE("Creating a cycle") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set result = get_sinks(g); - std::unordered_set correct = {n[3]}; - CHECK(result == correct); - } - } - - SUBCASE("get_sources") { - SUBCASE("Base") { - std::unordered_set correct = {n[0]}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Adding an edge to remove a source") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set correct = {}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Removing an edge to create a new source") { - g.remove_edge(DirectedEdge{n[0], n[1]}); - std::unordered_set correct = {n[0], n[1]}; - std::unordered_set result = get_sources(g); - CHECK(result == correct); - } - - SUBCASE("Creating a cycle") { - g.add_edge(DirectedEdge{n[2], n[0]}); - std::unordered_set result = get_sources(g); - std::unordered_set correct = {}; - CHECK(result.empty()); - } - } - } -} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc deleted file mode 100644 index a330d09839..0000000000 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_dominators.cc +++ /dev/null @@ -1,68 +0,0 @@ -#include "utils/graph/digraph/algorithms/get_dominators.h" -#include "test/utils/doctest.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/instances/adjacency_digraph.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_dominators") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - }; - add_edges(g, e); - - SUBCASE("single node") { - Node node = n.at(2); - std::unordered_set correct = {n.at(0), n.at(2)}; - std::unordered_set result = get_dominators(g, node); - CHECK(correct == result); - } - - SUBCASE("multiple nodes") { - std::unordered_set nodes = {n.at(1), n.at(3)}; - std::unordered_set result = get_dominators(g, nodes); - std::unordered_set correct = {n.at(0)}; - CHECK(correct == result); - } - - SUBCASE("graph with cycles") { - // example from - // https://en.wikipedia.org/w/index.php?title=Dominator_(graph_theory)&oldid=1189814332 - - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 6); - - add_edges(g, - { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(1), n.at(5)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(3), n.at(4)}, - DirectedEdge{n.at(4), n.at(1)}, - }); - - SUBCASE("node 1") { - std::unordered_set result = get_dominators(g, n.at(1)); - std::unordered_set correct = {n.at(0), n.at(1)}; - CHECK(result == correct); - } - - SUBCASE("node 3") { - std::unordered_set result = get_dominators(g, n.at(3)); - std::unordered_set correct = {n.at(0), n.at(1), n.at(3)}; - CHECK(result == correct); - } - } - } -} diff --git a/lib/utils/test/src/utils/graph/digraph/digraph.cc b/lib/utils/test/src/utils/graph/digraph/digraph.cc deleted file mode 100644 index 3a3648eec8..0000000000 --- a/lib/utils/test/src/utils/graph/digraph/digraph.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "utils/graph/digraph/digraph.h" -#include "utils/containers/repeat.h" -#include "utils/graph/digraph/directed_edge_query.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/node/node_query.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("DiGraph implementations", T, AdjacencyDiGraph) { - /* - graph TD - - n0 --> n1 - n0 --> n2 - n1 --> n2 - n2 --> n4 - n1 --> n3 - */ - - DiGraph g = DiGraph::create(); - std::vector n = repeat(5, [&] { return g.add_node(); }); - std::vector e = {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[4]}, - DirectedEdge{n[1], n[3]}}; - for (DirectedEdge const &edge : e) { - g.add_edge(edge); - } - - SUBCASE("query_nodes") { - - CHECK(g.query_nodes(node_query_all()) == - std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); - - std::unordered_set queried_edges = - g.query_edges(directed_edge_query_all()); - std::unordered_set expected = { - e[0], e[1], e[2], e[3], e[4]}; - CHECK(queried_edges == expected); - - queried_edges = g.query_edges( - DirectedEdgeQuery{query_set{{n[0]}}, query_set{{n[1]}}}); - expected = std::unordered_set{e[0]}; - CHECK(queried_edges == expected); - } - SUBCASE("remove_node_unsafe") { - g.remove_node_unsafe(n[0]); - - CHECK(g.query_nodes(node_query_all()) == - std::unordered_set{n[1], n[2], n[3], n[4]}); - - // removing a node also removes its adjacent edges - CHECK(g.query_edges(directed_edge_query_all()) == - std::unordered_set{e[2], e[3], e[4]}); - - g.remove_node_unsafe(n[1]); - - CHECK(g.query_nodes(node_query_all()) == - std::unordered_set{n[2], n[3], n[4]}); - - CHECK(g.query_edges(directed_edge_query_all()) == - std::unordered_set{e[3]}); - } - - SUBCASE("remove_edge") { - g.remove_edge(e[0]); - - CHECK(g.query_edges(directed_edge_query_all()) == - std::unordered_set{e[1], e[2], e[3], e[4]}); - CHECK(g.query_nodes(node_query_all()) == - std::unordered_set{n[0], n[1], n[2], n[3], n[4]}); - - g.remove_edge(e[1]); - g.remove_edge(e[3]); - CHECK(g.query_edges(directed_edge_query_all()) == - std::unordered_set{e[2], e[4]}); - } - } -} diff --git a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc b/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc deleted file mode 100644 index 1dde5c8f69..0000000000 --- a/lib/utils/test/src/utils/graph/digraph/directed_edge_query.cc +++ /dev/null @@ -1,70 +0,0 @@ -#include "utils/graph/digraph/directed_edge_query.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/digraph/algorithms/get_successors.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("directed_edge_query") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 5); - - add_edges(g, - {DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(1), n.at(3)}}); - - SUBCASE("directed_edge_query_all") { - - DirectedEdgeQuery result = directed_edge_query_all(); - - CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(1)})); - CHECK(matches_edge(result, DirectedEdge{n.at(0), n.at(2)})); - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(2)})); - CHECK(matches_edge(result, DirectedEdge{n.at(2), n.at(4)})); - CHECK(matches_edge(result, DirectedEdge{n.at(1), n.at(3)})); - } - - SUBCASE("matches_edge") { - DirectedEdgeQuery q = - DirectedEdgeQuery{query_set{n.at(0)}, query_set{n.at(1)}}; - - CHECK(matches_edge(q, DirectedEdge{n.at(0), n.at(1)})); - CHECK_FALSE(matches_edge(q, DirectedEdge{n.at(1), n.at(2)})); - } - - SUBCASE("query_intersection") { - SUBCASE("standard intersection") { - DirectedEdgeQuery q1 = DirectedEdgeQuery{ - query_set{n.at(0), n.at(1)}, query_set{n.at(1), n.at(2), n.at(4)}}; - DirectedEdgeQuery q2 = DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, - query_set{n.at(2), n.at(3)}}; - - DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery correct = DirectedEdgeQuery{ - query_set{n.at(1)}, - query_set{n.at(2)}, - }; - - CHECK(result == correct); - } - SUBCASE("intersection with std::nullopt") { - DirectedEdgeQuery q1 = - DirectedEdgeQuery{query_set{n.at(1), n.at(2)}, matchall()}; - DirectedEdgeQuery q2 = - DirectedEdgeQuery{matchall(), query_set{n.at(3), n.at(4)}}; - - DirectedEdgeQuery result = query_intersection(q1, q2); - DirectedEdgeQuery correct = DirectedEdgeQuery{ - query_set{n.at(1), n.at(2)}, query_set{n.at(3), n.at(4)}}; - CHECK(result == correct); - } - } - } -} diff --git a/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc b/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc deleted file mode 100644 index 954b6bd3b8..0000000000 --- a/lib/utils/test/src/utils/graph/digraph/get_topological_ordering.cc +++ /dev/null @@ -1,34 +0,0 @@ -#include "utils/graph/digraph/algorithms/get_topological_ordering.h" -#include "test/utils/doctest.h" -#include "utils/containers/index_of.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/instances/adjacency_digraph.h" - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("get_topological_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(0), n.at(2)}, - DirectedEdge{n.at(1), n.at(5)}, - DirectedEdge{n.at(2), n.at(3)}, - DirectedEdge{n.at(3), n.at(4)}, - DirectedEdge{n.at(4), n.at(5)}}; - add_edges(g, edges); - std::vector ordering = get_topological_ordering(g); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).value() < - index_of(ordering, n[r]).value()); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - CHECK_BEFORE(1, 5); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(3, 4); - CHECK_BEFORE(4, 5); - } -} diff --git a/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc b/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc deleted file mode 100644 index 8b97c48a03..0000000000 --- a/lib/utils/test/src/utils/graph/digraph/get_weakly_connected_components.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" -#include "test/utils/doctest.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/instances/adjacency_digraph.h" - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_weakly_connected_components") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); - - add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[2], n[1]}}); - - std::unordered_set> correct = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - std::unordered_set> result = - get_weakly_connected_components(g); - - CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc deleted file mode 100644 index 89d38111ea..0000000000 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" -#include "test/utils/doctest.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/instances/adjacency_multidigraph.h" -#include "utils/graph/multidigraph/algorithms/add_edges.h" -#include "utils/graph/multidigraph/algorithms/add_nodes.h" -#include "utils/graph/multidigraph/multidigraph.h" -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_incoming_edges(MultiDiGraphView, Node)") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 3); - - std::vector edges = add_edges(g, - {{n.at(0), n.at(0)}, - {n.at(0), n.at(1)}, - {n.at(0), n.at(1)}, - {n.at(1), n.at(0)}}); - - SUBCASE("node has incoming edges") { - std::unordered_set result = get_incoming_edges(g, n.at(1)); - std::unordered_set correct = {edges.at(1), edges.at(2)}; - CHECK(result == correct); - } - - SUBCASE("node has no incoming edges") { - std::unordered_set result = get_incoming_edges(g, n.at(2)); - std::unordered_set correct = {}; - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc deleted file mode 100644 index 9e5790bb5a..0000000000 --- a/lib/utils/test/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ /dev/null @@ -1,40 +0,0 @@ -#include "utils/graph/multidigraph/algorithms/get_outgoing_edges.h" -#include "test/utils/doctest.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/instances/adjacency_multidigraph.h" -#include "utils/graph/multidigraph/algorithms/add_edges.h" -#include "utils/graph/multidigraph/algorithms/add_nodes.h" -#include "utils/graph/multidigraph/multidigraph.h" -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_outgoing_edges(MultiDiGraph, Node)") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 3); - - std::vector> input = { - {n.at(0), n.at(0)}, - {n.at(0), n.at(1)}, - {n.at(0), n.at(1)}, - {n.at(1), n.at(0)}, - }; - - std::vector edges = add_edges(g, input); - - SUBCASE("node has outgoing edges") { - std::unordered_set result = get_outgoing_edges(g, n.at(0)); - std::unordered_set correct = { - edges.at(0), edges.at(1), edges.at(2)}; - CHECK(result == correct); - } - - SUBCASE("node has no outgoing edges") { - std::unordered_set result = get_outgoing_edges(g, n.at(2)); - std::unordered_set correct = {}; - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/traversal.cc b/lib/utils/test/src/utils/graph/traversal.cc deleted file mode 100644 index 66daaf8400..0000000000 --- a/lib/utils/test/src/utils/graph/traversal.cc +++ /dev/null @@ -1,110 +0,0 @@ -#include "utils/graph/traversal.h" -#include "test/utils/doctest.h" -#include "utils/fmt/vector.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include "utils/hash/vector.h" - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_unchecked_dfs_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); - add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[3]}}); - - SUBCASE("simple path") { - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_unchecked_dfs_ordering(g, {n[0]}); - CHECK(correct == result); - } - } - - TEST_CASE("get_bfs_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[3]}, - DirectedEdge{n[2], n[3]}, - DirectedEdge{n[3], n[4]}, - DirectedEdge{n[4], n[5]}}); - - SUBCASE("branching path") { - std::unordered_set> corrects = { - {n[0], n[1], n[2], n[3], n[4], n[5]}, - {n[0], n[2], n[1], n[3], n[4], n[5]}}; - std::vector result = get_bfs_ordering(g, {n[0]}); - CHECK(contains(corrects, result)); - } - - SUBCASE("isolated node") { - std::vector correct = {n[5]}; - std::vector result = get_bfs_ordering(g, {n[5]}); - CHECK(correct == result); - } - - SUBCASE("graph with cycle") { - g = DiGraph::create(); - n = add_nodes(g, 3); - add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[0], n[2]}, - DirectedEdge{n[1], n[0]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[0]}, - DirectedEdge{n[2], n[1]}}); - std::unordered_set> corrects = {{n[0], n[1], n[2]}, - {n[0], n[2], n[1]}}; - std::vector result = get_bfs_ordering(g, {n[0]}); - CHECK(contains(corrects, result)); - } - } - - TEST_CASE("get_dfs_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); - add_edges(g, - {DirectedEdge{n[0], n[1]}, - DirectedEdge{n[1], n[2]}, - DirectedEdge{n[2], n[3]}}); - - SUBCASE("simple path") { - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_dfs_ordering(g, {n[0]}); - CHECK(correct == result); - } - - SUBCASE("with cycle") { - g.add_edge(DirectedEdge{n[3], n[1]}); - std::vector correct = {n[0], n[1], n[2], n[3]}; - std::vector result = get_dfs_ordering(g, {n[0]}); - CHECK(correct == result); - } - - SUBCASE("branching") { - g.add_edge(DirectedEdge{n[1], n[3]}); - std::unordered_set> corrects = { - {n[0], n[1], n[2], n[3]}, {n[0], n[1], n[3], n[2]}}; - std::vector result = get_dfs_ordering(g, {n[0]}); - CHECK(contains(corrects, result)); - } - - SUBCASE("disconnected") { - g.remove_edge(DirectedEdge{n[2], n[3]}); - std::vector correct = {n[0], n[1], n[2]}; - std::vector result = get_dfs_ordering(g, {n[0]}); - CHECK(correct == result); - } - - SUBCASE("isolated node") { - g.remove_edge(DirectedEdge{n[2], n[3]}); - std::vector correct = {n[3]}; - std::vector result = get_dfs_ordering(g, {n[3]}); - CHECK(correct == result); - } - } -} diff --git a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc b/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc deleted file mode 100644 index 20d2bfbf34..0000000000 --- a/lib/utils/test/src/utils/graph/undirected/algorithms/get_connected_components.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "utils/graph/undirected/algorithms/get_connected_components.h" -#include "test/utils/doctest.h" -#include "utils/fmt/unordered_set.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/instances/hashmap_undirected_graph.h" -#include "utils/graph/undirected/undirected_graph.h" - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 4); - add_edges(g, {UndirectedEdge{{n[0], n[1]}}, UndirectedEdge{{n[2], n[1]}}}); - - std::unordered_set> correct = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - std::unordered_set> result = - get_connected_components(g); - - CHECK(correct == result); - } -} diff --git a/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc b/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc deleted file mode 100644 index b11cdd4753..0000000000 --- a/lib/utils/test/src/utils/graph/undirected/undirected_graph.cc +++ /dev/null @@ -1,65 +0,0 @@ -#include "utils/graph/undirected/undirected_graph.h" -#include "test/utils/all.h" -#include "test/utils/rapidcheck/visitable.h" -#include "utils/containers/repeat.h" -#include "utils/graph/instances/hashmap_undirected_graph.h" -#include "utils/graph/node/node_query.h" -#include "utils/graph/undirected/undirected_edge_query.h" - -/* namespace rc { */ - -/* template <> */ -/* struct Arbitrary { */ -/* static Gen arbitrary() { */ -/* int num_nodes = *gen::inRange( */ -/* } */ -/* }; */ - -/* } */ - -using namespace FlexFlow; - -using namespace rc; - -/* static_assert(supports_rc_arbitrary::value, ""); */ -/* static_assert(is_strong_typedef::value, ""); */ -/* static_assert(supports_rc_arbitrary>::value, ""); - */ -/* static_assert(supports_rc_arbitrary>::value, - * ""); */ -/* static_assert(supports_rc_arbitrary::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ -/* static_assert(is_streamable::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE( - "UndirectedGraph implementations", T, HashmapUndirectedGraph) { - - RC_SUBCASE("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct( - gen::construct>(gen::elementOf(n), - gen::elementOf(n)))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(node_query_all()) == unordered_set_of(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(undirected_edge_query_all()) == unordered_set_of(e)); - }); - } -} diff --git a/lib/utils/test/src/utils/graph/views/views.cc b/lib/utils/test/src/utils/graph/views/views.cc deleted file mode 100644 index 6bc4fdbcea..0000000000 --- a/lib/utils/test/src/utils/graph/views/views.cc +++ /dev/null @@ -1,230 +0,0 @@ -#include "utils/graph/views/views.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/graph/algorithms.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/instances/hashmap_undirected_graph.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/undirected/undirected_graph.h" -#include "utils/graph/undirected/undirected_graph_view.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - - TEST_CASE("UndirectedSubgraphView") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 5); - add_edges(g, - {UndirectedEdge{{n.at(0), n.at(3)}}, - UndirectedEdge{{n.at(1), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(1), n.at(3)}}, - UndirectedEdge{{n.at(2), n.at(3)}}, - UndirectedEdge{{n.at(2), n.at(4)}}}); - std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; - UndirectedGraphView view = get_subgraph(g, sub_nodes); - - SUBCASE("get_nodes") { - std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; - - std::unordered_set result = get_nodes(view); - - CHECK(result == expected); - } - - SUBCASE("get_edges") { - std::unordered_set expected = { - UndirectedEdge{{n.at(0), n.at(3)}}, - UndirectedEdge{{n.at(1), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(3)}}, - }; - - std::unordered_set result = get_edges(view); - - // TODO(@pietro) TODO(@lockshaw) current BUG, get_edges also - CHECK(result == expected); - } - } - - TEST_CASE("DiSubgraphView") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 5); - add_edges(g, - {DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(3), n.at(0)}, - DirectedEdge{n.at(1), n.at(1)}, - DirectedEdge{n.at(2), n.at(1)}, - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(2), n.at(3)}, - DirectedEdge{n.at(3), n.at(2)}, - DirectedEdge{n.at(2), n.at(4)}}); - std::unordered_set sub_nodes = {n.at(0), n.at(1), n.at(3)}; - DiGraphView view = get_subgraph(g, sub_nodes); - - SUBCASE("get_nodes") { - std::unordered_set expected = {n.at(0), n.at(1), n.at(3)}; - - std::unordered_set result = get_nodes(view); - - CHECK(result == expected); - } - - SUBCASE("get_edges") { - std::unordered_set expected = { - DirectedEdge{n.at(0), n.at(3)}, - DirectedEdge{n.at(3), n.at(0)}, - DirectedEdge{n.at(1), n.at(1)}, - DirectedEdge{n.at(1), n.at(3)}, - }; - - std::unordered_set result = get_edges(view); - - CHECK(result == expected); - } - } - - TEST_CASE("JoinedUndirectedGraphView") { - UndirectedGraph g1 = UndirectedGraph::create(); - UndirectedGraph g2 = UndirectedGraph::create(); - - std::vector n1 = add_nodes(g1, 3); - std::vector n2 = add_nodes(g2, 3); - - add_edges(g1, - {UndirectedEdge{{n1.at(0), n1.at(1)}}, - UndirectedEdge{{n1.at(1), n1.at(2)}}}); - add_edges(g2, - {UndirectedEdge{{n2.at(0), n2.at(2)}}, - UndirectedEdge{{n2.at(1), n2.at(2)}}}); - - UndirectedGraphView view = join(g1, g2); - - SUBCASE("get_nodes") { - std::unordered_set expected = - set_union(unordered_set_of(n1), unordered_set_of(n2)); - - std::unordered_set result = get_nodes(view); - - CHECK(result == expected); - } - - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // UndirectedEdge{{n1.at(0), n1.at(1)}}, UndirectedEdge{{n1.at(1), - // n1.at(2)}}, UndirectedEdge{{n2.at(0), n2.at(2)}}, - // UndirectedEdge{{n2.at(1), n2.at(2)}} - // }; - - // std::unordered_set result = get_edges(view); - - // CHECK(result == expected); - // } - } - - TEST_CASE("JoinedDigraphView") { - DiGraph g1 = DiGraph::create(); - DiGraph g2 = DiGraph::create(); - - std::vector n1 = add_nodes(g1, 3); - std::vector n2 = add_nodes(g2, 3); - - add_edges( - g1, - {DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), n1.at(2)}}); - add_edges( - g2, - {DirectedEdge{n2.at(0), n2.at(2)}, DirectedEdge{n2.at(1), n2.at(2)}}); - - DiGraphView view = join(g1, g2); - - SUBCASE("get_nodes") { - std::unordered_set expected = - set_union(unordered_set_of(n1), unordered_set_of(n2)); - - std::unordered_set result = get_nodes(view); - - CHECK(result == expected); - } - - // SUBCASE("get_edges") { - // std::unordered_set expected = { - // DirectedEdge{n1.at(0), n1.at(1)}, DirectedEdge{n1.at(1), - // n1.at(2)}, DirectedEdge{n2.at(0), n2.at(2)}, - // DirectedEdge{n2.at(1), n2.at(2)} - // }; - - // std::unordered_set result = get_edges(view); - - // CHECK(result == expected); - // } - } - - TEST_CASE("ViewDiGraphAsUndirectedGraph") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 3); - add_edges(g, - {DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(2), n.at(0)}, - DirectedEdge{n.at(0), n.at(2)}}); - - UndirectedGraphView view = as_undirected(g); - - SUBCASE("get_nodes") { - std::unordered_set expected = unordered_set_of(n); - - std::unordered_set result = get_nodes(view); - - CHECK(result == expected); - } - - SUBCASE("get_edges") { - std::unordered_set expected = { - UndirectedEdge{{n.at(0), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(2), n.at(0)}}}; - - std::unordered_set result = get_edges(view); - - CHECK(result == expected); - } - } - - TEST_CASE("ViewUndirectedGraphAsDiGraph") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 3); - add_edges(g, - {UndirectedEdge{{n.at(0), n.at(0)}}, - UndirectedEdge{{n.at(0), n.at(1)}}, - UndirectedEdge{{n.at(1), n.at(2)}}, - UndirectedEdge{{n.at(2), n.at(0)}}}); - - DiGraphView view = as_digraph(g); - - SUBCASE("get_nodes") { - std::unordered_set expected = unordered_set_of(n); - - std::unordered_set result = get_nodes(view); - - CHECK(result == expected); - } - - SUBCASE("get_edges") { - std::unordered_set expected = { - DirectedEdge{n.at(0), n.at(0)}, - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(1), n.at(0)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(2), n.at(1)}, - DirectedEdge{n.at(2), n.at(0)}, - DirectedEdge{n.at(0), n.at(2)}}; - - std::unordered_set result = get_edges(view); - - CHECK(result == expected); - } - } -} diff --git a/lib/utils/test/src/utils/random_utils.cc b/lib/utils/test/src/utils/random_utils.cc index ba994cefce..4f003301bd 100644 --- a/lib/utils/test/src/utils/random_utils.cc +++ b/lib/utils/test/src/utils/random_utils.cc @@ -25,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("correct distribution") { - auto check_probabilities = [](std::vector values, + auto check_probabilities = [](std::vector const &values, std::vector const &weights) { int num_iterations = 10'000; std::vector trials = repeat( From 18eb8a829df1e1dfa87d81d3a23500f15e140044 Mon Sep 17 00:00:00 2001 From: Pietro Max Marsella Date: Wed, 9 Oct 2024 19:45:40 -0700 Subject: [PATCH 17/21] minor fixes --- .../utils/bidict/algorithms/merge_disjoint_bidicts.h | 4 ++-- lib/utils/include/utils/containers/maximum.h | 8 +++++++- lib/utils/test/src/utils/containers/subvec.cc | 8 ++++---- lib/utils/test/src/utils/containers/vector_split.cc | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index 60637349c7..97e7334c26 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -14,11 +14,11 @@ bidict merge_disjoint_bidicts(bidict const &lhs, bidict const &rhs) { if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { throw mk_runtime_error( - fmt::format("Left entries of {} and {} are non-disjoint"), lhs, rhs); + fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); } if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { throw mk_runtime_error( - fmt::format("Right entries of {} and {} are non-disjoint"), lhs, rhs); + fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); } bidict result; diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h index 9953c44d0d..10a3f85ce8 100644 --- a/lib/utils/include/utils/containers/maximum.h +++ b/lib/utils/include/utils/containers/maximum.h @@ -2,6 +2,12 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H #include "utils/exception.h" +#include "utils/fmt/map.h" +#include "utils/fmt/multiset.h" +#include "utils/fmt/set.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include namespace FlexFlow { @@ -10,7 +16,7 @@ template typename C::value_type maximum(C const &c) { if (c.empty()) { throw mk_runtime_error( - "input to function maximum must be non-empty container"); + fmt::format("maximum expected non-empty container but received {}", c)); } return *std::max_element(c.begin(), c.end()); diff --git a/lib/utils/test/src/utils/containers/subvec.cc b/lib/utils/test/src/utils/containers/subvec.cc index 8ae7a94163..861d29093e 100644 --- a/lib/utils/test/src/utils/containers/subvec.cc +++ b/lib/utils/test/src/utils/containers/subvec.cc @@ -51,12 +51,12 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("Out of bounds index from above") { - CHECK_THROWS(subvec(v, 2, 100)); + SUBCASE("Upper index is out of bounds by 1") { + CHECK_THROWS(subvec(v, 2, 6)); } - SUBCASE("Out of bounds index from below") { - CHECK_THROWS(subvec(v, -100, 2)); + SUBCASE("Lower index is out of bounds by 1") { + CHECK_THROWS(subvec(v, -6, 2)); } } } diff --git a/lib/utils/test/src/utils/containers/vector_split.cc b/lib/utils/test/src/utils/containers/vector_split.cc index 5cf8672ef4..e6e7011610 100644 --- a/lib/utils/test/src/utils/containers/vector_split.cc +++ b/lib/utils/test/src/utils/containers/vector_split.cc @@ -22,7 +22,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(postfix == std::vector({1, 2, 3, 4, 5})); } - SUBCASE("Boundary case: idx is the last index in the list") { + SUBCASE("Boundary case: idx == list_size") { auto [prefix, postfix] = vector_split(v, 5); CHECK(prefix == std::vector({1, 2, 3, 4, 5})); CHECK(postfix.empty()); From 8322f9bda7a33c9be338345aba79d78f0912e7f6 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 9 Oct 2024 20:05:07 -0700 Subject: [PATCH 18/21] Remove unnecessary includes --- lib/utils/include/utils/containers/maximum.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h index 10a3f85ce8..5d21dd9fa1 100644 --- a/lib/utils/include/utils/containers/maximum.h +++ b/lib/utils/include/utils/containers/maximum.h @@ -2,12 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H #include "utils/exception.h" -#include "utils/fmt/map.h" -#include "utils/fmt/multiset.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" #include namespace FlexFlow { From 3579ee4169b09d8d8debc9399d2ea0220a463478 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 9 Oct 2024 21:25:56 -0700 Subject: [PATCH 19/21] Post-merge fixes --- .../machine_mapping/machine_mapping.cc | 1 - lib/pcg/src/pcg/operator_task_space.cc | 3 +- lib/utils/include/utils/fmt/optional.h | 4 +- .../src/utils/cli/cli_get_help_message.cc | 3 +- lib/utils/src/utils/fmt/optional.cc | 8 ++ .../get_cbc_decomposition.cc | 4 - .../is_complete_bipartite_digraph.cc | 1 - .../include/test/utils/doctest/fmt/optional.h | 5 + .../src/test/utils/doctest/fmt/optional.cc | 8 ++ .../test/src/utils/containers/as_vector.cc | 17 ---- .../test/src/utils/containers/compare_by.cc | 2 +- .../test/src/utils/containers/enumerate.cc | 3 +- .../test/src/utils/containers/filter_keys.cc | 2 +- lib/utils/test/src/utils/containers/find.cc | 3 +- .../test/src/utils/containers/flatmap.cc | 2 + .../test/src/utils/containers/index_of.cc | 2 +- lib/utils/test/src/utils/containers/keys.cc | 2 +- .../test/src/utils/containers/map_keys.cc | 2 +- .../test/src/utils/containers/map_values.cc | 2 +- .../test/src/utils/containers/maximum.cc | 2 +- .../test/src/utils/containers/merge_maps.cc | 2 +- .../src/utils/containers/restrict_keys.cc | 3 +- .../test/src/utils/containers/set_union.cc | 2 +- .../test/src/utils/containers/sorted_by.cc | 2 +- lib/utils/test/src/utils/containers/subvec.cc | 2 +- .../test/src/utils/containers/value_all.cc | 4 +- lib/utils/test/src/utils/containers/values.cc | 10 +- .../test/src/utils/containers/vector_split.cc | 2 +- .../src/utils/deduplicated_priority_queue.cc | 2 + lib/utils/test/src/utils/disjoint_set.cc | 4 +- .../utils/graph/multidigraph/multidigraph.cc | 2 +- lib/utils/test/src/utils/hash-utils.cc | 98 ------------------- lib/utils/test/src/utils/hash/map.cc | 20 ++++ lib/utils/test/src/utils/hash/set.cc | 20 ++++ lib/utils/test/src/utils/hash/tuple.cc | 21 ++++ .../test/src/utils/hash/unordered_map.cc | 20 ++++ .../test/src/utils/hash/unordered_set.cc | 20 ++++ lib/utils/test/src/utils/hash/vector.cc | 20 ++++ lib/utils/test/src/utils/join_strings.cc | 2 +- lib/utils/test/src/utils/random_utils.cc | 8 +- lib/utils/test/src/utils/stack_map.cc | 8 +- lib/utils/test/src/utils/stack_string.cc | 2 +- lib/utils/test/src/utils/stack_vector.cc | 3 +- lib/utils/test/src/utils/type_index.cc | 3 +- lib/utils/test/src/utils/variant.cc | 6 +- lib/utils/test/src/utils/vector.cc | 4 + 46 files changed, 203 insertions(+), 163 deletions(-) delete mode 100644 lib/utils/test/src/utils/containers/as_vector.cc delete mode 100644 lib/utils/test/src/utils/hash-utils.cc create mode 100644 lib/utils/test/src/utils/hash/map.cc create mode 100644 lib/utils/test/src/utils/hash/set.cc create mode 100644 lib/utils/test/src/utils/hash/tuple.cc create mode 100644 lib/utils/test/src/utils/hash/unordered_map.cc create mode 100644 lib/utils/test/src/utils/hash/unordered_set.cc create mode 100644 lib/utils/test/src/utils/hash/vector.cc diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 6f350d8773..57e82684e9 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -1,5 +1,4 @@ #include "compiler/machine_mapping/machine_mapping.h" -#include "utils/containers.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps.h" diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/pcg/src/pcg/operator_task_space.cc index 02522ae411..2538cb4ea0 100644 --- a/lib/pcg/src/pcg/operator_task_space.cc +++ b/lib/pcg/src/pcg/operator_task_space.cc @@ -5,6 +5,7 @@ #include "utils/containers/range.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { @@ -25,7 +26,7 @@ std::unordered_set TaskSpaceCoordinate get_task_space_maximum_coordinate(OperatorTaskSpace const &task) { - return maximum(get_task_space_coordinates(task)).value(); + return maximum(get_task_space_coordinates(task)); } size_t num_dims(OperatorTaskSpace const &task) { diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index cb72cd03e3..16ccf61878 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -48,9 +48,7 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { return s << fmt::to_string(t); } -inline std::ostream &operator<<(std::ostream &s, std::nullopt_t) { - return s << "nullopt"; -} +std::ostream &operator<<(std::ostream &, std::nullopt_t); } // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_get_help_message.cc b/lib/utils/src/utils/cli/cli_get_help_message.cc index 03c53c9356..51947aa885 100644 --- a/lib/utils/src/utils/cli/cli_get_help_message.cc +++ b/lib/utils/src/utils/cli/cli_get_help_message.cc @@ -2,6 +2,7 @@ #include "utils/containers/concat_vectors.h" #include "utils/containers/maximum.h" #include "utils/containers/transform.h" +#include "utils/fmt/vector.h" #include "utils/integer_conversions.h" #include "utils/join_strings.h" #include @@ -53,7 +54,7 @@ std::string cli_get_help_message(std::string const &program_name, if (!all_arg_columns.empty()) { int max_column_width = - std::min(int_from_size_t(maximum(all_arg_column_widths).value()), 20); + std::min(int_from_size_t(maximum(all_arg_column_widths)), 20); auto render_column = [&](std::string const &key, std::optional const &description) { diff --git a/lib/utils/src/utils/fmt/optional.cc b/lib/utils/src/utils/fmt/optional.cc index e21b32eaa9..4642292920 100644 --- a/lib/utils/src/utils/fmt/optional.cc +++ b/lib/utils/src/utils/fmt/optional.cc @@ -1 +1,9 @@ #include "utils/fmt/optional.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, std::nullopt_t) { + return (s << std::string{"nullopt"}); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 8afe7da926..92bd1e32ca 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,12 +1,9 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/extend.h" -#include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" -#include "utils/containers/set_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" -#include "utils/fmt/set.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" @@ -16,7 +13,6 @@ #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "utils/graph/node/algorithms.h" -#include "utils/hash/unordered_set.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc index 2eab8371b2..ccd2808603 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" -#include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h index 519cde7d74..25552003e3 100644 --- a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h @@ -13,6 +13,11 @@ struct StringMaker> { } }; +template <> +struct StringMaker { + static String convert(std::nullopt_t const &); +}; + } // namespace doctest #endif diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc index 8a3f7f158e..09b02ac059 100644 --- a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc @@ -1 +1,9 @@ #include "test/utils/doctest/fmt/optional.h" + +namespace doctest { + +String StringMaker::convert(std::nullopt_t const &m) { + return toString(fmt::to_string(m)); +} + +} // namespace doctest diff --git a/lib/utils/test/src/utils/containers/as_vector.cc b/lib/utils/test/src/utils/containers/as_vector.cc deleted file mode 100644 index deef9c00fd..0000000000 --- a/lib/utils/test/src/utils/containers/as_vector.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "utils/containers/as_vector.h" -#include "utils/containers/sorted.h" -#include "utils/fmt/vector.h" -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("as_vector") { - std::unordered_set input = {1, 2, 3}; - std::vector result = as_vector(input); - std::vector correct = {1, 2, 3}; - CHECK(sorted(result) == sorted(correct)); - } -} diff --git a/lib/utils/test/src/utils/containers/compare_by.cc b/lib/utils/test/src/utils/containers/compare_by.cc index 52e71ba175..8c7221ffb4 100644 --- a/lib/utils/test/src/utils/containers/compare_by.cc +++ b/lib/utils/test/src/utils/containers/compare_by.cc @@ -1,5 +1,5 @@ #include "utils/containers/compare_by.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index fb55643fb3..2f9a5b3c02 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,10 +1,11 @@ #include "utils/containers/enumerate.h" #include "test/utils/doctest/fmt/map.h" #include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" #include "utils/containers/keys.h" -#include "utils/containers/unordered_set_of.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include diff --git a/lib/utils/test/src/utils/containers/filter_keys.cc b/lib/utils/test/src/utils/containers/filter_keys.cc index c7a1e8ab0d..00e327a6f1 100644 --- a/lib/utils/test/src/utils/containers/filter_keys.cc +++ b/lib/utils/test/src/utils/containers/filter_keys.cc @@ -1,5 +1,5 @@ #include "utils/containers/filter_keys.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/find.cc b/lib/utils/test/src/utils/containers/find.cc index 432f318170..36d4b771d8 100644 --- a/lib/utils/test/src/utils/containers/find.cc +++ b/lib/utils/test/src/utils/containers/find.cc @@ -1,6 +1,7 @@ #include "utils/containers/find.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/check_without_stringify.h" #include +#include #include #include #include diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index 25dcea38dc..bd6d3ae5be 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -2,6 +2,7 @@ #include "test/utils/doctest/fmt/pair.h" #include "test/utils/doctest/fmt/unordered_map.h" #include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/map_keys.h" #include "utils/hash/pair.h" #include @@ -41,6 +42,7 @@ TEST_SUITE(FF_TEST_SUITE) { "1", "2", "4", "3", "4", "8", "9", "10", "20"}; CHECK(result == correct); } + } TEST_CASE("flatmap(std::unordered_set, F)") { auto get_chars = [](std::string const &s) { diff --git a/lib/utils/test/src/utils/containers/index_of.cc b/lib/utils/test/src/utils/containers/index_of.cc index 8be3632ff6..6ab49cfd42 100644 --- a/lib/utils/test/src/utils/containers/index_of.cc +++ b/lib/utils/test/src/utils/containers/index_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/index_of.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/optional.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/keys.cc b/lib/utils/test/src/utils/containers/keys.cc index 1244f5c292..5bdaef6d08 100644 --- a/lib/utils/test/src/utils/containers/keys.cc +++ b/lib/utils/test/src/utils/containers/keys.cc @@ -1,5 +1,5 @@ #include "utils/containers/keys.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/map_keys.cc b/lib/utils/test/src/utils/containers/map_keys.cc index 256a991a67..5c0a81d5e6 100644 --- a/lib/utils/test/src/utils/containers/map_keys.cc +++ b/lib/utils/test/src/utils/containers/map_keys.cc @@ -1,5 +1,5 @@ #include "utils/containers/map_keys.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/map_values.cc b/lib/utils/test/src/utils/containers/map_values.cc index c4c9fcafe6..a21645d0d5 100644 --- a/lib/utils/test/src/utils/containers/map_values.cc +++ b/lib/utils/test/src/utils/containers/map_values.cc @@ -1,5 +1,5 @@ #include "utils/containers/map_values.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc index 0baf26630c..2309458069 100644 --- a/lib/utils/test/src/utils/containers/maximum.cc +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -1,6 +1,6 @@ #include "utils/containers/maximum.h" +#include "test/utils/doctest/fmt/vector.h" #include -#include #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/merge_maps.cc b/lib/utils/test/src/utils/containers/merge_maps.cc index 148ad6fd87..a083e94de3 100644 --- a/lib/utils/test/src/utils/containers/merge_maps.cc +++ b/lib/utils/test/src/utils/containers/merge_maps.cc @@ -1,5 +1,5 @@ #include "utils/containers/merge_maps.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include #include diff --git a/lib/utils/test/src/utils/containers/restrict_keys.cc b/lib/utils/test/src/utils/containers/restrict_keys.cc index 169386d899..b4b376784e 100644 --- a/lib/utils/test/src/utils/containers/restrict_keys.cc +++ b/lib/utils/test/src/utils/containers/restrict_keys.cc @@ -1,7 +1,8 @@ #include "utils/containers/restrict_keys.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include #include + using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/utils/containers/set_union.cc b/lib/utils/test/src/utils/containers/set_union.cc index 92b53f0845..d842e4df96 100644 --- a/lib/utils/test/src/utils/containers/set_union.cc +++ b/lib/utils/test/src/utils/containers/set_union.cc @@ -1,5 +1,5 @@ #include "utils/containers/set_union.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include #include diff --git a/lib/utils/test/src/utils/containers/sorted_by.cc b/lib/utils/test/src/utils/containers/sorted_by.cc index e5c7e646f9..0ae2e0da77 100644 --- a/lib/utils/test/src/utils/containers/sorted_by.cc +++ b/lib/utils/test/src/utils/containers/sorted_by.cc @@ -1,5 +1,5 @@ #include "utils/containers/sorted_by.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/subvec.cc b/lib/utils/test/src/utils/containers/subvec.cc index 861d29093e..610fc55b5a 100644 --- a/lib/utils/test/src/utils/containers/subvec.cc +++ b/lib/utils/test/src/utils/containers/subvec.cc @@ -1,5 +1,5 @@ #include "utils/containers/subvec.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include #include diff --git a/lib/utils/test/src/utils/containers/value_all.cc b/lib/utils/test/src/utils/containers/value_all.cc index 07f96aa807..1a5f2c508a 100644 --- a/lib/utils/test/src/utils/containers/value_all.cc +++ b/lib/utils/test/src/utils/containers/value_all.cc @@ -1,6 +1,6 @@ #include "utils/containers/value_all.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/vector.h" #include #include #include diff --git a/lib/utils/test/src/utils/containers/values.cc b/lib/utils/test/src/utils/containers/values.cc index 245622c4ab..5fe69ac5e9 100644 --- a/lib/utils/test/src/utils/containers/values.cc +++ b/lib/utils/test/src/utils/containers/values.cc @@ -1,6 +1,5 @@ #include "utils/containers/values.h" -#include "utils/containers/sorted.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include #include @@ -12,8 +11,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("values") { std::unordered_map m = { {1, "one"}, {2, "two"}, {3, "three"}, {33, "three"}}; - std::vector result = values(m); - std::vector correct = {"one", "two", "three", "three"}; - CHECK(sorted(result) == sorted(correct)); + std::unordered_multiset result = values(m); + std::unordered_multiset correct = { + "one", "two", "three", "three"}; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/vector_split.cc b/lib/utils/test/src/utils/containers/vector_split.cc index e6e7011610..76bb21348e 100644 --- a/lib/utils/test/src/utils/containers/vector_split.cc +++ b/lib/utils/test/src/utils/containers/vector_split.cc @@ -1,5 +1,5 @@ #include "utils/containers/vector_split.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include #include #include diff --git a/lib/utils/test/src/utils/deduplicated_priority_queue.cc b/lib/utils/test/src/utils/deduplicated_priority_queue.cc index 048e95acb7..d5b35cf654 100644 --- a/lib/utils/test/src/utils/deduplicated_priority_queue.cc +++ b/lib/utils/test/src/utils/deduplicated_priority_queue.cc @@ -1,6 +1,8 @@ #include "utils/deduplicated_priority_queue.h" #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { DeduplicatedPriorityQueue queue; diff --git a/lib/utils/test/src/utils/disjoint_set.cc b/lib/utils/test/src/utils/disjoint_set.cc index 283b0a8983..be88a30cdd 100644 --- a/lib/utils/test/src/utils/disjoint_set.cc +++ b/lib/utils/test/src/utils/disjoint_set.cc @@ -1,6 +1,6 @@ #include "utils/disjoint_set.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc index 16ee406255..ce4d7a373b 100644 --- a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc +++ b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc @@ -1,9 +1,9 @@ #include "utils/graph/multidigraph/multidigraph.h" -#include "test/utils/doctest.h" #include "utils/containers/contains.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/multidiedge_query.h" #include "utils/graph/query_set.h" +#include #include #include diff --git a/lib/utils/test/src/utils/hash-utils.cc b/lib/utils/test/src/utils/hash-utils.cc deleted file mode 100644 index 077037d36a..0000000000 --- a/lib/utils/test/src/utils/hash-utils.cc +++ /dev/null @@ -1,98 +0,0 @@ -#include "utils/hash-utils.h" -#include "test/utils/doctest.h" -#include "utils/hash/map.h" -#include "utils/hash/set.h" -#include "utils/hash/tuple.h" -#include "utils/hash/unordered_map.h" -#include "utils/hash/unordered_set.h" -#include "utils/hash/vector.h" - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("hash-utils") { - SUBCASE("vector") { - std::vector vec1{1, 2, 3}; - std::vector vec2{1, 2, 3, 4}; - - size_t hash1 = get_std_hash(vec1); - size_t hash2 = get_std_hash(vec2); - - CHECK(hash1 != hash2); - - vec1.push_back(4); - hash1 = get_std_hash(vec1); - CHECK(hash1 == hash2); - } - - SUBCASE("map") { - std::map map1{{1, 2}}; - std::map map2{{1, 2}, {3, 4}}; - - size_t hash1 = get_std_hash(map1); - size_t hash2 = get_std_hash(map2); - - CHECK(hash1 != hash2); - - map1.insert({3, 4}); - hash1 = get_std_hash(map1); - CHECK(hash1 == hash2); - } - - SUBCASE("unordered_map") { - std::unordered_map map1{{1, 2}}; - std::unordered_map map2{{1, 2}, {3, 4}}; - - size_t hash1 = get_std_hash(map1); - size_t hash2 = get_std_hash(map2); - - CHECK(hash1 != hash2); - - map1.insert({3, 4}); - hash1 = get_std_hash(map1); - CHECK(hash1 == hash2); - } - - SUBCASE("set") { - std::set set1{1, 2, 3}; - std::set set2{1, 2, 3, 4}; - - size_t hash1 = get_std_hash(set1); - size_t hash2 = get_std_hash(set2); - - CHECK(hash1 != hash2); - - set1.insert(4); - hash1 = get_std_hash(set1); - CHECK(hash1 == hash2); - } - - SUBCASE("unordered_set") { - std::unordered_set set1{1, 2, 3}; - std::unordered_set set2{1, 2, 3, 4}; - - size_t hash1 = get_std_hash(set1); - size_t hash2 = get_std_hash(set2); - - CHECK(hash1 != hash2); - - set1.insert(4); - hash1 = get_std_hash(set1); - CHECK(hash1 == hash2); - } - - SUBCASE("tuple") { - std::tuple tuple1{1, "test", 3.14}; - std::tuple tuple2{2, "test", 3.14}; - - size_t hash1 = get_std_hash(tuple1); - size_t hash2 = get_std_hash(tuple2); - - CHECK(hash1 != hash2); - - std::get<0>(tuple1) = 2; - hash1 = get_std_hash(tuple1); - CHECK(hash1 == hash2); - } - } -} diff --git a/lib/utils/test/src/utils/hash/map.cc b/lib/utils/test/src/utils/hash/map.cc new file mode 100644 index 0000000000..b4da6ddb68 --- /dev/null +++ b/lib/utils/test/src/utils/hash/map.cc @@ -0,0 +1,20 @@ +#include "utils/hash/map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::map map1{{1, 2}}; + std::map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({3, 4}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/set.cc b/lib/utils/test/src/utils/hash/set.cc new file mode 100644 index 0000000000..f9ccd925c6 --- /dev/null +++ b/lib/utils/test/src/utils/hash/set.cc @@ -0,0 +1,20 @@ +#include "utils/hash/set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::set set1{1, 2, 3}; + std::set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/tuple.cc b/lib/utils/test/src/utils/hash/tuple.cc new file mode 100644 index 0000000000..61240fd7b1 --- /dev/null +++ b/lib/utils/test/src/utils/hash/tuple.cc @@ -0,0 +1,21 @@ +#include "utils/hash/tuple.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::tuple tuple1{1, "test", 3.14}; + std::tuple tuple2{2, "test", 3.14}; + + size_t hash1 = get_std_hash(tuple1); + size_t hash2 = get_std_hash(tuple2); + + CHECK(hash1 != hash2); + + std::get<0>(tuple1) = 2; + hash1 = get_std_hash(tuple1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/unordered_map.cc b/lib/utils/test/src/utils/hash/unordered_map.cc new file mode 100644 index 0000000000..bc0b2879a4 --- /dev/null +++ b/lib/utils/test/src/utils/hash/unordered_map.cc @@ -0,0 +1,20 @@ +#include "utils/hash/unordered_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::unordered_map map1{{1, 2}}; + std::unordered_map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({3, 4}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/unordered_set.cc b/lib/utils/test/src/utils/hash/unordered_set.cc new file mode 100644 index 0000000000..299b3e5d13 --- /dev/null +++ b/lib/utils/test/src/utils/hash/unordered_set.cc @@ -0,0 +1,20 @@ +#include "utils/hash/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::unordered_set set1{1, 2, 3}; + std::unordered_set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/vector.cc b/lib/utils/test/src/utils/hash/vector.cc new file mode 100644 index 0000000000..ee5bb456eb --- /dev/null +++ b/lib/utils/test/src/utils/hash/vector.cc @@ -0,0 +1,20 @@ +#include "utils/hash/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::vector vec1{1, 2, 3}; + std::vector vec2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(vec1); + size_t hash2 = get_std_hash(vec2); + + CHECK(hash1 != hash2); + + vec1.push_back(4); + hash1 = get_std_hash(vec1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/join_strings.cc b/lib/utils/test/src/utils/join_strings.cc index 84f2364fd8..ca1887e84d 100644 --- a/lib/utils/test/src/utils/join_strings.cc +++ b/lib/utils/test/src/utils/join_strings.cc @@ -1,5 +1,5 @@ #include "utils/join_strings.h" -#include "test/utils/doctest.h" +#include #include #include diff --git a/lib/utils/test/src/utils/random_utils.cc b/lib/utils/test/src/utils/random_utils.cc index 4f003301bd..8e7d22138f 100644 --- a/lib/utils/test/src/utils/random_utils.cc +++ b/lib/utils/test/src/utils/random_utils.cc @@ -1,11 +1,13 @@ #include "utils/random_utils.h" -#include "test/utils/doctest.h" #include "utils/containers/contains.h" #include "utils/containers/filter.h" #include "utils/containers/repeat.h" #include "utils/containers/sum.h" #include "utils/containers/zip.h" #include +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("select_random(std::vector)") { @@ -31,7 +33,9 @@ TEST_SUITE(FF_TEST_SUITE) { std::vector trials = repeat( num_iterations, [&]() { return select_random(values, weights); }); - for (auto const [v, w] : zip(values, weights)) { + for (std::pair const &p : zip(values, weights)) { + int v = p.first; + float w = p.second; float expectedProbability = w / sum(weights); int num_occurrences = filter(trials, [&](int c) { return (c == v); }).size(); diff --git a/lib/utils/test/src/utils/stack_map.cc b/lib/utils/test/src/utils/stack_map.cc index 20659d49a5..3ab4faaa80 100644 --- a/lib/utils/test/src/utils/stack_map.cc +++ b/lib/utils/test/src/utils/stack_map.cc @@ -1,7 +1,7 @@ #include "utils/stack_map.h" -#include "test/utils/doctest.h" -#include "utils/fmt/pair.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" +#include using namespace FlexFlow; @@ -9,7 +9,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("stack_map") { stack_map map; - SUBCASE("bracket operator") { + SUBCASE("operator[]") { map[1] = 10; map[2] = 20; diff --git a/lib/utils/test/src/utils/stack_string.cc b/lib/utils/test/src/utils/stack_string.cc index e0da9044c4..8dbfe36c9d 100644 --- a/lib/utils/test/src/utils/stack_string.cc +++ b/lib/utils/test/src/utils/stack_string.cc @@ -1,5 +1,5 @@ -#include "test/utils/rapidcheck.h" #include "utils/stack_string.h" +#include "test/utils/rapidcheck.h" #include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/stack_vector.cc b/lib/utils/test/src/utils/stack_vector.cc index bce068da20..6cdd91ece1 100644 --- a/lib/utils/test/src/utils/stack_vector.cc +++ b/lib/utils/test/src/utils/stack_vector.cc @@ -1,5 +1,6 @@ -#include "test/utils/rapidcheck.h" #include "utils/stack_vector.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" #include #include diff --git a/lib/utils/test/src/utils/type_index.cc b/lib/utils/test/src/utils/type_index.cc index 729ad2797a..0d53868a92 100644 --- a/lib/utils/test/src/utils/type_index.cc +++ b/lib/utils/test/src/utils/type_index.cc @@ -1,8 +1,9 @@ #include "utils/type_index.h" +#include "test/utils/doctest/check_without_stringify.h" #include #include -using namespace FlexFlow; +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_type_index_for_type") { diff --git a/lib/utils/test/src/utils/variant.cc b/lib/utils/test/src/utils/variant.cc index 0740a7ecec..36b014b2e3 100644 --- a/lib/utils/test/src/utils/variant.cc +++ b/lib/utils/test/src/utils/variant.cc @@ -1,9 +1,11 @@ #include "utils/variant.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/variant.h" #include "test/utils/rapidcheck.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/variant.h" #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { SUBCASE("widen function") { diff --git a/lib/utils/test/src/utils/vector.cc b/lib/utils/test/src/utils/vector.cc index c6eb0828b8..18eda49543 100644 --- a/lib/utils/test/src/utils/vector.cc +++ b/lib/utils/test/src/utils/vector.cc @@ -1,5 +1,9 @@ #include "utils/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") { From 27a0cfe3431810f654eae78e0224fd0aa47928b6 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Wed, 9 Oct 2024 22:22:28 -0700 Subject: [PATCH 20/21] Small bugfix and remove some unnecessary includes --- lib/kernels/include/kernels/accessor.h | 1 - .../include/kernels/initializer_kernels.h | 1 - lib/pcg/include/pcg/optimizer_attrs.h | 2 +- lib/utils/include/utils/graph/cow_ptr_t.h | 2 -- lib/utils/include/utils/rapidcheck/variant.h | 18 ++++++++++++++++++ lib/utils/include/utils/unique.h | 13 ------------- lib/utils/include/utils/variant.h | 11 ----------- lib/utils/src/utils/rapidcheck/variant.cc | 11 +++++++++++ lib/utils/test/src/utils/rapidcheck/variant.cc | 13 +++++++++++++ lib/utils/test/src/utils/variant.cc | 7 ------- 10 files changed, 43 insertions(+), 36 deletions(-) create mode 100644 lib/utils/include/utils/rapidcheck/variant.h delete mode 100644 lib/utils/include/utils/unique.h create mode 100644 lib/utils/src/utils/rapidcheck/variant.cc create mode 100644 lib/utils/test/src/utils/rapidcheck/variant.cc diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 5fbcd91a06..39da65c3be 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -7,7 +7,6 @@ #include "op-attrs/datatype.h" #include "utils/exception.h" #include "utils/required.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 52609a303f..9840e457e6 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -4,7 +4,6 @@ #include "accessor.h" #include "kernels/cpu.h" #include "op-attrs/datatype_value.dtg.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h index 4bac74b999..3e787503d6 100644 --- a/lib/pcg/include/pcg/optimizer_attrs.h +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -3,7 +3,7 @@ #include "pcg/optimizers/adam_optimizer_attrs.h" #include "pcg/optimizers/sgd_optimizer_attrs.h" -#include "utils/variant.h" +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 9a655ae072..7aed437136 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -2,8 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_COW_PTR_T_H #include "utils/type_traits.h" -#include "utils/unique.h" -#include "utils/variant.h" #include #include diff --git a/lib/utils/include/utils/rapidcheck/variant.h b/lib/utils/include/utils/rapidcheck/variant.h new file mode 100644 index 0000000000..1a295e19e1 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/variant.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::oneOf(gen::construct>(gen::arbitrary())...); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/unique.h b/lib/utils/include/utils/unique.h deleted file mode 100644 index cf6eb39026..0000000000 --- a/lib/utils/include/utils/unique.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H -#define _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H - -#include - -namespace FlexFlow { -template -std::unique_ptr make_unique(Args &&...args) { - return std::unique_ptr(new T(std::forward(args)...)); -} -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index bb2286a9cd..241d631200 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -213,15 +213,4 @@ std::optional cast(VariantIn const &v) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::oneOf(gen::cast>(gen::arbitrary())...); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/src/utils/rapidcheck/variant.cc b/lib/utils/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..5fbc9f6910 --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/variant.cc @@ -0,0 +1,11 @@ +#include "utils/rapidcheck/variant.h" + +namespace rc { + +using T0 = int; +using T1 = std::string; + +template + struct Arbitrary>; + +} // namespace rc diff --git a/lib/utils/test/src/utils/rapidcheck/variant.cc b/lib/utils/test/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..e201c7af8f --- /dev/null +++ b/lib/utils/test/src/utils/rapidcheck/variant.cc @@ -0,0 +1,13 @@ +#include "utils/rapidcheck/variant.h" +#include +#include "test/utils/rapidcheck.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Arbitrary") { + RC_SUBCASE("valid type", [](std::variant v) { + return std::holds_alternative(v) || std::holds_alternative(v); + }); + } +} diff --git a/lib/utils/test/src/utils/variant.cc b/lib/utils/test/src/utils/variant.cc index 36b014b2e3..3f6feadda0 100644 --- a/lib/utils/test/src/utils/variant.cc +++ b/lib/utils/test/src/utils/variant.cc @@ -1,7 +1,6 @@ #include "utils/variant.h" #include "test/utils/doctest/fmt/optional.h" #include "test/utils/doctest/fmt/variant.h" -#include "test/utils/rapidcheck.h" #include using namespace ::FlexFlow; @@ -74,10 +73,4 @@ TEST_SUITE(FF_TEST_SUITE) { // Check the result CHECK(get(wider_variant) == 42); } - - TEST_CASE("Arbitrary") { - RC_SUBCASE("valid type", [](std::variant v) { - return std::holds_alternative(v) || std::holds_alternative(v); - }); - } } From c10c7407d798cce0305b47bd493fb6d52ce214fd Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 10 Oct 2024 07:33:16 -0700 Subject: [PATCH 21/21] Format --- lib/utils/include/utils/rapidcheck/variant.h | 5 +++-- lib/utils/src/utils/rapidcheck/variant.cc | 3 +-- lib/utils/test/src/utils/rapidcheck/variant.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/utils/include/utils/rapidcheck/variant.h b/lib/utils/include/utils/rapidcheck/variant.h index 1a295e19e1..bc741ea340 100644 --- a/lib/utils/include/utils/rapidcheck/variant.h +++ b/lib/utils/include/utils/rapidcheck/variant.h @@ -1,15 +1,16 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H -#include #include +#include namespace rc { template struct Arbitrary> { static Gen> arbitrary() { - return gen::oneOf(gen::construct>(gen::arbitrary())...); + return gen::oneOf( + gen::construct>(gen::arbitrary())...); } }; diff --git a/lib/utils/src/utils/rapidcheck/variant.cc b/lib/utils/src/utils/rapidcheck/variant.cc index 5fbc9f6910..f0537d454f 100644 --- a/lib/utils/src/utils/rapidcheck/variant.cc +++ b/lib/utils/src/utils/rapidcheck/variant.cc @@ -5,7 +5,6 @@ namespace rc { using T0 = int; using T1 = std::string; -template - struct Arbitrary>; +template struct Arbitrary>; } // namespace rc diff --git a/lib/utils/test/src/utils/rapidcheck/variant.cc b/lib/utils/test/src/utils/rapidcheck/variant.cc index e201c7af8f..ca830fb3ab 100644 --- a/lib/utils/test/src/utils/rapidcheck/variant.cc +++ b/lib/utils/test/src/utils/rapidcheck/variant.cc @@ -1,6 +1,6 @@ #include "utils/rapidcheck/variant.h" -#include #include "test/utils/rapidcheck.h" +#include using namespace ::FlexFlow;