Skip to content

Commit

Permalink
remove output tensor computation
Browse files Browse the repository at this point in the history
  • Loading branch information
wmdi committed Sep 4, 2023
1 parent c2b6b04 commit 21b8549
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 123 deletions.
12 changes: 6 additions & 6 deletions lib/compiler/src/graph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ SerialParallelDecomposition
std::vector<MultiDiEdge>
get_sorted_node_input_edges(ParallelComputationGraph const &pcg,
Node const &n) {
std::unordered_map<size_t, std::unordered_set<MultiDiEdge>> incoming_edges =
std::unordered_map<NodePort, std::unordered_set<MultiDiEdge>> incoming_edges =
get_incoming_edges_by_idx(pcg, n);

std::vector<MultiDiEdge> result;
Expand All @@ -36,11 +36,11 @@ std::unordered_map<MultiDiEdge, ParallelTensorShape>

auto outgoing_edges = get_outgoing_edges_by_idx(pcg, n);

for (std::size_t i = 0; i < output_tensor_shapes.size(); i++) {
if (contains_key(outgoing_edges, i)) {
for (MultiDiEdge const &e : outgoing_edges.at(i)) {
result.insert({e, output_tensor_shapes[i]});
}
int i = 0;

for (auto const &[node_port, edges] : outgoing_edges) {
for (MultiDiEdge const &e : edges) {
result.insert({e, output_tensor_shapes[i++]});
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions lib/substitutions/include/substitutions/attribute_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ struct AttributeConstraint {
V attribute_value;
};

template <typename K, typename V>
struct AttributePattern {
std::unordered_set<AttributeConstraint<K, V>> attribute_constraints;
};

} // namespace FlexFlow

#endif
7 changes: 2 additions & 5 deletions lib/substitutions/include/substitutions/operator_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,9 @@ enum class OperatorAttributeKey {
using OperatorAttributeValue =
variant<int, float, bool, std::vector<int>, OperatorType, Activation>;

using OperatorAttributeConstraint =
AttributeConstraint<OperatorAttributeKey, OperatorAttributeValue>;
using OperatorAttributeConstraint = AttributeConstraint<OperatorAttributeKey, OperatorAttributeValue>;

struct OperatorPattern {
std::unordered_set<OperatorAttributeConstraint> attribute_constraints;
};
using OperatorPattern = AttributePattern<OperatorAttributeKey, OperatorAttributeValue>;

optional<OperatorAttributeValue>
evaluate_attribute_expr(Operator const &attrs,
Expand Down
25 changes: 6 additions & 19 deletions lib/substitutions/include/substitutions/output_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,29 @@

namespace FlexFlow {

using GraphAttributeValue =
variant<int, float, bool, std::vector<int>, OperatorType, Activation>;

// NOTE(@wmdi) I am not sure whether these should be part of attribute expr.
struct NodeAttrAccess {
struct OperatorAttrAccess {
Node node;
AttributeExpr<OperatorAttributeKey> attr_expr;
};

struct EdgeAttrAccess {
OpenMultiDiEdge edge;
AttributeExpr<TensorAttributeKey> attr_expr;
};

struct AttrConstant {
GraphAttributeValue value;
OperatorAttributeValue value;
};

using GraphAttributeExpr =
variant<NodeAttrAccess, EdgeAttrAccess, AttrConstant>;
using OperatorAttributeExpr =
variant<OperatorAttrAccess, AttrConstant>;

// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can
// define the assignment for each operator type.
struct OperatorAttrAssignment {
std::unordered_map<OperatorAttributeKey, GraphAttributeExpr> assignment;
};

struct ParallelTensorAttrAssignment {
std::unordered_map<TensorAttributeKey, GraphAttributeExpr> assignment;
std::unordered_map<OperatorAttributeKey, OperatorAttributeExpr> assignments;
};

struct OutputGraphExpr
: public strong_typedef<
OutputGraphExpr,
OutputLabelledOpenMultiDiGraph<OperatorAttrAssignment,
ParallelTensorAttrAssignment>> {
NodeLabelledOpenMultiDiGraph<OperatorAttrAssignment>> {
using strong_typedef::strong_typedef;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,14 @@

namespace FlexFlow {

enum class TensorDimensionAttribute { SIZE, DEGREE };

enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES };

using TensorAttributeValue = variant<int, std::vector<int>>;

using TensorAttributeConstraint =
AttributeConstraint<TensorAttributeKey, TensorAttributeValue>;

struct ParallelTensorPattern {
std::unordered_set<TensorAttributeConstraint> attribute_constraints;
};
using ParallelTensorPattern = AttributePattern<TensorAttributeKey, TensorAttributeValue>;

optional<TensorAttributeValue>
evaluate_attribute_expr(ParallelTensor const &tensor_shape,
Expand Down
8 changes: 4 additions & 4 deletions lib/substitutions/src/graph_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,19 +188,19 @@ bool assignment_satisfies(ParallelComputationGraph const &pcg,
GraphPattern const &pattern,
MultiDiGraphPatternMatch const &patternMatch) {
bool result = true;
for (auto const &kv : patternMatch.nodeAssignment) {
for (auto const &kv : patternMatch.node_assignment) {
auto patternNode = kv.first;
auto pcgNode = kv.second;
optional<bool> constraintResult =
satisfies(pcg.at(pcgNode), pattern.at(patternNode));
satisfies(pcg->at(pcgNode), pattern->at(patternNode));
result &= constraintResult.value_or(false);
}

for (auto const &kv : patternMatch.edgeAssignment) {
for (auto const &kv : patternMatch.edge_assignment) {
auto patternEdge = kv.first;
auto pcgEdge = kv.second;
optional<bool> constraintResult =
satisfies(pcg.at(pcgEdge), pattern.at(patternEdge));
satisfies(pcg->at(pcgEdge), pattern->at(patternEdge));
result &= constraintResult.value_or(false);
}

Expand Down
22 changes: 0 additions & 22 deletions lib/substitutions/src/graph_pattern_match.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,6 @@

namespace FlexFlow {

// MultiDiGraphPatternMatch narrow_match(MultiDiGraphPatternMatch const &match,
// OpenMultiDiGraphView const &pattern) {
// MultiDiGraphPatternMatch result;
// std::unordered_set<Node> nodes = get_nodes(pattern);
// for (auto const &kv : match.node_assignment) {
// Node pattern_node = kv.first;
// if (contains(nodes, pattern_node)) {
// result.node_assignment.equate(kv.first, kv.second);
// }
// }

// std::unordered_set<OpenMultiDiEdge> edges = get_edges(pattern);
// for (auto const &kv : match.edge_assignment) {
// OpenMultiDiEdge pattern_edge = kv.first;
// if (contains(edges, pattern_edge)) {
// result.edge_assignment.equate(kv.first, kv.second);
// }
// }

// return result;
// }

GraphSplit split_pattern(OpenMultiDiGraphView const &pattern) {
std::vector<Node> topological_ordering = get_topological_ordering(pattern);
assert(topological_ordering.size() >= 2);
Expand Down
103 changes: 42 additions & 61 deletions lib/substitutions/src/substitution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,74 +13,66 @@ bool is_valid_operator_attribute_expr(
return contains(get_valid_operator_attribute_exprs(pattern), expr);
}

struct IsValidGraphAttributeExprFunctor {
struct IsValidOperatorAttributeExprFunctor {
GraphPattern const &graph_pattern;

template <typename T>
bool operator()(T const &t) const {
return is_valid(t);
}

bool is_valid(NodeAttrAccess const &t) const {
bool is_valid(OperatorAttrAccess const &t) const {
return is_valid_operator_attribute_expr(graph_pattern->at(t.node),
t.attr_expr);
}

bool is_valid(EdgeAttrAccess const &t) const {
NOT_IMPLEMENTED();
}

bool is_valid(AttrConstant const &t) const {
return true;
}
};

bool is_valid_graph_attribute_expr(GraphPattern const &pattern,
GraphAttributeExpr const &expr) {
return visit(IsValidGraphAttributeExprFunctor{pattern}, expr);
bool is_valid_operator_attribute_expr(GraphPattern const &pattern,
OperatorAttributeExpr const &expr) {
return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr);
}

bool is_valid_substitution(Substitution const &s) {
for (Node const &node : get_nodes(s.output_graph_expr)) {
for (GraphAttributeExpr expr :
values(s.output_graph_expr.value().at(node).assignment)) {
if (!is_valid_graph_attribute_expr(s.input_graph, expr)) {
for (OperatorAttributeExpr expr :
values(s.output_graph_expr->at(node).assignment)) {
if (!is_valid_operator_attribute_expr(s.input_graph, expr)) {
return false;
}
}
}
return true;
}

struct EvaluateGraphAttributeExpr {
struct EvaluateOperatorAttributeExpr {
ParallelComputationGraph const &graph;
MultiDiGraphPatternMatch const &match;

template <typename T>
GraphAttributeValue operator()(T const &t) {
OperatorAttributeExpr operator()(T const &t) {
return evaluate(t);
}

GraphAttributeValue evaluate(NodeAttrAccess const &t) {
OperatorAttributeValue evaluate(OperatorAttrAccess const &t) {
Node node_in_pattern = t.node;
Node node_in_pcg = match.node_assignment.at_l(node_in_pattern);
return widen<GraphAttributeValue>(
evaluate_attribute_expr(graph->at(node_in_pcg), t.attr_expr).value());
return evaluate_attribute_expr(graph->at(node_in_pcg), t.attr_expr).value();
}

GraphAttributeValue evaluate(EdgeAttrAccess const &t) {
OpenMultiDiEdge output_in_pattern = t.edge;
MultiDiEdge output_in_pcg = match.edge_assignment.at_l(output_in_pattern);
return widen<GraphAttributeValue>(
evaluate_attribute_expr(graph->at(output_in_pcg), t.attr_expr).value());
OperatorAttributeValue evaluate(AttrConstant const &t) {
return t.value;
}
};

GraphAttributeValue
OperatorAttributeExpr
evaluate_graph_attribute_expr(ParallelComputationGraph const &g,
MultiDiGraphPatternMatch const &match,
GraphAttributeExpr const &expr) {
return visit(EvaluateGraphAttributeExpr{g, match}, expr);
OperatorAttributeExpr const &expr) {
return visit(EvaluateOperatorAttributeExpr{g, match}, expr);
}

Operator get_operator_attrs(ParallelComputationGraph const &graph,
Expand All @@ -89,13 +81,6 @@ Operator get_operator_attrs(ParallelComputationGraph const &graph,
NOT_IMPLEMENTED();
}

ParallelTensor
get_parallel_tensor_attrs(ParallelComputationGraph const &graph,
MultiDiGraphPatternMatch const &match,
ParallelTensorAttrAssignment const &assignment) {
NOT_IMPLEMENTED();
}

ParallelComputationGraph
apply_substitution(ParallelComputationGraph const &pcg,
Substitution const &substitution,
Expand All @@ -106,53 +91,49 @@ ParallelComputationGraph
bidict<Node, Node> node_mapping; // Refactor it with global nodes
for (Node const &node : get_nodes(pcg)) {
if (!contains_r(match.node_assignment, node)) {
node_mapping.equate(node, new_pcg.add_node(pcg.value().at(node)));
node_mapping.equate(node, new_pcg->add_node(pcg.value().at(node)));
}
}
for (MultiDiEdge const &edge : get_edges(pcg)) {
if (!contains_r(match.edge_assignment, edge)) {
new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(edge.src),
new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(edge.src),
node_mapping.at_r(edge.dst),
new_pcg.add_node_port(),
new_pcg.add_node_port()});
new_pcg->add_node_port(),
new_pcg->add_node_port()});
}
}
for (Node const &output_node : get_nodes(substitution.output_graph_expr)) {
Node new_node = new_pcg.add_node(get_operator_attrs(
pcg, match, substitution.output_graph_expr.at(output_node)));
Node new_node = new_pcg->add_node(get_operator_attrs(
pcg, match, substitution.output_graph_expr->at(output_node)));
node_mapping.equate(output_node, new_node);
}
for (OpenMultiDiEdge const &output_edge :
get_edges(substitution.output_graph_expr)) {
if (holds_alternative<InputMultiDiEdge>(output_edge)) {
MultiDiEdge origin_edge = match.edge_assignment.at_r(
substitution.input_mapping.at_r(output_edge));
new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(origin_edge.src),
node_mapping.at_l(output_edge.dst),
new_pcg.add_node_port(),
new_pcg.add_node_port()});
InputMultiDiEdge e = get<InputMultiDiEdge>(output_edge);
MultiDiEdge original_edge = match.edge_assignment.at_l(
substitution.input_mapping.at_r(e));
new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(original_edge.src),
node_mapping.at_l(e.dst),
new_pcg->add_node_port(),
new_pcg->add_node_port()});
} else if (holds_alternative<OutputMultiDiEdge>(output_edge)) {
MultiDiEdge origin_edge = match.edge_assignment.at_r(
substitution.output_mapping.at_r(output_edge));
new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src),
node_mapping.at_l(origin_edge.dst),
new_pcg.add_node_port(),
new_pcg.add_node_port()});
OutputMultiDiEdge e = get<OutputMultiDiEdge>(output_edge);
MultiDiEdge original_edge = match.edge_assignment.at_l(
substitution.output_mapping.at_r(e));
new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(e.src),
node_mapping.at_l(original_edge.dst),
new_pcg->add_node_port(),
new_pcg->add_node_port()});
} else {
assert(holds_alternative<MultiDiEdge>(output_edge));
new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(output_edge.src),
node_mapping.at_l(output_edge.dst),
new_pcg.add_node_port(),
new_pcg.add_node_port()});
MultiDiEdge e = get<MultiDiEdge>(output_edge);
new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(e.src),
node_mapping.at_l(e.dst),
new_pcg->add_node_port(),
new_pcg->add_node_port()});
}
}
for (MultiDiOutput const &output :
get_outputs(substitution.output_graph_expr)) {
new_pcg.add_output(
MultiDiOutput{node_mapping.at_l(output.src), new_pcg.add_node_port()},
get_parallel_tensor_attrs(
pcg, match, substitution.output_graph_expr.at(output)));
}

return new_pcg;
}
Expand Down
17 changes: 17 additions & 0 deletions lib/utils/include/utils/graph/labelled/node_labelled_open.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_NODE_LABELLED_OPEN
#define _FLEXFLOW_UTILS_GRAPH_LABELLED_NODE_LABELLED_OPEN

namespace FlexFlow {

template <typename NodeLabel>
struct NodeLabelledOpenMultiDiGraph {
NodeLabelledOpenMultiDiGraph() = delete;
NodeLabelledOpenMultiDiGraph(NodeLabelledOpenMultiDiGraph const &) = default;
NodeLabelledOpenMultiDiGraph &operator=(NodeLabelledOpenMultiDiGraph const &) = default;

operator OpenMultiDiGraphView();
};

}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct OutputLabelledOpenMultiDiGraph {
NodeLabel const &at(Node const &) const {
NOT_IMPLEMENTED();
}
NodeLabel &at(Node const &) const {
NodeLabel &at(Node const &) {
NOT_IMPLEMENTED();
}

Expand Down
1 change: 1 addition & 0 deletions lib/utils/include/utils/graph/labelled_graphs.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "labelled/labelled_open.h"
#include "labelled/labelled_upward_open.h"
#include "labelled/node_labelled.h"
#include "labelled/node_labelled_open.h"
#include "labelled/open_algorithms.h"
#include "labelled/output_labelled.h"
#include "labelled/output_labelled_open.h"
Expand Down

0 comments on commit 21b8549

Please sign in to comment.