diff --git a/lib/substitutions/CMakeLists.txt b/lib/substitutions/CMakeLists.txt
index 836cc91a36..b909f6183d 100644
--- a/lib/substitutions/CMakeLists.txt
+++ b/lib/substitutions/CMakeLists.txt
@@ -10,6 +10,7 @@ ff_add_library(
   DEPS
     utils
     op-attrs
+    pcg
   )
 
 add_subdirectory(ffi)
diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h
index 452ed60af5..12d1bc07b9 100644
--- a/lib/substitutions/include/substitutions/operator_pattern.h
+++ b/lib/substitutions/include/substitutions/operator_pattern.h
@@ -23,7 +23,7 @@ enum class OperatorAttributeKey {
   STRIDE_W,
   PADDING_H,
   PADDING_W,
-  AGGR_MODE,
+  AGGR,
   NUM_ENTRIES,
   OUT_CHANNELS,
   ACTIVATION,
@@ -42,14 +42,49 @@ enum class OperatorAttributeKey {
   PARALLEL_DIM,
   PARALLEL_DEGREE,
   PAD,
+  EMBED_DIM,
+  KDIM,
+  VDIM,
+  DROPOUT,
+  BIAS,
+  ADD_BIAS_KV,
+  ADD_ZERO_ATTN,
+  A_SEQ_LENGTH_DIM,
+  B_SEQ_LENGTH_DIM,
+  RELU,
+  TARGET_DIMS,
+  RATE,
+  SEED,
+  SHOULD_BROADCAST_LHS,
+  SHOULD_BROADCAST_RHS,
+  DIM,
+  ELEMENTWISE_AFFINE,
+  REGULARIZER,
+  SHAPE,
+  SPLITS,
+  K,
+  SORTED,
 };
 
-using OperatorAttributeValue =
-    variant<int, float, bool, std::vector<int>, OperatorType, Activation>;
+using OperatorAttributeValue = variant<int,
+                                       float,
+                                       bool,
+                                       stack_vector<int, MAX_TENSOR_DIM>,
+                                       OperatorType,
+                                       Activation,
+                                       ff_dim_t,
+                                       unsigned long long,
+                                       AggregateOp,
+                                       stack_vector<ff_dim_t, MAX_TENSOR_DIM>,
+                                       RegularizerAttrs,
+                                       PoolOp,
+                                       TensorShape>;
 
-using OperatorAttributeConstraint = AttributeConstraint<OperatorAttributeKey, OperatorAttributeValue>;
+using OperatorAttributeConstraint =
+    AttributeConstraint<OperatorAttributeKey, OperatorAttributeValue>;
 
-using OperatorPattern = AttributePattern<OperatorAttributeKey, OperatorAttributeValue>;
+using OperatorPattern =
+    AttributePattern<OperatorAttributeKey, OperatorAttributeValue>;
 
 optional<OperatorAttributeValue> evaluate_attribute_expr(Operator const &attrs,
diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h
index 5cc0dd5ffe..b9cf1f53f3 100644
--- a/lib/substitutions/include/substitutions/output_graph.h
+++ b/lib/substitutions/include/substitutions/output_graph.h
@@ -15,8 +15,7 @@ struct AttrConstant {
   OperatorAttributeValue value;
 };
 
-using OperatorAttributeExpr =
-    variant<OperatorAttrAccess, AttrConstant>;
+using OperatorAttributeExpr = variant<OperatorAttrAccess, AttrConstant>;
 
 // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can
 // define the assignment for each operator type.
diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h
index 4cfb6a0a69..d07a1da23b 100644
--- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h
+++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h
@@ -13,7 +13,8 @@ using TensorAttributeValue = variant<size_t, std::vector<int>>;
 using TensorAttributeConstraint =
     AttributeConstraint<TensorAttributeKey, TensorAttributeValue>;
 
-using ParallelTensorPattern = AttributePattern<TensorAttributeKey, TensorAttributeValue>;
+using ParallelTensorPattern =
+    AttributePattern<TensorAttributeKey, TensorAttributeValue>;
 
 optional<TensorAttributeValue>
     evaluate_attribute_expr(ParallelTensor const &tensor_shape,
diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc
index eaf96d6516..a81d84c9e5 100644
--- a/lib/substitutions/src/graph_pattern.cc
+++ b/lib/substitutions/src/graph_pattern.cc
@@ -9,23 +9,45 @@
 
 namespace FlexFlow {
 
-template <typename T, typename V>
-optional<V> evaluate_list_index_access(ListIndexAccess<T> const &index_access,
-                                       optional<V> const &v) {
+optional<OperatorAttributeValue>
+    evaluate_list_index_access(int index,
+                               optional<OperatorAttributeValue> const &v) {
+  if (!v.has_value() ||
+      (!holds_alternative<stack_vector<int, MAX_TENSOR_DIM>>(v.value()) &&
+       !holds_alternative<stack_vector<ff_dim_t, MAX_TENSOR_DIM>>(v.value()))) {
+    return nullopt;
+  }
+
+  if (index >= MAX_TENSOR_DIM) {
+    return nullopt;
+  }
+
+  if (holds_alternative<stack_vector<int, MAX_TENSOR_DIM>>(v.value())) {
+    return get<stack_vector<int, MAX_TENSOR_DIM>>(v.value()).at(index);
+  } else {
+    return get<stack_vector<ff_dim_t, MAX_TENSOR_DIM>>(v.value()).at(index);
+  }
+}
+
+optional<TensorAttributeValue>
+    evaluate_list_index_access(int const &index,
+                               optional<TensorAttributeValue> const &v) {
   if (!v.has_value() || !holds_alternative<std::vector<int>>(v.value())) {
     return nullopt;
   }
 
   auto vec = get<std::vector<int>>(v.value());
-  if (index_access.index >= vec.size()) {
+
+  if (index >= vec.size()) {
     return nullopt;
   }
-  return vec.at(index_access.index);
+
+  return vec.at(index);
 }
 
-template <typename V>
-optional<V> evaluate_list_size(optional<V> const &v) {
+optional<OperatorAttributeValue>
+    evaluate_list_size(optional<OperatorAttributeValue> const &v) {
+  return MAX_TENSOR_DIM;
+}
+
+optional<TensorAttributeValue>
+    evaluate_list_size(optional<TensorAttributeValue> const &v) {
   if (!v.has_value() || !holds_alternative<std::vector<int>>(v.value())) {
     return nullopt;
   }
@@ -44,7 +66,7 @@ struct EvaluateOperatorAttributeExpr {
       operator()(ListIndexAccess<OperatorAttributeKey> const &index_access) {
     optional<OperatorAttributeValue> v =
         get_attribute(this->attrs, index_access.attribute_key);
-    return evaluate_list_index_access(index_access, v);
+    return evaluate_list_index_access(index_access.index, v);
   }
 
   optional<OperatorAttributeValue>
@@ -94,8 +116,9 @@ struct EvaluateTensorAttributeExpr {
 
   optional<TensorAttributeValue>
       operator()(ListIndexAccess<TensorAttributeKey> const &index_access) {
-    auto v = this->evaluate(index_access.attribute_key);
-    return evaluate_list_index_access(index_access, v);
+    optional<TensorAttributeValue> v =
+        this->evaluate(index_access.attribute_key);
+    return evaluate_list_index_access(index_access.index, v);
   }
 
   optional<TensorAttributeValue>
@@ -184,6 +207,13 @@ optional<bool> satisfies(ParallelTensor const &params,
       [&](TensorAttributeConstraint const &c) { return satisfies(params, c); });
 }
 
+struct AlwaysTrueCriterion {
+  template <typename T>
+  bool operator()(T const &t) const {
+    return true;
+  }
+};
+
 bool assignment_satisfies(ParallelComputationGraph const &pcg,
                           GraphPattern const &pattern,
                           MultiDiGraphPatternMatch const &patternMatch) {
@@ -192,7 +222,7 @@ bool assignment_satisfies(ParallelComputationGraph const &pcg,
     auto patternNode = kv.first;
     auto pcgNode = kv.second;
     optional<bool> constraintResult =
-        satisfies(pcg->at(pcgNode), pattern->at(patternNode));
+        satisfies(pcg.value().at(pcgNode), pattern.value().at(patternNode));
     result &= constraintResult.value_or(false);
   }
 
@@ -200,11 +230,11 @@ bool assignment_satisfies(ParallelComputationGraph const &pcg,
     auto patternEdge = kv.first;
     auto pcgEdge = kv.second;
     optional<bool> constraintResult =
-        satisfies(pcg->at(pcgEdge), pattern->at(patternEdge));
+        satisfies(pcg.value().at(pcgEdge), pattern.value().at(patternEdge));
     result &= constraintResult.value_or(false);
   }
 
-  result &= pattern_matches(pattern, pcg, patternMatch);
+  result &= pattern_matches(pattern, pcg, patternMatch, AlwaysTrueCriterion{});
 
   return result;
 }
diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc
index afc3bb4e6d..3724f630d4 100644
--- a/lib/substitutions/src/substitution.cc
+++ b/lib/substitutions/src/substitution.cc
@@ -2,9 +2,37 @@
 
 namespace FlexFlow {
 
+struct DeriveValidOperatorAttributeExpr {
+  template <typename T>
+  std::unordered_set<AttributeExpr<OperatorAttributeKey>>
+      operator()(T const &t) {
+    return derive_valid_operator_attribute_expr(t);
+  }
+
+  std::unordered_set<AttributeExpr<OperatorAttributeKey>>
+      derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) {
+    return {key};
+  }
+
+  std::unordered_set<AttributeExpr<OperatorAttributeKey>>
+      derive_valid_operator_attribute_expr(
+          ListIndexAccess<OperatorAttributeKey> const &access) {
+    return {access, access.attribute_key};
+  }
+
+  std::unordered_set<AttributeExpr<OperatorAttributeKey>>
+      derive_valid_operator_attribute_expr(
+          ListSize<OperatorAttributeKey> const &ls) {
+    return {ls, ls.attribute_key};
+  }
+};
+
 std::unordered_set<AttributeExpr<OperatorAttributeKey>>
     get_valid_operator_attribute_exprs(OperatorPattern const &pattern) {
-  NOT_IMPLEMENTED();
+  return set_union(transform(
+      pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) {
+        return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr);
+      }));
 }
 
 bool is_valid_operator_attribute_expr(
@@ -22,7 +50,7 @@ struct IsValidOperatorAttributeExprFunctor {
   }
 
   bool is_valid(OperatorAttrAccess const &t) const {
-    return is_valid_operator_attribute_expr(graph_pattern->at(t.node),
+    return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node),
                                             t.attr_expr);
   }
 
@@ -32,7 +60,7 @@ struct IsValidOperatorAttributeExprFunctor {
 };
 
 bool is_valid_operator_attribute_expr(GraphPattern const &pattern,
-                                       OperatorAttributeExpr const &expr) {
+                                      OperatorAttributeExpr const &expr) {
   return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr);
 }
 
@@ -60,7 +88,8 @@ struct EvaluateOperatorAttributeExpr {
   OperatorAttributeValue evaluate(OperatorAttrAccess const &t) {
     Node node_in_pattern = t.node;
    Node node_in_pcg = match.node_assignment.at_l(node_in_pattern);
-    return evaluate_attribute_expr(graph->at(node_in_pcg), t.attr_expr).value();
+    return evaluate_attribute_expr(graph.value().at(node_in_pcg), t.attr_expr)
+        .value();
   }
 
   OperatorAttributeValue evaluate(AttrConstant const &t) {
@@ -78,7 +107,189 @@ OperatorAttributeExpr
 Operator get_operator_attrs(ParallelComputationGraph const &graph,
                             MultiDiGraphPatternMatch const &match,
                             OperatorAttrAssignment const &assignment) {
-  NOT_IMPLEMENTED();
+  std::unordered_map<OperatorAttributeKey, OperatorAttributeValue> assignments;
+  for (auto const &[key, expr] : assignment.assignments) {
+    assignments.emplace(key, evaluate_graph_attribute_expr(graph, match, expr));
+  }
+  assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE));
+  assert(holds_alternative<OperatorType>(
+      assignments.at(OperatorAttributeKey::OP_TYPE)));
+  OperatorType op_type =
+      get<OperatorType>(assignments.at(OperatorAttributeKey::OP_TYPE));
+  switch (op_type) {
+    case Op::BATCHMATMUL:
+      return Operator(
+          BatchMatmulAttrs{
+              get<int>(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)),
+              get<int>(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))},
+          nullopt);
+    case Op::BATCHNORM:
+      return Operator(
+          BatchNormAttrs{get<bool>(assignments.at(OperatorAttributeKey::RELU))},
+          nullopt);
+    case Op::CAST:
+      return Operator(CastAttrs{get<DataType>(
+                          assignments.at(OperatorAttributeKey::DATA_TYPE))},
+                      nullopt);
+    case Op::CONCAT:
+      return Operator(ConcatAttrs{get<ff_dim_t>(
+                          assignments.at(OperatorAttributeKey::AXIS))},
+                      nullopt);
+    case Op::CONV2D:
+      return Operator(
+          Conv2DAttrs{
+              get<int>(assignments.at(OperatorAttributeKey::OUT_CHANNELS)),
+              get<int>(assignments.at(OperatorAttributeKey::KERNEL_H)),
+              get<int>(assignments.at(OperatorAttributeKey::KERNEL_W)),
+              get<int>(assignments.at(OperatorAttributeKey::STRIDE_H)),
+              get<int>(assignments.at(OperatorAttributeKey::STRIDE_W)),
+              get<int>(assignments.at(OperatorAttributeKey::PADDING_H)),
+              get<int>(assignments.at(OperatorAttributeKey::PADDING_W)),
+              get<int>(assignments.at(OperatorAttributeKey::GROUPS)),
+              get<Activation>(assignments.at(OperatorAttributeKey::ACTIVATION)),
+              get<bool>(assignments.at(OperatorAttributeKey::USE_BIAS))},
+          nullopt);
+    case Op::DROPOUT:
+      return Operator(
+          DropoutAttrs{get<float>(assignments.at(OperatorAttributeKey::RATE)),
+                       get<unsigned long long>(
+                           assignments.at(OperatorAttributeKey::SEED))},
+          nullopt);
+    case Op::EW_ADD:
+    case Op::EW_DIV:
+    case Op::EW_EQUAL:
+    case Op::EW_GREATER:
+    case Op::EW_LESS:
+    case Op::EW_MAX:
+    case Op::EW_MIN:
+    case Op::EW_MUL:
+    case Op::EW_SUB:
+      return Operator(
+          ElementBinaryAttrs{
+              op_type,
+              get<DataType>(assignments.at(OperatorAttributeKey::DATA_TYPE)),
+              get<bool>(
+                  assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)),
+              get<bool>(
+                  assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))},
+          nullopt);
+    case Op::SCALAR_ADD:
+    case Op::SCALAR_FLOOR_DIV:
+    case Op::SCALAR_MULTIPLY:
+    case Op::SCALAR_SUB:
+    case Op::SCALAR_TRUE_DIV:
+      return Operator(
+          ElementScalarUnaryAttrs{
+              op_type,
+              get<float>(assignments.at(OperatorAttributeKey::SCALAR))},
+          nullopt);
+    case Op::EMBEDDING:
+      return Operator(
+          EmbeddingAttrs{
+              get<int>(assignments.at(OperatorAttributeKey::NUM_ENTRIES)),
+              get<int>(assignments.at(OperatorAttributeKey::OUT_CHANNELS)),
+              get<AggregateOp>(assignments.at(OperatorAttributeKey::AGGR)),
+              get<DataType>(assignments.at(OperatorAttributeKey::DATA_TYPE))},
+          nullopt);
+    case Op::FLAT:
+      return Operator(FlatAttrs{}, nullopt);
+    case Op::GATHER:
+      return Operator(
+          GatherAttrs{get<ff_dim_t>(assignments.at(OperatorAttributeKey::DIM))},
+          nullopt);
+    case Op::INPUT:
+      return Operator(InputAttrs{}, nullopt);
+    case Op::LAYERNORM:
+      return Operator(
+          LayerNormAttrs{
+              get<stack_vector<ff_dim_t, MAX_TENSOR_DIM>>(
+                  assignments.at(OperatorAttributeKey::AXES)),
+              get<bool>(
+                  assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)),
+              get<float>(assignments.at(OperatorAttributeKey::EPSILON))},
+          nullopt);
+    case Op::LINEAR:
+      return Operator(
+          LinearAttrs{
+              get<int>(assignments.at(OperatorAttributeKey::OUT_CHANNELS)),
+              get<bool>(assignments.at(OperatorAttributeKey::USE_BIAS)),
+              get<DataType>(assignments.at(OperatorAttributeKey::DATA_TYPE)),
+              get<Activation>(assignments.at(OperatorAttributeKey::ACTIVATION)),
+              get<RegularizerAttrs>(
+                  assignments.at(OperatorAttributeKey::REGULARIZER))},
+          nullopt);
+    case Op::MULTIHEAD_ATTENTION:
+      return Operator(
+          MultiHeadAttentionAttrs{
+              get<int>(assignments.at(OperatorAttributeKey::EMBED_DIM)),
+              get<int>(assignments.at(OperatorAttributeKey::NUM_HEADS)),
+              get<int>(assignments.at(OperatorAttributeKey::KDIM)),
+              get<int>(assignments.at(OperatorAttributeKey::VDIM)),
+              get<float>(assignments.at(OperatorAttributeKey::DROPOUT)),
+              get<bool>(assignments.at(OperatorAttributeKey::BIAS)),
+              get<bool>(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)),
+              get<bool>(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))},
+          nullopt);
+    case Op::NOOP:
+      return Operator(NoopAttrs{}, nullopt);
+    case Op::POOL2D:
+      return Operator(
+          Pool2DAttrs{
+              get<int>(assignments.at(OperatorAttributeKey::KERNEL_H)),
+              get<int>(assignments.at(OperatorAttributeKey::KERNEL_W)),
+              get<int>(assignments.at(OperatorAttributeKey::STRIDE_H)),
+              get<int>(assignments.at(OperatorAttributeKey::STRIDE_W)),
+              get<int>(assignments.at(OperatorAttributeKey::PADDING_H)),
+              get<int>(assignments.at(OperatorAttributeKey::PADDING_W)),
+              get<PoolOp>(assignments.at(OperatorAttributeKey::POOL_TYPE)),
+              get<Activation>(
+                  assignments.at(OperatorAttributeKey::ACTIVATION))},
+          nullopt);
+    case Op::REDUCE_ARGMAX:
+    case Op::REDUCE_ARGMIN:
+    case Op::REDUCE_MAX:
+    case Op::REDUCE_MEAN:
+    case Op::REDUCE_MIN:
+    case Op::REDUCE_PROD:
+    case Op::REDUCE_SUM:
+      return Operator(
+          ReduceAttrs{
+              get<stack_vector<ff_dim_t, MAX_TENSOR_DIM>>(
+                  assignments.at(OperatorAttributeKey::AXES)),
+              op_type,
+              get<bool>(assignments.at(OperatorAttributeKey::KEEP_DIMS))},
+          nullopt);
+    case Op::REVERSE:
+      return Operator(ReverseAttrs{get<ff_dim_t>(
+                          assignments.at(OperatorAttributeKey::AXIS))},
+                      nullopt);
+    case Op::RESHAPE:
+      return Operator(ReshapeAttrs{get<TensorShape>(
+                          assignments.at(OperatorAttributeKey::SHAPE))},
+                      nullopt);
+    case Op::SPLIT:
+      return Operator(
+          SplitAttrs{get<stack_vector<int, MAX_TENSOR_DIM>>(
+                         assignments.at(OperatorAttributeKey::SPLITS)),
+                     get<ff_dim_t>(assignments.at(OperatorAttributeKey::AXIS))},
+          nullopt);
+    case Op::SOFTMAX:
+      return Operator(SoftmaxAttrs{get<ff_dim_t>(
+                          assignments.at(OperatorAttributeKey::DIM))},
+                      nullopt);
+    case Op::TOPK:
+      return Operator(
+          TopKAttrs{get<int>(assignments.at(OperatorAttributeKey::K)),
+                    get<bool>(assignments.at(OperatorAttributeKey::SORTED))},
+          nullopt);
+    case Op::TRANSPOSE:
+      return Operator(
+          TransposeAttrs{get<stack_vector<ff_dim_t, MAX_TENSOR_DIM>>(
+              assignments.at(OperatorAttributeKey::PERMUTATION))},
+          nullopt);
+    default:
+      NOT_IMPLEMENTED();
+  }
 }
 
 ParallelComputationGraph
@@ -91,19 +302,19 @@ ParallelComputationGraph
   bidict<Node, Node> node_mapping; // Refactor it with global nodes
   for (Node const &node : get_nodes(pcg)) {
     if (!contains_r(match.node_assignment, node)) {
-      node_mapping.equate(node, new_pcg->add_node(pcg.value().at(node)));
+      node_mapping.equate(node, new_pcg.value().add_node(pcg.value().at(node)));
     }
   }
   for (MultiDiEdge const &edge : get_edges(pcg)) {
     if (!contains_r(match.edge_assignment, edge)) {
-      new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(edge.src),
-                                    node_mapping.at_r(edge.dst),
-                                    new_pcg->add_node_port(),
-                                    new_pcg->add_node_port()});
+      new_pcg.value().add_edge(MultiDiEdge{node_mapping.at_l(edge.src),
+                                           node_mapping.at_l(edge.dst),
+                                           new_pcg.value().add_node_port(),
+                                           new_pcg.value().add_node_port()});
     }
   }
   for (Node const &output_node : get_nodes(substitution.output_graph_expr)) {
-    Node new_node = new_pcg->add_node(get_operator_attrs(
+    Node new_node = new_pcg.value().add_node(get_operator_attrs(
         pcg, match, substitution.output_graph_expr->at(output_node)));
     node_mapping.equate(output_node, new_node);
   }
@@ -111,27 +322,28 @@ ParallelComputationGraph
        get_edges(substitution.output_graph_expr)) {
     if (holds_alternative<InputMultiDiEdge>(output_edge)) {
       InputMultiDiEdge e = get<InputMultiDiEdge>(output_edge);
-      MultiDiEdge original_edge = match.edge_assignment.at_l(
-          substitution.input_mapping.at_r(e));
-      new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(original_edge.src),
-                                    node_mapping.at_l(e.dst),
-                                    new_pcg->add_node_port(),
-                                    new_pcg->add_node_port()});
+      MultiDiEdge original_edge =
+          match.edge_assignment.at_l(substitution.input_mapping.at_r(e));
+      new_pcg.value().add_edge(
+          MultiDiOutput{node_mapping.at_l(original_edge.src),
+                        new_pcg.value().add_node_port()},
+          MultiDiInput{node_mapping.at_l(e.dst),
+                       new_pcg.value().add_node_port()});
     } else if (holds_alternative<OutputMultiDiEdge>(output_edge)) {
       OutputMultiDiEdge e = get<OutputMultiDiEdge>(output_edge);
-      MultiDiEdge original_edge = match.edge_assignment.at_l(
-          substitution.output_mapping.at_r(e));
-      new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(e.src),
-                                    node_mapping.at_l(original_edge.dst),
-                                    new_pcg->add_node_port(),
-                                    new_pcg->add_node_port()});
+      MultiDiEdge original_edge =
+          match.edge_assignment.at_l(substitution.output_mapping.at_r(e));
+      new_pcg.value().add_edge(MultiDiEdge{node_mapping.at_l(e.src),
+                                           node_mapping.at_l(original_edge.dst),
+                                           new_pcg.value().add_node_port(),
+                                           new_pcg.value().add_node_port()});
     } else {
       assert(holds_alternative<MultiDiEdge>(output_edge));
       MultiDiEdge e = get<MultiDiEdge>(output_edge);
-      new_pcg->add_edge(MultiDiEdge{node_mapping.at_l(e.src),
-                                    node_mapping.at_l(e.dst),
-                                    new_pcg->add_node_port(),
-                                    new_pcg->add_node_port()});
+      new_pcg.value().add_edge(MultiDiEdge{node_mapping.at_l(e.src),
+                                           node_mapping.at_l(e.dst),
+                                           new_pcg.value().add_node_port(),
+                                           new_pcg.value().add_node_port()});
     }
   }
 
diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h
index 29e88fd62b..223a9174c3 100644
--- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h
+++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h
@@ -7,11 +7,12 @@ template <typename NodeLabel>
 struct NodeLabelledOpenMultiDiGraph {
   NodeLabelledOpenMultiDiGraph() = delete;
   NodeLabelledOpenMultiDiGraph(NodeLabelledOpenMultiDiGraph const &) = default;
-  NodeLabelledOpenMultiDiGraph &operator=(NodeLabelledOpenMultiDiGraph const &) = default;
+  NodeLabelledOpenMultiDiGraph &
+      operator=(NodeLabelledOpenMultiDiGraph const &) = default;
 
   operator OpenMultiDiGraphView();
 };
 
-}
+} // namespace FlexFlow
 
 #endif
diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h
index ffb69b717d..fcbb2436d0 100644
--- a/lib/utils/include/utils/graph/labelled/output_labelled.h
+++ b/lib/utils/include/utils/graph/labelled/output_labelled.h
@@ -62,6 +62,10 @@ struct OutputLabelledMultiDiGraph {
     return this->ptr->add_node(l);
   }
 
+  NodePort add_node_port() {
+    return this->ptr->add_node_port();
+  }
+
   NodeLabel &at(Node const &n) {
     return this->ptr->at(n);
   }
@@ -83,6 +87,9 @@ struct OutputLabelledMultiDiGraph {
   OutputLabel const &at(MultiDiOutput const &o) const {
     return this->ptr->at(o);
   }
+  OutputLabel const &at(MultiDiEdge const &e) const {
+    return at(MultiDiOutput{e.src, e.srcIdx});
+  }
 
   std::unordered_set<Node> query_nodes(NodeQuery const &q) const {
     return this->ptr->query_nodes(q);
diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h
index 5e4d59a829..15c554b97d 100644
--- a/lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h
+++ b/lib/utils/include/utils/graph/labelled/output_labelled_interfaces.h
@@ -23,6 +23,7 @@ struct IOutputLabelledMultiDiGraph
                         OutputLabel const &label) = 0;
   virtual void add_edge(MultiDiOutput const &output,
                         MultiDiInput const &input) = 0;
+  virtual NodePort add_node_port() = 0;
 
   virtual NodeLabel &at(Node const &) = 0;
   virtual NodeLabel const &at(Node const &) const = 0;
diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h
index 75e0608837..8c6597d464 100644
--- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h
+++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h
@@ -13,7 +13,7 @@ struct OutputLabelledOpenMultiDiGraph {
   OutputLabelledOpenMultiDiGraph &
       operator=(OutputLabelledOpenMultiDiGraph const &) = default;
 
-  operator OpenMultiDiGraphView() {
+  operator OpenMultiDiGraphView() const {
     NOT_IMPLEMENTED();
   }
 
@@ -60,6 +60,10 @@ struct OutputLabelledOpenMultiDiGraph {
   OutputLabel &at(MultiDiOutput const &) {
     NOT_IMPLEMENTED();
   }
+
+  OutputLabel const &at(OpenMultiDiEdge const &) const {
+    NOT_IMPLEMENTED();
+  }
 };
 
 } // namespace FlexFlow