Skip to content

Commit

Permalink
merge the conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
lambda7xx committed Jul 15, 2023
1 parent a85093d commit 05bd453
Show file tree
Hide file tree
Showing 18 changed files with 367 additions and 353 deletions.
2 changes: 1 addition & 1 deletion lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
#add_subdirectory(kernels)
add_subdirectory(utils)
# add_subdirectory(ffi)
add_subdirectory(substitutions)
#add_subdirectory(substitutions)
4 changes: 1 addition & 3 deletions lib/utils/include/utils/graph/digraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ struct DirectedEdgeQuery {
query_set<Node> srcs;
query_set<Node> dsts;

static DirectedEdgeQuery all() {
NOT_IMPLEMENTED();
}
static DirectedEdgeQuery all();
};
FF_VISITABLE_STRUCT(DirectedEdgeQuery, srcs, dsts);

Expand Down
1 change: 1 addition & 0 deletions lib/utils/include/utils/graph/multidiedge.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct MultiDiEdge {
NodePort srcIdx, dstIdx;
};
FF_VISITABLE_STRUCT(MultiDiEdge, src, dst, srcIdx, dstIdx);
std::ostream& operator<<(std::ostream& os, const MultiDiEdge& edge);

struct MultiDiInput {
Node node;
Expand Down
8 changes: 7 additions & 1 deletion lib/utils/include/utils/graph/multidigraph_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ struct MultiDiEdgeQuery {
MultiDiEdgeQuery with_src_idxs(query_set<NodePort> const &) const;
MultiDiEdgeQuery with_dst_idxs(query_set<NodePort> const &) const;

MultiDiEdgeQuery with_dst_node(Node const &) const;
MultiDiEdgeQuery with_src_node(Node const &) const;
MultiDiEdgeQuery with_src_idx(NodePort const &) const;
MultiDiEdgeQuery with_dst_idx(NodePort const &) const;

static MultiDiEdgeQuery all();
};
FF_VISITABLE_STRUCT(MultiDiEdgeQuery, srcs, dsts, srcIdxs, dstIdxs);
Expand All @@ -35,7 +40,7 @@ struct IMultiDiGraphView : public IGraphView {
using EdgeQuery = MultiDiEdgeQuery;

virtual std::unordered_set<Edge> query_edges(EdgeQuery const &) const = 0;
virtual ~IMultiDiGraphView();
virtual ~IMultiDiGraphView()=default;
};
CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraphView);

Expand All @@ -51,6 +56,7 @@ struct IMultiDiGraph : public IMultiDiGraphView, public IGraph {
}

virtual IMultiDiGraph *clone() const override = 0;
std::size_t next_nodeport_idx = 0;
};
CHECK_RC_COPY_VIRTUAL_COMPLIANT(IMultiDiGraph);

Expand Down
5 changes: 2 additions & 3 deletions lib/utils/include/utils/graph/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@ struct Node : public strong_typedef<Node, size_t> {
};
FF_TYPEDEF_HASHABLE(Node);
FF_TYPEDEF_PRINTABLE(Node, "Node");
std::ostream &operator<<(std::ostream &, Node const &);

struct NodeQuery {
NodeQuery(query_set<Node> const &nodes) : nodes(nodes) {}

query_set<Node> nodes;

static NodeQuery all() {
NOT_IMPLEMENTED();
}
static NodeQuery all();
};
FF_VISITABLE_STRUCT(NodeQuery, nodes);

Expand Down
5 changes: 2 additions & 3 deletions lib/utils/include/utils/graph/open_graphs.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ struct OpenMultiDiGraph {
return OpenMultiDiGraph(make_unique<T>());
}

private:
OpenMultiDiGraph(std::unique_ptr<IOpenMultiDiGraph>);

OpenMultiDiGraph(std::unique_ptr<IOpenMultiDiGraph> ptr);

private:
cow_ptr_t<IOpenMultiDiGraph> ptr;
};
Expand Down
5 changes: 2 additions & 3 deletions lib/utils/include/utils/graph/query_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ struct query_set {
query_set(T const &) {
NOT_IMPLEMENTED();
}
query_set(std::unordered_set<T> const &) {
NOT_IMPLEMENTED();
}
query_set(std::unordered_set<T> const & query): query(query) {}

query_set(optional<std::unordered_set<T>> const &) {
NOT_IMPLEMENTED();
}
Expand Down
8 changes: 4 additions & 4 deletions lib/utils/include/utils/graph/serialparallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ struct Serial {
std::vector<variant<Parallel, Node>> children;
};

FF_VISITABLE_STRUCT(Serial, children);
MAKE_VISIT_HASHABLE(Serial);
// FF_VISITABLE_STRUCT(Serial, children);
// MAKE_VISIT_HASHABLE(Serial);

struct Parallel {
std::vector<variant<Serial, Node>> children;
};

FF_VISITABLE_STRUCT(Parallel, children);
MAKE_VISIT_HASHABLE(Parallel);
// FF_VISITABLE_STRUCT(Parallel, children);
// MAKE_VISIT_HASHABLE(Parallel);

using SerialParallelDecomposition = variant<Serial, Parallel, Node>;

Expand Down
4 changes: 1 addition & 3 deletions lib/utils/include/utils/graph/undirected.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ namespace FlexFlow {
struct UndirectedEdgeQuery {
query_set<Node> nodes;

static UndirectedEdgeQuery all() {
NOT_IMPLEMENTED();
}
static UndirectedEdgeQuery all();
};
FF_VISITABLE_STRUCT(UndirectedEdgeQuery, nodes);

Expand Down
1 change: 1 addition & 0 deletions lib/utils/src/graph/adjacency_multidigraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace FlexFlow {
Node AdjacencyMultiDiGraph::add_node() {
Node node{this->next_node_idx};
adjacency[node];
std::cout<<"add node "<<node.value()<<std::endl;
this->next_node_idx++;
return node;
}
Expand Down
22 changes: 13 additions & 9 deletions lib/utils/src/graph/digraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,23 @@ DiGraphView unsafe_create(IDiGraphView const &graphView) {
return DiGraphView(ptr);
}

DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs,
DirectedEdgeQuery const &rhs) {
DirectedEdgeQuery DirectedEdgeQuery::all() {
return {matchall<Node>(),
matchall<Node>()};
}

DirectedEdgeQuery query_intersection(DirectedEdgeQuery const &lhs, DirectedEdgeQuery const &rhs){
assert(lhs != tl::nullopt);
assert(rhs != tl::nullopt);
assert(lhs.srcs.has_value() && lhs.dsts.has_value() && rhs.srcs.has_value() &&
rhs.dsts.has_value());
assert (lhs.srcs.has_value() && lhs.dsts.has_value() && rhs.srcs.has_value() && rhs.dsts.has_value());

tl::optional<std::unordered_set<Node>> srcs_t1 =
intersection(*lhs.srcs, *rhs.srcs);
tl::optional<std::unordered_set<Node>> dsts_t1 =
intersection(*lhs.dsts, *rhs.dsts);
std::unordered_set<Node> srcs_t1 = intersection(allowed_values(lhs.srcs), allowed_values(rhs.srcs));
std::unordered_set<Node> dsts_t1 = intersection(allowed_values(lhs.dsts), allowed_values(rhs.dsts));

return DirectedEdgeQuery(srcs_t1, dsts_t1);
DirectedEdgeQuery result = DirectedEdgeQuery::all();
result.srcs = srcs_t1;
result.dsts = dsts_t1;
return result;
}

} // namespace FlexFlow
45 changes: 26 additions & 19 deletions lib/utils/src/graph/multidigraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,11 @@ MultiDiEdgeQuery
return e;
}

<<<<<<< HEAD
std::ostream &operator<<(std::ostream &os, MultiDiEdge const &edge) {
return os << "MultiDiEdge{" << edge.src.value() << "," << edge.dst.value()
<< "," << edge.srcIdx.value() << "," << edge.dstIdx.value() << "}";
}

MultiDiEdgeQuery::MultiDiEdgeQuery(
tl::optional<std::unordered_set<Node>> const &srcs,
tl::optional<std::unordered_set<Node>> const &dsts,
tl::optional<std::unordered_set<NodePort>> const &srcIdxs,
tl::optional<std::unordered_set<NodePort>> const &dstIdxs)
: srcs(srcs), dsts(dsts), srcIdxs(srcIdxs), dstIdxs(dstIdxs) {}

MultiDiEdgeQuery MultiDiEdgeQuery::with_src_node(Node const &n) const {
return this->with_src_nodes({n});
}
Expand All @@ -47,26 +39,41 @@ MultiDiEdgeQuery
return e;
}

MultiDiEuery query_intersection(MultiDiEdgeQuery const &lhs,
MultiDiEdgeQuery query_intersection(MultiDiEdgeQuery const &lhs,
MultiDiEdgeQuery const &rhs) {
assert(lhs.srcs.has_value() && dgeQ lhs.dsts.has_value() &&
rhs.srcs.has_value() && rhs.dsts.has_value());
tl::optional<std::unordered_set<Node>> srcs =
intersection(*lhs.srcs, *rhs.srcs);
tl::optional<std::unordered_set<Node>> dsts =
intersection(*lhs.dsts, *rhs.dsts);
return MultiDiEdgeQuery(srcs, dsts);
assert(lhs != tl::nullopt);
assert(rhs != tl::nullopt);

std::unordered_set<Node> srcs_t1 = intersection(allowed_values(lhs.srcs), allowed_values(rhs.srcs));
std::unordered_set<Node> dsts_t1 = intersection(allowed_values(lhs.dsts), allowed_values(rhs.dsts));

std::unordered_set<NodePort> srcIdxs_t1 = intersection(allowed_values(lhs.srcIdxs), allowed_values(rhs.srcIdxs));
std::unordered_set<NodePort> dstIdxs_t1 = intersection(allowed_values(lhs.dstIdxs), allowed_values(rhs.dstIdxs));

MultiDiEdgeQuery e = MultiDiEdgeQuery::all();
e.srcs = srcs_t1;
e.dsts = dsts_t1;
e.srcIdxs = srcIdxs_t1;
e.dstIdxs = dstIdxs_t1;
return e;
}

MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_node(Node const &n) const {
return this->with_dst_nodes({n});
}
MultiDiEdgeQuery MultiDiEdgeQuery::with_src_idx(NodePort const & p) const {
return this->with_src_idxs({p});
}

MultiDiEdgeQuery MultiDiEdgeQuery::with_dst_idx(NodePort const & p) const {
return this->with_dst_idxs({p});
}

MultiDiEdgeQuery MultiDiEdgeQuery::with_src_idxs(
std::unordered_set<NodePort> const &idxs) const {
query_set<NodePort> const &idxs) const {
MultiDiEdgeQuery e{*this};
if (e.srcIdxs != tl::nullopt) {
throw std::runtime_error("expected srcIdxs == tl::nullopt");
if (is_matchall(e.srcIdxs)) {
throw mk_runtime_error("Expected matchall previous value");
}
e.srcIdxs = idxs;
return e;
Expand Down
20 changes: 13 additions & 7 deletions lib/utils/src/graph/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,23 @@ std::ostream &operator<<(std::ostream &os, Node const &node) {
return os << fmt::format("Node({})", node.value());
}

NodeQuery::NodeQuery(std::unordered_set<Node> const &nodes)
: NodeQuery(tl::optional<std::unordered_set<Node>>{nodes}) {}
NodeQuery NodeQuery::all() {
return {matchall<Node>()} ;
}

NodeQuery::NodeQuery(tl::optional<std::unordered_set<Node>> const &nodes)
: nodes(nodes) {}

NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) {
NodeQuery query_intersection(NodeQuery const & lhs, NodeQuery const & rhs) {
assert(lhs != tl::nullopt && rhs != tl::nullopt);
return intersection(*lhs.nodes, *rhs.nodes);
std::unordered_set<Node> nodes = intersection(allowed_values(lhs.nodes), allowed_values(rhs.nodes));
NodeQuery intersection_result = NodeQuery::all();
intersection_result.nodes = nodes;
return intersection_result;
}

// NodeQuery query_intersection(NodeQuery const &lhs, NodeQuery const &rhs) {
// assert(lhs != tl::nullopt && rhs != tl::nullopt);
// return intersection(*lhs.nodes, *rhs.nodes);
// }

std::unordered_set<Node> GraphView::query_nodes(NodeQuery const &g) const {
return this->ptr->query_nodes(g);
}
Expand Down
2 changes: 1 addition & 1 deletion lib/utils/src/graph/open_graphs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ std::unordered_set<OpenMultiDiEdge>
}

OpenMultiDiGraph::OpenMultiDiGraph(OpenMultiDiGraph const &other)
: ptr(other.ptr->clone()) {}
: ptr(other.ptr) {}

OpenMultiDiGraph &OpenMultiDiGraph::operator=(OpenMultiDiGraph other) {
swap(*this, other);
Expand Down
4 changes: 4 additions & 0 deletions lib/utils/src/graph/undirected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ namespace FlexFlow {
UndirectedEdge::UndirectedEdge(Node const &src, Node const &dst)
: smaller(std::min(smaller, bigger)), bigger(std::max(smaller, bigger)) {}

UndirectedEdgeQuery UndirectedEdgeQuery::all() {
return {matchall<Node>()};
}

UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs,
UndirectedEdgeQuery const &rhs) {
return {
Expand Down
39 changes: 14 additions & 25 deletions lib/utils/src/graph/views.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,20 @@ std::unordered_set<Node> ViewOpenMultiDiGraphAsMultiDiGraph::query_nodes(

std::unordered_set<MultiDiEdge> ViewOpenMultiDiGraphAsMultiDiGraph::query_edges(
MultiDiEdgeQuery const &query) const {
OpenMultiDiEdgeQuery q;
q.standard_edge_query = query;
std::unordered_set<OpenMultiDiEdge> edges = g.query_edges(q);
std::unordered_set<MultiDiEdge> result;

for (auto const &edge : edges) {
if (holds_alternative<MultiDiEdge>(edge)) {
result.insert(get<MultiDiEdge>(edge));
}
}

// OpenMultiDiEdgeQuery q{query};
// // q.standard_edge_query = query;
// std::unordered_set<OpenMultiDiEdge> edges = g.query_edges(q);
// std::unordered_set<MultiDiEdge> result;

return result;
// for (auto const &edge : edges) {
// if (holds_alternative<MultiDiEdge>(edge)) {
// result.insert(get<MultiDiEdge>(edge));
// }
// }

// return result;
NOT_IMPLEMENTED();
}

DirectedEdge flipped(DirectedEdge const &e) {
Expand Down Expand Up @@ -229,15 +231,9 @@ std::unordered_set<Node>

std::unordered_set<DirectedEdge>
JoinedDigraphView::query_edges(DirectedEdgeQuery const &query) const {
<<<<<<< HEAD
std::unordered_set<Node> srcs =
query.srcs.value_or(get_nodes(unsafe_create(*this)));
std::unordered_set<Node> dsts =
query.dsts.value_or(get_nodes(unsafe_create(*this)));
=======

std::unordered_set<Node> srcs = this->query_nodes(query.srcs);
std::unordered_set<Node> dsts = this->query_nodes(query.dsts);
>>>>>>> 804707dc140044e43c2436f6a40fe0e9889f5a0a
auto traced_srcs = this->joined_nodes.trace_nodes(srcs);
auto traced_dsts = this->joined_nodes.trace_nodes(dsts);
DirectedEdgeQuery left_query = {traced_srcs.first, traced_dsts.first};
Expand Down Expand Up @@ -275,15 +271,8 @@ std::unordered_set<Node>

std::unordered_set<MultiDiEdge>
JoinedMultiDigraphView::query_edges(MultiDiEdgeQuery const &query) const {
<<<<<<< HEAD
std::unordered_set<Node> srcs =
query.srcs.value_or(get_nodes(unsafe_create(*this)));
std::unordered_set<Node> dsts =
query.dsts.value_or(get_nodes(unsafe_create(*this)));
=======
std::unordered_set<Node> srcs = this->query_nodes(query.srcs);
std::unordered_set<Node> dsts = this->query_nodes(query.dsts);
>>>>>>> 804707dc140044e43c2436f6a40fe0e9889f5a0a

auto traced_srcs = this->joined_nodes.trace_nodes(srcs);
auto traced_dsts = this->joined_nodes.trace_nodes(dsts);
Expand Down
Loading

0 comments on commit 05bd453

Please sign in to comment.