Skip to content

Commit

Permalink
Refactoring of DivExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jul 7, 2023
1 parent 4984f1e commit 4a68cee
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 39 deletions.
68 changes: 35 additions & 33 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ DivExpression::SimplifyFunctionsVector DivExpression::getFunctionsForSimplify()

DivExpression::SimplifyFunctionsVector DivExpression::getFunctionsForPostSimplify() const {
static const DivExpression::SimplifyFunctionsVector simplifyFunctions = {
&DivExpression::zeroSimplify, //
&DivExpression::negSimplify, //
&DivExpression::sumSimplify, //
&DivExpression::polynomSimplify, //
&DivExpression::zeroSimplify, //
&DivExpression::negSimplify, //
&DivExpression::sumSimplify, //
&DivExpression::nestedDivSimplify, //
};
return simplifyFunctions;
}
Expand Down Expand Up @@ -288,32 +288,32 @@ ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const Argument
return {};
}

ArgumentsPtrVector divSuccess;
ArgumentsPtrVector divFailure;
ArgumentsPtrVector result;
ArgumentsPtrVector remainder;

for (const auto &child : lhsChildren) {
ArgumentPtr divResult = makeExpr(Div(), child, rhs);
simplifyChild(divResult);

if (const auto divResultExpr = cast<IExpression>(divResult);
divResultExpr && is<Div>(divResultExpr->getFunction()) && *divResultExpr->getChildren().back() == *rhs) {
divFailure.emplace_back(child);
remainder.emplace_back(child);
}
else {
divSuccess.emplace_back(divResult);
result.emplace_back(divResult);
}
}

if (divFailure.size() == lhsChildren.size()) {
if (remainder.size() == lhsChildren.size()) {
return {};
}

if (!divFailure.empty()) {
ArgumentPtr divExpr = makeExpr(Div(), makeExpr(Add(), divFailure), rhs);
divSuccess.emplace_back(divExpr);
if (!remainder.empty()) {
ArgumentPtr divExpr = makeExpr(Div(), makeExpr(Add(), remainder), rhs);
result.emplace_back(divExpr);
}

return makeExpr(Add(), divSuccess);
return makeExpr(Add(), result);
}

std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
Expand All @@ -328,22 +328,22 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const Argument
return {};
}

ArgumentPtr divResult = makeExpr(Div(), lhs, rhsChildren.front());
simplifyChild(divResult);
ArgumentPtr result = makeExpr(Div(), lhs, rhsChildren.front());
simplifyChild(result);

if (const auto divExpr = cast<IExpression>(divResult); divExpr && is<Div>(divExpr->getFunction())) {
if (const auto divExpr = cast<IExpression>(result); divExpr && is<Div>(divExpr->getFunction())) {
return {};
}

ArgumentsPtrVector multiplicator;

for (size_t i = 1; i < rhsChildren.size(); i++) {
multiplicator.emplace_back(makeExpr(Mul(), rhsChildren[i], divResult));
multiplicator.emplace_back(makeExpr(Mul(), rhsChildren[i], result));
}

ArgumentPtr negSum = makeExpr(Neg(), makeExpr(Add(), multiplicator));
ArgumentPtr div = makeExpr(Div(), negSum, rhs);
return {divResult, div};
ArgumentPtr remainder = makeExpr(Div(), negSum, rhs);
return {result, remainder};
}

ArgumentPtr DivExpression::divPowSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
Expand Down Expand Up @@ -404,16 +404,16 @@ ArgumentPtr DivExpression::addRatesToValue(const ArgumentsPtrVector &rates, cons
return makeExpr(Pow(), value, ratesSum);
}

ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
ArgumentPtr DivExpression::nestedDivSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs,
const ArgumentPtr &rhs) {
ArgumentPtr result;

if (const auto &lhsExpr = cast<IExpression>(lhs)) {
if (is<Add>(lhsExpr->getFunction())) {
result = numeratorSumSimplify(lhsExpr->getChildren(), rhs);
}

if (is<Mul>(lhsExpr->getFunction())) {
result = numeratorMulSimplify(lhsExpr->getChildren(), rhs);
result = nestedDivInNumeratorMulSimplify(lhsExpr->getChildren(), rhs);
}
else if (is<Add>(lhsExpr->getFunction())) {
result = nestedDivInNumeratorSumSimplify(lhsExpr->getChildren(), rhs);
}
}

Expand All @@ -424,15 +424,16 @@ ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const Arg

if (const auto &rhsExpr = cast<IExpression>(rhs)) {
if (is<Add>(rhsExpr->getFunction())) {
result = denominatorSumSimplify(lhs, rhs, rhsExpr->getChildren());
result = nestedDivInDenominatorSumSimplify(lhs, rhs, rhsExpr->getChildren());
}
}

simplifyChild(result);
return result;
}

ArgumentPtr DivExpression::numeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs) {
ArgumentPtr DivExpression::nestedDivInNumeratorSumSimplify(const ArgumentsPtrVector &lhsChildren,
const ArgumentPtr &rhs) {
ArgumentsPtrVector newNumerator;
ArgumentsPtrVector resultPolynom;

Expand All @@ -441,7 +442,7 @@ ArgumentPtr DivExpression::numeratorSumSimplify(const ArgumentsPtrVector &lhsChi
bool isNeg = unwrapNeg(childForCheck);

if (const auto &exprChild = cast<IExpression>(childForCheck); exprChild && is<Mul>(exprChild->getFunction())) {
if (auto result = numeratorMulSimplify(exprChild->getChildren(), rhs)) {
if (auto result = nestedDivInNumeratorMulSimplify(exprChild->getChildren(), rhs)) {
resultPolynom.emplace_back(result);
continue;
}
Expand Down Expand Up @@ -474,7 +475,8 @@ ArgumentPtr DivExpression::numeratorSumSimplify(const ArgumentsPtrVector &lhsChi
return makeExpr(Add(), resultPolynom);
}

ArgumentPtr DivExpression::numeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs) {
ArgumentPtr DivExpression::nestedDivInNumeratorMulSimplify(const ArgumentsPtrVector &lhsChildren,
const ArgumentPtr &rhs) {
ArgumentsPtrVector numeratorChildren;
ArgumentsPtrVector denominatorChildren;

Expand Down Expand Up @@ -506,8 +508,8 @@ bool DivExpression::unwrapNeg(ArgumentPtr &lhs) {
return false;
}

ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs,
const ArgumentsPtrVector &rhsChildren) {
ArgumentPtr DivExpression::nestedDivInDenominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs,
const ArgumentsPtrVector &rhsChildren) {
ArgumentsPtrVector multiplicator;

for (const auto &child : rhsChildren) {
Expand All @@ -525,7 +527,7 @@ ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const
}

if (const auto &exprChild = cast<IExpression>(childForCheck); exprChild && is<Mul>(exprChild->getFunction())) {
if (const auto &childForAdd = denominatorMulSimplify(exprChild->getChildren())) {
if (const auto &childForAdd = nestedDivInDenominatorMulSimplify(exprChild->getChildren())) {
multiplicator.emplace_back(childForAdd);
}
}
Expand All @@ -550,7 +552,7 @@ ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const
return makeExpr(Div(), numerator, denominator);
}

ArgumentPtr DivExpression::denominatorMulSimplify(const ArgumentsPtrVector &rhsChildren) {
ArgumentPtr DivExpression::nestedDivInDenominatorMulSimplify(const ArgumentsPtrVector &rhsChildren) {
ArgumentsPtrVector multiplicator;

for (const auto &child : rhsChildren) {
Expand Down
12 changes: 6 additions & 6 deletions src/fintamath/expressions/binary/DivExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,16 @@ class DivExpression : public IBinaryExpressionCRTP<DivExpression> {

static ArgumentPtr sumSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr polynomSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs);
static ArgumentPtr nestedDivSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr numeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs);
static ArgumentPtr nestedDivInNumeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs);

static ArgumentPtr numeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs);
static ArgumentPtr nestedDivInNumeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs);

static ArgumentPtr denominatorMulSimplify(const ArgumentsPtrVector &rhsChildren);
static ArgumentPtr nestedDivInDenominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs,
const ArgumentsPtrVector &rhsChildren);

static ArgumentPtr denominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs,
const ArgumentsPtrVector &rhsChildren);
static ArgumentPtr nestedDivInDenominatorMulSimplify(const ArgumentsPtrVector &rhsChildren);

static ArgumentPtr sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

Expand Down

0 comments on commit 4a68cee

Please sign in to comment.