From 2fb2c7dfefa10f6c09dee74c625ae02804cc27c8 Mon Sep 17 00:00:00 2001 From: wmdi Date: Sun, 27 Aug 2023 21:18:42 -0400 Subject: [PATCH] minor fix --- lib/compiler/src/unity_algorithm.cc | 6 ++- .../include/substitutions/graph_pattern.h | 2 +- .../substitutions/graph_pattern_match.h | 23 ++++++++--- .../include/substitutions/output_graph.h | 8 ++-- .../substitutions/parallel_tensor_pattern.h | 9 ----- .../include/substitutions/substitution.h | 13 +++++-- lib/substitutions/src/graph_pattern.cc | 2 +- lib/substitutions/src/graph_pattern_match.cc | 32 +++++++-------- lib/substitutions/src/substitution.cc | 39 ++++++++++--------- 9 files changed, 73 insertions(+), 61 deletions(-) diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ef093fc11e..86fdd88d92 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,6 +1,6 @@ #include "compiler/unity_algorithm.h" #include "graph_utils.h" -#include "substitutions_implementation.h" +#include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" namespace FlexFlow { @@ -14,7 +14,9 @@ std::unordered_set std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, - Substitution const &) {} + Substitution const &) { + NOT_IMPLEMENTED(); +} Strategy graph_optimize(ComputationGraph &cg, diff --git a/lib/substitutions/include/substitutions/graph_pattern.h b/lib/substitutions/include/substitutions/graph_pattern.h index e2054f1a4f..7697ddf55d 100644 --- a/lib/substitutions/include/substitutions/graph_pattern.h +++ b/lib/substitutions/include/substitutions/graph_pattern.h @@ -22,7 +22,7 @@ bool is_singleton_pattern(OpenMultiDiGraphView const &); bool assignment_satisfies(ParallelComputationGraph const &, GraphPattern const &, - DiGraphPatternMatch const &); + MultiDiGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/graph_pattern_match.h b/lib/substitutions/include/substitutions/graph_pattern_match.h index 449c26c846..498ec6cfd0 100644 --- a/lib/substitutions/include/substitutions/graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/graph_pattern_match.h @@ -6,22 +6,33 @@ namespace FlexFlow { -struct DiGraphPatternMatch { - bidict nodeAssignment; - bidict edgeAssignment; +struct MultiDiGraphPatternMatch { + using PatternNode = Node; + using PCGNode = Node; + using PatternEdge = OpenMultiDiEdge; + using PCGEdge = MultiDiEdge; + + bidict nodeAssignment; + bidict edgeAssignment; }; struct MatchSplit { - DiGraphPatternMatch prefix_submatch; - DiGraphPatternMatch postfix_submatch; + MultiDiGraphPatternMatch prefix_submatch; + MultiDiGraphPatternMatch postfix_submatch; }; template bool pattern_matches(OpenMultiDiGraphView const &, MultiDiGraphView const &, - DiGraphPatternMatch const &, + MultiDiGraphPatternMatch const &, F const &additional_criterion); +template +std::unordered_set + find_pattern_matches(OpenMultiDiGraphView const &pattern, + MultiDiGraphView const &graph, + F const &additional_criterion); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index 7b32ca9900..f5b6328d7d 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -46,16 +46,16 @@ using GraphAttributeExpr = // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can // define the assignment for each operator type. struct OperatorAttrAssignment { - std::vector> assignment; + std::unordered_map assignment; }; struct ParallelTensorAttrAssignment { - std::vector> assignment; + std::unordered_map assignment; }; -struct OutputGraph +struct OutputGraphExpr : public strong_typedef< - OutputGraph, + OutputGraphExpr, OutputLabelledOpenMultiDiGraph> { using strong_typedef::strong_typedef; diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index c62237d0fd..2b5f4d0f58 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -8,15 +8,6 @@ namespace FlexFlow { enum class TensorDimensionAttribute { SIZE, DEGREE }; -struct TensorNumDimensionsConstraint { - int value; -}; - -struct TensorDimensionAttributeConstraint { - TensorDimensionAttribute attribute; - int index; -}; - enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; using TensorAttributeValue = variant>; diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 55820da33f..a805d0dae1 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -7,15 +7,20 @@ namespace FlexFlow { struct Substitution { + using InputPatternInput = InputMultiDiEdge; + using InputPatternOutput = OutputMultiDiEdge; + using OutputPatternInput = InputMultiDiEdge; + using OutputPatternOutput = OutputMultiDiEdge; + GraphPattern input_graph; - OutputGraph output_graph; - bidict input_mapping; - bidict output_mapping; + OutputGraphExpr output_graph_expr; + bidict input_mapping; + bidict output_mapping; }; ParallelComputationGraph apply_substitution(ParallelComputationGraph const &, Substitution const &, - DiGraphPatternMatch const &); + MultiDiGraphPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 9b7529ad8a..dfaf47910b 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -186,7 +186,7 @@ optional satisfies(ParallelTensor const ¶ms, bool assignment_satisfies(ParallelComputationGraph const &pcg, GraphPattern const &pattern, - DiGraphPatternMatch const &patternMatch) { + MultiDiGraphPatternMatch const &patternMatch) { bool result = true; for (auto const &kv : patternMatch.nodeAssignment) { auto patternNode = kv.first; diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc index a5c185aba0..2e0150c808 100644 --- a/lib/substitutions/src/graph_pattern_match.cc +++ b/lib/substitutions/src/graph_pattern_match.cc @@ -4,9 +4,9 @@ namespace FlexFlow { -// DiGraphPatternMatch narrow_match(DiGraphPatternMatch const &match, +// MultiDiGraphPatternMatch narrow_match(MultiDiGraphPatternMatch const &match, // OpenMultiDiGraphView const &pattern) { -// DiGraphPatternMatch result; +// MultiDiGraphPatternMatch result; // std::unordered_set nodes = get_nodes(pattern); // for (auto const &kv : match.nodeAssignment) { // Node pattern_node = kv.first; @@ -47,7 +47,7 @@ std::pair Given a match and a pattern split, gets the submatches in subpatterns. */ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, GraphSplit const &split) { auto prefix = split.first; auto postfix = split.second; @@ -99,7 +99,7 @@ bool is_singleton_pattern(OpenMultiDiGraphView const &pattern) { template bool pattern_matches(OpenMultiDiGraphView const &pattern, MultiDiGraphView const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, F const &additional_criterion) { if (is_singleton_pattern(pattern)) { Node pattern_node = get_only(get_nodes(pattern)); @@ -149,7 +149,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, additional_criterion); } -optional +optional get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, MultiDiGraphView const &graph, Node const &graph_node) { @@ -157,7 +157,7 @@ optional Node pattern_node = get_only(get_nodes(pattern)); - DiGraphPatternMatch match; + MultiDiGraphPatternMatch match; match.nodeAssignment.equate(pattern_node, graph_node); auto incoming = get_incoming_edges_by_idx(graph, graph_node); @@ -185,12 +185,12 @@ optional return match; } -optional unsplit_matches( - DiGraphPatternMatch const &prefix, - DiGraphPatternMatch const &postfix, +optional unsplit_matches( + MultiDiGraphPatternMatch const &prefix, + MultiDiGraphPatternMatch const &postfix, bidict> const &edge_splits) { - DiGraphPatternMatch result; + MultiDiGraphPatternMatch result; std::unordered_set handled; for (auto const &kv : edge_splits) { MultiDiEdge standard_edge = kv.first; @@ -222,14 +222,14 @@ optional unsplit_matches( } template -std::unordered_set +std::unordered_set find_pattern_matches(OpenMultiDiGraphView const &pattern, MultiDiGraphView const &graph, F const &additional_criterion) { - std::unordered_set matches; + std::unordered_set matches; if (is_singleton_pattern(pattern)) { for (Node const &graph_node : get_nodes(graph)) { - optional candidate = + optional candidate = get_candidate_singleton_match(pattern, graph, graph_node); if (candidate.has_value() || pattern_matches(pattern, graph, candidate.value())) { @@ -244,9 +244,9 @@ std::unordered_set auto postfix_matches = find_pattern_matches(subpatterns.first, graph, additional_criterion); auto edge_splits = get_edge_splits(pattern, split); - for (DiGraphPatternMatch const &prefix_match : prefix_matches) { - for (DiGraphPatternMatch const &postfix_match : postfix_matches) { - optional unsplit = + for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { + for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { + optional unsplit = unsplit_matches(prefix_match, postfix_match, edge_splits); if (unsplit.has_value()) { matches.insert(unsplit.value()); diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index c5cc870a7a..ff8d6ef541 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -3,7 +3,9 @@ namespace FlexFlow { -struct GraphAttributeValueOp { +struct GraphAttributeValueOpFunctor { + AttrOpType op; + template GraphAttributeValue operator()(T const &lhs, T const &rhs) { switch (op) { @@ -23,13 +25,12 @@ struct GraphAttributeValueOp { mk_runtime_error("Unknown attribute operator type"); } } - AttrOpType op; }; GraphAttributeValue graph_attribute_value_op(AttrOpType op, GraphAttributeValue const &lhs, GraphAttributeValue const &rhs) { - visit(GraphAttributeValueOp{op}, lhs, rhs); + visit(GraphAttributeValueOpFunctor{op}, lhs, rhs); } struct EvaluateGraphAttributeExprLeaf { @@ -53,12 +54,12 @@ struct EvaluateGraphAttributeExprLeaf { } ParallelComputationGraph const &graph; - DiGraphPatternMatch const &match; + MultiDiGraphPatternMatch const &match; }; GraphAttributeValue evaluate_graph_attribute_expr_leaf(ParallelComputationGraph const &g, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, GraphAttributeExprLeaf const &expr) { return visit(EvaluateGraphAttributeExprLeaf{g, match}, expr); } @@ -82,36 +83,37 @@ struct EvaluateGraphAttributeExpr { } EvaluateGraphAttributeExpr(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match) + MultiDiGraphPatternMatch const &match) : graph(graph), match(match) {} ParallelComputationGraph const &graph; - DiGraphPatternMatch const &match; + MultiDiGraphPatternMatch const &match; }; GraphAttributeValue evaluate_graph_attribute_expr(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, GraphAttributeExpr const &expr) { return visit(EvaluateGraphAttributeExpr(graph, match), expr); } Operator get_operator_attrs(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, OperatorAttrAssignment const &assignment) { NOT_IMPLEMENTED(); } ParallelTensor get_parallel_tensor_attrs(ParallelComputationGraph const &graph, - DiGraphPatternMatch const &match, + MultiDiGraphPatternMatch const &match, ParallelTensorAttrAssignment const &assignment) { NOT_IMPLEMENTED(); } -ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, - Substitution const &substitution, - DiGraphPatternMatch const &match) { +ParallelComputationGraph + apply_substitution(ParallelComputationGraph const &pcg, + Substitution const &substitution, + MultiDiGraphPatternMatch const &match) { ParallelComputationGraph new_pcg = ParallelComputationGraph::create(); bidict node_mapping; // Refactor it with global nodes @@ -128,13 +130,13 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, new_pcg.add_node_port()}); } } - for (Node const &output_node : get_nodes(substitution.output_graph)) { + for (Node const &output_node : get_nodes(substitution.output_graph_expr)) { Node new_node = new_pcg.add_node(get_operator_attrs( - pcg, match, substitution.output_graph.at(output_node))); + pcg, match, substitution.output_graph_expr.at(output_node))); node_mapping.equate(output_node, new_node); } for (OpenMultiDiEdge const &output_edge : - get_edges(substitution.output_graph)) { + get_edges(substitution.output_graph_expr)) { if (holds_alternative(output_edge)) { MultiDiEdge origin_edge = match.edgeAssignment.at_r( substitution.input_mapping.at_r(output_edge)); @@ -157,11 +159,12 @@ ParallelComputationGraph apply_substitution(ParallelComputationGraph const &pcg, new_pcg.add_node_port()}); } } - for (MultiDiOutput const &output : get_outputs(substitution.output_graph)) { + for (MultiDiOutput const &output : + get_outputs(substitution.output_graph_expr)) { new_pcg.add_output( MultiDiOutput{node_mapping.at_l(output.src), new_pcg.add_node_port()}, get_parallel_tensor_attrs( - pcg, match, substitution.output_graph.at(output))); + pcg, match, substitution.output_graph_expr.at(output))); } return new_pcg;