Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wmdi committed Aug 28, 2023
1 parent a5e111e commit 2fb2c7d
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 61 deletions.
6 changes: 4 additions & 2 deletions lib/compiler/src/unity_algorithm.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "compiler/unity_algorithm.h"
#include "graph_utils.h"
#include "substitutions_implementation.h"
#include "substitutions/substitution.h"
#include "utils/deduplicated_priority_queue.h"

namespace FlexFlow {
Expand All @@ -14,7 +14,9 @@ std::unordered_set<Substitution>

std::unordered_set<ParallelComputationGraph>
apply_substitution(ParallelComputationGraph const &pcg,
Substitution const &) {}
Substitution const &) {
NOT_IMPLEMENTED();
}

Strategy
graph_optimize(ComputationGraph &cg,
Expand Down
2 changes: 1 addition & 1 deletion lib/substitutions/include/substitutions/graph_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bool is_singleton_pattern(OpenMultiDiGraphView const &);

bool assignment_satisfies(ParallelComputationGraph const &,
GraphPattern const &,
DiGraphPatternMatch const &);
MultiDiGraphPatternMatch const &);

} // namespace FlexFlow

Expand Down
23 changes: 17 additions & 6 deletions lib/substitutions/include/substitutions/graph_pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,33 @@

namespace FlexFlow {

struct DiGraphPatternMatch {
bidict<Node, Node> nodeAssignment;
bidict<OpenMultiDiEdge, MultiDiEdge> edgeAssignment;
struct MultiDiGraphPatternMatch {
using PatternNode = Node;
using PCGNode = Node;
using PatternEdge = OpenMultiDiEdge;
using PCGEdge = MultiDiEdge;

bidict<PatternNode, PCGNode> nodeAssignment;
bidict<PatternEdge, PCGEdge> edgeAssignment;
};

struct MatchSplit {
DiGraphPatternMatch prefix_submatch;
DiGraphPatternMatch postfix_submatch;
MultiDiGraphPatternMatch prefix_submatch;
MultiDiGraphPatternMatch postfix_submatch;
};

template <typename F>
bool pattern_matches(OpenMultiDiGraphView const &,
MultiDiGraphView const &,
DiGraphPatternMatch const &,
MultiDiGraphPatternMatch const &,
F const &additional_criterion);

template <typename F>
std::unordered_set<MultiDiGraphPatternMatch>
find_pattern_matches(OpenMultiDiGraphView const &pattern,
MultiDiGraphView const &graph,
F const &additional_criterion);

} // namespace FlexFlow

#endif
8 changes: 4 additions & 4 deletions lib/substitutions/include/substitutions/output_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ using GraphAttributeExpr =
// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can
// define the assignment for each operator type.
struct OperatorAttrAssignment {
std::vector<std::pair<OperatorAttributeKey, GraphAttributeExpr>> assignment;
std::unordered_map<OperatorAttributeKey, GraphAttributeExpr> assignment;
};

struct ParallelTensorAttrAssignment {
std::vector<std::pair<TensorAttributeKey, GraphAttributeExpr>> assignment;
std::unordered_map<TensorAttributeKey, GraphAttributeExpr> assignment;
};

struct OutputGraph
struct OutputGraphExpr
: public strong_typedef<
OutputGraph,
OutputGraphExpr,
OutputLabelledOpenMultiDiGraph<OperatorAttrAssignment,
ParallelTensorAttrAssignment>> {
using strong_typedef::strong_typedef;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@ namespace FlexFlow {

enum class TensorDimensionAttribute { SIZE, DEGREE };

struct TensorNumDimensionsConstraint {
int value;
};

struct TensorDimensionAttributeConstraint {
TensorDimensionAttribute attribute;
int index;
};

enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES };

using TensorAttributeValue = variant<int, std::vector<int>>;
Expand Down
13 changes: 9 additions & 4 deletions lib/substitutions/include/substitutions/substitution.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
namespace FlexFlow {

struct Substitution {
using InputPatternInput = InputMultiDiEdge;
using InputPatternOutput = OutputMultiDiEdge;
using OutputPatternInput = InputMultiDiEdge;
using OutputPatternOutput = OutputMultiDiEdge;

GraphPattern input_graph;
OutputGraph output_graph;
bidict<InputMultiDiEdge, InputMultiDiEdge> input_mapping;
bidict<OutputMultiDiEdge, OutputMultiDiEdge> output_mapping;
OutputGraphExpr output_graph_expr;
bidict<InputPatternInput, OutputPatternInput> input_mapping;
bidict<InputPatternOutput, OutputPatternOutput> output_mapping;
};

ParallelComputationGraph apply_substitution(ParallelComputationGraph const &,
Substitution const &,
DiGraphPatternMatch const &);
MultiDiGraphPatternMatch const &);

} // namespace FlexFlow

Expand Down
2 changes: 1 addition & 1 deletion lib/substitutions/src/graph_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ optional<bool> satisfies(ParallelTensor const &params,

bool assignment_satisfies(ParallelComputationGraph const &pcg,
GraphPattern const &pattern,
DiGraphPatternMatch const &patternMatch) {
MultiDiGraphPatternMatch const &patternMatch) {
bool result = true;
for (auto const &kv : patternMatch.nodeAssignment) {
auto patternNode = kv.first;
Expand Down
32 changes: 16 additions & 16 deletions lib/substitutions/src/graph_pattern_match.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

namespace FlexFlow {

// DiGraphPatternMatch narrow_match(DiGraphPatternMatch const &match,
// MultiDiGraphPatternMatch narrow_match(MultiDiGraphPatternMatch const &match,
// OpenMultiDiGraphView const &pattern) {
// DiGraphPatternMatch result;
// MultiDiGraphPatternMatch result;
// std::unordered_set<Node> nodes = get_nodes(pattern);
// for (auto const &kv : match.nodeAssignment) {
// Node pattern_node = kv.first;
Expand Down Expand Up @@ -47,7 +47,7 @@ std::pair<OpenMultiDiGraphView, OpenMultiDiGraphView>
Given a match and a pattern split, gets the submatches in subpatterns.
*/
MatchSplit apply_split(OpenMultiDiGraphView const &pattern,
DiGraphPatternMatch const &match,
MultiDiGraphPatternMatch const &match,
GraphSplit const &split) {
auto prefix = split.first;
auto postfix = split.second;
Expand Down Expand Up @@ -99,7 +99,7 @@ bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) {
template <typename F>
bool pattern_matches(OpenMultiDiGraphView const &pattern,
MultiDiGraphView const &graph,
DiGraphPatternMatch const &match,
MultiDiGraphPatternMatch const &match,
F const &additional_criterion) {
if (is_singleton_pattern(pattern)) {
Node pattern_node = get_only(get_nodes(pattern));
Expand Down Expand Up @@ -149,15 +149,15 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern,
additional_criterion);
}

optional<DiGraphPatternMatch>
optional<MultiDiGraphPatternMatch>
get_candidate_singleton_match(OpenMultiDiGraphView const &pattern,
MultiDiGraphView const &graph,
Node const &graph_node) {
assert(is_singleton_pattern(pattern));

Node pattern_node = get_only(get_nodes(pattern));

DiGraphPatternMatch match;
MultiDiGraphPatternMatch match;
match.nodeAssignment.equate(pattern_node, graph_node);

auto incoming = get_incoming_edges_by_idx(graph, graph_node);
Expand Down Expand Up @@ -185,12 +185,12 @@ optional<DiGraphPatternMatch>
return match;
}

optional<DiGraphPatternMatch> unsplit_matches(
DiGraphPatternMatch const &prefix,
DiGraphPatternMatch const &postfix,
optional<MultiDiGraphPatternMatch> unsplit_matches(
MultiDiGraphPatternMatch const &prefix,
MultiDiGraphPatternMatch const &postfix,
bidict<MultiDiEdge, std::pair<OutputMultiDiEdge, InputMultiDiEdge>> const
&edge_splits) {
DiGraphPatternMatch result;
MultiDiGraphPatternMatch result;
std::unordered_set<OpenMultiDiEdge> handled;
for (auto const &kv : edge_splits) {
MultiDiEdge standard_edge = kv.first;
Expand Down Expand Up @@ -222,14 +222,14 @@ optional<DiGraphPatternMatch> unsplit_matches(
}

template <typename F>
std::unordered_set<DiGraphPatternMatch>
std::unordered_set<MultiDiGraphPatternMatch>
find_pattern_matches(OpenMultiDiGraphView const &pattern,
MultiDiGraphView const &graph,
F const &additional_criterion) {
std::unordered_set<DiGraphPatternMatch> matches;
std::unordered_set<MultiDiGraphPatternMatch> matches;
if (is_singleton_pattern(pattern)) {
for (Node const &graph_node : get_nodes(graph)) {
optional<DiGraphPatternMatch> candidate =
optional<MultiDiGraphPatternMatch> candidate =
get_candidate_singleton_match(pattern, graph, graph_node);
if (candidate.has_value() ||
pattern_matches<F>(pattern, graph, candidate.value())) {
Expand All @@ -244,9 +244,9 @@ std::unordered_set<DiGraphPatternMatch>
auto postfix_matches =
find_pattern_matches(subpatterns.first, graph, additional_criterion);
auto edge_splits = get_edge_splits(pattern, split);
for (DiGraphPatternMatch const &prefix_match : prefix_matches) {
for (DiGraphPatternMatch const &postfix_match : postfix_matches) {
optional<DiGraphPatternMatch> unsplit =
for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) {
for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) {
optional<MultiDiGraphPatternMatch> unsplit =
unsplit_matches(prefix_match, postfix_match, edge_splits);
if (unsplit.has_value()) {
matches.insert(unsplit.value());
Expand Down
39 changes: 21 additions & 18 deletions lib/substitutions/src/substitution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

namespace FlexFlow {

struct GraphAttributeValueOp {
struct GraphAttributeValueOpFunctor {
AttrOpType op;

template <typename T>
GraphAttributeValue operator()(T const &lhs, T const &rhs) {
switch (op) {
Expand All @@ -23,13 +25,12 @@ struct GraphAttributeValueOp {
mk_runtime_error("Unknown attribute operator type");
}
}
AttrOpType op;
};

GraphAttributeValue graph_attribute_value_op(AttrOpType op,
GraphAttributeValue const &lhs,
GraphAttributeValue const &rhs) {
visit(GraphAttributeValueOp{op}, lhs, rhs);
visit(GraphAttributeValueOpFunctor{op}, lhs, rhs);
}

struct EvaluateGraphAttributeExprLeaf {
Expand All @@ -53,12 +54,12 @@ struct EvaluateGraphAttributeExprLeaf {
}

ParallelComputationGraph const &graph;
DiGraphPatternMatch const &match;
MultiDiGraphPatternMatch const &match;
};

GraphAttributeValue
evaluate_graph_attribute_expr_leaf(ParallelComputationGraph const &g,
DiGraphPatternMatch const &match,
MultiDiGraphPatternMatch const &match,
GraphAttributeExprLeaf const &expr) {
return visit(EvaluateGraphAttributeExprLeaf{g, match}, expr);
}
Expand All @@ -82,36 +83,37 @@ struct EvaluateGraphAttributeExpr {
}

EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph,
DiGraphPatternMatch const &match)
MultiDiGraphPatternMatch const &match)
: graph(graph), match(match) {}

ParallelComputationGraph const &graph;
DiGraphPatternMatch const &match;
MultiDiGraphPatternMatch const &match;
};

GraphAttributeValue
evaluate_graph_attribute_expr(ParallelComputationGraph const &graph,
DiGraphPatternMatch const &match,
MultiDiGraphPatternMatch const &match,
GraphAttributeExpr const &expr) {
return visit(EvaluateGraphAttributeExpr(graph, match), expr);
}

Operator get_operator_attrs(ParallelComputationGraph const &graph,
DiGraphPatternMatch const &match,
MultiDiGraphPatternMatch const &match,
OperatorAttrAssignment const &assignment) {
NOT_IMPLEMENTED();
}

ParallelTensor
get_parallel_tensor_attrs(ParallelComputationGraph const &graph,
DiGraphPatternMatch const &match,
MultiDiGraphPatternMatch const &match,
ParallelTensorAttrAssignment const &assignment) {
NOT_IMPLEMENTED();
}

ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg,
Substitution const &substitution,
DiGraphPatternMatch const &match) {
ParallelComputationGraph
apply_substitution(ParallelComputationGraph const &pcg,
Substitution const &substitution,
MultiDiGraphPatternMatch const &match) {
ParallelComputationGraph new_pcg =
ParallelComputationGraph::create<UnorderedOutputLabelledMultiDiGraph>();
bidict<Node, Node> node_mapping; // Refactor it with global nodes
Expand All @@ -128,13 +130,13 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg,
new_pcg.add_node_port()});
}
}
for (Node const &output_node : get_nodes(substitution.output_graph)) {
for (Node const &output_node : get_nodes(substitution.output_graph_expr)) {
Node new_node = new_pcg.add_node(get_operator_attrs(
pcg, match, substitution.output_graph.at(output_node)));
pcg, match, substitution.output_graph_expr.at(output_node)));
node_mapping.equate(output_node, new_node);
}
for (OpenMultiDiEdge const &output_edge :
get_edges(substitution.output_graph)) {
get_edges(substitution.output_graph_expr)) {
if (holds_alternative<InputMultiDiEdge>(output_edge)) {
MultiDiEdge origin_edge = match.edgeAssignment.at_r(
substitution.input_mapping.at_r(output_edge));
Expand All @@ -157,11 +159,12 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg,
new_pcg.add_node_port()});
}
}
for (MultiDiOutput const &output : get_outputs(substitution.output_graph)) {
for (MultiDiOutput const &output :
get_outputs(substitution.output_graph_expr)) {
new_pcg.add_output(
MultiDiOutput{node_mapping.at_l(output.src), new_pcg.add_node_port()},
get_parallel_tensor_attrs(
pcg, match, substitution.output_graph.at(output)));
pcg, match, substitution.output_graph_expr.at(output)));
}

return new_pcg;
Expand Down

0 comments on commit 2fb2c7d

Please sign in to comment.