Skip to content

Commit

Permalink
implement get_operator_attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
wmdi committed Sep 8, 2023
1 parent 21b8549 commit e3b633f
Show file tree
Hide file tree
Showing 10 changed files with 342 additions and 51 deletions.
1 change: 1 addition & 0 deletions lib/substitutions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ ff_add_library(
DEPS
utils
op-attrs
pcg
)

add_subdirectory(ffi)
Expand Down
45 changes: 40 additions & 5 deletions lib/substitutions/include/substitutions/operator_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ enum class OperatorAttributeKey {
STRIDE_W,
PADDING_H,
PADDING_W,
AGGR_MODE,
AGGR,
NUM_ENTRIES,
OUT_CHANNELS,
ACTIVATION,
Expand All @@ -42,14 +42,49 @@ enum class OperatorAttributeKey {
PARALLEL_DIM,
PARALLEL_DEGREE,
PAD,
EMBED_DIM,
KDIM,
VDIM,
DROPOUT,
BIAS,
ADD_BIAS_KV,
ADD_ZERO_ATTN,
A_SEQ_LENGTH_DIM,
B_SEQ_LENGTH_DIM,
RELU,
TARGET_DIMS,
RATE,
SEED,
SHOULD_BROADCAST_LHS,
SHOULD_BROADCAST_RHS,
DIM,
ELEMENTWISE_AFFINE,
REGULARIZER,
SHAPE,
SPLITS,
K,
SORTED,
};

using OperatorAttributeValue =
variant<int, float, bool, std::vector<int>, OperatorType, Activation>;
using OperatorAttributeValue = variant<int,
float,
bool,
stack_vector<int, MAX_TENSOR_DIM>,
OperatorType,
Activation,
ff_dim_t,
unsigned long long,
AggregateOp,
stack_vector<ff_dim_t, MAX_TENSOR_DIM>,
RegularizerAttrs,
PoolOp,
TensorShape>;

using OperatorAttributeConstraint = AttributeConstraint<OperatorAttributeKey, OperatorAttributeValue>;
using OperatorAttributeConstraint =
AttributeConstraint<OperatorAttributeKey, OperatorAttributeValue>;

using OperatorPattern = AttributePattern<OperatorAttributeKey, OperatorAttributeValue>;
using OperatorPattern =
AttributePattern<OperatorAttributeKey, OperatorAttributeValue>;

optional<OperatorAttributeValue>
evaluate_attribute_expr(Operator const &attrs,
Expand Down
3 changes: 1 addition & 2 deletions lib/substitutions/include/substitutions/output_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ struct AttrConstant {
OperatorAttributeValue value;
};

using OperatorAttributeExpr =
variant<OperatorAttrAccess, AttrConstant>;
using OperatorAttributeExpr = variant<OperatorAttrAccess, AttrConstant>;

// NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can
// define the assignment for each operator type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ using TensorAttributeValue = variant<int, std::vector<int>>;
using TensorAttributeConstraint =
AttributeConstraint<TensorAttributeKey, TensorAttributeValue>;

using ParallelTensorPattern = AttributePattern<TensorAttributeKey, TensorAttributeValue>;
using ParallelTensorPattern =
AttributePattern<TensorAttributeKey, TensorAttributeValue>;

optional<TensorAttributeValue>
evaluate_attribute_expr(ParallelTensor const &tensor_shape,
Expand Down
56 changes: 43 additions & 13 deletions lib/substitutions/src/graph_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,45 @@

namespace FlexFlow {

template <typename T, typename V>
optional<V> evaluate_list_index_access(ListIndexAccess<T> const &index_access,
optional<V> const &v) {
optional<OperatorAttributeValue> evaluate_list_index_access(int index,
optional<OperatorAttributeValue> const &v) {
if (!v.has_value() ||
!holds_alternative<stack_vector<int, MAX_TENSOR_DIM>>(v.value()) ||
!holds_alternative<stack_vector<ff_dim_t, MAX_TENSOR_DIM>>(v.value())) {
return nullopt;
}

if (index >= MAX_TENSOR_DIM) {
return nullopt;
}

if (holds_alternative<stack_vector<int, MAX_TENSOR_DIM>>(v.value())) {
return get<stack_vector<int, MAX_TENSOR_DIM>>(v.value()).at(index);
} else {
return get<stack_vector<ff_dim_t, MAX_TENSOR_DIM>>(v.value()).at(index);
}
}

optional<TensorAttributeValue> evaluate_list_index_access(int const &index,
optional<TensorAttributeValue> const &v) {
if (!v.has_value() || !holds_alternative<std::vector<int>>(v.value())) {
return nullopt;
}

auto vec = get<std::vector<int>>(v.value());
if (index_access.index >= vec.size()) {

if (index >= vec.size()) {
return nullopt;
}

return vec.at(index_access.index);
return vec.at(index);
}

template <typename V>
optional<V> evaluate_list_size(optional<V> const &v) {
optional<OperatorAttributeValue> evaluate_list_size(optional<OperatorAttributeValue> const &v) {
return MAX_TENSOR_DIM;
}

optional<TensorAttributeValue> evaluate_list_size(optional<TensorAttributeValue> const &v) {
if (!v.has_value() || !holds_alternative<std::vector<int>>(v.value())) {
return nullopt;
}
Expand All @@ -44,7 +66,7 @@ struct EvaluateOperatorAttributeExpr {
operator()(ListIndexAccess<OperatorAttributeKey> const &index_access) {
optional<OperatorAttributeValue> v =
get_attribute(this->attrs, index_access.attribute_key);
return evaluate_list_index_access(index_access, v);
return evaluate_list_index_access(index_access.index, v);
}

optional<OperatorAttributeValue>
Expand Down Expand Up @@ -94,8 +116,9 @@ struct EvaluateTensorAttributeExpr {

optional<TensorAttributeValue>
operator()(ListIndexAccess<TensorAttributeKey> const &index_access) {
auto v = this->evaluate(index_access.attribute_key);
return evaluate_list_index_access(index_access, v);
optional<TensorAttributeValue> v =
this->evaluate(index_access.attribute_key);
return evaluate_list_index_access(index_access.index, v);
}

optional<TensorAttributeValue>
Expand Down Expand Up @@ -184,6 +207,13 @@ optional<bool> satisfies(ParallelTensor const &params,
[&](TensorAttributeConstraint const &c) { return satisfies(params, c); });
}

struct AlwaysTrueCriterion {
template <typename T>
bool operator()(T const &t) const {
return true;
}
};

bool assignment_satisfies(ParallelComputationGraph const &pcg,
GraphPattern const &pattern,
MultiDiGraphPatternMatch const &patternMatch) {
Expand All @@ -192,19 +222,19 @@ bool assignment_satisfies(ParallelComputationGraph const &pcg,
auto patternNode = kv.first;
auto pcgNode = kv.second;
optional<bool> constraintResult =
satisfies(pcg->at(pcgNode), pattern->at(patternNode));
satisfies(pcg.value().at(pcgNode), pattern.value().at(patternNode));
result &= constraintResult.value_or(false);
}

for (auto const &kv : patternMatch.edge_assignment) {
auto patternEdge = kv.first;
auto pcgEdge = kv.second;
optional<bool> constraintResult =
satisfies(pcg->at(pcgEdge), pattern->at(patternEdge));
satisfies(pcg.value().at(pcgEdge), pattern.value().at(patternEdge));
result &= constraintResult.value_or(false);
}

result &= pattern_matches(pattern, pcg, patternMatch);
result &= pattern_matches(pattern, pcg, patternMatch, AlwaysTrueCriterion{});

return result;
}
Expand Down
Loading

0 comments on commit e3b633f

Please sign in to comment.