Skip to content

Commit

Permalink
check substitution validity
Browse files Browse the repository at this point in the history
  • Loading branch information
wmdi committed Aug 30, 2023
1 parent 08dd3fe commit 82e2c2c
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion lib/substitutions/src/substitution.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,57 @@
#include "substitutions/substitution.h"
#include

namespace FlexFlow {

std::unordered_set<AttributeExpr<OperatorAttributeKey>>
get_valid_operator_attribute_exprs(OperatorPattern const &pattern) {
NOT_IMPLEMENTED();
}

bool is_valid_operator_attribute_expr(
OperatorPattern const &pattern,
AttributeExpr<OperatorAttributeKey> const &expr) {
return contains(get_valid_operator_attribute_exprs(pattern), expr);
}

struct IsValidGraphAttributeExprFunctor {
GraphPattern const &graph_pattern;

template <typename T>
bool operator()(T const &t) const {
return is_valid(t);
}

bool is_valid(NodeAttrAccess const &t) const {
return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node),
t.attr_expr);
}

bool is_valid(EdgeAttrAccess const &t) const {
NOT_IMPLEMENTED();
}

bool is_valid(AttrConstant const &t) const {
return true;
}
};

bool is_valid_graph_attribute_expr(GraphPattern const &pattern,
GraphAttributeExpr const &expr) {
return visit(IsValidGraphAttributeExprFunctor{pattern}, expr);
}

bool is_valid_substitution(Substitution const &s) {
for (Node const &node : get_nodes(s.output_graph_expr)) {
for (GraphAttributeExpr expr :
values(s.output_graph_expr.value().at(node).assignment)) {
if (!is_valid_graph_attribute_expr(s.input_graph, expr)) {
return false;
}
}
}
return true;
}

struct EvaluateGraphAttributeExpr {
ParallelComputationGraph const &graph;
MultiDiGraphPatternMatch const &match;
Expand Down

0 comments on commit 82e2c2c

Please sign in to comment.