Skip to content

Commit

Permalink
Use simplifyChild instead of toMinimalObject & refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
fintarin committed Jul 6, 2023
1 parent 560b115 commit 4984f1e
Show file tree
Hide file tree
Showing 14 changed files with 223 additions and 131 deletions.
2 changes: 1 addition & 1 deletion src/fintamath/expressions/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ void Expression::setChildren(const ArgumentsPtrVector &childVect) {
throw InvalidInputFunctionException("", argumentVectorToStringVector(childVect));
}

child = childVect.front()->toMinimalObject();
*this = Expression(childVect.front());
}

void Expression::setVariables(const std::vector<std::pair<Variable, ArgumentPtr>> &varsToVals) {
Expand Down
20 changes: 13 additions & 7 deletions src/fintamath/expressions/binary/CompExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ std::string CompExpression::toString() const {
if (is<Variable>(solLhs)) {
sumChildren.erase(sumChildren.begin());

ArgumentPtr solRhs = makeExpr(Neg(), sumChildren)->toMinimalObject();
ArgumentPtr solRhs = makeExpr(Neg(), sumChildren);
simplifyChild(solRhs);

if (!is<IExpression>(solRhs)) {
return CompExpression(cast<IOperator>(*func), solLhs, solRhs).toString();
Expand All @@ -48,7 +49,8 @@ ArgumentPtr CompExpression::preSimplify() const {
auto simplExpr = cast<CompExpression>(simpl);

if (!simplExpr->isSolution && (!is<Integer>(rhsChild) || *rhsChild != Integer(0))) {
ArgumentPtr resLhs = makeExpr(Sub(), simplExpr->lhsChild, simplExpr->rhsChild)->toMinimalObject();
ArgumentPtr resLhs = makeExpr(Sub(), simplExpr->lhsChild, simplExpr->rhsChild);
simplifyChild(resLhs);
return std::make_shared<CompExpression>(cast<IOperator>(*func), resLhs, std::make_shared<Integer>(0));
}

Expand Down Expand Up @@ -81,7 +83,9 @@ std::shared_ptr<IFunction> CompExpression::getOppositeFunction(const IFunction &
ArgumentPtr CompExpression::coeffSimplify(const IFunction &func, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
if (auto lhsExpr = cast<IExpression>(lhs)) {
if (is<Neg>(lhsExpr->getFunction())) {
return makeExpr(*getOppositeFunction(func), lhsExpr->getChildren().front(), rhs)->toMinimalObject();
ArgumentPtr res = makeExpr(*getOppositeFunction(func), lhsExpr->getChildren().front(), rhs);
simplifyChild(res);
return res;
}

ArgumentsPtrVector dividendPolynom;
Expand Down Expand Up @@ -112,12 +116,14 @@ ArgumentPtr CompExpression::coeffSimplify(const IFunction &func, const ArgumentP
child = makeExpr(Div(), child, dividerNum);
}

ArgumentPtr newLhs = makeExpr(Add(), dividendPolynom)->toMinimalObject();
ArgumentPtr newRhs = rhs;
ArgumentPtr newLhs = makeExpr(Add(), dividendPolynom);
simplifyChild(newLhs);

if (*dividerNum < Integer(0)) {
return makeExpr(*cast<IFunction>(getOppositeFunction(func)), newLhs, newRhs);
return makeExpr(*cast<IFunction>(getOppositeFunction(func)), newLhs, rhs);
}
return makeExpr(func, newLhs, newRhs);

return makeExpr(func, newLhs, rhs);
}
}
return {};
Expand Down
112 changes: 68 additions & 44 deletions src/fintamath/expressions/binary/DivExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ ArgumentPtr DivExpression::numSimplify(const IFunction & /*func*/, const Argumen
}

if (Div().doArgsMatch({one, *rhs})) {
return makeExpr(Mul(), lhs, Div()(one, *rhs))->toMinimalObject();
ArgumentPtr res = makeExpr(Mul(), lhs, Div()(one, *rhs));
simplifyChild(res);
return res;
}

return {};
Expand Down Expand Up @@ -102,7 +104,9 @@ ArgumentPtr DivExpression::divSimplify(const IFunction & /*func*/, const Argumen
denominator = makeExpr(Mul(), denominatorChildren);
}

return makeExpr(Div(), numerator, denominator)->toMinimalObject();
ArgumentPtr res = makeExpr(Div(), numerator, denominator);
simplifyChild(res);
return res;
}

ArgumentPtr DivExpression::mulSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
Expand Down Expand Up @@ -153,7 +157,8 @@ ArgumentPtr DivExpression::mulSimplify(const IFunction & /*func*/, const Argumen

ArgumentPtr numerator;
if (lhsChildren.size() > 1) {
numerator = makeExpr(Mul(), lhsChildren)->toMinimalObject();
numerator = makeExpr(Mul(), lhsChildren);
simplifyChild(numerator);
}
else {
numerator = lhsChildren.front();
Expand All @@ -164,24 +169,28 @@ ArgumentPtr DivExpression::mulSimplify(const IFunction & /*func*/, const Argumen
}

ArgumentPtr denominator;

if (rhsChildren.size() > 1) {
denominator = makeExpr(Mul(), rhsChildren)->toMinimalObject();
denominator = makeExpr(Mul(), rhsChildren);
simplifyChild(denominator);
}
else {
denominator = rhsChildren.front();
}

if (lhsChildren.size() != lhsChildrenSizeInitial || rhsChildren.size() != rhsChildrenSizeInitial) {
return makeExpr(Div(), numerator, denominator)->toMinimalObject();
ArgumentPtr res = makeExpr(Div(), numerator, denominator);
simplifyChild(res);
return res;
}

return {};
}

ArgumentPtr DivExpression::negSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
if (isNeg(rhs)) {
return makeExpr(Div(), makeExpr(Neg(), lhs), makeExpr(Neg(), rhs))->toMinimalObject();
ArgumentPtr res = makeExpr(Div(), makeExpr(Neg(), lhs), makeExpr(Neg(), rhs));
simplifyChild(res);
return res;
}

return {};
Expand Down Expand Up @@ -210,7 +219,9 @@ ArgumentPtr DivExpression::sumSimplify(const IFunction & /*func*/, const Argumen
}

if (auto [lhsRes, rhsRes] = mulSumSimplify(lhs, rhs); lhsRes) {
return makeExpr(Add(), lhsRes, rhsRes)->toMinimalObject();
ArgumentPtr res = makeExpr(Add(), lhsRes, rhsRes);
simplifyChild(res);
return res;
}

if (auto res = sumSumSimplify(lhs, rhs)) {
Expand All @@ -236,29 +247,33 @@ ArgumentPtr DivExpression::sumSumSimplify(const ArgumentPtr &lhs, const Argument
return {};
}

ArgumentsPtrVector answerVect;
ArgumentsPtrVector restVect;
ArgumentsPtrVector resultVect;
ArgumentsPtrVector remainderVect;

for (const auto &child : lhsChildren) {
auto [result, rest] = mulSumSimplify(child, rhs);
auto [result, remainder] = mulSumSimplify(child, rhs);

if (result) {
answerVect.emplace_back(result);
if (rest) {
auto restDiv = cast<DivExpression>(rest);
restVect.emplace_back(restDiv->getChildren().front());
resultVect.emplace_back(result);

if (remainder) {
auto remainderDiv = cast<DivExpression>(remainder);
remainderVect.emplace_back(remainderDiv->getChildren().front());
}
}
else {
restVect.emplace_back(child);
remainderVect.emplace_back(child);
}
}
if (answerVect.empty()) {
if (resultVect.empty()) {
return {};
}

ArgumentPtr restSimplResult = makeExpr(Add(), restVect);
answerVect.emplace_back(makeExpr(Div(), restSimplResult, rhs));
return makeExpr(Add(), answerVect)->toMinimalObject();
resultVect.emplace_back(makeExpr(Div(), makeExpr(Add(), remainderVect), rhs));

ArgumentPtr result = makeExpr(Add(), resultVect);
simplifyChild(result);
return result;
}

ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
Expand All @@ -277,7 +292,8 @@ ArgumentPtr DivExpression::sumMulSimplify(const ArgumentPtr &lhs, const Argument
ArgumentsPtrVector divFailure;

for (const auto &child : lhsChildren) {
ArgumentPtr divResult = makeExpr(Div(), child, rhs)->toMinimalObject();
ArgumentPtr divResult = makeExpr(Div(), child, rhs);
simplifyChild(divResult);

if (const auto divResultExpr = cast<IExpression>(divResult);
divResultExpr && is<Div>(divResultExpr->getFunction()) && *divResultExpr->getChildren().back() == *rhs) {
Expand Down Expand Up @@ -312,19 +328,20 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::mulSumSimplify(const Argument
return {};
}

ArgumentPtr divResult = makeExpr(Div(), lhs, rhsChildren.front())->toMinimalObject();
ArgumentPtr divResult = makeExpr(Div(), lhs, rhsChildren.front());
simplifyChild(divResult);

if (const auto divExpr = cast<IExpression>(divResult); divExpr && is<Div>(divExpr->getFunction())) {
return {};
}

ArgumentsPtrVector multiplicates;
ArgumentsPtrVector multiplicator;

for (size_t i = 1; i < rhsChildren.size(); i++) {
multiplicates.emplace_back(makeExpr(Mul(), rhsChildren[i], divResult));
multiplicator.emplace_back(makeExpr(Mul(), rhsChildren[i], divResult));
}

ArgumentPtr negSum = makeExpr(Neg(), makeExpr(Add(), multiplicates));
ArgumentPtr negSum = makeExpr(Neg(), makeExpr(Add(), multiplicator));
ArgumentPtr div = makeExpr(Div(), negSum, rhs);
return {divResult, div};
}
Expand All @@ -334,21 +351,21 @@ ArgumentPtr DivExpression::divPowSimplify(const ArgumentPtr &lhs, const Argument
return std::make_shared<Integer>(1);
}

bool negation = false;
bool isResultNegated = false;

ArgumentPtr lhsChild;
ArgumentPtr rhsChild;

if (const auto lhsExpr = cast<IExpression>(lhs); lhsExpr && is<Neg>(lhsExpr->getFunction())) {
negation = !negation;
isResultNegated = !isResultNegated;
lhsChild = lhsExpr->getChildren().front();
}
else {
lhsChild = lhs;
}

if (const auto rhsExpr = cast<IExpression>(rhs); rhsExpr && is<Neg>(rhsExpr->getFunction())) {
negation = !negation;
isResultNegated = !isResultNegated;
rhsChild = rhsExpr->getChildren().front();
}
else {
Expand All @@ -359,19 +376,17 @@ ArgumentPtr DivExpression::divPowSimplify(const ArgumentPtr &lhs, const Argument
auto [rhsRate, rhsValue] = getRateValuePair(rhsChild);

ArgumentPtr result;

if (*lhsValue == *rhsValue) {
result = addRatesToValue({lhsRate, makeExpr(Neg(), rhsRate)}, lhsValue);
}

if (result) {
if (negation) {
return makeExpr(Neg(), result)->toMinimalObject();
if (isResultNegated) {
result = makeExpr(Neg(), result);
simplifyChild(result);
}

return result;
}

return {};
return result;
}

std::pair<ArgumentPtr, ArgumentPtr> DivExpression::getRateValuePair(const ArgumentPtr &rhs) {
Expand All @@ -384,7 +399,8 @@ std::pair<ArgumentPtr, ArgumentPtr> DivExpression::getRateValuePair(const Argume
}

ArgumentPtr DivExpression::addRatesToValue(const ArgumentsPtrVector &rates, const ArgumentPtr &value) {
ArgumentPtr ratesSum = makeExpr(Add(), rates)->toMinimalObject();
ArgumentPtr ratesSum = makeExpr(Add(), rates);
simplifyChild(ratesSum);
return makeExpr(Pow(), value, ratesSum);
}

Expand All @@ -402,7 +418,8 @@ ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const Arg
}

if (result) {
return result->toMinimalObject();
simplifyChild(result);
return result;
}

if (const auto &rhsExpr = cast<IExpression>(rhs)) {
Expand All @@ -411,7 +428,8 @@ ArgumentPtr DivExpression::polynomSimplify(const IFunction & /*func*/, const Arg
}
}

return result != nullptr ? result->toMinimalObject() : result;
simplifyChild(result);
return result;
}

ArgumentPtr DivExpression::numeratorSumSimplify(const ArgumentsPtrVector &lhsChildren, const ArgumentPtr &rhs) {
Expand Down Expand Up @@ -517,13 +535,19 @@ ArgumentPtr DivExpression::denominatorSumSimplify(const ArgumentPtr &lhs, const
return {};
}

ArgumentsPtrVector newNumerator = multiplicator;
newNumerator.emplace_back(lhs);
ArgumentsPtrVector newDenominator = multiplicator;
newDenominator.emplace_back(rhs);
ArgumentsPtrVector numeratorChildren = multiplicator;
numeratorChildren.emplace_back(lhs);

ArgumentPtr numerator = makeExpr(Mul(), numeratorChildren);
simplifyChild(numerator);

ArgumentsPtrVector denominatorChildren = multiplicator;
denominatorChildren.emplace_back(rhs);

ArgumentPtr denominator = makeExpr(Mul(), denominatorChildren);
simplifyChild(denominator);

return makeExpr(Div(), makeExpr(Mul(), newNumerator)->toMinimalObject(),
makeExpr(Mul(), newDenominator)->toMinimalObject());
return makeExpr(Div(), numerator, denominator);
}

ArgumentPtr DivExpression::denominatorMulSimplify(const ArgumentsPtrVector &rhsChildren) {
Expand Down
6 changes: 3 additions & 3 deletions src/fintamath/expressions/binary/IntegralExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ ArgumentPtr IntegralExpression::integralSimplify(const IFunction & /*func*/, con
ArgumentPtr res;

if (is<INumber>(lhs) || is<IConstant>(lhs)) {
res = makeExpr(Mul(), lhs, rhs)->toMinimalObject();
res = makeExpr(Mul(), lhs, rhs);
}
else if (is<Variable>(lhs) && is<Variable>(rhs) && *lhs == *rhs) {
res = makeExpr(Div(), makeExpr(Pow(), lhs, std::make_shared<Integer>(2)), std::make_shared<Integer>(2))
->toMinimalObject();
res = makeExpr(Div(), makeExpr(Pow(), lhs, std::make_shared<Integer>(2)), std::make_shared<Integer>(2));
}

// TODO: res + integral constant

simplifyChild(res);
return res;
}

Expand Down
19 changes: 11 additions & 8 deletions src/fintamath/expressions/binary/LogExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,21 @@ ArgumentPtr LogExpression::equalSimplify(const IFunction & /*func*/, const Argum
}

ArgumentPtr LogExpression::powSimplify(const IFunction & /*func*/, const ArgumentPtr &lhs, const ArgumentPtr &rhs) {
ArgumentPtr res;

if (auto rhsExpr = cast<IExpression>(rhs); rhsExpr && is<Pow>(rhsExpr->getFunction())) {
return makeExpr(Mul(), rhsExpr->getChildren().back(), makeExpr(Log(), lhs, rhsExpr->getChildren().front()))
->toMinimalObject();
ArgumentPtr multiplier = rhsExpr->getChildren().back();
ArgumentPtr logExpr = makeExpr(Log(), lhs, rhsExpr->getChildren().front());
res = makeExpr(Mul(), multiplier, logExpr);
}

if (auto lhsExpr = cast<IExpression>(lhs); lhsExpr && is<Pow>(lhsExpr->getFunction())) {
return makeExpr(Mul(), makeExpr(Div(), std::make_shared<Integer>(1), lhsExpr->getChildren().back()),
makeExpr(Log(), lhsExpr->getChildren().front(), rhs))
->toMinimalObject();
else if (auto lhsExpr = cast<IExpression>(lhs); lhsExpr && is<Pow>(lhsExpr->getFunction())) {
ArgumentPtr multiplier = makeExpr(Div(), std::make_shared<Integer>(1), lhsExpr->getChildren().back());
ArgumentPtr logExpr = makeExpr(Log(), lhsExpr->getChildren().front(), rhs);
res = makeExpr(Mul(), multiplier, logExpr);
}

return {};
simplifyChild(res);
return res;
}

}
Loading

0 comments on commit 4984f1e

Please sign in to comment.