Skip to content

Commit

Permalink
Simplify Div + Div with different denominators
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jul 11, 2023
1 parent 7f8c4ff commit 1623d29
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 191 deletions.
171 changes: 7 additions & 164 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ DivExpression::SimplifyFunctionsVector DivExpression::getFunctionsForSimplify()

DivExpression::SimplifyFunctionsVector DivExpression::getFunctionsForPostSimplify() const {
static const DivExpression::SimplifyFunctionsVector simplifyFunctions = {
&DivExpression::zeroSimplify, //
&DivExpression::negSimplify, //
&DivExpression::sumSimplify, //
&DivExpression::nestedDivSimplify, //
&DivExpression::zeroSimplify, //
&DivExpression::negSimplify, //
&DivExpression::sumSimplify, //
};
return simplifyFunctions;
}
Expand Down Expand Up @@ -131,7 +130,7 @@ ArgumentPtr DivExpression::mulSimplify(const IFunction & /*func*/, const Argumen
for (size_t j = 0; j < rhsChildren.size(); j++) {
bool isResFound = false;

if (auto divPowRes = divPowSimplify(lhsChild, rhsChildren[j])) {
if (auto divPowRes = powSimplify(lhsChild, rhsChildren[j])) {
lhsChild = divPowRes;
rhsChildren.erase(rhsChildren.begin() + ArgumentsPtrVector::difference_type(j));
isResFound = true;
Expand Down Expand Up @@ -257,6 +256,7 @@ ArgumentPtr DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const Argument
remainderVect.emplace_back(child);
}
}

if (resultVect.empty()) {
return {};
}
Expand Down Expand Up @@ -319,6 +319,7 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const Argument
}

ArgumentPtr result = makeExpr(Div(), lhs, rhsChildren.front());
postSimplifyChild(result);

if (const auto divExpr = cast<IExpression>(result); divExpr && is<Div>(divExpr->getFunction())) {
return {};
Expand All @@ -335,7 +336,7 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const Argument
return {result, remainder};
}

ArgumentPtr DivExpression::divPowSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
ArgumentPtr DivExpression::powSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
if (*lhs == *rhs) {
return std::make_shared<Integer>(1);
}
Expand Down Expand Up @@ -391,162 +392,4 @@ ArgumentPtr DivExpression::addRatesToValue(const ArgumentsPtrVector &rates, cons
return makeExpr(Pow(), value, ratesSum);
}

ArgumentPtr DivExpression::nestedDivSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs,
const ArgumentPtr &rhs) {
ArgumentPtr result;

if (const auto &lhsExpr = cast<IExpression>(lhs)) {
if (is<Mul>(lhsExpr->getFunction())) {
result = nestedDivInNumeratorMulSimplify(lhsExpr->getChildren(), rhs);
}
else if (is<Add>(lhsExpr->getFunction())) {
result = nestedDivInNumeratorSumSimplify(lhsExpr->getChildren(), rhs);
}
}

if (result) {
return result;
}

if (const auto &rhsExpr = cast<IExpression>(rhs)) {
if (is<Add>(rhsExpr->getFunction())) {
result = nestedDivInDenominatorSumSimplify(lhs, rhs, rhsExpr->getChildren());
}
}

return result;
}

ArgumentPtr DivExpression::nestedDivInNumeratorSumSimplify(const ArgumentsPtrVector &lhsChildren,
const ArgumentPtr &rhs) {
ArgumentsPtrVector newNumerator;
ArgumentsPtrVector resultPolynom;

for (const auto &child : lhsChildren) {
ArgumentPtr childForCheck = child;
bool isNeg = unwrapNeg(childForCheck);

if (const auto &exprChild = cast<IExpression>(childForCheck); exprChild && is<Mul>(exprChild->getFunction())) {
if (auto result = nestedDivInNumeratorMulSimplify(exprChild->getChildren(), rhs)) {
resultPolynom.emplace_back(result);
continue;
}
}
else if (const auto &divChild = cast<DivExpression>(childForCheck)) {
ArgumentPtr childForAdd = makeExpr(Div(), divChild->lhsChild, makeExpr(Mul(), divChild->rhsChild, rhs));
resultPolynom.emplace_back(isNeg ? makeExpr(Neg(), childForAdd) : childForAdd);
continue;
}
else if (const auto &rationalChild = cast<Rational>(childForCheck)) {
ArgumentPtr childForAdd =
makeExpr(Div(), std::make_shared<const Integer>(rationalChild->numerator()),
makeExpr(Mul(), std::make_shared<const Integer>(rationalChild->denominator()), rhs));
resultPolynom.emplace_back(isNeg ? makeExpr(Neg(), childForAdd) : childForAdd);
continue;
}

newNumerator.emplace_back(child);
}

if (resultPolynom.empty()) {
return {};
}

if (!newNumerator.empty()) {
resultPolynom.emplace_back(
makeExpr(Div(), newNumerator.size() > 1 ? makeExpr(Add(), newNumerator) : newNumerator.front(), rhs));
}

return makeExpr(Add(), resultPolynom);
}

ArgumentPtr DivExpression::nestedDivInNumeratorMulSimplify(const ArgumentsPtrVector &lhsChildren,
const ArgumentPtr &rhs) {
ArgumentsPtrVector numeratorChildren;
ArgumentsPtrVector denominatorChildren;

for (const auto &child : lhsChildren) {
if (const auto &rationalChild = cast<Rational>(child)) {
numeratorChildren.emplace_back(std::make_shared<const Integer>(rationalChild->numerator()));
denominatorChildren.emplace_back(std::make_shared<const Integer>(rationalChild->denominator()));
continue;
}

numeratorChildren.emplace_back(child);
}

if (!denominatorChildren.empty()) {
denominatorChildren.emplace_back(rhs);
ArgumentPtr numerator = makeExpr(Mul(), numeratorChildren);
ArgumentPtr denominator = makeExpr(Mul(), denominatorChildren);
return makeExpr(Div(), numerator, denominator);
}

return {};
}

bool DivExpression::unwrapNeg(ArgumentPtr &lhs) {
if (const auto &exprLhs = cast<IExpression>(lhs); exprLhs && is<Neg>(exprLhs->getFunction())) {
lhs = exprLhs->getChildren().front();
return true;
}
return false;
}

ArgumentPtr DivExpression::nestedDivInDenominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs,
const ArgumentsPtrVector &rhsChildren) {
ArgumentsPtrVector multiplicator;

for (const auto &child : rhsChildren) {
ArgumentPtr childForCheck = child;
unwrapNeg(childForCheck);

if (const auto &divChild = cast<DivExpression>(childForCheck)) {
multiplicator.emplace_back(divChild->rhsChild);
continue;
}

if (const auto &rationalChild = cast<Rational>(childForCheck)) {
multiplicator.emplace_back(std::make_shared<Integer>(rationalChild->denominator()));
continue;
}

if (const auto &exprChild = cast<IExpression>(childForCheck); exprChild && is<Mul>(exprChild->getFunction())) {
if (const auto &childForAdd = nestedDivInDenominatorMulSimplify(exprChild->getChildren())) {
multiplicator.emplace_back(childForAdd);
}
}
}

if (multiplicator.empty()) {
return {};
}

ArgumentsPtrVector numeratorChildren = multiplicator;
numeratorChildren.emplace_back(lhs);
ArgumentPtr numerator = makeExpr(Mul(), numeratorChildren);

ArgumentsPtrVector denominatorChildren = multiplicator;
denominatorChildren.emplace_back(rhs);
ArgumentPtr denominator = makeExpr(Mul(), denominatorChildren);

return makeExpr(Div(), numerator, denominator);
}

ArgumentPtr DivExpression::nestedDivInDenominatorMulSimplify(const ArgumentsPtrVector &rhsChildren) {
ArgumentsPtrVector multiplicator;

for (const auto &child : rhsChildren) {
if (const auto &rationalChild = cast<Rational>(child)) {
multiplicator.emplace_back(std::make_shared<Integer>(rationalChild->denominator()));
}
}

if (multiplicator.empty()) {
return {};
}

return multiplicator.size() == 1 ? multiplicator.front() : makeExpr(Mul(), multiplicator);
}

}
15 changes: 1 addition & 14 deletions src/fintamath/expressions/binary/DivExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,19 @@ class DivExpression : public IBinaryExpressionCRTP<DivExpression> {

static ArgumentPtr sumSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr nestedDivSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr nestedDivInNumeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs);

static ArgumentPtr nestedDivInNumeratorMulSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs);

static ArgumentPtr nestedDivInDenominatorSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs,
const ArgumentsPtrVector &rhsChildren);

static ArgumentPtr nestedDivInDenominatorMulSimplify(const ArgumentsPtrVector &rhsChildren);

static ArgumentPtr sumSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static std::pair<ArgumentPtr, ArgumentPtr> mulSumSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static ArgumentPtr divPowSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);
static ArgumentPtr powSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs);

static std::pair<ArgumentPtr, ArgumentPtr> getRateValuePair(const ArgumentPtr &rhs);

static ArgumentPtr addRatesToValue(const ArgumentsPtrVector &rates, const ArgumentPtr &value);

static bool isNeg(const ArgumentPtr &expr);

static bool unwrapNeg(ArgumentPtr &lhs);
};

}
45 changes: 34 additions & 11 deletions src/fintamath/expressions/polynomial/AddExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,7 @@ AddExpression::SimplifyFunctionsVector AddExpression::getFunctionsForPreSimplify
static const AddExpression::SimplifyFunctionsVector simplifyFunctions = {
&AddExpression::simplifyNegations, //
&AddExpression::simplifyCallFunction, //
};
return simplifyFunctions;
}

AddExpression::SimplifyFunctionsVector AddExpression::getFunctionsForPostSimplify() const {
static const AddExpression::SimplifyFunctionsVector simplifyFunctions = {
&AddExpression::sumDivisions, //
&AddExpression::sumDivisions, //
};
return simplifyFunctions;
}
Expand Down Expand Up @@ -296,15 +290,44 @@ ArgumentPtr AddExpression::sumDivisions(const IFunction & /*func*/, const Argume
std::shared_ptr<const IExpression> lhsExpr = cast<IExpression>(lhsChild);
std::shared_ptr<const IExpression> rhsExpr = cast<IExpression>(rhsChild);

ArgumentPtr res;

if (lhsExpr && is<Div>(lhsExpr->getFunction()) && rhsExpr && is<Div>(rhsExpr->getFunction())) {
if (*lhsExpr->getChildren().back() == *rhsExpr->getChildren().back()) {
ArgumentPtr divLhs = makeExpr(Add(), lhsExpr->getChildren().front(), rhsExpr->getChildren().front());
ArgumentPtr divRhs = lhsExpr->getChildren().back();
return makeExpr(Div(), divLhs, divRhs);
ArgumentPtr lhsNumerator = lhsExpr->getChildren().front();
ArgumentPtr rhsNumerator = rhsExpr->getChildren().front();
ArgumentPtr rhsDenominator = rhsExpr->getChildren().back();

ArgumentPtr numerator = makeExpr(Add(), lhsNumerator, rhsNumerator);
ArgumentPtr denominator = rhsDenominator;
res = makeExpr(Div(), numerator, denominator);
}
else {
ArgumentPtr lhsNumerator = lhsExpr->getChildren().front();
ArgumentPtr rhsNumerator = rhsExpr->getChildren().front();
ArgumentPtr lhsDenominator = lhsExpr->getChildren().back();
ArgumentPtr rhsDenominator = rhsExpr->getChildren().back();

ArgumentPtr lhsNumeratorMulRhsDenominator = makeExpr(Mul(), lhsNumerator, rhsDenominator);
ArgumentPtr rhsNumeratorMulLhsDenominator = makeExpr(Mul(), rhsNumerator, lhsDenominator);

ArgumentPtr numerator = makeExpr(Add(), lhsNumeratorMulRhsDenominator, rhsNumeratorMulLhsDenominator);
ArgumentPtr denominator = makeExpr(Mul(), lhsDenominator, rhsDenominator);
res = makeExpr(Div(), numerator, denominator);
}
}
else if (rhsExpr && is<Div>(rhsExpr->getFunction())) {
ArgumentPtr rhsNumerator = rhsExpr->getChildren().front();
ArgumentPtr rhsDenominator = rhsExpr->getChildren().back();

return {};
ArgumentPtr lhsMulRhsDenominator = makeExpr(Mul(), lhsChild, rhsDenominator);

ArgumentPtr numerator = makeExpr(Add(), lhsMulRhsDenominator, rhsNumerator);
ArgumentPtr denominator = rhsDenominator;
res = makeExpr(Div(), numerator, denominator);
}

return res;
}

}
2 changes: 0 additions & 2 deletions src/fintamath/expressions/polynomial/AddExpression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ class AddExpression : public IPolynomExpressionCRTP<AddExpression> {

SimplifyFunctionsVector getFunctionsForPreSimplify() const override;

SimplifyFunctionsVector getFunctionsForPostSimplify() const override;

std::string operatorChildToString(const ArgumentPtr &inChild, const ArgumentPtr &prevChild) const override;

/**
Expand Down

0 comments on commit 1623d29

Please sign in to comment.