Skip to content

Commit

Permalink
Refactoring & fix of DivExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jul 12, 2023
1 parent 0234b87 commit e1d21dd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
51 changes: 27 additions & 24 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,24 +209,22 @@ bool DivExpression::isNeg(const ArgumentPtr &expr) {
}

ArgumentPtr DivExpression::sumSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
if (auto res = sumMulSimplify(lhs, rhs)) {
return res;
if (auto [result, remainder] = sumMulSimplify(lhs, rhs); result) {
return makeExpr(Add(), result, remainder);
}

if (auto [result, remainder] = mulSumSimplify(lhs, rhs); result) {
simplifyChild(remainder);
ArgumentPtr res = makeExpr(Add(), result, remainder);
return res;
return makeExpr(Add(), result, remainder);
}

if (auto res = sumSumSimplify(lhs, rhs)) {
return res;
if (auto [result, remainder] = sumSumSimplify(lhs, rhs); result) {
return makeExpr(Add(), result, remainder);
}

return {};
}

ArgumentPtr DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
std::pair<ArgumentPtr, ArgumentPtr> DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
ArgumentsPtrVector lhsChildren;
ArgumentsPtrVector rhsChildren;

Expand Down Expand Up @@ -265,15 +263,16 @@ ArgumentPtr DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const Argument
return {};
}

ArgumentPtr remainder = makeExpr(Add(), remainderVect);
simplifyChild(remainder);
resultVect.emplace_back(makeExpr(Div(), remainder, rhs));

ArgumentPtr result = makeExpr(Add(), resultVect);
return result;

ArgumentPtr remainderAdd = makeExpr(Add(), remainderVect);
simplifyChild(remainderAdd);
ArgumentPtr remainder = makeExpr(Div(), remainderAdd, rhs);

return {result, remainder};
}

ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
std::pair<ArgumentPtr, ArgumentPtr> DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
ArgumentsPtrVector lhsChildren;

if (const auto lhsExpr = cast<IExpression>(lhs); lhsExpr && is<Add>(lhsExpr->getFunction())) {
Expand Down Expand Up @@ -305,12 +304,16 @@ ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const Argument
return {};
}

ArgumentPtr result = makeExpr(Add(), resultChildren);

ArgumentPtr remainder;
if (!remainderChildren.empty()) {
ArgumentPtr remainder = makeExpr(Div(), makeExpr(Add(), remainderChildren), rhs);
resultChildren.emplace_back(remainder);
ArgumentPtr remainderAdd = makeExpr(Add(), remainderChildren);
simplifyChild(remainder);
remainder = makeExpr(Div(), remainderAdd, rhs);
}

return makeExpr(Add(), resultChildren);
return {result, remainder};
}

std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
Expand Down Expand Up @@ -338,8 +341,10 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const Argument
multiplicator.emplace_back(makeExpr(Mul(), rhsChildren[i], result));
}

ArgumentPtr negSum = makeExpr(Neg(), makeExpr(Add(), multiplicator));
ArgumentPtr remainder = makeExpr(Div(), negSum, rhs);
ArgumentPtr remainderNegAdd = makeExpr(Neg(), makeExpr(Add(), multiplicator));
simplifyChild(remainderNegAdd);
ArgumentPtr remainder = makeExpr(Div(), remainderNegAdd, rhs);

return {result, remainder};
}

Expand Down Expand Up @@ -441,8 +446,7 @@ ArgumentPtr DivExpression::nestedRationalsInNumeratorSimplify(const ArgumentsPtr
denominatorChildren.emplace_back(rhs);

ArgumentPtr numerator = makeExpr(Mul(), numeratorChildren);
ArgumentPtr denominator =
denominatorChildren.size() > 1 ? makeExpr(Mul(), denominatorChildren) : denominatorChildren.front();
ArgumentPtr denominator = makeExpr(Mul(), denominatorChildren);
return makeExpr(Div(), numerator, denominator);
}

Expand All @@ -464,11 +468,10 @@ ArgumentPtr DivExpression::nestedRationalsInDenominatorSimplify(const ArgumentPt
denominatorChildren.emplace_back(child);
}

if (!denominatorChildren.empty()) {
if (!numeratorChildren.empty()) {
numeratorChildren.emplace_back(lhs);

ArgumentPtr numerator =
numeratorChildren.size() > 1 ? makeExpr(Mul(), numeratorChildren) : numeratorChildren.front();
ArgumentPtr numerator = makeExpr(Mul(), numeratorChildren);
ArgumentPtr denominator = makeExpr(Mul(), denominatorChildren);
return makeExpr(Div(), numerator, denominator);
}
Expand Down
4 changes: 2 additions & 2 deletions src/fintamath/expressions/binary/DivExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ class DivExpression : public IBinaryExpressionCRTP<DivExpression> {
static ArgumentPtr nestedRationalsInDenominatorSimplify(const ArgumentPtr &lhs,
const ArgumentsPtrVector &rhsChildren);

static ArgumentPtr sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);
static std::pair<ArgumentPtr, ArgumentPtr> sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);
static std::pair<ArgumentPtr, ArgumentPtr> sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static std::pair<ArgumentPtr, ArgumentPtr> mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

Expand Down

0 comments on commit e1d21dd

Please sign in to comment.