From 3923b4e193dcaf7c4301d442bb1506bdb2d35dcb Mon Sep 17 00:00:00 2001 From: fintarin Date: Sat, 6 Apr 2024 11:26:15 +0400 Subject: [PATCH] Refactor ExpressionComparator --- .../expressions/ExpressionComparator.cpp | 247 +++++++----------- .../interfaces/IPolynomExpression.cpp | 2 +- tests/src/overall/simplify/SimplifyTests.cpp | 2 +- 3 files changed, 101 insertions(+), 150 deletions(-) diff --git a/src/fintamath/expressions/ExpressionComparator.cpp b/src/fintamath/expressions/ExpressionComparator.cpp index 8e57661e6..616a1d21e 100644 --- a/src/fintamath/expressions/ExpressionComparator.cpp +++ b/src/fintamath/expressions/ExpressionComparator.cpp @@ -16,52 +16,46 @@ #include "fintamath/functions/FunctionArguments.hpp" #include "fintamath/functions/FunctionUtils.hpp" #include "fintamath/functions/IFunction.hpp" +#include "fintamath/functions/arithmetic/Mul.hpp" +#include "fintamath/functions/logic/Not.hpp" +#include "fintamath/functions/powers/Pow.hpp" #include "fintamath/literals/ILiteral.hpp" #include "fintamath/literals/Variable.hpp" namespace fintamath::detail { -using ExpressionTreePathStack = std::stack, size_t>>; - using Ordering = std::strong_ordering; +using FunctionPtr = std::shared_ptr; + +using ExpressionPtr = std::shared_ptr; + +using PolynomPtr = std::shared_ptr; + +using ExpressionTreePathStack = std::stack>; + struct ChildrenComparatorResult final { Ordering postfix = Ordering::equal; - Ordering postfixUnary = Ordering::equal; Ordering prefixFirst = Ordering::equal; Ordering prefixLast = Ordering::equal; - Ordering prefixVariables = Ordering::equal; - Ordering prefixLiterals = Ordering::equal; + Ordering unary = Ordering::equal; + Ordering literals = Ordering::equal; Ordering size = Ordering::equal; }; -Ordering compareNonExpressions(const ArgumentPtr &lhs, - const ArgumentPtr &rhs, - const ComparatorOptions &options); +Ordering compareNonExpressions(const ArgumentPtr &lhs, const ArgumentPtr &rhs, const ComparatorOptions &options); -Ordering comparePolynoms(const std::shared_ptr &lhs, - const std::shared_ptr &rhs, - const ComparatorOptions &options); +Ordering comparePolynoms(const PolynomPtr &lhs, const PolynomPtr &rhs, const ComparatorOptions &options); -Ordering compareExpressions(const std::shared_ptr &lhs, - const std::shared_ptr &rhs, - const ComparatorOptions &options); +Ordering compareExpressions(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ComparatorOptions &options); -Ordering comparePolynomAndNonPolynom(const std::shared_ptr &lhs, - const ArgumentPtr &rhs, - const ComparatorOptions &options); +Ordering comparePolynomAndNonPolynom(const PolynomPtr &lhs, const ArgumentPtr &rhs, const ComparatorOptions &options); -Ordering compareExpressionAndNonExpression(const std::shared_ptr &lhs, - const ArgumentPtr &rhs, - const ComparatorOptions &options); +Ordering compareExpressionAndNonExpression(const ExpressionPtr &lhs, const ArgumentPtr &rhs, const ComparatorOptions &options); -Ordering compareFunctions(const std::shared_ptr &lhs, - const std::shared_ptr &rhs, - const ComparatorOptions &options); +Ordering compareFunctions(const FunctionPtr &lhs, const FunctionPtr &rhs, const ComparatorOptions &options); -ChildrenComparatorResult compareChildren(const ArgumentPtrVector &lhsChildren, - const ArgumentPtrVector &rhsChildren, - const ComparatorOptions &options); +ChildrenComparatorResult compareChildren(const ArgumentPtrVector &lhsChildren, const ArgumentPtrVector &rhsChildren, const ComparatorOptions &options); bool unwrapUnaryExpression(ArgumentPtr &arg); @@ -69,10 +63,10 @@ bool unwrapEmptyExpression(ArgumentPtr &arg); Ordering reverse(Ordering ordering); -template +template size_t getPositionOfFirstChildWithTerm(const ArgumentPtrVector &children) { for (size_t i = 0; i < children.size(); i++) { - if (containsIf(children[i], [](const ArgumentPtr &child) { return is(child); })) { + if (containsIf(children[i], [](const ArgumentPtr &child) { return is(child); })) { return i; } } @@ -80,8 +74,8 @@ size_t getPositionOfFirstChildWithTerm(const ArgumentPtrVector &children) { return children.size(); } -template -std::shared_ptr popNextTerm(ExpressionTreePathStack &stack) { +template +std::shared_ptr popNextTerm(ExpressionTreePathStack &stack) { while (!stack.empty()) { const ArgumentPtrVector &children = stack.top().first->getChildren(); @@ -98,7 +92,7 @@ std::shared_ptr popNextTerm(ExpressionTreePathStack &stack) { break; } - if (const auto &varChild = cast(children[exprIndex])) { + if (const auto &varChild = cast(children[exprIndex])) { return varChild; } } @@ -113,30 +107,27 @@ std::shared_ptr popNextTerm(ExpressionTreePathStack &stack) { return {}; } -template -Ordering compareTerms(const ArgumentPtr &lhs, - const ArgumentPtr &rhs, - const ComparatorOptions &options) { - +template +Ordering compareTerms(const ArgumentPtr &lhs, const ArgumentPtr &rhs, const ComparatorOptions &options) { ExpressionTreePathStack lhsPath; ExpressionTreePathStack rhsPath; - std::shared_ptr lhsTerm; - std::shared_ptr rhsTerm; + std::shared_ptr lhsTerm; + std::shared_ptr rhsTerm; if (const auto &expr = cast(lhs)) { lhsPath.emplace(expr, -1); - lhsTerm = popNextTerm(lhsPath); + lhsTerm = popNextTerm(lhsPath); } - else if (const auto &term = cast(lhs)) { + else if (const auto &term = cast(lhs)) { lhsTerm = term; } if (const auto &expr = cast(rhs)) { rhsPath.emplace(expr, -1); - rhsTerm = popNextTerm(rhsPath); + rhsTerm = popNextTerm(rhsPath); } - else if (const auto &term = cast(rhs)) { + else if (const auto &term = cast(rhs)) { rhsTerm = term; } @@ -153,17 +144,14 @@ Ordering compareTerms(const ArgumentPtr &lhs, return res; } - lhsTerm = popNextTerm(lhsPath); - rhsTerm = popNextTerm(rhsPath); + lhsTerm = popNextTerm(lhsPath); + rhsTerm = popNextTerm(rhsPath); } return Ordering::equal; } -Ordering compare(ArgumentPtr lhs, - ArgumentPtr rhs, - const ComparatorOptions options) { - +Ordering compare(ArgumentPtr lhs, ArgumentPtr rhs, const ComparatorOptions options) { unwrapEmptyExpression(lhs); unwrapEmptyExpression(rhs); @@ -200,10 +188,7 @@ Ordering compare(ArgumentPtr lhs, return compareExpressions(lhsExpr, rhsExpr, options); } -Ordering compareNonExpressions(const ArgumentPtr &lhs, - const ArgumentPtr &rhs, - const ComparatorOptions &options) { - +Ordering compareNonExpressions(const ArgumentPtr &lhs, const ArgumentPtr &rhs, const ComparatorOptions &options) { if (is(lhs) && !is(rhs)) { return !options.termOrderInversed ? Ordering::greater : Ordering::less; } @@ -235,21 +220,18 @@ Ordering compareNonExpressions(const ArgumentPtr &lhs, return lhs->toString() < rhs->toString() ? Ordering::greater : Ordering::less; } -Ordering comparePolynoms(const std::shared_ptr &lhs, - const std::shared_ptr &rhs, - const ComparatorOptions &options) { - +Ordering comparePolynoms(const PolynomPtr &lhs, const PolynomPtr &rhs, const ComparatorOptions &options) { const ChildrenComparatorResult childrenComp = compareChildren(lhs->getChildren(), rhs->getChildren(), options); if (childrenComp.postfix != Ordering::equal) { return childrenComp.postfix; } - if (childrenComp.postfixUnary != Ordering::equal) { - return childrenComp.postfixUnary; - } if (childrenComp.size != Ordering::equal) { return childrenComp.size; } + if (childrenComp.unary != Ordering::equal) { + return childrenComp.unary; + } if (childrenComp.prefixFirst != Ordering::equal) { return childrenComp.prefixFirst; } @@ -257,10 +239,7 @@ Ordering comparePolynoms(const std::shared_ptr &lhs, return compareFunctions(lhs->getFunction(), rhs->getFunction(), options); } -Ordering compareExpressions(const std::shared_ptr &lhs, - const std::shared_ptr &rhs, - const ComparatorOptions &options) { - +Ordering compareExpressions(const ExpressionPtr &lhs, const ExpressionPtr &rhs, const ComparatorOptions &options) { const auto lhsOper = cast(lhs->getFunction()); const auto rhsOper = cast(rhs->getFunction()); @@ -272,11 +251,8 @@ Ordering compareExpressions(const std::shared_ptr &lhs, childCompOptions.termOrderInversed = false; const ChildrenComparatorResult childrenComp = compareChildren(lhs->getChildren(), rhs->getChildren(), childCompOptions); - if (childrenComp.prefixVariables != Ordering::equal) { - return childrenComp.prefixVariables; - } - if (childrenComp.prefixLiterals != Ordering::equal) { - return childrenComp.prefixLiterals; + if (childrenComp.literals != Ordering::equal) { + return childrenComp.literals; } if (childrenComp.size != Ordering::equal) { return childrenComp.size; @@ -299,10 +275,7 @@ Ordering compareExpressions(const std::shared_ptr &lhs, return compareFunctions(lhs->getFunction(), rhs->getFunction(), options); } -Ordering comparePolynomAndNonPolynom(const std::shared_ptr &lhs, - const ArgumentPtr &rhs, - const ComparatorOptions &options) { - +Ordering comparePolynomAndNonPolynom(const PolynomPtr &lhs, const ArgumentPtr &rhs, const ComparatorOptions &options) { const ChildrenComparatorResult childrenComp = compareChildren(lhs->getChildren(), {rhs}, options); if (childrenComp.postfix != Ordering::equal) { @@ -312,10 +285,7 @@ Ordering comparePolynomAndNonPolynom(const std::shared_ptr &lhs, - const ArgumentPtr &rhs, - const ComparatorOptions &options) { - +Ordering compareExpressionAndNonExpression(const ExpressionPtr &lhs, const ArgumentPtr &rhs, const ComparatorOptions &options) { if (!is(rhs)) { return !options.termOrderInversed ? Ordering::greater : Ordering::less; } @@ -332,35 +302,24 @@ Ordering compareExpressionAndNonExpression(const std::shared_ptr(lhs->getFunction())) { - switch (lhsOper->getPriority()) { - case IOperator::Priority::PostfixUnary: - case IOperator::Priority::PrefixUnary: { - if (const Ordering res = compare(lhs->getChildren().front(), rhs); res != Ordering::equal) { - return res; - } - - return Ordering::less; - } - case IOperator::Priority::Exponentiation: - case IOperator::Priority::Multiplication: { - const ArgumentPtr rhsExpr = makeExpr(*lhsOper, rhs, Integer(1).clone()); - const Ordering res = compare(lhs, rhsExpr); - return options.termOrderInversed ? reverse(res) : res; - } - default: { - break; - } + if (is(lhs->getFunction())) { + if (const Ordering res = compare(lhs->getChildren().front(), rhs); res != Ordering::equal) { + return res; } + + return Ordering::less; + } + + if (is(lhs->getFunction()) || is(lhs->getFunction())) { + const ArgumentPtr rhsExpr = makeExpr(*lhs->getFunction(), rhs, Integer(1).clone()); + const Ordering res = compare(lhs, rhsExpr); + return !options.termOrderInversed ? res : reverse(res); } return !options.termOrderInversed ? Ordering::greater : Ordering::less; } -Ordering compareFunctions(const std::shared_ptr &lhs, - const std::shared_ptr &rhs, - const ComparatorOptions &options) { - +Ordering compareFunctions(const FunctionPtr &lhs, const FunctionPtr &rhs, const ComparatorOptions &options) { if (is(lhs) && !is(rhs)) { return options.termOrderInversed ? Ordering::greater : Ordering::less; } @@ -379,16 +338,39 @@ Ordering compareFunctions(const std::shared_ptr &lhs, return lhs->toString() < rhs->toString() ? Ordering::greater : Ordering::less; } -ChildrenComparatorResult compareChildren(const ArgumentPtrVector &lhsChildren, - const ArgumentPtrVector &rhsChildren, - const ComparatorOptions &options) { - +ChildrenComparatorResult compareChildren(const ArgumentPtrVector &lhsChildren, const ArgumentPtrVector &rhsChildren, const ComparatorOptions &options) { ChildrenComparatorResult result = {}; - const size_t lhsStart = getPositionOfFirstChildWithTerm(lhsChildren); - const size_t rhsStart = getPositionOfFirstChildWithTerm(rhsChildren); + if (lhsChildren.size() != rhsChildren.size()) { + result.size = lhsChildren.size() > rhsChildren.size() ? Ordering::greater : Ordering::less; + } + + const size_t lhsPostfixStart = getPositionOfFirstChildWithTerm(lhsChildren); + const size_t rhsPostfixStart = getPositionOfFirstChildWithTerm(rhsChildren); + const size_t prefixSize = std::min(std::max(lhsPostfixStart, rhsPostfixStart), + std::min(lhsChildren.size(), rhsChildren.size())); + + for (const auto i : stdv::iota(0U, prefixSize)) { + const Ordering childrenComp = compare(lhsChildren[i], rhsChildren[i], options); + + if (childrenComp != Ordering::equal) { + result.prefixLast = childrenComp; + } + + if (result.prefixFirst == Ordering::equal) { + result.prefixFirst = childrenComp; + } + + if (result.literals == Ordering::equal) { + result.literals = compareTerms(lhsChildren[i], rhsChildren[i], {}); + } + + if (result.literals != Ordering::equal && result.prefixLast != Ordering::equal) { + break; + } + } - for (size_t i = lhsStart, j = rhsStart; i < lhsChildren.size() && j < rhsChildren.size(); i++, j++) { + for (size_t i = lhsPostfixStart, j = rhsPostfixStart; i < lhsChildren.size() && j < rhsChildren.size(); i++, j++) { ArgumentPtr compLhs = lhsChildren[i]; ArgumentPtr compRhs = rhsChildren[j]; @@ -400,8 +382,12 @@ ChildrenComparatorResult compareChildren(const ArgumentPtrVector &lhsChildren, compRhs = rhsChildren[j]; } - if (result.postfixUnary == Ordering::equal && isLhsUnary != isRhsUnary) { - result.postfixUnary = !isLhsUnary ? Ordering::greater : Ordering::less; + if (result.literals == Ordering::equal) { + result.literals = compareTerms(compLhs, compRhs, options); + } + + if (result.unary == Ordering::equal && isLhsUnary != isRhsUnary) { + result.unary = !isLhsUnary ? Ordering::greater : Ordering::less; } if (result.postfix == Ordering::equal) { @@ -414,49 +400,14 @@ ChildrenComparatorResult compareChildren(const ArgumentPtrVector &lhsChildren, } if (result.postfix == Ordering::equal) { - const size_t lhsPostfixSize = lhsChildren.size() - lhsStart; - const size_t rhsPostfixSize = rhsChildren.size() - rhsStart; + const size_t lhsPostfixSize = lhsChildren.size() - lhsPostfixStart; + const size_t rhsPostfixSize = rhsChildren.size() - rhsPostfixStart; if (lhsPostfixSize != rhsPostfixSize) { result.postfix = lhsPostfixSize > rhsPostfixSize ? Ordering::greater : Ordering::less; } } - if (lhsChildren.size() != rhsChildren.size()) { - result.postfixUnary = Ordering::equal; - } - - auto size = std::min(std::max(lhsStart, rhsStart), - std::min(lhsChildren.size(), rhsChildren.size())); - - for (size_t i = 0; i < size; i++) { - const Ordering childrenComp = compare(lhsChildren[i], rhsChildren[i], options); - - if (childrenComp != Ordering::equal) { - result.prefixLast = childrenComp; - } - - if (result.prefixFirst == Ordering::equal) { - result.prefixFirst = childrenComp; - } - - if (result.prefixVariables == Ordering::equal) { - result.prefixVariables = compareTerms(lhsChildren[i], rhsChildren[i], {}); - } - - if (result.prefixLiterals == Ordering::equal) { - result.prefixLiterals = compareTerms(lhsChildren[i], rhsChildren[i], {}); - } - - if (result.prefixLiterals != Ordering::equal && result.prefixLast != Ordering::equal) { - break; - } - } - - if (lhsChildren.size() != rhsChildren.size()) { - result.size = lhsChildren.size() > rhsChildren.size() ? Ordering::greater : Ordering::less; - } - return result; } @@ -474,8 +425,8 @@ bool unwrapUnaryExpression(ArgumentPtr &arg) { bool unwrapEmptyExpression(ArgumentPtr &arg) { if (const auto expr = cast(arg); - expr && - !expr->getFunction()) { + expr && + !expr->getFunction()) { arg = expr->getChildren().front(); return true; @@ -488,4 +439,4 @@ Ordering reverse(const Ordering ordering) { return 0 <=> ordering; } -} +} \ No newline at end of file diff --git a/src/fintamath/expressions/interfaces/IPolynomExpression.cpp b/src/fintamath/expressions/interfaces/IPolynomExpression.cpp index bba48aab8..c84cbd762 100644 --- a/src/fintamath/expressions/interfaces/IPolynomExpression.cpp +++ b/src/fintamath/expressions/interfaces/IPolynomExpression.cpp @@ -230,4 +230,4 @@ void IPolynomExpression::sort() { }); } -} +} \ No newline at end of file diff --git a/tests/src/overall/simplify/SimplifyTests.cpp b/tests/src/overall/simplify/SimplifyTests.cpp index e8e4be571..d25724fa6 100644 --- a/tests/src/overall/simplify/SimplifyTests.cpp +++ b/tests/src/overall/simplify/SimplifyTests.cpp @@ -605,7 +605,7 @@ TEST(SimplifyTests, simplifyTest) { EXPECT_EQ(Expression("cos(b) log(b, a)").toString(), "log(b, a) cos(b)"); EXPECT_EQ(Expression("cos(a) log(b, c)").toString(), - "log(b, c) cos(a)"); + "cos(a) log(b, c)"); EXPECT_EQ(Expression("cos(b^2) log(b, c)").toString(), "log(b, c) cos(b^2)"); EXPECT_EQ(Expression("(x + y^3)^2 * sin(x)/ln(2)/x^2 - (2 sin(x) y^3)/(x ln(2))").toString(),