From b41a871ecb5c9260417b90a65853f08a06ec223f Mon Sep 17 00:00:00 2001 From: fintarin Date: Wed, 12 Jul 2023 13:57:06 +0300 Subject: [PATCH] Refactoring & fix of DivExpression --- .../expressions/binary/DivExpression.cpp | 45 ++++++++++--------- .../expressions/binary/DivExpression.hpp | 4 +- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/fintamath/expressions/binary/DivExpression.cpp b/src/fintamath/expressions/binary/DivExpression.cpp index 5bbb173c8..6707e1e12 100644 --- a/src/fintamath/expressions/binary/DivExpression.cpp +++ b/src/fintamath/expressions/binary/DivExpression.cpp @@ -209,24 +209,22 @@ bool DivExpression::isNeg(const ArgumentPtr &expr) { } ArgumentPtr DivExpression::sumSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) { - if (auto res = sumMulSimplify(lhs, rhs)) { - return res; + if (auto [result, remainder] = sumMulSimplify(lhs, rhs); result) { + return makeExpr(Add(), result, remainder); } if (auto [result, remainder] = mulSumSimplify(lhs, rhs); result) { - simplifyChild(remainder); - ArgumentPtr res = makeExpr(Add(), result, remainder); - return res; + return makeExpr(Add(), result, remainder); } - if (auto res = sumSumSimplify(lhs, rhs)) { - return res; + if (auto [result, remainder] = sumSumSimplify(lhs, rhs); result) { + return makeExpr(Add(), result, remainder); } return {}; } -ArgumentPtr DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) { +std::pair DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) { ArgumentsPtrVector lhsChildren; ArgumentsPtrVector rhsChildren; @@ -265,15 +263,16 @@ ArgumentPtr DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const Argument return {}; } - ArgumentPtr remainder = makeExpr(Add(), remainderVect); - simplifyChild(remainder); - resultVect.emplace_back(makeExpr(Div(), remainder, rhs)); - ArgumentPtr result = makeExpr(Add(), resultVect); - return result; + + ArgumentPtr remainderAdd = makeExpr(Add(), remainderVect); + simplifyChild(remainderAdd); + ArgumentPtr remainder = makeExpr(Div(), remainderAdd, rhs); + + return {result, remainder}; } -ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) { +std::pair DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) { ArgumentsPtrVector lhsChildren; if (const auto lhsExpr = cast(lhs); lhsExpr && is(lhsExpr->getFunction())) { @@ -305,12 +304,16 @@ ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const Argument return {}; } + ArgumentPtr result = makeExpr(Add(), resultChildren); + + ArgumentPtr remainder; if (!remainderChildren.empty()) { - ArgumentPtr remainder = makeExpr(Div(), makeExpr(Add(), remainderChildren), rhs); - resultChildren.emplace_back(remainder); + ArgumentPtr remainderAdd = makeExpr(Add(), remainderChildren); + simplifyChild(remainder); + remainder = makeExpr(Div(), remainderAdd, rhs); } - return makeExpr(Add(), resultChildren); + return {result, remainder}; } std::pair DivExpression::mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) { @@ -338,8 +341,10 @@ std::pair DivExpression::mulSumSimplify(const Argument multiplicator.emplace_back(makeExpr(Mul(), rhsChildren[i], result)); } - ArgumentPtr negSum = makeExpr(Neg(), makeExpr(Add(), multiplicator)); - ArgumentPtr remainder = makeExpr(Div(), negSum, rhs); + ArgumentPtr remainderNegAdd = makeExpr(Neg(), makeExpr(Add(), multiplicator)); + simplifyChild(remainderNegAdd); + ArgumentPtr remainder = makeExpr(Div(), remainderNegAdd, rhs); + return {result, remainder}; } @@ -464,7 +469,7 @@ ArgumentPtr DivExpression::nestedRationalsInDenominatorSimplify(const ArgumentPt denominatorChildren.emplace_back(child); } - if (!denominatorChildren.empty()) { + if (!numeratorChildren.empty()) { numeratorChildren.emplace_back(lhs); ArgumentPtr numerator = diff --git a/src/fintamath/expressions/binary/DivExpression.hpp b/src/fintamath/expressions/binary/DivExpression.hpp index 1579dca09..bbb6d640b 100644 --- a/src/fintamath/expressions/binary/DivExpression.hpp +++ b/src/fintamath/expressions/binary/DivExpression.hpp @@ -38,9 +38,9 @@ class DivExpression : public IBinaryExpressionCRTP { static ArgumentPtr nestedRationalsInDenominatorSimplify(const ArgumentPtr &lhs, const ArgumentsPtrVector &rhsChildren); - static ArgumentPtr sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs); + static std::pair sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs); - static ArgumentPtr sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs); + static std::pair sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs); static std::pair mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);