From 4a68ceecf7d700fd35aea05b1dda9dcb9ac1a03f Mon Sep 17 00:00:00 2001 From: fintarin Date: Thu, 6 Jul 2023 15:05:01 +0300 Subject: [PATCH] Refactoring of DivExpression --- .../expressions/binary/DivExpression.cpp | 68 ++++++++++--------- .../expressions/binary/DivExpression.hpp | 12 ++-- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/src/fintamath/expressions/binary/DivExpression.cpp b/src/fintamath/expressions/binary/DivExpression.cpp index 46a1899b9..cef1d89df 100644 --- a/src/fintamath/expressions/binary/DivExpression.cpp +++ b/src/fintamath/expressions/binary/DivExpression.cpp @@ -27,10 +27,10 @@ DivExpression::SimplifyFunctionsVector DivExpression::getFunctionsForSimplify() DivExpression::SimplifyFunctionsVector DivExpression::getFunctionsForPostSimplify() const { static const DivExpression::SimplifyFunctionsVector simplifyFunctions = { - &DivExpression::zeroSimplify, // - &DivExpression::negSimplify, // - &DivExpression::sumSimplify, // - &DivExpression::polynomSimplify, // + &DivExpression::zeroSimplify, // + &DivExpression::negSimplify, // + &DivExpression::sumSimplify, // + &DivExpression::nestedDivSimplify, // }; return simplifyFunctions; } @@ -288,8 +288,8 @@ ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const Argument return {}; } - ArgumentsPtrVector divSuccess; - ArgumentsPtrVector divFailure; + ArgumentsPtrVector result; + ArgumentsPtrVector remainder; for (const auto &child : lhsChildren) { ArgumentPtr divResult = makeExpr(Div(), child, rhs); @@ -297,23 +297,23 @@ ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const Argument if (const auto divResultExpr = cast(divResult); divResultExpr && is
(divResultExpr->getFunction()) && *divResultExpr->getChildren().back() == *rhs) { - divFailure.emplace_back(child); + remainder.emplace_back(child); } else { - divSuccess.emplace_back(divResult); + result.emplace_back(divResult); } } - if (divFailure.size() == lhsChildren.size()) { + if (remainder.size() == lhsChildren.size()) { return {}; } - if (!divFailure.empty()) { - ArgumentPtr divExpr = makeExpr(Div(), makeExpr(Add(), divFailure), rhs); - divSuccess.emplace_back(divExpr); + if (!remainder.empty()) { + ArgumentPtr divExpr = makeExpr(Div(), makeExpr(Add(), remainder), rhs); + result.emplace_back(divExpr); } - return makeExpr(Add(), divSuccess); + return makeExpr(Add(), result); } std::pair DivExpression::mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) { @@ -328,22 +328,22 @@ std::pair DivExpression::mulSumSimplify(const Argument return {}; } - ArgumentPtr divResult = makeExpr(Div(), lhs, rhsChildren.front()); - simplifyChild(divResult); + ArgumentPtr result = makeExpr(Div(), lhs, rhsChildren.front()); + simplifyChild(result); - if (const auto divExpr = cast(divResult); divExpr && is
(divExpr->getFunction())) { + if (const auto divExpr = cast(result); divExpr && is
(divExpr->getFunction())) { return {}; } ArgumentsPtrVector multiplicator; for (size_t i = 1; i < rhsChildren.size(); i++) { - multiplicator.emplace_back(makeExpr(Mul(), rhsChildren[i], divResult)); + multiplicator.emplace_back(makeExpr(Mul(), rhsChildren[i], result)); } ArgumentPtr negSum = makeExpr(Neg(), makeExpr(Add(), multiplicator)); - ArgumentPtr div = makeExpr(Div(), negSum, rhs); - return {divResult, div}; + ArgumentPtr remainder = makeExpr(Div(), negSum, rhs); + return {result, remainder}; } ArgumentPtr DivExpression::divPowSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) { @@ -404,16 +404,16 @@ ArgumentPtr DivExpression::addRatesToValue(const ArgumentsPtrVector &rates, cons return makeExpr(Pow(), value, ratesSum); } -ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) { +ArgumentPtr DivExpression::nestedDivSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, + const ArgumentPtr &rhs) { ArgumentPtr result; if (const auto &lhsExpr = cast(lhs)) { - if (is(lhsExpr->getFunction())) { - result = numeratorSumSimplify(lhsExpr->getChildren(), rhs); - } - if (is(lhsExpr->getFunction())) { - result = numeratorMulSimplify(lhsExpr->getChildren(), rhs); + result = nestedDivInNumeratorMulSimplify(lhsExpr->getChildren(), rhs); + } + else if (is(lhsExpr->getFunction())) { + result = nestedDivInNumeratorSumSimplify(lhsExpr->getChildren(), rhs); } } @@ -424,7 +424,7 @@ ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const Arg if (const auto &rhsExpr = cast(rhs)) { if (is(rhsExpr->getFunction())) { - result = denominatorSumSimplify(lhs, rhs, rhsExpr->getChildren()); + result = nestedDivInDenominatorSumSimplify(lhs, rhs, rhsExpr->getChildren()); } } @@ -432,7 +432,8 @@ ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const Arg return result; } -ArgumentPtr DivExpression::numeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs) { +ArgumentPtr DivExpression::nestedDivInNumeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, + const ArgumentPtr &rhs) { ArgumentsPtrVector newNumerator; ArgumentsPtrVector resultPolynom; @@ -441,7 +442,7 @@ ArgumentPtr DivExpression::numeratorSumSimplify(const ArgumentsPtrVector &lhsChi bool isNeg = unwrapNeg(childForCheck); if (const auto &exprChild = cast(childForCheck); exprChild && is(exprChild->getFunction())) { - if (auto result = numeratorMulSimplify(exprChild->getChildren(), rhs)) { + if (auto result = nestedDivInNumeratorMulSimplify(exprChild->getChildren(), rhs)) { resultPolynom.emplace_back(result); continue; } @@ -474,7 +475,8 @@ ArgumentPtr DivExpression::numeratorSumSimplify(const ArgumentsPtrVector &lhsChi return makeExpr(Add(), resultPolynom); } -ArgumentPtr DivExpression::numeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs) { +ArgumentPtr DivExpression::nestedDivInNumeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, + const ArgumentPtr &rhs) { ArgumentsPtrVector numeratorChildren; ArgumentsPtrVector denominatorChildren; @@ -506,8 +508,8 @@ bool DivExpression::unwrapNeg(ArgumentPtr &lhs) { return false; } -ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs, - const ArgumentsPtrVector &rhsChildren) { +ArgumentPtr DivExpression::nestedDivInDenominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs, + const ArgumentsPtrVector &rhsChildren) { ArgumentsPtrVector multiplicator; for (const auto &child : rhsChildren) { @@ -525,7 +527,7 @@ ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const } if (const auto &exprChild = cast(childForCheck); exprChild && is(exprChild->getFunction())) { - if (const auto &childForAdd = denominatorMulSimplify(exprChild->getChildren())) { + if (const auto &childForAdd = nestedDivInDenominatorMulSimplify(exprChild->getChildren())) { multiplicator.emplace_back(childForAdd); } } @@ -550,7 +552,7 @@ ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const return makeExpr(Div(), numerator, denominator); } -ArgumentPtr DivExpression::denominatorMulSimplify(const ArgumentsPtrVector &rhsChildren) { +ArgumentPtr DivExpression::nestedDivInDenominatorMulSimplify(const ArgumentsPtrVector &rhsChildren) { ArgumentsPtrVector multiplicator; for (const auto &child : rhsChildren) { diff --git a/src/fintamath/expressions/binary/DivExpression.hpp b/src/fintamath/expressions/binary/DivExpression.hpp index 3a290d1d7..d0cbf83a5 100644 --- a/src/fintamath/expressions/binary/DivExpression.hpp +++ b/src/fintamath/expressions/binary/DivExpression.hpp @@ -31,16 +31,16 @@ class DivExpression : public IBinaryExpressionCRTP { static ArgumentPtr sumSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs); - static ArgumentPtr polynomSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs); + static ArgumentPtr nestedDivSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs); - static ArgumentPtr numeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs); + static ArgumentPtr nestedDivInNumeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs); - static ArgumentPtr numeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs); + static ArgumentPtr nestedDivInNumeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs); - static ArgumentPtr denominatorMulSimplify(const ArgumentsPtrVector &rhsChildren); + static ArgumentPtr nestedDivInDenominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs, + const ArgumentsPtrVector &rhsChildren); - static ArgumentPtr denominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs, - const ArgumentsPtrVector &rhsChildren); + static ArgumentPtr nestedDivInDenominatorMulSimplify(const ArgumentsPtrVector &rhsChildren); static ArgumentPtr sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);