Skip to content

Commit

Permalink
refactor the pattern graph to be OutputLabelledOpenMultiDiGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
wmdi committed Aug 25, 2023
1 parent cc5837b commit d1aa92f
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 55 deletions.
3 changes: 2 additions & 1 deletion lib/substitutions/include/substitutions/graph_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace FlexFlow {
struct GraphPattern
: public strong_typedef<
GraphPattern,
LabelledOpenMultiDiGraph<OperatorPattern, ParallelTensorPattern>> {
OutputLabelledOpenMultiDiGraph<OperatorPattern,
ParallelTensorPattern>> {
using strong_typedef::strong_typedef;
};

Expand Down
11 changes: 5 additions & 6 deletions lib/substitutions/include/substitutions/output_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@

namespace FlexFlow {

using GraphAttributeKey = variant<OperatorAttributeKey, TensorAttributeKey>;
using GraphAttributeValue =
variant<int, float, bool, std::vector<int>, OperatorType, Activation>;

// NOTE(@wmdi) I am not sure whether these should be part of attribute expr.
struct NodeAttrAccess {
Node node;
GraphAttributeKey attr_expr;
AttributeExpr<OperatorAttributeKey> attr_expr;
};

struct EdgeAttrAccess {
OpenMultiDiEdge edge;
GraphAttributeKey attr_expr;
AttributeExpr<TensorAttributeKey> attr_expr;
};

struct AttrConstant {
Expand All @@ -32,7 +31,7 @@ enum class AttrOpType { ADD, SUB, MUL, DIV };
struct AttrUnary {
AttrOpType op_type;
GraphAttributeExprLeaf lhs;
GraphAttributeExprLeaf rhs;
GraphAttributeValue rhs;
};

struct AttrBinary {
Expand All @@ -57,8 +56,8 @@ struct ParallelTensorAttrAssignment {
struct OutputGraph
: public strong_typedef<
OutputGraph,
OutputLabelledMultiDiGraph<OperatorAttrAssignment,
ParallelTensorAttrAssignment>> {
OutputLabelledOpenMultiDiGraph<OperatorAttrAssignment,
ParallelTensorAttrAssignment>> {
using strong_typedef::strong_typedef;
};

Expand Down
2 changes: 1 addition & 1 deletion lib/substitutions/include/substitutions/substitution.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace FlexFlow {
struct Substitution {
GraphPattern input_graph;
OutputGraph output_graph;
bidict<InputMultiDiEdge, IutputMultiDiEdge> input_mapping;
bidict<InputMultiDiEdge, InputMultiDiEdge> input_mapping;
bidict<OutputMultiDiEdge, OutputMultiDiEdge> output_mapping;
};

Expand Down
10 changes: 6 additions & 4 deletions lib/substitutions/src/graph_pattern.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "substitutions/graph_pattern.h"
#include "op-attrs/operator_attrs.h"
#include "op-attrs/parallel_tensor_shape.h"
#include "pcg/parallel_computation_graph.h"
#include "substitutions/get_attribute.h"
#include "substitutions/graph_pattern_match.h"
#include "substitutions/operator_pattern.h"
#include "substitutions/parallel_tensor_pattern.h"

Expand Down Expand Up @@ -72,14 +75,14 @@ struct EvaluateTensorAttributeExpr {
switch (key) {
case TensorAttributeKey::DIM_SIZES: {
std::vector<int> result;
for (ParallelDim const &dim : this->tensor_shape) {
for (ParallelDim const &dim : this->tensor_shape.dims) {
result.push_back(dim.size);
}
return result;
}
case TensorAttributeKey::DIM_DEGREES: {
std::vector<int> result;
for (ParallelDim const &dim : this->tensor_shape) {
for (ParallelDim const &dim : this->tensor_shape.dims) {
result.push_back(dim.degree);
}
return result;
Expand Down Expand Up @@ -201,8 +204,7 @@ bool assignment_satisfies(ParallelComputationGraph const &pcg,
result &= constraintResult.value_or(false);
}

result &= pattern_matches(
OpenMultiDiGraphView(pattern), MultiDiGraphView(pcg), patternMatch);
result &= pattern_matches(pattern, pcg, patternMatch);

return result;
}
Expand Down
109 changes: 66 additions & 43 deletions lib/substitutions/src/substitution.cc
Original file line number Diff line number Diff line change
@@ -1,60 +1,84 @@
#include "substitutions/substitution.h"
#include

namespace FlexFlow {

template <typename T>
GraphAttributeValue
graph_attribute_value_op(AttrOpType op, T const &lhs, T const &rhs) {
switch (op) {
case AttrOpType::ADD:
return lhs + rhs;
break;
case AttrOpType::SUB:
return lhs - rhs;
break;
case AttrOpType::MUL:
return lhs * rhs;
break;
case AttrOpType::DIV:
return lhs / rhs;
break;
default:
mk_runtime_error("Unknown attribute operator type");
struct GraphAttributeValueOp {
template <typename T>
GraphAttributeValue operator()(T const &lhs, T const &rhs) {
switch (op) {
case AttrOpType::ADD:
return lhs + rhs;
break;
case AttrOpType::SUB:
return lhs - rhs;
break;
case AttrOpType::MUL:
return lhs * rhs;
break;
case AttrOpType::DIV:
return lhs / rhs;
break;
default:
mk_runtime_error("Unknown attribute operator type");
}
}
AttrOpType op;
};

GraphAttributeValue graph_attribute_value_op(AttrOpType op,
GraphAttributeValue const &lhs,
GraphAttributeValue const &rhs) {
visit(GraphAttributeValueOp{op}, lhs, rhs);
}

struct EvaluateGraphAttributeExpr {
template <typename... Ts>
GraphAttributeValue operator()(Ts... const &ts) {
return evaluate(ts);
struct EvaluateGraphAttributeExprLeaf {
template <typename T>
GraphAttributeValue operator()(T const &t) {
return evaluate(t);
}

template <typename T>
GraphAttributeValue evaluate(NodeAttrAccess<T> const &t) {
GraphAttributeValue evaluate(NodeAttrAccess const &t) {
Node node_in_pattern = t.node;
Node node_in_pcg = match.nodeAssignment.at_l(node_in_pattern);
return evaluate_attribute_expr(node_in_pcg, t.attr_expr);
return widen<GraphAttributeValue>(
evaluate_attribute_expr(graph.at(node_in_pcg), t.attr_expr).value());
}

GraphAttributeValue evaluate(EdgeAttrAccess const &t) {
OpenMultiDiEdge output_in_pattern = t.edge;
MultiDiEdge output_in_pcg = match.edgeAssignment.at_l(output_in_pattern);
return widen<GraphAttributeValue>(
evaluate_attribute_expr(graph.at(output_in_pcg), t.attr_expr).value());
}

ParallelComputationGraph const &graph;
DiGraphPatternMatch const &match;
};

GraphAttributeValue
evaluate_graph_attribute_expr_leaf(ParallelComputationGraph const &g,
DiGraphPatternMatch const &match,
GraphAttributeExprLeaf const &expr) {
return visit(EvaluateGraphAttributeExprLeaf{g, match}, expr);
}

struct EvaluateGraphAttributeExpr {
template <typename T>
GraphAttributeValue evaluate(EdgeAttrAccess<T> const &t) {
OpenMultiDiEdge edge_in_pattern = t.edge;
MultiDiEdge edge_in_pcg = match.edgeAssignment.at_l(edge_in_pattern);
return evaluate_attribute_expr(edge_in_pcg, t.attr_expr);
GraphAttributeValue operator()(T const &t) {
return evaluate(t);
}

template <typename L, typename R>
GraphAttributeValue evaluate(AttrUnary<L, R> const &t) {
auto lhs = (*this)(t.lhs).value();
auto rhs = t.rhs;
return graph_attribute_value_op(lhs, rhs);
GraphAttributeValue evaluate(AttrUnary const &expr) {
auto lhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.lhs);
auto rhs = expr.rhs;
return graph_attribute_value_op(expr.op_type, lhs, rhs);
}

template <typename L, typename R>
GraphAttributeValue evaluate(AttrBinary<L, R> const &t) {
auto lhs = (*this)(t.lhs).value();
auto rhs = (*this)(t.rhs).value();
return graph_attribute_value_op(lhs, rhs);
GraphAttributeValue evaluate(AttrBinary const &expr) {
auto lhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.lhs);
auto rhs = evaluate_graph_attribute_expr_leaf(graph, match, expr.rhs);
return graph_attribute_value_op(expr.op_type, lhs, rhs);
}

EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph,
Expand All @@ -65,11 +89,10 @@ struct EvaluateGraphAttributeExpr {
DiGraphPatternMatch const &match;
};

template <typename... Ts>
GraphAttributeValue
evaluate_graph_attribute_expr(ParallelComputationGraph const &graph,
DiGraphPatternMatch const &match,
GraphAttributeExpr<Ts...> const &expr) {
GraphAttributeExpr const &expr) {
return visit(EvaluateGraphAttributeExpr(graph, match), expr);
}

Expand All @@ -93,12 +116,12 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg,
ParallelComputationGraph::create<UnorderedOutputLabelledMultiDiGraph>();
bidict<Node, Node> node_mapping; // Refactor it with global nodes
for (Node const &node : get_nodes(pcg)) {
if (!contains_r(match.nodeAssignment)) {
if (!contains_r(match.nodeAssignment, node)) {
node_mapping.equate(node, new_pcg.add_node(pcg.at(node)));
}
}
for (MultiDiEdge const &edge : get_edges(pcg)) {
if (!contains_r(match.edgeAssignment)) {
if (!contains_r(match.edgeAssignment, edge)) {
new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(edge.src),
node_mapping.at_r(edge.dst),
new_pcg.add_node_port(),
Expand Down
17 changes: 17 additions & 0 deletions lib/utils/include/utils/graph/labelled/output_labelled_open.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN
#define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN

namespace FlexFlow {

template <typename NodeLabel, typename InputLabel, typename OutputLabel=InputLabel>
struct OutputLabelledOpenMultiDiGraph {
OutputLabelledOpenMultiDiGraph() = delete;
OutputLabelledOpenMultiDiGraph(OutputLabelledOpenMultiDiGraph const &) = default;
OutputLabelledOpenMultiDiGraph& operator=(OutputLabelledOpenMultiDiGraph const &) = default;

operator OpenMultiDiGraphView();
};

}

#endif
1 change: 1 addition & 0 deletions lib/utils/include/utils/graph/labelled_graphs.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "labelled/node_labelled.h"
#include "labelled/open_algorithms.h"
#include "labelled/output_labelled.h"
#include "labelled/output_labelled_open.h"
#include "labelled/standard_labelled.h"
#include "labelled/unordered_labelled_graphs.h"

Expand Down

0 comments on commit d1aa92f

Please sign in to comment.