From 82e2c2c061ed8115ab7e167efa7bd584b9854a5f Mon Sep 17 00:00:00 2001 From: wmdi Date: Wed, 30 Aug 2023 16:59:06 -0400 Subject: [PATCH] check substitution validity --- lib/substitutions/src/substitution.cc | 51 ++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index 1b56a0443f..e83d522cd7 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -1,8 +1,57 @@ #include "substitutions/substitution.h" -#include namespace FlexFlow { +std::unordered_set> + get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { + NOT_IMPLEMENTED(); +} + +bool is_valid_operator_attribute_expr( + OperatorPattern const &pattern, + AttributeExpr const &expr) { + return contains(get_valid_operator_attribute_exprs(pattern), expr); +} + +struct IsValidGraphAttributeExprFunctor { + GraphPattern const &graph_pattern; + + template + bool operator()(T const &t) const { + return is_valid(t); + } + + bool is_valid(NodeAttrAccess const &t) const { + return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), + t.attr_expr); + } + + bool is_valid(EdgeAttrAccess const &t) const { + NOT_IMPLEMENTED(); + } + + bool is_valid(AttrConstant const &t) const { + return true; + } +}; + +bool is_valid_graph_attribute_expr(GraphPattern const &pattern, + GraphAttributeExpr const &expr) { + return visit(IsValidGraphAttributeExprFunctor{pattern}, expr); +} + +bool is_valid_substitution(Substitution const &s) { + for (Node const &node : get_nodes(s.output_graph_expr)) { + for (GraphAttributeExpr expr : + values(s.output_graph_expr.value().at(node).assignment)) { + if (!is_valid_graph_attribute_expr(s.input_graph, expr)) { + return false; + } + } + } + return true; +} + struct EvaluateGraphAttributeExpr { ParallelComputationGraph const &graph; MultiDiGraphPatternMatch const &match;